Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/algorithms/regression/sketched_regression_solver_Elemental.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_SKETCHED_REGRESSION_SOLVER_ELEMENTAL_HPP
00002 #define SKYLARK_SKETCHED_REGRESSION_SOLVER_ELEMENTAL_HPP
00003 
00004 #include <boost/mpi.hpp>
00005 #include <elemental.hpp>
00006 
00007 #include "../../base/context.hpp"
00008 #include "regression_problem.hpp"
00009 #include "../../sketch/sketch.hpp"
00010 #include "../../utility/typer.hpp"
00011 
00012 namespace skylark {
00013 namespace algorithms {
00014 
00019 template <
00020     typename RegressionType,
00021     typename PenaltyType,
00022     typename RegularizationType,
00023     typename InputType,
00024     typename RhsType,
00025     typename SolType,
00026     typename SketchedRegressionType,
00027     template <typename, typename> class TransformType,
00028     typename ExactAlgTag>
00029 class sketched_regression_solver_t<
00030     regression_problem_t<InputType,
00031                          RegressionType, PenaltyType, RegularizationType>,
00032     RhsType,
00033     SolType,
00034     SketchedRegressionType,
00035     elem::Matrix<
00036         typename utility::typer_t<InputType>::value_type >,
00037     elem::Matrix<
00038         typename utility::typer_t<InputType>::value_type >,
00039     TransformType,
00040     ExactAlgTag> {
00041 
00042 public:
00043 
00044     typedef typename utility::typer_t<InputType>::value_type value_type;
00045 
00046     typedef elem::Matrix<value_type> sketch_type;
00047     typedef elem::Matrix<value_type> sketch_rhs_type;
00048     typedef InputType matrix_type;
00049     typedef RhsType rhs_type;
00050     typedef SolType sol_type;
00051 
00052     typedef RegressionType regression_type;
00053     typedef PenaltyType penalty_type;
00054     typedef RegularizationType regularization_type;
00055     typedef SketchedRegressionType sketched_regression_type;
00056 
00057     typedef regression_problem_t<matrix_type,
00058                                  regression_type, penalty_type,
00059                                  regularization_type> problem_type;
00060     typedef regression_problem_t<sketch_type,
00061                                  sketched_regression_type, penalty_type,
00062                                  regularization_type> sketched_problem_type;
00063 
00064 
00065     typedef regression_solver_t<sketched_problem_type,
00066                               sketch_rhs_type,
00067                               sol_type,
00068                               ExactAlgTag> underlying_solver_type;
00069 
00070 private:
00071     typedef typename TransformType<matrix_type, sketch_type>::data_type
00072     transform_data_type;
00073 
00074     const int _my_rank;
00075     const int _sketch_size;
00076     const transform_data_type _sketch;
00077     const underlying_solver_type  *_underlying_solver;
00078 
00079 public:
00080     sketched_regression_solver_t(const problem_type& problem, int sketch_size,
00081         base::context_t& context) :
00082         _my_rank(utility::get_communicator(problem.input_matrix)),
00083         _sketch_size(sketch_size),
00084         _sketch(problem.m, sketch_size, context) {
00085 
00086         // TODO m < n
00087         TransformType<matrix_type, sketch_type> S(_sketch);
00088         // TODO For DistMatrix this will allocate on DefaultGrid...
00089         sketch_type sketch(sketch_size, problem.n);
00090         S.apply(problem.input_matrix, sketch, sketch::columnwise_tag());
00091         sketched_problem_type sketched_problem(sketch_size, problem.n, sketch);
00092         _underlying_solver = new underlying_solver_type(sketched_problem);
00093     }
00094 
00095     ~sketched_regression_solver_t() {
00096         delete _underlying_solver;
00097     }
00098 
00099     void solve(const rhs_type& b, sol_type& x) {
00100         TransformType<rhs_type, sketch_type> S(_sketch);
00101         sketch_type Sb(_sketch_size, 1);
00102         S.apply(b, Sb, sketch::columnwise_tag());
00103         if (_my_rank == 0)
00104             _underlying_solver->solve(Sb, x);
00105     }
00106 
00107     void solve_mulitple(const rhs_type& B, sol_type& X) {
00108         TransformType<rhs_type, sketch_type> S(_sketch);
00109         sketch_type SB(_sketch_size, B.Width());
00110         S.apply(SB, SB, sketch::columnwise_tag());
00111         if (_my_rank == 0)
00112             _underlying_solver->solve_mulitple(SB, X);
00113     }
00114 };
00115 
00119 template <
00120     typename RegressionType,
00121     typename PenaltyType,
00122     typename RegularizationType,
00123     typename InputType,
00124     typename RhsType,
00125     typename SolType,
00126     typename SketchedRegressionType,
00127     elem::Distribution CD, elem::Distribution RD,
00128     template <typename, typename> class TransformType,
00129     typename ExactAlgTag>
00130 class sketched_regression_solver_t<
00131     regression_problem_t<InputType,
00132                          RegressionType, PenaltyType, RegularizationType>,
00133     RhsType,
00134     SolType,
00135     SketchedRegressionType,
00136     elem::DistMatrix<
00137         typename utility::typer_t<InputType>::value_type,
00138         CD, RD >,
00139    elem::DistMatrix<
00140         typename utility::typer_t<InputType>::value_type,
00141         CD, RD >,
00142     TransformType,
00143     ExactAlgTag> {
00144 
00145 public:
00146 
00147     typedef typename utility::typer_t<InputType>::value_type value_type;
00148 
00149     typedef elem::DistMatrix<value_type, CD, RD> sketch_type;
00150     typedef elem::DistMatrix<value_type, CD, RD> sketch_rhs_type;
00151     typedef InputType matrix_type;
00152     typedef RhsType rhs_type;
00153     typedef SolType sol_type;
00154 
00155     typedef RegressionType regression_type;
00156     typedef PenaltyType penalty_type;
00157     typedef RegularizationType regularization_type;
00158     typedef SketchedRegressionType sketched_regression_type;
00159 
00160     typedef regression_problem_t<matrix_type,
00161                                  regression_type, penalty_type,
00162                                  regularization_type> problem_type;
00163     typedef regression_problem_t<sketch_type,
00164                                  sketched_regression_type, penalty_type,
00165                                  regularization_type> sketched_problem_type;
00166 
00167     typedef regression_solver_t<sketched_problem_type,
00168                                 sketch_rhs_type,
00169                                 sol_type,
00170                                 ExactAlgTag> underlying_solver_type;
00171 
00172 private:
00173     typedef typename TransformType<matrix_type, sketch_type>::data_type
00174     transform_data_type;
00175 
00176     const int _sketch_size;
00177     const transform_data_type _sketch;
00178     const underlying_solver_type  *_underlying_solver;
00179 
00180 public:
00181     sketched_regression_solver_t(const problem_type& problem, int sketch_size,
00182         base::context_t& context) :
00183         _sketch_size(sketch_size),
00184         _sketch(problem.m, sketch_size, context) {
00185 
00186         // TODO m < n
00187         TransformType<matrix_type, sketch_type> S(_sketch);
00188         // TODO For DistMatrix this will allocate on DefaultGrid...
00189         sketch_type sketch(sketch_size, problem.n);
00190         S.apply(problem.input_matrix, sketch, sketch::columnwise_tag());
00191         sketched_problem_type sketched_problem(sketch_size, problem.n, sketch);
00192         _underlying_solver = new underlying_solver_type(sketched_problem);
00193     }
00194 
00195     ~sketched_regression_solver_t() {
00196         delete _underlying_solver;
00197     }
00198 
00199     void solve(const rhs_type& b, sol_type& x) {
00200         // TODO For DistMatrix this will allocate on DefaultGrid
00201         //      MIGHT BE VERY WRONG (grid is different).
00202         TransformType<rhs_type, sketch_type> S(_sketch);
00203         sketch_type Sb(_sketch_size, 1);
00204         S.apply(b, Sb, sketch::columnwise_tag());
00205         _underlying_solver->solve(Sb, x);
00206     }
00207 
00208     void solve_mulitple(const rhs_type& B, sol_type& X) {
00209         // TODO For DistMatrix this will allocate on DefaultGrid...
00210         //      MIGHT BE VERY WRONG (grid is different).
00211         TransformType<rhs_type, sketch_type> S(_sketch);
00212         sketch_type SB(_sketch_size, B.Width());
00213         S.apply(SB, SB, sketch::columnwise_tag());
00214         _underlying_solver->solve_mulitple(SB, X);
00215     }
00216 };
00217 
00218 } } // namespace skylark::algorithms
00219 
00220 #endif // SKYLARK_SKETCHED_REGRESSION_SOLVER_ELEMENTAL_HPP