// Skylark (Sketching Library) 0.1
00001 00002 #ifndef SKYLARK_RAND_SVD_HPP 00003 #define SKYLARK_RAND_SVD_HPP 00004 00005 #include "../config.h" 00006 #include "../base/exception.hpp" 00007 #include "../base/svd.hpp" 00008 #include "../base/QR.hpp" 00009 #include "../base/Gemm.hpp" 00010 #include "../sketch/capi/sketchc.hpp" 00011 00012 00013 #if SKYLARK_HAVE_ELEMENTAL 00014 #include <elemental.hpp> 00015 #endif 00016 00017 namespace skylark { namespace nla { 00018 00019 00020 struct rand_svd_params_t { 00021 00022 int oversampling; 00023 int num_iterations; 00024 sketch::c::transform_type_t transform; 00025 bool skip_qr; 00026 00027 rand_svd_params_t(int oversampling, sketch::c::transform_type_t transform = sketch::c::transform_type_t::JLT, 00028 int num_iterations = 0, bool skip_qr = 0) : oversampling(oversampling), transform(transform), 00029 num_iterations(num_iterations), skip_qr(skip_qr) {}; 00030 }; 00031 00032 00033 #if SKYLARK_HAVE_ELEMENTAL 00034 00035 template < template <typename, typename> class SketchTransform > 00036 struct randsvd_t { 00037 00038 template <typename InputMatrixType, 00039 typename UMatrixType, 00040 typename SingularValuesMatrixType, 00041 typename VMatrixType> 00042 void operator()(InputMatrixType &A, 00043 int target_rank, 00044 UMatrixType &U, 00045 SingularValuesMatrixType &SV, 00046 VMatrixType &V, 00047 rand_svd_params_t params, 00048 skylark::base::context_t& context) { 00049 00050 00051 // TODO: input matrix should provide Height() and Width() 00052 int input_height = A.Height(); 00053 int input_width = A.Width(); 00054 int sketch_size = target_rank + params.oversampling; 00055 00056 00064 if ((target_rank > std::min(input_height, input_width)) || 00065 (sketch_size > input_width) || 00066 (sketch_size < target_rank)) { 00067 std::ostringstream msg; 00068 msg << "Incompatible matrix dimensions and target rank\n"; 00069 SKYLARK_THROW_EXCEPTION( 00070 skylark::base::skylark_exception() 00071 << skylark::base::error_msg(msg.str())); 00072 } 00073 00074 00076 
UMatrixType Q(input_height, sketch_size); 00077 00078 typedef typename SketchTransform<InputMatrixType, UMatrixType>::data_type sketch_data_type; 00079 sketch_data_type sketch_data(input_width, sketch_size, context); 00080 //typedef typename SketchTransform<InputMatrixType, UMatrixType> sketch_transform_type; 00081 SketchTransform<InputMatrixType, UMatrixType> sketch_transform(sketch_data); 00082 sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag()); 00083 00084 #if 0 00085 if (params.transform == sketch::c::transform_type_t::JLT) 00086 { 00087 typedef typename skylark::sketch::JLT_t<InputMatrixType, UMatrixType>::data_type sketch_data_type; 00088 sketch_data_type sketch_data(input_width, sketch_size, context); 00089 typedef typename skylark::sketch::JLT_t<InputMatrixType, UMatrixType> sketch_transform_type; 00090 sketch_transform_type sketch_transform(sketch_data); 00091 sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag()); 00092 } 00093 else if (params.transform == sketch::c::transform_type_t::FJLT) 00094 { 00095 typedef typename skylark::sketch::FJLT_t<InputMatrixType, UMatrixType>::data_type sketch_data_type; 00096 sketch_data_type sketch_data(input_width, sketch_size, context); 00097 typedef typename skylark::sketch::FJLT_t<InputMatrixType, UMatrixType> sketch_transform_type; 00098 sketch_transform_type sketch_transform(sketch_data); 00099 sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag()); 00100 } 00101 else if (params.transform == sketch::c::transform_type_t::CWT) 00102 { 00103 /*typedef typename skylark::sketch::CWT_t<InputMatrixType, UMatrixType>::data_type sketch_data_type; 00104 sketch_data_type sketch_data(input_width, sketch_size, context); 00105 typedef typename skylark::sketch::CWT_t<InputMatrixType, UMatrixType> sketch_transform_type; 00106 sketch_transform_type sketch_transform(sketch_data); 00107 sketch_transform.apply(A, Q, skylark::sketch::rowwise_tag());*/ 00108 } 00109 else 00110 { 00111 std::ostringstream msg; 00112 
msg << "Unknown sketch transform type\n"; 00113 SKYLARK_THROW_EXCEPTION( 00114 skylark::base::skylark_exception() 00115 << skylark::base::error_msg(msg.str())); 00116 00117 } 00118 #endif 00119 00127 UMatrixType Y; 00128 00130 skylark::base::qr::Explicit(Q); 00131 00133 for(int step = 0; step < params.num_iterations; step++) { 00135 skylark::base::Gemm(elem::ADJOINT, elem::NORMAL, 00136 double(1), A, Q, Y); 00137 skylark::base::qr::Explicit(Y); 00138 00139 skylark::base::Gemm(elem::NORMAL, elem::NORMAL, 00140 double(1), A,Y, Q); 00141 if (!params.skip_qr) 00142 skylark::base::qr::Explicit(Q); 00143 } 00144 00145 00147 UMatrixType B; 00148 skylark::base::Gemm(elem::ADJOINT, elem::NORMAL, 00149 double(1), Q, A, B); 00150 skylark::base::SVD(B, SV, V); 00151 skylark::base::Gemm(elem::NORMAL, elem::NORMAL, 00152 double(1), Q, B, U); 00153 } 00154 }; 00155 00156 #endif 00157 00158 } } 00161 #endif