// Skylark (Sketching Library), version 0.1
00001 #ifndef SKYLARK_LEAST_SQUARES_HPP 00002 #define SKYLARK_LEAST_SQUARES_HPP 00003 00004 #include <elemental.hpp> 00005 #include "../algorithms/regression/regression.hpp" 00006 #include "../base/exception.hpp" 00007 00008 namespace skylark { namespace nla { 00009 00010 template<typename T> 00011 void ApproximateLeastSquares(elem::Orientation orientation, 00012 const elem::Matrix<T>& A, const elem::Matrix<T>& B, elem::Matrix<T>& X, 00013 base::context_t& context, int sketch_size = -1) { 00014 00015 if (orientation != elem::NORMAL) 00016 SKYLARK_THROW_EXCEPTION ( 00017 base::sketch_exception() 00018 << base::error_msg( 00019 "Only NORMAL orientation is supported for ApproximateLeastSquares")); 00020 00021 if (sketch_size == -1) 00022 sketch_size = 4 * base::Width(A); 00023 00024 typedef algorithms::regression_problem_t<elem::Matrix<double>, 00025 algorithms::linear_tag, 00026 algorithms::l2_tag, 00027 algorithms::no_reg_tag> ptype; 00028 ptype problem(base::Height(A), base::Width(A), A); 00029 00030 algorithms::sketched_regression_solver_t< 00031 ptype, elem::Matrix<double>, elem::Matrix<double>, 00032 algorithms::linear_tag, 00033 elem::Matrix<double>, 00034 elem::Matrix<double>, 00035 sketch::FJLT_t, 00036 algorithms::qr_l2_solver_tag> solver(problem, sketch_size, context); 00037 00038 solver.solve(B, X); 00039 } 00040 00041 template<typename T, elem::Distribution CA, elem::Distribution RA, 00042 elem::Distribution CB, elem::Distribution RB, elem::Distribution CX, 00043 elem::Distribution RX> 00044 void ApproximateLeastSquares(elem::Orientation orientation, 00045 const elem::DistMatrix<T, CA, RA>& A, const elem::DistMatrix<T, CB, RB>& B, 00046 elem::DistMatrix<T, CX, RX>& X, base::context_t& context, 00047 int sketch_size = -1) { 00048 00049 if (orientation != elem::NORMAL) 00050 SKYLARK_THROW_EXCEPTION ( 00051 base::sketch_exception() 00052 << base::error_msg( 00053 "Only NORMAL orientation is supported for ApproximateLeastSquares")); 00054 00055 if 
(sketch_size == -1) 00056 sketch_size = 4 * base::Width(A); 00057 00058 typedef algorithms::regression_problem_t<elem::DistMatrix<T, CA, RA>, 00059 algorithms::linear_tag, 00060 algorithms::l2_tag, 00061 algorithms::no_reg_tag> ptype; 00062 ptype problem(base::Height(A), base::Width(A), A); 00063 00064 algorithms::sketched_regression_solver_t< 00065 ptype, 00066 elem::DistMatrix<T, CB, RB>, 00067 elem::DistMatrix<T, CX, RX>, 00068 algorithms::linear_tag, 00069 elem::DistMatrix<T, elem::STAR, elem::STAR>, 00070 elem::DistMatrix<T, elem::STAR, elem::STAR>, 00071 sketch::FJLT_t, 00072 algorithms::qr_l2_solver_tag> solver(problem, sketch_size, context); 00073 00074 solver.solve(B, X); 00075 } 00076 00077 /* 00078 template<typename AT, typename BT, typename XT> 00079 void ApproximateLeastSquares(elem::Orientation orientation, const AT& A, const BT& B, 00080 XT& X, base::context_t& context, int sketch_size = -1) { 00081 00082 if (orientation != elem::NORMAL) 00083 SKYLARK_THROW_EXCEPTION ( 00084 base::sketch_exception() 00085 << base::error_msg( 00086 "Only NORMAL orientation is supported for ApproximateLeastSquares")); 00087 00088 if (sketch_size == -1) 00089 sketch_size = 4 * base::Width(A); 00090 00091 typedef algorithms::regression_problem_t<AT, 00092 algorithms::linear_tag, 00093 algorithms::l2_tag, 00094 algorithms::no_reg_tag> ptype; 00095 ptype problem(base::Height(A), base::Width(A), A); 00096 00097 algorithms::sketched_regression_solver_t< 00098 ptype, BT, XT, 00099 algorithms::linear_tag, 00100 elem::DistMatrix<double, elem::STAR, elem::STAR>, 00101 elem::DistMatrix<double, elem::STAR, elem::STAR>, 00102 sketch::FJLT_t, 00103 algorithms::qr_l2_solver_tag> solver(problem, sketch_size, context); 00104 00105 solver.solve(B, X); 00106 } 00107 */ 00108 00109 template<typename AT, typename BT, typename XT> 00110 void FastLeastSquares(elem::Orientation orientation, const AT& A, const BT& B, 00111 XT& X, base::context_t& context) { 00112 00113 if (orientation != 
elem::NORMAL) 00114 SKYLARK_THROW_EXCEPTION ( 00115 base::sketch_exception() 00116 << base::error_msg( 00117 "Only NORMAL orientation is supported for FastLeastSquares")); 00118 00119 typedef algorithms::regression_problem_t<AT, 00120 algorithms::linear_tag, 00121 algorithms::l2_tag, 00122 algorithms::no_reg_tag> ptype; 00123 ptype problem(base::Height(A), base::Width(A), A); 00124 00125 algorithms::accelerated_regression_solver_t<ptype, BT, XT, 00126 algorithms::blendenpik_tag< 00127 algorithms::qr_precond_tag> > 00128 solver(problem, context); 00129 solver.solve(B, X); 00130 } 00131 00132 00133 } } 00134 #endif