Skylark (Sketching Library)
0.1
|
00001 #include <iostream> 00002 00003 #include <elemental.hpp> 00004 #include <boost/mpi.hpp> 00005 #include <boost/format.hpp> 00006 #include <skylark.hpp> 00007 00008 /*******************************************/ 00009 namespace bmpi = boost::mpi; 00010 namespace skybase = skylark::base; 00011 namespace skysk = skylark::sketch; 00012 namespace skynla = skylark::nla; 00013 namespace skyalg = skylark::algorithms; 00014 namespace skyutil = skylark::utility; 00015 /*******************************************/ 00016 00017 const int m = 50000; 00018 const int n = 500; 00019 00020 typedef elem::DistMatrix<double, elem::VC, elem::STAR> matrix_type; 00021 typedef elem::DistMatrix<double, elem::VC, elem::STAR> rhs_type; 00022 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> sol_type; 00023 00024 template<typename MatrixType, typename RhsType, typename SolType> 00025 void check_solution(const MatrixType &A, const RhsType &b, const SolType &x, 00026 const RhsType &r0, 00027 double &res, double &resAtr, double &resFac) { 00028 RhsType r(b); 00029 skybase::Gemv(elem::NORMAL, -1.0, A, x, 1.0, r); 00030 res = skybase::Nrm2(r); 00031 00032 SolType Atr(x.Height(), x.Width(), x.Grid()); 00033 skybase::Gemv(elem::TRANSPOSE, 1.0, A, r, 0.0, Atr); 00034 resAtr = skybase::Nrm2(Atr); 00035 00036 skybase::Axpy(-1.0, r0, r); 00037 RhsType dr(b); 00038 skybase::Axpy(-1.0, r0, dr); 00039 resFac = skybase::Nrm2(r) / skybase::Nrm2(dr); 00040 } 00041 00042 int main(int argc, char** argv) { 00043 double res, resAtr, resFac; 00044 00045 elem::Initialize(argc, argv); 00046 00047 bmpi::communicator world; 00048 int rank = world.rank(); 00049 00050 skybase::context_t context(23234); 00051 00052 // Setup problem and righthand side 00053 // Using Skylark's uniform generator (as opposed to Elemental's) 00054 // will insure the same A and b are generated regardless of the number 00055 // of processors. 00056 matrix_type A = 00057 skyutil::uniform_matrix_t<matrix_type>::generate(m, 00058 n, elem::DefaultGrid(), context); 00059 matrix_type b = 00060 skyutil::uniform_matrix_t<matrix_type>::generate(m, 00061 1, elem::DefaultGrid(), context); 00062 00063 sol_type x(n,1); 00064 rhs_type r(b); 00065 00066 boost::mpi::timer timer; 00067 double telp; 00068 00069 // Solve using Elemental. Note: Elemental only supports [MC,MR]... 00070 elem::DistMatrix<double> A1 = A, b1 = b, x1; 00071 timer.restart(); 00072 elem::LeastSquares(elem::NORMAL, A1, b1, x1); 00073 telp = timer.elapsed(); 00074 x = x1; 00075 check_solution(A, b, x, r, res, resAtr, resFac); 00076 if (rank == 0) 00077 std::cout << "Elemental:\t\t\t||r||_2 = " 00078 << boost::format("%.2f") % res 00079 << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00080 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00081 << std::endl; 00082 double res_opt = res; 00083 00084 skybase::Gemv(elem::NORMAL, -1.0, A, x, 1.0, r); 00085 00086 // Solve using Sylark 00087 timer.restart(); 00088 skynla::FastLeastSquares(elem::NORMAL, A, b, x, context); 00089 telp = timer.elapsed(); 00090 check_solution(A, b, x, r, res, resAtr, resFac); 00091 if (rank == 0) 00092 std::cout << "Skylark:\t\t\t||r||_2 = " 00093 << boost::format("%.2f") % res 00094 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00095 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00096 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00097 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00098 << std::endl; 00099 00100 // Approximately solve using Sylark 00101 timer.restart(); 00102 skynla::ApproximateLeastSquares(elem::NORMAL, A, b, x, context); 00103 telp = timer.elapsed(); 00104 check_solution(A, b, x, r, res, resAtr, resFac); 00105 if (rank == 0) 00106 std::cout << "Skylark (approximate):\t\t||r||_2 = " 00107 << boost::format("%.2f") % res 00108 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00109 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00110 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00111 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00112 << std::endl; 00113 00114 return 0; 00115 }