Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYlARK_LINEARL2_REGRESSION_SOLVER_KRYLOV_HPP 00002 #define SKYLARK_LINEARL2_REGRESSION_SOLVER_KRYLOV_HPP 00003 00004 #include "../../base/query.hpp" 00005 #include "../../utility/typer.hpp" 00006 #include "../Krylov/LSQR.hpp" 00007 00008 namespace skylark { namespace algorithms { 00009 00015 template <typename MatrixType, 00016 typename RhsType, 00017 typename SolType, 00018 typename KrylovMethod> 00019 struct regression_solver_t< 00020 regression_problem_t<MatrixType, linear_tag, l2_tag, no_reg_tag>, 00021 RhsType, 00022 SolType, 00023 iterative_l2_solver_tag<KrylovMethod> > { 00024 00025 typedef typename utility::typer_t<MatrixType>::value_type value_type; 00026 00027 typedef MatrixType matrix_type; 00028 typedef RhsType rhs_type; 00029 typedef SolType sol_type; 00030 00031 typedef regression_problem_t<matrix_type, 00032 linear_tag, l2_tag, no_reg_tag> problem_type; 00033 00034 private: 00035 const nla::id_precond_t<sol_type> _id_precond_obj; 00036 00037 public: 00038 const int m; 00039 const int n; 00040 const matrix_type& A; 00041 const nla::precond_t<sol_type>& R; 00042 const nla::iter_params_t iter_params; 00043 00044 regression_solver_t (const problem_type& problem, 00045 nla::iter_params_t iter_params = nla::iter_params_t()) : 00046 m(problem.m), n(problem.n), A(problem.input_matrix), 00047 R(_id_precond_obj), iter_params(iter_params) 00048 { /* Check if m<n? */ } 00049 00050 regression_solver_t (const problem_type& problem, 00051 const nla::precond_t<sol_type>& R, 00052 nla::iter_params_t iter_params = nla::iter_params_t()) : 00053 m(problem.m), n(problem.n), A(problem.input_matrix), R(R), 00054 iter_params(iter_params) 00055 { /* Check if m<n? */ } 00056 00057 00058 00059 int solve(const rhs_type& b, sol_type& x) { 00060 00061 if (m != base::Height(b)) { /* error */ return -1; } 00062 if (n != base::Height(x)) { /* error */ return -1; } 00063 if (base::Width(b) != base::Width(x)) { /* error */ return -1; } 00064 00072 return solve_impl (b, x, KrylovMethod()); 00073 } 00074 00075 private: 00077 int solve_impl (const rhs_type& b, sol_type& x, lsqr_tag) { 00078 00079 return LSQR(A, b, x, iter_params, R); 00080 } 00081 }; 00082 00083 } } // namespace skylark::algorithms 00084 00085 #endif // SKYLARK_LINEARL2_REGRESSION_SOLVER_KRYLOV_HPP