Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/algorithms/regression/linearl2_regression_solver_Elemental.hpp
Go to the documentation of this file.
00001 #ifndef SKYlARK_LINEARL2_REGRESSION_SOLVER_ELEMENTAL_HPP
00002 #define SKYLARK_LINEARL2_REGRESSION_SOLVER_ELEMENTAL_HPP
00003 
00004 #include <elemental.hpp>
00005 #include "../../base/base.hpp"
00006 
00007 #include "regression_problem.hpp"
00008 
00009 namespace skylark {
00010 namespace algorithms {
00011 
00022 template <typename ValueType>
00023 class regression_solver_t<
00024     regression_problem_t<elem::Matrix<ValueType>,
00025                          linear_tag, l2_tag, no_reg_tag>,
00026     elem::Matrix<ValueType>,
00027     elem::Matrix<ValueType>,
00028     qr_l2_solver_tag> {
00029 
00030 public:
00031 
00032     typedef ValueType value_type;
00033 
00034     typedef elem::Matrix<ValueType> matrix_type;
00035     typedef elem::Matrix<ValueType> rhs_type;
00036     typedef elem::Matrix<ValueType> sol_type;
00037 
00038     typedef regression_problem_t<
00039         elem::Matrix<ValueType>, linear_tag, l2_tag, no_reg_tag> problem_type;
00040 
00041 private:
00042     const int _m;
00043     const int _n;
00044     matrix_type _QR;
00045     matrix_type _t;
00046     matrix_type _R;
00047 
00048 public:
00054     regression_solver_t(const problem_type& problem) :
00055         _m(problem.m), _n(problem.n) {
00056         // TODO n < m
00057         _QR = problem.input_matrix;
00058         elem::QR(_QR, _t);
00059         elem::LockedView(_R, _QR, 0, 0, _n, _n);
00060     }
00061 
00068     void solve(const rhs_type& B, rhs_type& X) const {
00069         // TODO error checking
00070         X = B;
00071         elem::qr::ApplyQ(elem::LEFT, elem::ADJOINT, _QR, _t, X);
00072         X.Resize(_n, B.Width());
00073         elem::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT,
00074             1.0, _R, X, true);
00075     }
00076 };
00077 
00088 template <typename ValueType>
00089 class regression_solver_t<
00090     regression_problem_t<elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00091                          linear_tag, l2_tag, no_reg_tag>,
00092     elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00093     elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00094     qr_l2_solver_tag> {
00095 
00096 public:
00097 
00098     typedef ValueType value_type;
00099 
00100     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> matrix_type;
00101     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> rhs_type;
00102     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type;
00103 
00104     typedef regression_problem_t<
00105         elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00106         linear_tag, l2_tag, no_reg_tag> problem_type;
00107 
00108 private:
00109     const int _m;
00110     const int _n;
00111     matrix_type _QR;
00112     matrix_type _t;
00113     matrix_type _R;
00114 
00115 public:
00121     regression_solver_t(const problem_type& problem) :
00122         _m(problem.m), _n(problem.n) {
00123         // TODO n < m
00124         _QR = problem.input_matrix;
00125         elem::QR(_QR.Matrix(), _t.Matrix());
00126         elem::LockedView(_R.Matrix(), _QR.Matrix(), 0, 0, _n, _n);
00127     }
00128 
00135     void solve(const rhs_type& B, rhs_type& X) const {
00136         // TODO error checking
00137         X = B;
00138         elem::qr::ApplyQ(elem::LEFT, elem::ADJOINT,
00139             _QR.LockedMatrix(), _t.LockedMatrix(), X.Matrix());
00140         X.Resize(_n, B.Width());
00141         elem::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT,
00142             1.0, _R.LockedMatrix(), X.Matrix(), true);
00143     }
00144 };
00145 
00156 template <typename ValueType>
00157 class regression_solver_t<
00158     regression_problem_t<elem::DistMatrix<ValueType>,
00159                          linear_tag, l2_tag, no_reg_tag>,
00160     elem::DistMatrix<ValueType>,
00161     elem::DistMatrix<ValueType>,
00162     qr_l2_solver_tag> {
00163 
00164 public:
00165 
00166     typedef ValueType value_type;
00167 
00168     typedef elem::DistMatrix<ValueType> matrix_type;
00169     typedef elem::DistMatrix<ValueType> rhs_type;
00170     typedef elem::DistMatrix<ValueType> sol_type;
00171 
00172     typedef regression_problem_t<matrix_type,
00173                                  linear_tag, l2_tag, no_reg_tag> problem_type;
00174 
00175 private:
00176     const int _m;
00177     const int _n;
00178     matrix_type _QR;
00179     matrix_type _R;
00180     elem::DistMatrix<ValueType, elem::MD, elem::STAR> _t;
00181 
00182 public:
00188     regression_solver_t(const problem_type& problem) :
00189         _m(problem.m), _n(problem.n),
00190         _QR(problem.input_matrix.Grid()), _R(problem.input_matrix.Grid()),
00191         _t(problem.input_matrix.Grid()) {
00192         // TODO n < m
00193         _QR = problem.input_matrix;
00194         elem::QR(_QR, _t);
00195         elem::LockedView(_R, _QR, 0, 0, _n, _n);
00196     }
00197 
00204     void solve (const rhs_type& B, sol_type& X) const {
00205         // TODO error checking
00206         X = B;
00207         elem::qr::ApplyQ(elem::LEFT, elem::ADJOINT, _QR, _t, X);
00208         X.Resize(_n, B.Width());
00209         elem::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT,
00210             1.0, _R, X);
00211     }
00212 };
00213 
00225 template <typename ValueType, elem::Distribution VD>
00226 class regression_solver_t<
00227     regression_problem_t<elem::DistMatrix<ValueType, VD, elem::STAR>,
00228                          linear_tag, l2_tag, no_reg_tag>,
00229     elem::DistMatrix<ValueType, VD, elem::STAR>,
00230     elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00231     qr_l2_solver_tag> {
00232 
00233 public:
00234 
00235     typedef ValueType value_type;
00236 
00237     typedef elem::DistMatrix<ValueType, VD, elem::STAR> matrix_type;
00238     typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type;
00239     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type;
00240 
00241     typedef regression_problem_t<matrix_type,
00242                                  linear_tag, l2_tag, no_reg_tag> problem_type;
00243 
00244 private:
00245     const int _m;
00246     const int _n;
00247     matrix_type _Q;
00248     sol_type _R;
00249 
00250 public:
00256     regression_solver_t(const problem_type& problem) :
00257         _m(problem.m), _n(problem.n),
00258         _Q(problem.input_matrix.Grid()), _R(problem.input_matrix.Grid()) {
00259         // TODO n < m ???
00260         _Q = problem.input_matrix;
00261         elem::qr::ExplicitTS(_Q, _R);
00262     }
00263 
00270     void solve (const rhs_type& B, sol_type& X) const {
00271         // TODO error checking
00272 
00273         base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _Q, B, X);
00274         base::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT,
00275             1.0, _R, X);
00276     }
00277 
00278 };
00279 
00291 template <typename ValueType, elem::Distribution VD>
00292 class regression_solver_t<
00293     regression_problem_t<elem::DistMatrix<ValueType, VD, elem::STAR>,
00294                          linear_tag, l2_tag, no_reg_tag>,
00295     elem::DistMatrix<ValueType, VD, elem::STAR>,
00296     elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00297     svd_l2_solver_tag> {
00298 
00299 public:
00300 
00301     typedef ValueType value_type;
00302 
00303     typedef elem::DistMatrix<ValueType, VD, elem::STAR> matrix_type;
00304     typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type;
00305     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type;
00306 
00307     typedef regression_problem_t<matrix_type,
00308                                  linear_tag, l2_tag, no_reg_tag> problem_type;
00309 
00310 private:
00311     const int _m;
00312     const int _n;
00313     matrix_type _U;
00314     sol_type _S, _V;
00315 
00316 public:
00322     regression_solver_t(const problem_type& problem) :
00323         _m(problem.m), _n(problem.n),
00324         _U(problem.input_matrix.Grid()), _S(problem.input_matrix.Grid()),
00325         _V(problem.input_matrix.Grid()) {
00326         // TODO n < m ???
00327         _U = problem.input_matrix;
00328         base::SVD(_U, _S, _V);
00329         for(int i = 0; i < _S.Height(); i++)
00330             _S.Set(i, 0, 1 / _S.Get(i, 0));   // TODO handle rank deficiency
00331     }
00332 
00339     void solve (const rhs_type& B, sol_type& X) const {
00340         // TODO error checking
00341         sol_type UB(X); // Not copying -- just taking grid and size.
00342         base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _U, B, UB);
00343         elem::DiagonalScale(elem::LEFT, elem::NORMAL, _S, UB);
00344         base::Gemm(elem::NORMAL, elem::NORMAL, 1.0, _V, UB, X);
00345     }
00346 
00347 };
00348 
00363 template <typename ValueType, elem::Distribution VD>
00364 class regression_solver_t<
00365     regression_problem_t<elem::DistMatrix<ValueType, VD, elem::STAR>,
00366                          linear_tag, l2_tag, no_reg_tag>,
00367     elem::DistMatrix<ValueType, VD, elem::STAR>,
00368     elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00369     sne_l2_solver_tag> {
00370 
00371 public:
00372 
00373     typedef ValueType value_type;
00374 
00375     typedef elem::DistMatrix<ValueType, VD, elem::STAR> matrix_type;
00376     typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type;
00377     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type;
00378 
00379     typedef regression_problem_t<matrix_type,
00380                                  linear_tag, l2_tag, no_reg_tag> problem_type;
00381 
00382 private:
00383     const int _m;
00384     const int _n;
00385     const matrix_type& _A;
00386     sol_type _R;
00387 
00388 public:
00394     regression_solver_t(const problem_type& problem) :
00395         _m(problem.m), _n(problem.n),
00396         _A(problem.input_matrix), _R(problem.input_matrix.Grid()) {
00397         // TODO n < m ???
00398         matrix_type _Q = problem.input_matrix;
00399         elem::qr::ExplicitTS(_Q, _R);
00400     }
00401 
00408     void solve (const rhs_type& B, sol_type& X) const {
00409         // TODO error checking
00410 
00411         base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _A, B, X);
00412         base::Trsm(elem::LEFT, elem::UPPER, elem::ADJOINT, elem::NON_UNIT,
00413             1.0, _R, X);
00414         base::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT,
00415             1.0, _R, X);
00416     }
00417 
00418 };
00419 
00435 template <typename ValueType, elem::Distribution VD>
00436 class regression_solver_t<
00437     regression_problem_t<
00438         base::computed_matrix_t< elem::DistMatrix<ValueType, VD, elem::STAR> >,
00439         linear_tag, l2_tag, no_reg_tag>,
00440     elem::DistMatrix<ValueType, VD, elem::STAR>,
00441     elem::DistMatrix<ValueType, elem::STAR, elem::STAR>,
00442     sne_l2_solver_tag> {
00443 
00444 public:
00445 
00446     typedef ValueType value_type;
00447 
00448     typedef base::computed_matrix_t< elem::DistMatrix<ValueType, VD, elem::STAR> >
00449     matrix_type;
00450     typedef elem::DistMatrix<ValueType, VD, elem::STAR> rhs_type;
00451     typedef elem::DistMatrix<ValueType, elem::STAR, elem::STAR> sol_type;
00452 
00453     typedef regression_problem_t<matrix_type,
00454                                  linear_tag, l2_tag, no_reg_tag> problem_type;
00455 
00456 private:
00457     const int _m;
00458     const int _n;
00459     const matrix_type& _A;
00460     sol_type _R;
00461 
00462 public:
00468     regression_solver_t(const problem_type& problem) :
00469         _m(problem.m), _n(problem.n),_A(problem.input_matrix)  {
00470         // TODO n < m ???
00471         elem::DistMatrix<ValueType, VD, elem::STAR> _Q =
00472             problem.input_matrix.materialize();
00473         elem::qr::ExplicitTS(_Q, _R);
00474     }
00475 
00482     void solve (const rhs_type& B, sol_type& X) const {
00483         // TODO error checking
00484 
00485         base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, _A, B, X);
00486         base::Trsm(elem::LEFT, elem::UPPER, elem::ADJOINT, elem::NON_UNIT,
00487             1.0, _R, X);
00488         base::Trsm(elem::LEFT, elem::UPPER, elem::NORMAL, elem::NON_UNIT,
00489             1.0, _R, X);
00490     }
00491 
00492 };
00493 
00494 } } 
00496 #endif // SKYLARK_LINEARL2_REGRESSION_SOLVER_ELEMENTAL_HPP