Skylark (Sketching Library)
0.1
|
00001 /* 00002 * predict.cpp 00003 * 00004 * Created on: Feb 7, 2014 00005 * Author: vikas 00006 */ 00007 00008 #include "hilbert.hpp" 00009 #include <boost/mpi.hpp> 00010 #include "../base/context.hpp" 00011 00012 00013 #define DEBUG std::cout << "error " << std::endl; 00014 00015 int main (int argc, char** argv) { 00016 00017 int provided; 00018 MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); 00019 00020 std::string testfile = argv[1]; 00021 std::string modelfile = argv[2]; 00022 boost::mpi::environment env (argc, argv); 00023 00024 // get communicator 00025 boost::mpi::communicator comm; 00026 int rank = comm.rank(); 00027 00028 skylark::base::context_t context (12345); 00029 elem::Initialize (argc, argv); 00030 MPI_Comm mpi_world(comm); 00031 DistInputMatrixType X, Y; 00032 int m,n; 00033 elem::Matrix<double> W; 00034 00035 00036 if (rank == 0) { 00037 00038 read_model_file(modelfile, W); 00039 m = W.Height(); 00040 n = W.Width(); 00041 } 00042 boost::mpi::broadcast(comm, m, 0); 00043 boost::mpi::broadcast(comm, n, 0); 00044 if (rank != 0) { 00045 W.Resize(m,n); 00046 } 00047 00048 boost::mpi::broadcast(comm, W.Buffer(), m*n, 0); 00049 00050 read_libsvm_dense(context, testfile, X, Y); 00051 00052 elem::DistMatrix<double, elem::VC, elem::STAR> O(X.Height(), W.Width()); 00053 elem::MakeZeros(O); 00054 00055 elem::Gemm(elem::NORMAL,elem::NORMAL,1.0, X.Matrix(), W, 0.0, O.Matrix()); 00056 00057 int correct = 0; 00058 double o, o1; 00059 int pred; 00060 for(int i=0; i<O.LocalHeight(); i++) { 00061 o = O.GetLocal(i,0); 00062 pred = 0; 00063 for(int j=1; j<O.Width(); j++) { 00064 o1 = O.GetLocal(i,j); 00065 if ( o1 > o) { 00066 o = o1; 00067 pred = j; 00068 } 00069 } 00070 if(pred== (int) Y.GetLocal(i,0)) 00071 correct++; 00072 } 00073 00074 comm.barrier(); 00075 00076 //if(rank ==0) { 00077 int totalcorrect; 00078 boost::mpi::reduce(comm, correct, totalcorrect, std::plus<double>(), 0); 00079 if(rank ==0) 00080 std::cout << "Accuracy = " << totalcorrect*100.0/X.Height() << " %" << std::endl; 00081 //} 00082 }