// Skylark (Sketching Library)
// Version 0.1
00001 #include <iostream> 00002 00003 #include <elemental.hpp> 00004 #include <boost/mpi.hpp> 00005 #include <boost/format.hpp> 00006 #include <skylark.hpp> 00007 00008 /*******************************************/ 00009 namespace bmpi = boost::mpi; 00010 namespace skybase = skylark::base; 00011 namespace skysk = skylark::sketch; 00012 namespace skynla = skylark::nla; 00013 namespace skyalg = skylark::algorithms; 00014 namespace skyutil = skylark::utility; 00015 /*******************************************/ 00016 00017 // Parameters 00018 #if 0 00019 const int m = 2000; 00020 const int n = 10; 00021 const int t = 500; 00022 #else 00023 const int m = 30000; 00024 const int n = 500; 00025 const int t = 2000; 00026 #endif 00027 00028 typedef elem::DistMatrix<double, elem::VC, elem::STAR> matrix_type; 00029 typedef elem::DistMatrix<double, elem::VC, elem::STAR> rhs_type; 00030 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> sol_type; 00031 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> sketch_type; 00032 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> precond_type; 00033 00034 typedef skyalg::regression_problem_t<matrix_type, 00035 skyalg::linear_tag, 00036 skyalg::l2_tag, 00037 skyalg::no_reg_tag> regression_problem_type; 00038 00039 template<typename AlgTag> 00040 struct exact_solver_type : 00041 public skyalg::regression_solver_t< 00042 regression_problem_type, rhs_type, sol_type, AlgTag> { 00043 00044 typedef skyalg::regression_solver_t< 00045 regression_problem_type, rhs_type, sol_type, AlgTag> base_type; 00046 00047 exact_solver_type(const regression_problem_type& problem) : 00048 base_type(problem) { 00049 00050 } 00051 }; 00052 00053 00054 // Just a temporary small example on using "computed matrices" 00055 class cmatrix : public skybase::computed_matrix_t<matrix_type> { 00056 const matrix_type &_A; 00057 00058 public: 00059 cmatrix(const matrix_type& A) : _A(A) { }; 00060 00061 int height() const { return _A.Height(); } 00062 int 
width() const { return _A.Width(); } 00063 00064 void materialize(matrix_type& Z) const { Z = _A; } 00065 matrix_type materialize() const { matrix_type Z(_A); return Z; } 00066 }; 00067 00068 typedef skyalg::regression_problem_t<skybase::computed_matrix_t<matrix_type>, 00069 skyalg::linear_tag, 00070 skyalg::l2_tag, 00071 skyalg::no_reg_tag> regression_problem_type1; 00072 00073 template<typename AlgTag> 00074 struct exact_solver_type1 : 00075 public skyalg::regression_solver_t< 00076 regression_problem_type1, rhs_type, sol_type, AlgTag> { 00077 00078 typedef skyalg::regression_solver_t< 00079 regression_problem_type1, rhs_type, sol_type, AlgTag> base_type; 00080 00081 exact_solver_type1(const regression_problem_type1& problem) : 00082 base_type(problem) { 00083 00084 } 00085 }; 00086 00087 00088 template<template <typename, typename> class TransformType > 00089 struct accelerated_exact_solver_type_sb : 00090 public skyalg::accelerated_regression_solver_t< 00091 regression_problem_type, rhs_type, sol_type, 00092 skyalg::simplified_blendenpik_tag<TransformType, 00093 skyalg::svd_precond_tag> > { 00094 00095 typedef skyalg::accelerated_regression_solver_t< 00096 regression_problem_type, rhs_type, sol_type, 00097 skyalg::simplified_blendenpik_tag<TransformType, 00098 skyalg::svd_precond_tag > > base_type; 00099 00100 accelerated_exact_solver_type_sb(const regression_problem_type& problem, 00101 skybase::context_t& context) : 00102 base_type(problem, context) { 00103 00104 } 00105 }; 00106 00107 struct accelerated_exact_solver_type_blendenpik : 00108 public skyalg::accelerated_regression_solver_t< 00109 regression_problem_type, rhs_type, sol_type, 00110 skyalg::blendenpik_tag<skyalg::qr_precond_tag> > { 00111 00112 typedef skyalg::accelerated_regression_solver_t< 00113 regression_problem_type, rhs_type, sol_type, 00114 skyalg::blendenpik_tag<skyalg::qr_precond_tag > > base_type; 00115 00116 accelerated_exact_solver_type_blendenpik(const regression_problem_type& 
problem, 00117 skybase::context_t& context) : 00118 base_type(problem, context) { 00119 00120 } 00121 }; 00122 00123 00124 struct accelerated_exact_solver_type_lsrn : 00125 public skyalg::accelerated_regression_solver_t< 00126 regression_problem_type, rhs_type, sol_type, 00127 skyalg::lsrn_tag<skyalg::svd_precond_tag> > { 00128 00129 typedef skyalg::accelerated_regression_solver_t< 00130 regression_problem_type, rhs_type, sol_type, 00131 skyalg::lsrn_tag<skyalg::svd_precond_tag > > base_type; 00132 00133 accelerated_exact_solver_type_lsrn(const regression_problem_type& problem, 00134 skybase::context_t& context) : 00135 base_type(problem, context) { 00136 00137 } 00138 }; 00139 00140 template<> 00141 template<typename KT> 00142 struct exact_solver_type< skyalg::iterative_l2_solver_tag<KT> >: 00143 public skyalg::regression_solver_t< 00144 regression_problem_type, rhs_type, sol_type, 00145 skyalg::iterative_l2_solver_tag<KT> > { 00146 00147 typedef skyalg::regression_solver_t< 00148 regression_problem_type, rhs_type, sol_type, 00149 skyalg::iterative_l2_solver_tag<KT> > base_type; 00150 00151 exact_solver_type(const regression_problem_type& problem, 00152 skynla::iter_params_t iter_params) : 00153 base_type(problem, iter_params) { 00154 00155 } 00156 00157 exact_solver_type(const regression_problem_type& problem, 00158 const skynla::precond_t<sol_type>& R, 00159 skynla::iter_params_t iter_params) : 00160 base_type(problem, R, iter_params) { 00161 00162 } 00163 00164 00165 }; 00166 00167 template<template <typename, typename> class TransformType > 00168 struct sketched_solver_type : 00169 public skyalg::sketched_regression_solver_t< 00170 regression_problem_type, matrix_type, sol_type, 00171 skyalg::linear_tag, 00172 sketch_type, 00173 sketch_type, 00174 TransformType, 00175 skyalg::qr_l2_solver_tag> { 00176 00177 typedef skyalg::sketched_regression_solver_t< 00178 regression_problem_type, matrix_type, sol_type, 00179 skyalg::linear_tag, 00180 sketch_type, 00181 
sketch_type, 00182 TransformType, 00183 skyalg::qr_l2_solver_tag> base_type; 00184 00185 sketched_solver_type(const regression_problem_type& problem, 00186 int sketch_size, 00187 skybase::context_t& context) : 00188 base_type(problem, sketch_size, context) { 00189 00190 } 00191 00192 }; 00193 00194 template<typename ProblemType, typename RhsType, typename SolType> 00195 void check_solution(const ProblemType &pr, const RhsType &b, const SolType &x, 00196 const RhsType &r0, 00197 double &res, double &resAtr, double &resFac) { 00198 RhsType r(b); 00199 skybase::Gemv(elem::NORMAL, -1.0, pr.input_matrix, x, 1.0, r); 00200 res = skybase::Nrm2(r); 00201 00202 SolType Atr(x.Height(), x.Width(), x.Grid()); 00203 skybase::Gemv(elem::TRANSPOSE, 1.0, pr.input_matrix, r, 0.0, Atr); 00204 resAtr = skybase::Nrm2(Atr); 00205 00206 skybase::Axpy(-1.0, r0, r); 00207 RhsType dr(b); 00208 skybase::Axpy(-1.0, r0, dr); 00209 resFac = skybase::Nrm2(r) / skybase::Nrm2(dr); 00210 } 00211 00212 int main(int argc, char** argv) { 00213 double res, resAtr, resFac; 00214 00215 elem::Initialize(argc, argv); 00216 00217 bmpi::communicator world; 00218 int rank = world.rank(); 00219 00220 skybase::context_t context(23234); 00221 00222 // Setup problem and righthand side 00223 // Using Skylark's uniform generator (as opposed to Elemental's) 00224 // will insure the same A and b are generated regardless of the number 00225 // of processors. 
00226 matrix_type A = 00227 skyutil::uniform_matrix_t<matrix_type>::generate(m, 00228 n, elem::DefaultGrid(), context); 00229 matrix_type b = 00230 skyutil::uniform_matrix_t<matrix_type>::generate(m, 00231 1, elem::DefaultGrid(), context); 00232 00233 regression_problem_type problem(m, n, A); 00234 00235 boost::mpi::timer timer; 00236 double telp; 00237 00238 sol_type x(n,1); 00239 00240 rhs_type r(b); 00241 00242 // Using QR 00243 timer.restart(); 00244 exact_solver_type<skyalg::qr_l2_solver_tag> exact_solver(problem); 00245 exact_solver.solve(b, x); 00246 telp = timer.elapsed(); 00247 check_solution(problem, b, x, r, res, resAtr, resFac); 00248 if (rank == 0) 00249 std::cout << "Exact (QR):\t\t\t||r||_2 = " 00250 << boost::format("%.2f") % res 00251 << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00252 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00253 << std::endl; 00254 double res_opt = res; 00255 00256 skybase::Gemv(elem::NORMAL, -1.0, problem.input_matrix, x, 1.0, r); 00257 00258 // Using SNE (semi-normal equations) 00259 timer.restart(); 00260 exact_solver_type<skyalg::sne_l2_solver_tag>(problem).solve(b, x); 00261 telp = timer.elapsed(); 00262 check_solution(problem, b, x, r, res, resAtr, resFac); 00263 if (rank == 0) 00264 std::cout << "Exact (SNE):\t\t\t||r||_2 = " 00265 << boost::format("%.2f") % res 00266 << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00267 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00268 << std::endl; 00269 res_opt = res; 00270 00271 // Again, using SNE, only with the computed interface (example; to be removed.) 
00272 cmatrix CA(A); 00273 regression_problem_type1 problem1(m, n, CA); 00274 timer.restart(); 00275 exact_solver_type1<skyalg::sne_l2_solver_tag>(problem1).solve(b, x); 00276 telp = timer.elapsed(); 00277 check_solution(problem, b, x, r, res, resAtr, resFac); 00278 if (rank == 0) 00279 std::cout << "Exact (SNE) (COMPUTED):\t\t\t||r||_2 = " 00280 << boost::format("%.2f") % res 00281 << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00282 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00283 << std::endl; 00284 res_opt = res; 00285 00286 // Using SVD 00287 timer.restart(); 00288 exact_solver_type<skyalg::svd_l2_solver_tag>(problem).solve(b, x); 00289 telp = timer.elapsed(); 00290 check_solution(problem, b, x, r, res, resAtr, resFac); 00291 if (rank == 0) 00292 std::cout << "Exact (SVD):\t\t\t||r||_2 = " 00293 << boost::format("%.2f") % res 00294 << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00295 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00296 << std::endl; 00297 res_opt = res; 00298 00299 // Using LSQR 00300 skynla::iter_params_t lsqrparams; 00301 lsqrparams.am_i_printing = rank == 0; 00302 lsqrparams.log_level = 0; 00303 timer.restart(); 00304 exact_solver_type< 00305 skyalg::iterative_l2_solver_tag< 00306 skyalg::lsqr_tag > >(problem, lsqrparams) 00307 .solve(b, x); 00308 telp = timer.elapsed(); 00309 check_solution(problem, b, x, r, res, resAtr, resFac); 00310 if (rank == 0) 00311 std::cout << "Exact (LSQR):\t\t\t||r||_2 = " 00312 << boost::format("%.2f") % res 00313 << "\t\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00314 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00315 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00316 << std::endl; 00317 00318 // Using sketch-and-solve 00319 00320 #if 0 00321 timer.restart(); 00322 sketched_solver_type<skysk::JLT_t>(problem, t, context).solve(b, x); 00323 telp = timer.elapsed(); 00324 check_solution(problem, 
b, x, r, res, resAtr, resFac); 00325 if (rank == 0) 00326 std::cout << "Sketch-and-Solve (JLT):\t\t||r||_2 = " 00327 << boost::format("%.2f") % res 00328 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00329 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00330 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00331 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00332 << std::endl; 00333 #endif 00334 00335 timer.restart(); 00336 sketched_solver_type<skysk::CWT_t>(problem, t, context).solve(b, x); 00337 telp = timer.elapsed(); 00338 check_solution(problem, b, x, r, res, resAtr, resFac); 00339 if (rank == 0) 00340 std::cout << "Sketch-and-Solve (CWT):\t\t||r||_2 = " 00341 << boost::format("%.2f") % res 00342 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00343 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00344 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00345 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00346 << std::endl; 00347 00348 timer.restart(); 00349 sketched_solver_type<skysk::FJLT_t>(problem, t, context).solve(b, x); 00350 telp = timer.elapsed(); 00351 check_solution(problem, b, x, r, res, resAtr, resFac); 00352 if (rank == 0) 00353 std::cout << "Sketch-and-Solve (FJLT):\t||r||_2 = " 00354 << boost::format("%.2f") % res 00355 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00356 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00357 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00358 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00359 << std::endl; 00360 00361 // Accelerate-using-sketching 00362 #if 0 00363 timer.restart(); 00364 accelerated_exact_solver_type_sb<skysk::JLT_t>(problem, context).solve(b, x); 00365 telp = timer.elapsed(); 00366 check_solution(problem, b, x, r, res, resAtr, resFac); 00367 if (rank == 0) 00368 std::cout << "Simplified Blendenpik (JLT):\t||r||_2 = " 00369 
<< boost::format("%.2f") % res 00370 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00371 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00372 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00373 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00374 << std::endl; 00375 #endif 00376 00377 timer.restart(); 00378 accelerated_exact_solver_type_sb<skysk::FJLT_t>(problem, context).solve(b, x); 00379 telp = timer.elapsed(); 00380 check_solution(problem, b, x, r, res, resAtr, resFac); 00381 if (rank == 0) 00382 std::cout << "Simplified Blendenpik (FJLT):\t||r||_2 = " 00383 << boost::format("%.2f") % res 00384 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00385 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00386 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00387 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00388 << std::endl; 00389 00390 timer.restart(); 00391 accelerated_exact_solver_type_sb<skysk::CWT_t>(problem, context).solve(b, x); 00392 telp = timer.elapsed(); 00393 check_solution(problem, b, x, r, res, resAtr, resFac); 00394 if (rank == 0) 00395 std::cout << "Simplified Blendenpik (CWT):\t||r||_2 = " 00396 << boost::format("%.2f") % res 00397 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00398 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00399 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00400 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00401 << std::endl; 00402 00403 timer.restart(); 00404 accelerated_exact_solver_type_blendenpik(problem, context).solve(b, x); 00405 telp = timer.elapsed(); 00406 check_solution(problem, b, x, r, res, resAtr, resFac); 00407 if (rank == 0) 00408 std::cout << "Blendenpik:\t\t\t||r||_2 = " 00409 << boost::format("%.2f") % res 00410 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00411 << "\t||r - r*||_2 / ||b - r*||_2 = " << 
boost::format("%.2e") % resFac 00412 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00413 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00414 << std::endl; 00415 00416 timer.restart(); 00417 accelerated_exact_solver_type_lsrn(problem, context).solve(b, x); 00418 telp = timer.elapsed(); 00419 check_solution(problem, b, x, r, res, resAtr, resFac); 00420 if (rank == 0) 00421 std::cout << "LSRN:\t\t\t\t||r||_2 = " 00422 << boost::format("%.2f") % res 00423 << " (x " << boost::format("%.5f") % (res / res_opt) << ")" 00424 << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac 00425 << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr 00426 << "\t\tTime: " << boost::format("%.2e") % telp << " sec" 00427 << std::endl; 00428 00429 return 0; 00430 }