Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/nla/least_squares.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_LEAST_SQUARES_HPP
00002 #define SKYLARK_LEAST_SQUARES_HPP
00003 
00004 #include <elemental.hpp>
00005 #include "../algorithms/regression/regression.hpp"
00006 #include "../base/exception.hpp"
00007 
00008 namespace skylark { namespace nla {
00009 
00010 template<typename T>
00011 void ApproximateLeastSquares(elem::Orientation orientation,
00012     const elem::Matrix<T>& A, const elem::Matrix<T>& B, elem::Matrix<T>& X, 
00013     base::context_t& context, int sketch_size = -1) {
00014 
00015     if (orientation != elem::NORMAL)
00016         SKYLARK_THROW_EXCEPTION (
00017           base::sketch_exception()
00018               << base::error_msg(
00019                  "Only NORMAL orientation is supported for ApproximateLeastSquares"));
00020 
00021     if (sketch_size == -1)
00022         sketch_size = 4 * base::Width(A);
00023 
00024     typedef algorithms::regression_problem_t<elem::Matrix<double>,
00025                                              algorithms::linear_tag,
00026                                              algorithms::l2_tag,
00027                                              algorithms::no_reg_tag> ptype;
00028     ptype problem(base::Height(A), base::Width(A), A);
00029 
00030     algorithms::sketched_regression_solver_t<
00031         ptype, elem::Matrix<double>, elem::Matrix<double>,
00032         algorithms::linear_tag,
00033         elem::Matrix<double>,
00034         elem::Matrix<double>,
00035         sketch::FJLT_t,
00036         algorithms::qr_l2_solver_tag> solver(problem, sketch_size, context);
00037 
00038     solver.solve(B, X);
00039 }
00040 
00041 template<typename T, elem::Distribution CA, elem::Distribution RA,
00042          elem::Distribution CB, elem::Distribution RB, elem::Distribution CX, 
00043          elem::Distribution RX>
00044 void ApproximateLeastSquares(elem::Orientation orientation,
00045     const elem::DistMatrix<T, CA, RA>& A, const elem::DistMatrix<T, CB, RB>& B,
00046     elem::DistMatrix<T, CX, RX>& X, base::context_t& context,
00047     int sketch_size = -1) {
00048 
00049     if (orientation != elem::NORMAL)
00050         SKYLARK_THROW_EXCEPTION (
00051           base::sketch_exception()
00052               << base::error_msg(
00053                  "Only NORMAL orientation is supported for ApproximateLeastSquares"));
00054 
00055     if (sketch_size == -1)
00056         sketch_size = 4 * base::Width(A);
00057 
00058     typedef algorithms::regression_problem_t<elem::DistMatrix<T, CA, RA>,
00059                                              algorithms::linear_tag,
00060                                              algorithms::l2_tag,
00061                                              algorithms::no_reg_tag> ptype;
00062     ptype problem(base::Height(A), base::Width(A), A);
00063 
00064     algorithms::sketched_regression_solver_t<
00065         ptype,
00066         elem::DistMatrix<T, CB, RB>,
00067         elem::DistMatrix<T, CX, RX>,
00068         algorithms::linear_tag,
00069         elem::DistMatrix<T, elem::STAR, elem::STAR>,
00070         elem::DistMatrix<T, elem::STAR, elem::STAR>,
00071         sketch::FJLT_t,
00072         algorithms::qr_l2_solver_tag> solver(problem, sketch_size, context);
00073 
00074     solver.solve(B, X);
00075 }
00076 
00077 /*
00078 template<typename AT, typename BT, typename XT>
00079 void ApproximateLeastSquares(elem::Orientation orientation, const AT& A, const BT& B,
00080     XT& X, base::context_t& context, int sketch_size = -1) {
00081 
00082     if (orientation != elem::NORMAL) 
00083         SKYLARK_THROW_EXCEPTION (
00084           base::sketch_exception()
00085               << base::error_msg(
00086                  "Only NORMAL orientation is supported for ApproximateLeastSquares"));
00087 
00088     if (sketch_size == -1)
00089         sketch_size = 4 * base::Width(A);
00090 
00091     typedef algorithms::regression_problem_t<AT,
00092                                              algorithms::linear_tag,
00093                                              algorithms::l2_tag,
00094                                              algorithms::no_reg_tag> ptype;
00095     ptype problem(base::Height(A), base::Width(A), A);
00096 
00097     algorithms::sketched_regression_solver_t<
00098         ptype, BT, XT,
00099         algorithms::linear_tag,
00100         elem::DistMatrix<double, elem::STAR, elem::STAR>,
00101         elem::DistMatrix<double, elem::STAR, elem::STAR>,
00102         sketch::FJLT_t,
00103         algorithms::qr_l2_solver_tag> solver(problem, sketch_size, context);
00104 
00105     solver.solve(B, X);
00106 }
00107 */
00108 
00109 template<typename AT, typename BT, typename XT>
00110 void FastLeastSquares(elem::Orientation orientation, const AT& A, const BT& B,
00111     XT& X, base::context_t& context) {
00112 
00113     if (orientation != elem::NORMAL)
00114         SKYLARK_THROW_EXCEPTION (
00115           base::sketch_exception()
00116               << base::error_msg(
00117                  "Only NORMAL orientation is supported for FastLeastSquares"));
00118 
00119     typedef algorithms::regression_problem_t<AT,
00120                                              algorithms::linear_tag,
00121                                              algorithms::l2_tag,
00122                                              algorithms::no_reg_tag> ptype;
00123     ptype problem(base::Height(A), base::Width(A), A);
00124 
00125     algorithms::accelerated_regression_solver_t<ptype, BT, XT,
00126                                     algorithms::blendenpik_tag<
00127                                         algorithms::qr_precond_tag> >
00128         solver(problem, context);
00129     solver.solve(B, X);
00130 }
00131 
00132 
00133 } }
00134 #endif