// Skylark (Sketching Library) 0.1
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_STAR_STAR_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_STAR_STAR_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 #include "dense_transform_Elemental_star_rowdist.hpp" 00013 00014 namespace skylark { namespace sketch { 00018 template <typename ValueType, 00019 elem::Distribution RowDist, 00020 template <typename> class ValueDistribution> 00021 struct dense_transform_t < 00022 elem::DistMatrix<ValueType, elem::STAR, RowDist>, 00023 elem::DistMatrix<ValueType, elem::STAR, elem::STAR>, 00024 ValueDistribution > : 00025 public dense_transform_data_t<ValueDistribution> { 00026 // Typedef matrix and distribution types so that we can use them regularly 00027 typedef ValueType value_type; 00028 typedef elem::DistMatrix<value_type, elem::STAR, RowDist> matrix_type; 00029 typedef elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00030 output_matrix_type; 00031 typedef ValueDistribution<value_type> value_distribution_type; 00032 typedef dense_transform_data_t<ValueDistribution> data_type; 00033 00037 dense_transform_t (int N, int S, double scale, base::context_t& context) 00038 : data_type (N, S, scale, context) { 00039 00040 } 00041 00042 00046 dense_transform_t (dense_transform_t<matrix_type, 00047 output_matrix_type, 00048 ValueDistribution>& other) 00049 : data_type(other) {} 00050 00054 dense_transform_t(const data_type& other_data) 00055 : data_type(other_data) {} 00056 00057 00061 template <typename Dimension> 00062 void apply (const matrix_type& A, 00063 output_matrix_type& sketch_of_A, 00064 Dimension dimension) const { 00065 00066 switch(RowDist) { 00067 case elem::VR: 00068 case elem::VC: 00069 try { 00070 apply_impl_vdist (A, sketch_of_A, dimension); 00071 } catch 
(std::logic_error e) { 00072 SKYLARK_THROW_EXCEPTION ( 00073 base::elemental_exception() 00074 << base::error_msg(e.what()) ); 00075 } catch(boost::mpi::exception e) { 00076 SKYLARK_THROW_EXCEPTION ( 00077 base::mpi_exception() 00078 << base::error_msg(e.what()) ); 00079 } 00080 00081 break; 00082 00083 default: 00084 SKYLARK_THROW_EXCEPTION ( 00085 base::unsupported_matrix_distribution() ); 00086 } 00087 } 00088 00089 int get_N() const { return this->_N; } 00090 int get_S() const { return this->_S; } 00092 const sketch_transform_data_t* get_data() const { return this; } 00093 00094 private: 00095 00100 void apply_impl_vdist(const matrix_type& A, 00101 output_matrix_type& sketch_of_A, 00102 skylark::sketch::rowwise_tag tag) const { 00103 00104 00105 matrix_type sketch_of_A_STAR_RD(A.Height(), 00106 data_type::_S); 00107 00108 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00109 transform(*this); 00110 00111 transform.apply(A, sketch_of_A_STAR_RD, tag); 00112 00113 sketch_of_A = sketch_of_A_STAR_RD; 00114 } 00115 00116 00117 void apply_impl_vdist(const matrix_type& A, 00118 output_matrix_type& sketch_of_A, 00119 skylark::sketch::columnwise_tag tag) const { 00120 00121 matrix_type sketch_of_A_STAR_RD(data_type::_S, 00122 A.Width()); 00123 00124 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00125 transform(*this); 00126 00127 transform.apply(A, sketch_of_A_STAR_RD, tag); 00128 00129 sketch_of_A = sketch_of_A_STAR_RD; 00130 } 00131 }; 00132 00133 } } 00135 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_STAR_STAR_HPP