Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYlARK_LINEARL2_REGRESSION_SOLVER_ELEMENTAL_HPP 00002 #define SKYLARK_LINEARL2_REGRESSION_SOLVER_ELEMENTAL_HPP 00003 00004 #include <elemental.hpp> 00005 #include "../../base/base.hpp" 00006 00007 #include "regression_problem.hpp" 00008 00009 namespace skylark { 00010 namespace algorithms { 00011 00022 template <typename ValueType> 00023 class regression_solver_t< 00024 regression_problem_t<elem::Matrix<ValueType>, 00025 linear_tag, l2_tag, no_reg_tag>, 00026 elem::Matrix<ValueType>, 00027 elem::Matrix<ValueType>, 00028 qr_l2_solver_tag> { 00029 00030 public: 00031 00032 typedef ValueType value_type; 00033 00034 typedef elem::Matrix<ValueType> matrix_type; 00035 typedef elem::Matrix<ValueType> rhs_type; 00036 typedef elem::Matrix<ValueType> sol_type; 00037 00038 typedef regression_problem_t< 00039 elem::Matrix<ValueType>, linear_tag, l2_tag, no_reg_tag> problem_type; 00040 00041 private: 00042 const int _m; 00043 const int _n; 00044 matrix_type _QR; 00045 matrix_type _t; 00046 matrix_type _R; 00047 00048 public: 00054 regression_solver_t(const problem_type& problem) : 00055 _m(problem.m), _n(problem.n) { 00056 // TODO n < m 00057 _QR = problem.input_matrix; 00058 elem::QR(_QR, _t); 00059 elem::LockedView(_R, _QR, 0, 0, _n, _n); 00060 } 00061 00068 void solve(const rhs_type& B, rhs_type& X) const { 00069 // TODO error checking 00070 X = B; 00071 elem::qr::ApplyQ(elem::LEFT, elem::ADJOINT, _QR, _t, X); 00072 X.Resize(_n, B.Width()); 00073 elem::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT, 00074 1.0, _R, X, true); 00075 } 00076 }; 00077 00088 template <typename ValueType> 00089 class regression_solver_t< 00090 regression_problem_t<elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00091 linear_tag, l2_tag, no_reg_tag>, 00092 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00093 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00094 qr_l2_solver_tag> { 00095 00096 public: 00097 00098 typedef ValueType value_type; 00099 00100 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> matrix_type; 00101 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> rhs_type; 00102 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type; 00103 00104 typedef regression_problem_t< 00105 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00106 linear_tag, l2_tag, no_reg_tag> problem_type; 00107 00108 private: 00109 const int _m; 00110 const int _n; 00111 matrix_type _QR; 00112 matrix_type _t; 00113 matrix_type _R; 00114 00115 public: 00121 regression_solver_t(const problem_type& problem) : 00122 _m(problem.m), _n(problem.n) { 00123 // TODO n < m 00124 _QR = problem.input_matrix; 00125 elem::QR(_QR.Matrix(), _t.Matrix()); 00126 elem::LockedView(_R.Matrix(), _QR.Matrix(), 0, 0, _n, _n); 00127 } 00128 00135 void solve(const rhs_type& B, rhs_type& X) const { 00136 // TODO error checking 00137 X = B; 00138 elem::qr::ApplyQ(elem::LEFT, elem::ADJOINT, 00139 _QR.LockedMatrix(), _t.LockedMatrix(), X.Matrix()); 00140 X.Resize(_n, B.Width()); 00141 elem::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT, 00142 1.0, _R.LockedMatrix(), X.Matrix(), true); 00143 } 00144 }; 00145 00156 template <typename ValueType> 00157 class regression_solver_t< 00158 regression_problem_t<elem::DistMatrix<ValueType>, 00159 linear_tag, l2_tag, no_reg_tag>, 00160 elem::DistMatrix<ValueType>, 00161 elem::DistMatrix<ValueType>, 00162 qr_l2_solver_tag> { 00163 00164 public: 00165 00166 typedef ValueType value_type; 00167 00168 typedef elem::DistMatrix<ValueType> matrix_type; 00169 typedef elem::DistMatrix<ValueType> rhs_type; 00170 typedef elem::DistMatrix<ValueType> sol_type; 00171 00172 typedef regression_problem_t<matrix_type, 00173 linear_tag, l2_tag, no_reg_tag> problem_type; 00174 00175 private: 00176 const int _m; 00177 const int _n; 00178 matrix_type _QR; 00179 matrix_type _R; 00180 elem::DistMatrix<ValueType, elem::MD, elem::STAR> _t; 00181 00182 public: 00188 regression_solver_t(const problem_type& problem) : 00189 _m(problem.m), _n(problem.n), 00190 _QR(problem.input_matrix.Grid()), _R(problem.input_matrix.Grid()), 00191 _t(problem.input_matrix.Grid()) { 00192 // TODO n < m 00193 _QR = problem.input_matrix; 00194 elem::QR(_QR, _t); 00195 elem::LockedView(_R, _QR, 0, 0, _n, _n); 00196 } 00197 00204 void solve (const rhs_type& B, sol_type& X) const { 00205 // TODO error checking 00206 X = B; 00207 elem::qr::ApplyQ(elem::LEFT, elem::ADJOINT, _QR, _t, X); 00208 X.Resize(_n, B.Width()); 00209 elem::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT, 00210 1.0, _R, X); 00211 } 00212 }; 00213 00225 template <typename ValueType, elem::Distribution VD> 00226 class regression_solver_t< 00227 regression_problem_t<elem::DistMatrix<ValueType, VD, elem::STAR>, 00228 linear_tag, l2_tag, no_reg_tag>, 00229 elem::DistMatrix<ValueType, VD, elem::STAR>, 00230 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00231 qr_l2_solver_tag> { 00232 00233 public: 00234 00235 typedef ValueType value_type; 00236 00237 typedef elem::DistMatrix<ValueType, VD, elem::STAR> matrix_type; 00238 typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type; 00239 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type; 00240 00241 typedef regression_problem_t<matrix_type, 00242 linear_tag, l2_tag, no_reg_tag> problem_type; 00243 00244 private: 00245 const int _m; 00246 const int _n; 00247 matrix_type _Q; 00248 sol_type _R; 00249 00250 public: 00256 regression_solver_t(const problem_type& problem) : 00257 _m(problem.m), _n(problem.n), 00258 _Q(problem.input_matrix.Grid()), _R(problem.input_matrix.Grid()) { 00259 // TODO n < m ??? 00260 _Q = problem.input_matrix; 00261 elem::qr::ExplicitTS(_Q, _R); 00262 } 00263 00270 void solve (const rhs_type& B, sol_type& X) const { 00271 // TODO error checking 00272 00273 base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _Q, B, X); 00274 base::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT, 00275 1.0, _R, X); 00276 } 00277 00278 }; 00279 00291 template <typename ValueType, elem::Distribution VD> 00292 class regression_solver_t< 00293 regression_problem_t<elem::DistMatrix<ValueType, VD, elem::STAR>, 00294 linear_tag, l2_tag, no_reg_tag>, 00295 elem::DistMatrix<ValueType, VD, elem::STAR>, 00296 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00297 svd_l2_solver_tag> { 00298 00299 public: 00300 00301 typedef ValueType value_type; 00302 00303 typedef elem::DistMatrix<ValueType, VD, elem::STAR> matrix_type; 00304 typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type; 00305 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type; 00306 00307 typedef regression_problem_t<matrix_type, 00308 linear_tag, l2_tag, no_reg_tag> problem_type; 00309 00310 private: 00311 const int _m; 00312 const int _n; 00313 matrix_type _U; 00314 sol_type _S, _V; 00315 00316 public: 00322 regression_solver_t(const problem_type& problem) : 00323 _m(problem.m), _n(problem.n), 00324 _U(problem.input_matrix.Grid()), _S(problem.input_matrix.Grid()), 00325 _V(problem.input_matrix.Grid()) { 00326 // TODO n < m ??? 00327 _U = problem.input_matrix; 00328 base::SVD(_U, _S, _V); 00329 for(int i = 0; i < _S.Height(); i++) 00330 _S.Set(i, 0, 1 / _S.Get(i, 0)); // TODO handle rank deficiency 00331 } 00332 00339 void solve (const rhs_type& B, sol_type& X) const { 00340 // TODO error checking 00341 sol_type UB(X); // Not copying -- just taking grid and size. 00342 base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _U, B, UB); 00343 elem::DiagonalScale(elem::LEFT, elem::NORMAL, _S, UB); 00344 base::Gemm(elem::NORMAL, elem::NORMAL, 1.0, _V, UB, X); 00345 } 00346 00347 }; 00348 00363 template <typename ValueType, elem::Distribution VD> 00364 class regression_solver_t< 00365 regression_problem_t<elem::DistMatrix<ValueType, VD, elem::STAR>, 00366 linear_tag, l2_tag, no_reg_tag>, 00367 elem::DistMatrix<ValueType, VD, elem::STAR>, 00368 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00369 sne_l2_solver_tag> { 00370 00371 public: 00372 00373 typedef ValueType value_type; 00374 00375 typedef elem::DistMatrix<ValueType, VD, elem::STAR> matrix_type; 00376 typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type; 00377 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type; 00378 00379 typedef regression_problem_t<matrix_type, 00380 linear_tag, l2_tag, no_reg_tag> problem_type; 00381 00382 private: 00383 const int _m; 00384 const int _n; 00385 const matrix_type& _A; 00386 sol_type _R; 00387 00388 public: 00394 regression_solver_t(const problem_type& problem) : 00395 _m(problem.m), _n(problem.n), 00396 _A(problem.input_matrix), _R(problem.input_matrix.Grid()) { 00397 // TODO n < m ??? 00398 matrix_type _Q = problem.input_matrix; 00399 elem::qr::ExplicitTS(_Q, _R); 00400 } 00401 00408 void solve (const rhs_type& B, sol_type& X) const { 00409 // TODO error checking 00410 00411 base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _A, B, X); 00412 base::Trsm(elem::LEFT, elem::UPPER, elem::ADJOINT, elem::NON_UNIT, 00413 1.0, _R, X); 00414 base::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT, 00415 1.0, _R, X); 00416 } 00417 00418 }; 00419 00435 template <typename ValueType, elem::Distribution VD> 00436 class regression_solver_t< 00437 regression_problem_t< 00438 base::computed_matrix_t< elem::DistMatrix<ValueType, VD, elem::STAR> >, 00439 linear_tag, l2_tag, no_reg_tag>, 00440 elem::DistMatrix<ValueType, VD, elem::STAR>, 00441 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00442 sne_l2_solver_tag> { 00443 00444 public: 00445 00446 typedef ValueType value_type; 00447 00448 typedef base::computed_matrix_t< elem::DistMatrix<ValueType, VD, elem::STAR> > 00449 matrix_type; 00450 typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type; 00451 typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type; 00452 00453 typedef regression_problem_t<matrix_type, 00454 linear_tag, l2_tag, no_reg_tag> problem_type; 00455 00456 private: 00457 const int _m; 00458 const int _n; 00459 const matrix_type& _A; 00460 sol_type _R; 00461 00462 public: 00468 regression_solver_t(const problem_type& problem) : 00469 _m(problem.m), _n(problem.n),_A(problem.input_matrix) { 00470 // TODO n < m ??? 00471 elem::DistMatrix<ValueType, VD, elem::STAR> _Q = 00472 problem.input_matrix.materialize(); 00473 elem::qr::ExplicitTS(_Q, _R); 00474 } 00475 00482 void solve (const rhs_type& B, sol_type& X) const { 00483 // TODO error checking 00484 00485 base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _A, B, X); 00486 base::Trsm(elem::LEFT, elem::UPPER, elem::ADJOINT, elem::NON_UNIT, 00487 1.0, _R, X); 00488 base::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT, 00489 1.0, _R, X); 00490 } 00491 00492 }; 00493 00494 } } 00496 #endif // SKYLARK_LINEARL2_REGRESSION_SOLVER_ELEMENTAL_HPP