Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_FJLT_ELEMENTAL_HPP 00002 #define SKYLARK_FJLT_ELEMENTAL_HPP 00003 00004 #include "../utility/get_communicator.hpp" 00005 00006 namespace skylark { namespace sketch { 00010 template <typename ValueType, elem::Distribution ColDist> 00011 struct FJLT_t < 00012 elem::DistMatrix<ValueType, ColDist, elem::STAR>, 00013 elem::Matrix<ValueType> > : 00014 public FJLT_data_t, 00015 virtual public sketch_transform_t<elem::DistMatrix<ValueType, 00016 ColDist, 00017 elem::STAR>, 00018 elem::Matrix<ValueType> > { 00019 // Typedef value, matrix, transform, distribution and transform data types 00020 // so that we can use them regularly and consistently. 00021 typedef ValueType value_type; 00022 typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type; 00023 typedef elem::Matrix<value_type> output_matrix_type; 00024 typedef elem::DistMatrix<ValueType, 00025 elem::STAR, ColDist> intermediate_type; 00026 typedef fft_futs<double>::DCT_t transform_type; 00027 typedef utility::rademacher_distribution_t<value_type> 00028 underlying_value_distribution_type; 00029 00030 typedef FJLT_data_t data_type; 00031 typedef data_type::params_t params_t; 00032 00033 protected: 00034 typedef RFUT_t<intermediate_type, 00035 transform_type, 00036 underlying_value_distribution_type> underlying_type; 00037 00038 public: 00039 FJLT_t(int N, int S, base::context_t& context) 00040 : data_type (N, S, context) { 00041 00042 } 00043 00044 FJLT_t(int N, int S, const params_t& params, base::context_t& context) 00045 : data_type (N, S, params, context) { 00046 00047 } 00048 00049 FJLT_t(const boost::property_tree::ptree &pt) 00050 : data_type(pt) { 00051 00052 } 00053 00054 template <typename OtherInputMatrixType, 00055 typename OtherOutputMatrixType> 00056 FJLT_t(const FJLT_t<OtherInputMatrixType, OtherOutputMatrixType>& other) 00057 : data_type(other) { 00058 00059 } 00060 00061 FJLT_t(const data_type& other_data) 00062 : data_type(other_data) { 00063 00064 } 00065 00070 void apply (const matrix_type& A, 00071 output_matrix_type& sketch_of_A, 00072 columnwise_tag dimension) const { 00073 switch (ColDist) { 00074 case elem::VR: 00075 case elem::VC: 00076 try { 00077 apply_impl_vdist (A, sketch_of_A, dimension); 00078 } catch (std::logic_error e) { 00079 SKYLARK_THROW_EXCEPTION ( 00080 base::elemental_exception() 00081 << base::error_msg(e.what()) ); 00082 } catch(boost::mpi::exception e) { 00083 SKYLARK_THROW_EXCEPTION ( 00084 base::mpi_exception() 00085 << base::error_msg(e.what()) ); 00086 } 00087 00088 break; 00089 00090 default: 00091 SKYLARK_THROW_EXCEPTION ( 00092 base::unsupported_matrix_distribution() ); 00093 00094 } 00095 } 00096 00097 00102 void apply (const matrix_type& A, 00103 output_matrix_type& sketch_of_A, 00104 rowwise_tag dimension) const { 00105 switch (ColDist) { 00106 case elem::VR: 00107 case elem::VC: 00108 try { 00109 apply_impl_vdist (A, sketch_of_A, dimension); 00110 } catch (std::logic_error e) { 00111 SKYLARK_THROW_EXCEPTION ( 00112 base::elemental_exception() 00113 << base::error_msg(e.what()) ); 00114 } catch(boost::mpi::exception e) { 00115 SKYLARK_THROW_EXCEPTION ( 00116 base::mpi_exception() 00117 << base::error_msg(e.what()) ); 00118 } 00119 00120 break; 00121 00122 default: 00123 SKYLARK_THROW_EXCEPTION ( 00124 base::unsupported_matrix_distribution() ); 00125 00126 } 00127 } 00128 00129 int get_N() const { return this->_N; } 00130 int get_S() const { return this->_S; } 00132 const sketch_transform_data_t* get_data() const { return this; } 00133 00134 private: 00139 void apply_impl_vdist(const matrix_type& A, 00140 output_matrix_type& sketch_A, 00141 skylark::sketch::columnwise_tag) const { 00142 00143 // Rearrange the matrix to fit the underlying transform 00144 intermediate_type inter_A(A.Grid()); 00145 inter_A = A; 00146 00147 // Apply the underlying transform 00148 underlying_type underlying(*data_type::underlying_data); 00149 underlying.apply(inter_A, inter_A, 00150 skylark::sketch::columnwise_tag()); 00151 00152 // Create the sampled and scaled matrix -- still in distributed mode 00153 intermediate_type dist_sketch_A(data_type::_S, 00154 inter_A.Width(), inter_A.Grid()); 00155 double scale = sqrt((double)data_type::_N / (double)data_type::_S); 00156 for (int j = 0; j < inter_A.LocalWidth(); j++) 00157 for (int i = 0; i < data_type::_S; i++) { 00158 int row = data_type::samples[i]; 00159 dist_sketch_A.Matrix().Set(i, j, 00160 scale * inter_A.Matrix().Get(row, j)); 00161 } 00162 00163 // get communicator from matrix 00164 boost::mpi::communicator comm = skylark::utility::get_communicator(A); 00165 skylark::utility::collect_dist_matrix(comm, comm.rank() == 0, 00166 dist_sketch_A, sketch_A); 00167 } 00168 00173 void apply_impl_vdist(const matrix_type& A, 00174 output_matrix_type& sketch_of_A, 00175 skylark::sketch::rowwise_tag) const { 00176 00177 // TODO This is a quick&dirty hack - uses the columnwise implementation. 00178 matrix_type A_t(A.Grid()); 00179 elem::Transpose(A, A_t); 00180 output_matrix_type sketch_of_A_t(sketch_of_A.Width(), 00181 sketch_of_A.Height()); 00182 apply_impl_vdist(A_t, sketch_of_A_t, 00183 skylark::sketch::columnwise_tag()); 00184 elem::Transpose(sketch_of_A_t, sketch_of_A); 00185 } 00186 }; 00187 00191 template <typename ValueType, elem::Distribution ColDist> 00192 struct FJLT_t < 00193 elem::DistMatrix<ValueType, ColDist, elem::STAR>, 00194 elem::DistMatrix<ValueType, elem::STAR, elem::STAR> > : 00195 public FJLT_data_t, 00196 virtual public sketch_transform_t<elem::DistMatrix<ValueType, 00197 ColDist, 00198 elem::STAR>, 00199 elem::DistMatrix<ValueType, 00200 elem::STAR, 00201 elem::STAR> > { 00202 // Typedef value, matrix, transform, distribution and transform data types 00203 // so that we can use them regularly and consistently. 00204 typedef ValueType value_type; 00205 typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type; 00206 typedef elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00207 output_matrix_type; 00208 typedef elem::DistMatrix<ValueType, 00209 elem::STAR, ColDist> intermediate_type; 00210 typedef fft_futs<double>::DCT_t transform_type; 00211 typedef utility::rademacher_distribution_t<value_type> 00212 underlying_value_distribution_type; 00213 00214 typedef FJLT_data_t data_type; 00215 typedef data_type::params_t params_t; 00216 00217 protected: 00218 typedef RFUT_t<intermediate_type, 00219 transform_type, 00220 underlying_value_distribution_type> underlying_type; 00221 00222 public: 00223 00224 FJLT_t(int N, int S, base::context_t& context) 00225 : data_type (N, S, context) { 00226 00227 } 00228 00229 FJLT_t(int N, int S, const params_t& params, base::context_t& context) 00230 : data_type (N, S, params, context) { 00231 00232 } 00233 00234 template <typename OtherInputMatrixType, 00235 typename OtherOutputMatrixType> 00236 FJLT_t(const FJLT_t<OtherInputMatrixType, OtherOutputMatrixType>& other) 00237 : data_type(other) { 00238 00239 } 00240 00241 FJLT_t(const data_type& other_data) 00242 : data_type(other_data) { 00243 00244 } 00245 00246 FJLT_t(const boost::property_tree::ptree &pt) 00247 : data_type(pt) { 00248 00249 } 00250 00255 void apply (const matrix_type& A, 00256 output_matrix_type& sketch_of_A, 00257 columnwise_tag dimension) const { 00258 switch (ColDist) { 00259 case elem::VR: 00260 case elem::VC: 00261 try { 00262 apply_impl_vdist (A, sketch_of_A, dimension); 00263 } catch (std::logic_error e) { 00264 SKYLARK_THROW_EXCEPTION ( 00265 base::elemental_exception() 00266 << base::error_msg(e.what()) ); 00267 } catch(boost::mpi::exception e) { 00268 SKYLARK_THROW_EXCEPTION ( 00269 base::mpi_exception() 00270 << base::error_msg(e.what()) ); 00271 } 00272 00273 break; 00274 00275 default: 00276 SKYLARK_THROW_EXCEPTION ( 00277 base::unsupported_matrix_distribution() ); 00278 00279 } 00280 } 00281 00282 00287 void apply (const matrix_type& A, 00288 output_matrix_type& sketch_of_A, 00289 rowwise_tag dimension) const { 00290 switch (ColDist) { 00291 case elem::VR: 00292 case elem::VC: 00293 try { 00294 apply_impl_vdist (A, sketch_of_A, dimension); 00295 } catch (std::logic_error e) { 00296 SKYLARK_THROW_EXCEPTION ( 00297 base::elemental_exception() 00298 << base::error_msg(e.what()) ); 00299 } catch(boost::mpi::exception e) { 00300 SKYLARK_THROW_EXCEPTION ( 00301 base::mpi_exception() 00302 << base::error_msg(e.what()) ); 00303 } 00304 00305 break; 00306 00307 default: 00308 SKYLARK_THROW_EXCEPTION ( 00309 base::unsupported_matrix_distribution() ); 00310 00311 } 00312 } 00313 00314 int get_N() const { return this->_N; } 00315 int get_S() const { return this->_S; } 00317 const sketch_transform_data_t* get_data() const { return this; } 00318 00319 private: 00324 void apply_impl_vdist(const matrix_type& A, 00325 output_matrix_type& sketch_A, 00326 skylark::sketch::columnwise_tag) const { 00327 00328 // Rearrange the matrix to fit the underlying transform 00329 intermediate_type inter_A(A.Grid()); 00330 inter_A = A; 00331 00332 // Apply the underlying transform 00333 underlying_type underlying(*data_type::underlying_data); 00334 underlying.apply(inter_A, inter_A, 00335 skylark::sketch::columnwise_tag()); 00336 00337 // Create the sampled and scaled matrix -- still in distributed mode 00338 intermediate_type dist_sketch_A(data_type::_S, 00339 inter_A.Width(), inter_A.Grid()); 00340 double scale = sqrt((double)data_type::_N / (double)data_type::_S); 00341 for (int j = 0; j < inter_A.LocalWidth(); j++) 00342 for (int i = 0; i < data_type::_S; i++) { 00343 int row = data_type::samples[i]; 00344 dist_sketch_A.Matrix().Set(i, j, 00345 scale * inter_A.Matrix().Get(row, j)); 00346 } 00347 00348 sketch_A = dist_sketch_A; 00349 } 00350 00355 void apply_impl_vdist(const matrix_type& A, 00356 output_matrix_type& sketch_of_A, 00357 skylark::sketch::rowwise_tag) const { 00358 00359 // TODO This is a quick&dirty hack - uses the columnwise implementation. 00360 matrix_type A_t(A.Grid()); 00361 elem::Transpose(A, A_t); 00362 output_matrix_type sketch_of_A_t(sketch_of_A.Width(), 00363 sketch_of_A.Height()); 00364 apply_impl_vdist(A_t, sketch_of_A_t, 00365 skylark::sketch::columnwise_tag()); 00366 elem::Transpose(sketch_of_A_t, sketch_of_A); 00367 } 00368 }; 00369 00370 } } 00372 #endif // FJLT_ELEMENTAL_HPP