Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/nla/RandSVD.hpp
Go to the documentation of this file.
00001 
00002 #ifndef SKYLARK_RAND_SVD_HPP
00003 #define SKYLARK_RAND_SVD_HPP
00004 
#include <algorithm>
#include <sstream>

#include "../config.h"
#include "../base/exception.hpp"
#include "../base/svd.hpp"
#include "../base/QR.hpp"
#include "../base/Gemm.hpp"
#include "../sketch/capi/sketchc.hpp"
00011 
00012 
00013 #if SKYLARK_HAVE_ELEMENTAL
00014 #include <elemental.hpp>
00015 #endif
00016 
00017 namespace skylark { namespace nla {
00018 
00019 
00020 struct rand_svd_params_t {
00021 
00022     int oversampling;
00023     int num_iterations;
00024     sketch::c::transform_type_t transform;
00025     bool skip_qr;
00026 
00027     rand_svd_params_t(int oversampling, sketch::c::transform_type_t transform = sketch::c::transform_type_t::JLT,
00028         int num_iterations = 0, bool skip_qr = 0) : oversampling(oversampling), transform(transform),
00029                                   num_iterations(num_iterations), skip_qr(skip_qr) {};
00030 };
00031 
00032 
00033 #if SKYLARK_HAVE_ELEMENTAL
00034 
00035 template < template <typename, typename> class SketchTransform >
00036 struct randsvd_t {
00037 
00038 template <typename InputMatrixType,
00039           typename UMatrixType,
00040           typename SingularValuesMatrixType,
00041           typename VMatrixType>
00042 void operator()(InputMatrixType &A,
00043              int target_rank,
00044              UMatrixType &U,
00045              SingularValuesMatrixType &SV,
00046              VMatrixType &V,
00047              rand_svd_params_t params,
00048              skylark::base::context_t& context) {
00049 
00050 
00051     // TODO: input matrix should provide Height() and Width()
00052     int input_height = A.Height();
00053     int input_width  = A.Width();
00054     int sketch_size = target_rank + params.oversampling;
00055 
00056 
00064     if ((target_rank > std::min(input_height, input_width)) ||
00065                     (sketch_size > input_width) ||
00066                     (sketch_size < target_rank)) {
00067             std::ostringstream msg;
00068             msg << "Incompatible matrix dimensions and target rank\n";
00069             SKYLARK_THROW_EXCEPTION(
00070                             skylark::base::skylark_exception()
00071                             << skylark::base::error_msg(msg.str()));
00072     }
00073 
00074 
00076     UMatrixType Q(input_height, sketch_size);
00077 
00078     typedef typename SketchTransform<InputMatrixType, UMatrixType>::data_type sketch_data_type;
00079     sketch_data_type sketch_data(input_width, sketch_size, context);
00080     //typedef typename SketchTransform<InputMatrixType, UMatrixType> sketch_transform_type;
00081     SketchTransform<InputMatrixType, UMatrixType> sketch_transform(sketch_data);
00082     sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag());
00083 
00084 #if 0
00085     if (params.transform == sketch::c::transform_type_t::JLT)
00086         {
00087           typedef typename skylark::sketch::JLT_t<InputMatrixType, UMatrixType>::data_type sketch_data_type;
00088           sketch_data_type sketch_data(input_width, sketch_size, context);
00089           typedef typename skylark::sketch::JLT_t<InputMatrixType, UMatrixType> sketch_transform_type;
00090           sketch_transform_type sketch_transform(sketch_data);
00091           sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag());
00092         }
00093     else if (params.transform == sketch::c::transform_type_t::FJLT)
00094         {
00095           typedef typename skylark::sketch::FJLT_t<InputMatrixType, UMatrixType>::data_type sketch_data_type;
00096           sketch_data_type sketch_data(input_width, sketch_size, context);
00097           typedef typename skylark::sketch::FJLT_t<InputMatrixType, UMatrixType> sketch_transform_type;
00098           sketch_transform_type sketch_transform(sketch_data);
00099           sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag());
00100         }
00101     else if (params.transform == sketch::c::transform_type_t::CWT)
00102         {
00103           /*typedef typename skylark::sketch::CWT_t<InputMatrixType, UMatrixType>::data_type sketch_data_type;
00104           sketch_data_type sketch_data(input_width, sketch_size, context);
00105           typedef typename skylark::sketch::CWT_t<InputMatrixType, UMatrixType> sketch_transform_type;
00106           sketch_transform_type sketch_transform(sketch_data);
00107           sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag());*/
00108         }
00109     else
00110         {
00111             std::ostringstream msg;
00112             msg << "Unknown sketch transform type\n";
00113             SKYLARK_THROW_EXCEPTION(
00114                             skylark::base::skylark_exception()
00115                             << skylark::base::error_msg(msg.str()));
00116 
00117         }
00118 #endif
00119 
00127     UMatrixType Y;
00128 
00130     skylark::base::qr::Explicit(Q);
00131 
00133     for(int step = 0; step < params.num_iterations; step++) {
00135             skylark::base::Gemm(elem::ADJOINT, elem::NORMAL,
00136                             double(1), A, Q, Y);
00137             skylark::base::qr::Explicit(Y);
00138 
00139             skylark::base::Gemm(elem::NORMAL, elem::NORMAL,
00140                             double(1), A,Y, Q);
00141             if (!params.skip_qr)
00142                 skylark::base::qr::Explicit(Q);
00143     }
00144 
00145 
00147     UMatrixType B;
00148     skylark::base::Gemm(elem::ADJOINT, elem::NORMAL,
00149                     double(1), Q, A, B);
00150     skylark::base::SVD(B, SV, V);
00151     skylark::base::Gemm(elem::NORMAL, elem::NORMAL,
00152                     double(1), Q, B, U);
00153 }
00154 };
00155 
00156 #endif
00157 
00158 } } 
00161 #endif