Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/ml/predict.cpp
Go to the documentation of this file.
00001 /*
00002  * predict.cpp
00003  *
00004  *  Created on: Feb 7, 2014
00005  *      Author: vikas
00006  */
00007 
#include <exception>
#include <functional>
#include <iostream>
#include <string>

#include <boost/mpi.hpp>

#include "hilbert.hpp"
#include "../base/context.hpp"
00011 
00012 
00013 #define DEBUG std::cout << "error " << std::endl;
00014 
00015 int main (int argc, char** argv) {
00016 
00017     int provided;
00018     MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
00019 
00020         std::string testfile = argv[1];
00021         std::string modelfile = argv[2];
00022         boost::mpi::environment env (argc, argv);
00023 
00024         // get communicator
00025         boost::mpi::communicator comm;
00026         int rank = comm.rank();
00027 
00028         skylark::base::context_t context (12345);
00029         elem::Initialize (argc, argv);
00030         MPI_Comm mpi_world(comm);
00031         DistInputMatrixType X, Y;
00032         int m,n;
00033         elem::Matrix<double> W;
00034 
00035 
00036         if (rank == 0) {
00037 
00038                 read_model_file(modelfile, W);
00039                 m = W.Height();
00040                 n = W.Width();
00041         }
00042         boost::mpi::broadcast(comm, m, 0);
00043         boost::mpi::broadcast(comm, n, 0);
00044         if (rank != 0) {
00045                 W.Resize(m,n);
00046         }
00047 
00048         boost::mpi::broadcast(comm, W.Buffer(), m*n, 0);
00049 
00050         read_libsvm_dense(context, testfile, X, Y);
00051 
00052         elem::DistMatrix<double, elem::VC, elem::STAR>  O(X.Height(), W.Width());
00053         elem::MakeZeros(O);
00054 
00055         elem::Gemm(elem::NORMAL,elem::NORMAL,1.0, X.Matrix(), W, 0.0, O.Matrix());
00056 
00057         int correct = 0;
00058         double o, o1;
00059         int pred;
00060         for(int i=0; i<O.LocalHeight(); i++) {
00061                 o = O.GetLocal(i,0);
00062                 pred = 0;
00063                 for(int j=1; j<O.Width(); j++) {
00064                         o1 = O.GetLocal(i,j);
00065                         if ( o1 > o) {
00066                                 o = o1;
00067                                 pred = j;
00068                         }
00069                 }
00070                 if(pred== (int) Y.GetLocal(i,0))
00071                         correct++;
00072         }
00073 
00074         comm.barrier();
00075 
00076         //if(rank ==0) {
00077         int totalcorrect;
00078         boost::mpi::reduce(comm, correct, totalcorrect, std::plus<double>(), 0);
00079         if(rank ==0)
00080             std::cout << "Accuracy = " << totalcorrect*100.0/X.Height() << " %" << std::endl;
00081         //}
00082 }