Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/examples/regression.cpp
Go to the documentation of this file.
00001 #include <iostream>
00002 
00003 #include <elemental.hpp>
00004 #include <boost/mpi.hpp>
00005 #include <boost/format.hpp>
00006 #include <skylark.hpp>
00007 
00008 /*******************************************/
00009 namespace bmpi =  boost::mpi;
00010 namespace skybase = skylark::base;
00011 namespace skysk =  skylark::sketch;
00012 namespace skynla = skylark::nla;
00013 namespace skyalg = skylark::algorithms;
00014 namespace skyutil = skylark::utility;
00015 /*******************************************/
00016 
// Parameters
//
// m and n are the dimensions used to generate the input matrix (and m the
// height of the right-hand side); t is the sketch size passed to the
// sketch-and-solve solvers.
//
// The large configuration is the default. Compile with
// -DSKYLARK_REGRESSION_SMALL to select the small one without editing
// this file.
#ifdef SKYLARK_REGRESSION_SMALL
const int m = 2000;
const int n = 10;
const int t = 500;
#else
const int m = 30000;
const int n = 500;
const int t = 2000;
#endif
00027 
00028 typedef elem::DistMatrix<double, elem::VC, elem::STAR> matrix_type;
00029 typedef elem::DistMatrix<double, elem::VC, elem::STAR> rhs_type;
00030 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> sol_type;
00031 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> sketch_type;
00032 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> precond_type;
00033 
00034 typedef skyalg::regression_problem_t<matrix_type,
00035                                      skyalg::linear_tag,
00036                                      skyalg::l2_tag,
00037                                      skyalg::no_reg_tag> regression_problem_type;
00038 
00039 template<typename AlgTag>
00040 struct exact_solver_type :
00041     public skyalg::regression_solver_t<
00042     regression_problem_type, rhs_type, sol_type, AlgTag> {
00043 
00044     typedef skyalg::regression_solver_t<
00045         regression_problem_type, rhs_type, sol_type, AlgTag> base_type;
00046 
00047     exact_solver_type(const regression_problem_type& problem) :
00048         base_type(problem) {
00049 
00050     }
00051 };
00052 
00053 
00054 // Just a temporary small example on using "computed matrices"
00055 class cmatrix : public skybase::computed_matrix_t<matrix_type> {
00056     const matrix_type &_A;
00057 
00058 public:
00059     cmatrix(const matrix_type& A) : _A(A) { };
00060 
00061     int height() const { return _A.Height(); }
00062     int width() const { return _A.Width(); }
00063 
00064     void materialize(matrix_type& Z) const { Z = _A; }
00065     matrix_type materialize() const { matrix_type Z(_A); return Z; }
00066 };
00067 
00068 typedef skyalg::regression_problem_t<skybase::computed_matrix_t<matrix_type>,
00069                                      skyalg::linear_tag,
00070                                      skyalg::l2_tag,
00071                                      skyalg::no_reg_tag> regression_problem_type1;
00072 
00073 template<typename AlgTag>
00074 struct exact_solver_type1 :
00075     public skyalg::regression_solver_t<
00076     regression_problem_type1, rhs_type, sol_type, AlgTag> {
00077 
00078     typedef skyalg::regression_solver_t<
00079         regression_problem_type1, rhs_type, sol_type, AlgTag> base_type;
00080 
00081     exact_solver_type1(const regression_problem_type1& problem) :
00082         base_type(problem) {
00083 
00084     }
00085 };
00086 
00087 
00088 template<template <typename, typename> class TransformType >
00089 struct accelerated_exact_solver_type_sb :
00090     public skyalg::accelerated_regression_solver_t<
00091     regression_problem_type, rhs_type, sol_type,
00092     skyalg::simplified_blendenpik_tag<TransformType,
00093                                       skyalg::svd_precond_tag> > {
00094 
00095     typedef  skyalg::accelerated_regression_solver_t<
00096         regression_problem_type, rhs_type, sol_type,
00097         skyalg::simplified_blendenpik_tag<TransformType,
00098                                           skyalg::svd_precond_tag > > base_type;
00099 
00100     accelerated_exact_solver_type_sb(const regression_problem_type& problem,
00101         skybase::context_t& context) :
00102         base_type(problem, context) {
00103 
00104     }
00105 };
00106 
00107 struct accelerated_exact_solver_type_blendenpik :
00108     public skyalg::accelerated_regression_solver_t<
00109     regression_problem_type, rhs_type, sol_type,
00110     skyalg::blendenpik_tag<skyalg::qr_precond_tag> > {
00111 
00112     typedef  skyalg::accelerated_regression_solver_t<
00113         regression_problem_type, rhs_type, sol_type,
00114         skyalg::blendenpik_tag<skyalg::qr_precond_tag > > base_type;
00115 
00116     accelerated_exact_solver_type_blendenpik(const regression_problem_type& problem,
00117         skybase::context_t& context) :
00118         base_type(problem, context) {
00119 
00120     }
00121 };
00122 
00123 
00124 struct accelerated_exact_solver_type_lsrn :
00125     public skyalg::accelerated_regression_solver_t<
00126     regression_problem_type, rhs_type, sol_type,
00127     skyalg::lsrn_tag<skyalg::svd_precond_tag> > {
00128 
00129     typedef  skyalg::accelerated_regression_solver_t<
00130         regression_problem_type, rhs_type, sol_type,
00131         skyalg::lsrn_tag<skyalg::svd_precond_tag > > base_type;
00132 
00133     accelerated_exact_solver_type_lsrn(const regression_problem_type& problem,
00134         skybase::context_t& context) :
00135         base_type(problem, context) {
00136 
00137     }
00138 };
00139 
00140 template<>
00141 template<typename KT>
00142 struct exact_solver_type< skyalg::iterative_l2_solver_tag<KT> >:
00143     public skyalg::regression_solver_t<
00144     regression_problem_type, rhs_type, sol_type,
00145     skyalg::iterative_l2_solver_tag<KT> > {
00146 
00147     typedef skyalg::regression_solver_t<
00148         regression_problem_type, rhs_type, sol_type,
00149         skyalg::iterative_l2_solver_tag<KT> > base_type;
00150 
00151     exact_solver_type(const regression_problem_type& problem,
00152         skynla::iter_params_t iter_params) :
00153         base_type(problem, iter_params) {
00154 
00155     }
00156 
00157     exact_solver_type(const regression_problem_type& problem,
00158         const skynla::precond_t<sol_type>& R,
00159         skynla::iter_params_t iter_params) :
00160         base_type(problem, R, iter_params) {
00161 
00162     }
00163 
00164 
00165 };
00166 
00167 template<template <typename, typename> class TransformType >
00168 struct sketched_solver_type :
00169     public skyalg::sketched_regression_solver_t<
00170     regression_problem_type, matrix_type, sol_type,
00171     skyalg::linear_tag,
00172     sketch_type,
00173     sketch_type,
00174     TransformType,
00175     skyalg::qr_l2_solver_tag> {
00176 
00177     typedef skyalg::sketched_regression_solver_t<
00178         regression_problem_type, matrix_type, sol_type,
00179         skyalg::linear_tag,
00180         sketch_type,
00181         sketch_type,
00182         TransformType,
00183         skyalg::qr_l2_solver_tag> base_type;
00184 
00185     sketched_solver_type(const regression_problem_type& problem,
00186         int sketch_size,
00187         skybase::context_t& context) :
00188         base_type(problem, sketch_size, context) {
00189 
00190     }
00191 
00192 };
00193 
00194 template<typename ProblemType, typename RhsType, typename SolType>
00195 void check_solution(const ProblemType &pr, const RhsType &b, const SolType &x, 
00196     const RhsType &r0,
00197     double &res, double &resAtr, double &resFac) {
00198     RhsType r(b);
00199     skybase::Gemv(elem::NORMAL, -1.0, pr.input_matrix, x, 1.0, r);
00200     res = skybase::Nrm2(r);
00201 
00202     SolType Atr(x.Height(), x.Width(), x.Grid());
00203     skybase::Gemv(elem::TRANSPOSE, 1.0, pr.input_matrix, r, 0.0, Atr);
00204     resAtr = skybase::Nrm2(Atr);
00205 
00206     skybase::Axpy(-1.0, r0, r);
00207     RhsType dr(b);
00208     skybase::Axpy(-1.0, r0, dr);
00209     resFac = skybase::Nrm2(r) / skybase::Nrm2(dr);
00210 }
00211 
00212 int main(int argc, char** argv) {
00213     double res, resAtr, resFac;
00214 
00215     elem::Initialize(argc, argv);
00216 
00217     bmpi::communicator world;
00218     int rank = world.rank();
00219 
00220     skybase::context_t context(23234);
00221 
00222     // Setup problem and righthand side
00223     // Using Skylark's uniform generator (as opposed to Elemental's)
00224     // will insure the same A and b are generated regardless of the number
00225     // of processors.
00226     matrix_type A =
00227         skyutil::uniform_matrix_t<matrix_type>::generate(m,
00228             n, elem::DefaultGrid(), context);
00229     matrix_type b =
00230         skyutil::uniform_matrix_t<matrix_type>::generate(m,
00231             1, elem::DefaultGrid(), context);
00232 
00233     regression_problem_type problem(m, n, A);
00234 
00235     boost::mpi::timer timer;
00236     double telp;
00237 
00238     sol_type x(n,1);
00239 
00240     rhs_type r(b);
00241 
00242     // Using QR
00243     timer.restart();
00244     exact_solver_type<skyalg::qr_l2_solver_tag> exact_solver(problem);
00245     exact_solver.solve(b, x);
00246     telp = timer.elapsed();
00247     check_solution(problem, b, x, r, res, resAtr, resFac);
00248     if (rank == 0)
00249         std::cout << "Exact (QR):\t\t\t||r||_2 =  "
00250                   << boost::format("%.2f") % res
00251                   << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00252                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00253                   << std::endl;
00254     double res_opt = res;
00255 
00256     skybase::Gemv(elem::NORMAL, -1.0, problem.input_matrix, x, 1.0, r);
00257 
00258     // Using SNE (semi-normal equations)
00259     timer.restart();
00260     exact_solver_type<skyalg::sne_l2_solver_tag>(problem).solve(b, x);
00261     telp = timer.elapsed();
00262     check_solution(problem, b, x, r, res, resAtr, resFac);
00263     if (rank == 0)
00264         std::cout << "Exact (SNE):\t\t\t||r||_2 =  "
00265                   << boost::format("%.2f") % res
00266                   << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00267                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00268                   << std::endl;
00269     res_opt = res;
00270 
00271     // Again, using SNE, only with the computed interface (example; to be removed.)
00272     cmatrix CA(A);
00273     regression_problem_type1 problem1(m, n, CA);
00274     timer.restart();
00275     exact_solver_type1<skyalg::sne_l2_solver_tag>(problem1).solve(b, x);
00276     telp = timer.elapsed();
00277     check_solution(problem, b, x, r, res, resAtr, resFac);
00278     if (rank == 0)
00279         std::cout << "Exact (SNE) (COMPUTED):\t\t\t||r||_2 =  "
00280                   << boost::format("%.2f") % res
00281                   << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00282                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00283                   << std::endl;
00284     res_opt = res;
00285 
00286     // Using SVD
00287     timer.restart();
00288     exact_solver_type<skyalg::svd_l2_solver_tag>(problem).solve(b, x);
00289     telp = timer.elapsed();
00290     check_solution(problem, b, x, r, res, resAtr, resFac);
00291     if (rank == 0)
00292         std::cout << "Exact (SVD):\t\t\t||r||_2 =  "
00293                   << boost::format("%.2f") % res
00294                   << "\t\t\t\t\t\t\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00295                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00296                   << std::endl;
00297     res_opt = res;
00298 
00299     // Using LSQR
00300     skynla::iter_params_t lsqrparams;
00301     lsqrparams.am_i_printing = rank == 0;
00302     lsqrparams.log_level = 0;
00303     timer.restart();
00304     exact_solver_type<
00305         skyalg::iterative_l2_solver_tag<
00306             skyalg::lsqr_tag > >(problem, lsqrparams)
00307         .solve(b, x);
00308     telp = timer.elapsed();
00309     check_solution(problem, b, x, r, res, resAtr, resFac);
00310     if (rank == 0)
00311         std::cout << "Exact (LSQR):\t\t\t||r||_2 =  "
00312                   << boost::format("%.2f") % res
00313                   << "\t\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00314                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00315                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00316                   << std::endl;
00317 
00318     // Using sketch-and-solve
00319 
00320 #if 0 
00321     timer.restart();
00322     sketched_solver_type<skysk::JLT_t>(problem, t, context).solve(b, x);
00323     telp = timer.elapsed();
00324     check_solution(problem, b, x, r, res, resAtr, resFac);
00325     if (rank == 0)
00326         std::cout << "Sketch-and-Solve (JLT):\t\t||r||_2 =  "
00327                   << boost::format("%.2f") % res
00328                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00329                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00330                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00331                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00332                   << std::endl;
00333 #endif
00334 
00335     timer.restart();
00336     sketched_solver_type<skysk::CWT_t>(problem, t, context).solve(b, x);
00337     telp = timer.elapsed();
00338     check_solution(problem, b, x, r, res, resAtr, resFac);
00339     if (rank == 0)
00340         std::cout << "Sketch-and-Solve (CWT):\t\t||r||_2 =  "
00341                   << boost::format("%.2f") % res
00342                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00343                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00344                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00345                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00346                   << std::endl;
00347 
00348     timer.restart();
00349     sketched_solver_type<skysk::FJLT_t>(problem, t, context).solve(b, x);
00350     telp = timer.elapsed();
00351     check_solution(problem, b, x, r, res, resAtr, resFac);
00352     if (rank == 0)
00353         std::cout << "Sketch-and-Solve (FJLT):\t||r||_2 =  "
00354                   << boost::format("%.2f") % res
00355                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00356                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00357                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00358                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00359                   << std::endl;
00360 
00361     // Accelerate-using-sketching
00362 #if 0
00363     timer.restart();
00364     accelerated_exact_solver_type_sb<skysk::JLT_t>(problem, context).solve(b, x);
00365     telp = timer.elapsed();
00366     check_solution(problem, b, x, r, res, resAtr, resFac);
00367     if (rank == 0)
00368         std::cout << "Simplified Blendenpik (JLT):\t||r||_2 =  "
00369                   << boost::format("%.2f") % res
00370                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00371                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00372                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00373                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00374                   << std::endl;
00375 #endif
00376 
00377     timer.restart();
00378     accelerated_exact_solver_type_sb<skysk::FJLT_t>(problem, context).solve(b, x);
00379     telp = timer.elapsed();
00380     check_solution(problem, b, x, r, res, resAtr, resFac);
00381     if (rank == 0)
00382         std::cout << "Simplified Blendenpik (FJLT):\t||r||_2 =  "
00383                   << boost::format("%.2f") % res
00384                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00385                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00386                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00387                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00388                   << std::endl;
00389 
00390     timer.restart();
00391     accelerated_exact_solver_type_sb<skysk::CWT_t>(problem, context).solve(b, x);
00392     telp = timer.elapsed();
00393     check_solution(problem, b, x, r, res, resAtr, resFac);
00394     if (rank == 0)
00395         std::cout << "Simplified Blendenpik (CWT):\t||r||_2 =  "
00396                   << boost::format("%.2f") % res
00397                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00398                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00399                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00400                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00401                   << std::endl;
00402 
00403     timer.restart();
00404     accelerated_exact_solver_type_blendenpik(problem, context).solve(b, x);
00405     telp = timer.elapsed();
00406     check_solution(problem, b, x, r, res, resAtr, resFac);
00407     if (rank == 0)
00408         std::cout << "Blendenpik:\t\t\t||r||_2 =  "
00409                   << boost::format("%.2f") % res
00410                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00411                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00412                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00413                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00414                   << std::endl;
00415 
00416     timer.restart();
00417     accelerated_exact_solver_type_lsrn(problem, context).solve(b, x);
00418     telp = timer.elapsed();
00419     check_solution(problem, b, x, r, res, resAtr, resFac);
00420     if (rank == 0)
00421         std::cout << "LSRN:\t\t\t\t||r||_2 =  "
00422                   << boost::format("%.2f") % res
00423                   << " (x " << boost::format("%.5f") % (res / res_opt) << ")"
00424                   << "\t||r - r*||_2 / ||b - r*||_2 = " << boost::format("%.2e") % resFac
00425                   << "\t||A' * r||_2 = " << boost::format("%.2e") % resAtr
00426                   << "\t\tTime: " << boost::format("%.2e") % telp << " sec"
00427                   << std::endl;
00428 
00429     return 0;
00430 }