Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_SKETCHED_REGRESSION_SOLVER_ELEMENTAL_HPP 00002 #define SKYLARK_SKETCHED_REGRESSION_SOLVER_ELEMENTAL_HPP 00003 00004 #include <boost/mpi.hpp> 00005 #include <elemental.hpp> 00006 00007 #include "../../base/context.hpp" 00008 #include "regression_problem.hpp" 00009 #include "../../sketch/sketch.hpp" 00010 #include "../../utility/typer.hpp" 00011 00012 namespace skylark { 00013 namespace algorithms { 00014 00019 template < 00020 typename RegressionType, 00021 typename PenaltyType, 00022 typename RegularizationType, 00023 typename InputType, 00024 typename RhsType, 00025 typename SolType, 00026 typename SketchedRegressionType, 00027 template <typename, typename> class TransformType, 00028 typename ExactAlgTag> 00029 class sketched_regression_solver_t< 00030 regression_problem_t<InputType, 00031 RegressionType, PenaltyType, RegularizationType>, 00032 RhsType, 00033 SolType, 00034 SketchedRegressionType, 00035 elem::Matrix< 00036 typename utility::typer_t<InputType>::value_type >, 00037 elem::Matrix< 00038 typename utility::typer_t<InputType>::value_type >, 00039 TransformType, 00040 ExactAlgTag> { 00041 00042 public: 00043 00044 typedef typename utility::typer_t<InputType>::value_type value_type; 00045 00046 typedef elem::Matrix<value_type> sketch_type; 00047 typedef elem::Matrix<value_type> sketch_rhs_type; 00048 typedef InputType matrix_type; 00049 typedef RhsType rhs_type; 00050 typedef SolType sol_type; 00051 00052 typedef RegressionType regression_type; 00053 typedef PenaltyType penalty_type; 00054 typedef RegularizationType regularization_type; 00055 typedef SketchedRegressionType sketched_regression_type; 00056 00057 typedef regression_problem_t<matrix_type, 00058 regression_type, penalty_type, 00059 regularization_type> problem_type; 00060 typedef regression_problem_t<sketch_type, 00061 sketched_regression_type, penalty_type, 00062 regularization_type> sketched_problem_type; 00063 00064 00065 typedef regression_solver_t<sketched_problem_type, 00066 sketch_rhs_type, 00067 sol_type, 00068 ExactAlgTag> underlying_solver_type; 00069 00070 private: 00071 typedef typename TransformType<matrix_type, sketch_type>::data_type 00072 transform_data_type; 00073 00074 const int _my_rank; 00075 const int _sketch_size; 00076 const transform_data_type _sketch; 00077 const underlying_solver_type *_underlying_solver; 00078 00079 public: 00080 sketched_regression_solver_t(const problem_type& problem, int sketch_size, 00081 base::context_t& context) : 00082 _my_rank(utility::get_communicator(problem.input_matrix)), 00083 _sketch_size(sketch_size), 00084 _sketch(problem.m, sketch_size, context) { 00085 00086 // TODO m < n 00087 TransformType<matrix_type, sketch_type> S(_sketch); 00088 // TODO For DistMatrix this will allocate on DefaultGrid... 00089 sketch_type sketch(sketch_size, problem.n); 00090 S.apply(problem.input_matrix, sketch, sketch::columnwise_tag()); 00091 sketched_problem_type sketched_problem(sketch_size, problem.n, sketch); 00092 _underlying_solver = new underlying_solver_type(sketched_problem); 00093 } 00094 00095 ~sketched_regression_solver_t() { 00096 delete _underlying_solver; 00097 } 00098 00099 void solve(const rhs_type& b, sol_type& x) { 00100 TransformType<rhs_type, sketch_type> S(_sketch); 00101 sketch_type Sb(_sketch_size, 1); 00102 S.apply(b, Sb, sketch::columnwise_tag()); 00103 if (_my_rank == 0) 00104 _underlying_solver->solve(Sb, x); 00105 } 00106 00107 void solve_mulitple(const rhs_type& B, sol_type& X) { 00108 TransformType<rhs_type, sketch_type> S(_sketch); 00109 sketch_type SB(_sketch_size, B.Width()); 00110 S.apply(SB, SB, sketch::columnwise_tag()); 00111 if (_my_rank == 0) 00112 _underlying_solver->solve_mulitple(SB, X); 00113 } 00114 }; 00115 00119 template < 00120 typename RegressionType, 00121 typename PenaltyType, 00122 typename RegularizationType, 00123 typename InputType, 00124 typename RhsType, 00125 typename SolType, 00126 typename SketchedRegressionType, 00127 elem::Distribution CD, elem::Distribution RD, 00128 template <typename, typename> class TransformType, 00129 typename ExactAlgTag> 00130 class sketched_regression_solver_t< 00131 regression_problem_t<InputType, 00132 RegressionType, PenaltyType, RegularizationType>, 00133 RhsType, 00134 SolType, 00135 SketchedRegressionType, 00136 elem::DistMatrix< 00137 typename utility::typer_t<InputType>::value_type, 00138 CD, RD >, 00139 elem::DistMatrix< 00140 typename utility::typer_t<InputType>::value_type, 00141 CD, RD >, 00142 TransformType, 00143 ExactAlgTag> { 00144 00145 public: 00146 00147 typedef typename utility::typer_t<InputType>::value_type value_type; 00148 00149 typedef elem::DistMatrix<value_type, CD, RD> sketch_type; 00150 typedef elem::DistMatrix<value_type, CD, RD> sketch_rhs_type; 00151 typedef InputType matrix_type; 00152 typedef RhsType rhs_type; 00153 typedef SolType sol_type; 00154 00155 typedef RegressionType regression_type; 00156 typedef PenaltyType penalty_type; 00157 typedef RegularizationType regularization_type; 00158 typedef SketchedRegressionType sketched_regression_type; 00159 00160 typedef regression_problem_t<matrix_type, 00161 regression_type, penalty_type, 00162 regularization_type> problem_type; 00163 typedef regression_problem_t<sketch_type, 00164 sketched_regression_type, penalty_type, 00165 regularization_type> sketched_problem_type; 00166 00167 typedef regression_solver_t<sketched_problem_type, 00168 sketch_rhs_type, 00169 sol_type, 00170 ExactAlgTag> underlying_solver_type; 00171 00172 private: 00173 typedef typename TransformType<matrix_type, sketch_type>::data_type 00174 transform_data_type; 00175 00176 const int _sketch_size; 00177 const transform_data_type _sketch; 00178 const underlying_solver_type *_underlying_solver; 00179 00180 public: 00181 sketched_regression_solver_t(const problem_type& problem, int sketch_size, 00182 base::context_t& context) : 00183 _sketch_size(sketch_size), 00184 _sketch(problem.m, sketch_size, context) { 00185 00186 // TODO m < n 00187 TransformType<matrix_type, sketch_type> S(_sketch); 00188 // TODO For DistMatrix this will allocate on DefaultGrid... 00189 sketch_type sketch(sketch_size, problem.n); 00190 S.apply(problem.input_matrix, sketch, sketch::columnwise_tag()); 00191 sketched_problem_type sketched_problem(sketch_size, problem.n, sketch); 00192 _underlying_solver = new underlying_solver_type(sketched_problem); 00193 } 00194 00195 ~sketched_regression_solver_t() { 00196 delete _underlying_solver; 00197 } 00198 00199 void solve(const rhs_type& b, sol_type& x) { 00200 // TODO For DistMatrix this will allocate on DefaultGrid 00201 // MIGHT BE VERY WRONG (grid is different). 00202 TransformType<rhs_type, sketch_type> S(_sketch); 00203 sketch_type Sb(_sketch_size, 1); 00204 S.apply(b, Sb, sketch::columnwise_tag()); 00205 _underlying_solver->solve(Sb, x); 00206 } 00207 00208 void solve_mulitple(const rhs_type& B, sol_type& X) { 00209 // TODO For DistMatrix this will allocate on DefaultGrid... 00210 // MIGHT BE VERY WRONG (grid is different). 00211 TransformType<rhs_type, sketch_type> S(_sketch); 00212 sketch_type SB(_sketch_size, B.Width()); 00213 S.apply(SB, SB, sketch::columnwise_tag()); 00214 _underlying_solver->solve_mulitple(SB, X); 00215 } 00216 }; 00217 00218 } } // namespace skylark::algorithms 00219 00220 #endif // SKYLARK_SKETCHED_REGRESSION_SOLVER_ELEMENTAL_HPP