Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_LOCAL_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_LOCAL_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 #include "dense_transform_Elemental_star_rowdist.hpp" 00013 00014 namespace skylark { namespace sketch { 00018 template <typename ValueType, 00019 elem::Distribution RowDist, 00020 template <typename> class ValueDistribution> 00021 struct dense_transform_t < 00022 elem::DistMatrix<ValueType, elem::STAR, RowDist>, 00023 elem::Matrix<ValueType>, 00024 ValueDistribution > : 00025 public dense_transform_data_t<ValueDistribution> { 00026 // Typedef matrix and distribution types so that we can use them regularly 00027 typedef ValueType value_type; 00028 typedef elem::DistMatrix<value_type, elem::STAR, RowDist> matrix_type; 00029 typedef elem::Matrix<value_type> output_matrix_type; 00030 typedef ValueDistribution<value_type> value_distribution_type; 00031 typedef dense_transform_data_t<ValueDistribution> data_type; 00032 00036 dense_transform_t (int N, int S, double scale, base::context_t& context) 00037 : data_type (N, S, scale, context) { 00038 00039 } 00040 00041 00045 dense_transform_t (dense_transform_t<matrix_type, 00046 output_matrix_type, 00047 ValueDistribution>& other) 00048 : data_type(other) {} 00049 00050 00054 dense_transform_t(const data_type& other_data) 00055 : data_type(other_data) {} 00056 00057 00061 template <typename Dimension> 00062 void apply (const matrix_type& A, 00063 output_matrix_type& sketch_of_A, 00064 Dimension dimension) const { 00065 00066 switch(RowDist) { 00067 case elem::VR: 00068 case elem::VC: 00069 try { 00070 apply_impl_vdist (A, sketch_of_A, dimension); 00071 } catch (std::logic_error e) { 00072 SKYLARK_THROW_EXCEPTION ( 00073 base::elemental_exception() 00074 << base::error_msg(e.what()) ); 00075 } catch(boost::mpi::exception e) { 00076 SKYLARK_THROW_EXCEPTION ( 00077 base::mpi_exception() 00078 << base::error_msg(e.what()) ); 00079 } 00080 00081 break; 00082 00083 default: 00084 SKYLARK_THROW_EXCEPTION ( 00085 base::unsupported_matrix_distribution() ); 00086 } 00087 } 00088 00089 int get_N() const { return this->_N; } 00090 int get_S() const { return this->_S; } 00092 const sketch_transform_data_t* get_data() const { return this; } 00093 00094 private: 00095 00100 void apply_impl_vdist(const matrix_type& A, 00101 output_matrix_type& sketch_of_A, 00102 skylark::sketch::rowwise_tag tag) const { 00103 00104 typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC> 00105 intermediate_matrix_type; 00106 00107 matrix_type sketch_of_A_STAR_RD(A.Height(), 00108 data_type::_S); 00109 intermediate_matrix_type sketch_of_A_CIRC_CIRC(A.Height(), 00110 data_type::_S); 00111 00112 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00113 transform(*this); 00114 00115 transform.apply(A, sketch_of_A_STAR_RD, tag); 00116 00117 sketch_of_A_CIRC_CIRC = sketch_of_A_STAR_RD; 00118 00119 boost::mpi::communicator world; 00120 MPI_Comm mpi_world(world); 00121 elem::Grid grid(mpi_world); 00122 int rank = world.rank(); 00123 if (rank == 0) { 00124 sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix(); 00125 } 00126 } 00127 00128 00129 void apply_impl_vdist(const matrix_type& A, 00130 output_matrix_type& sketch_of_A, 00131 skylark::sketch::columnwise_tag tag) const { 00132 00133 typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC> 00134 intermediate_matrix_type; 00135 00136 matrix_type sketch_of_A_STAR_RD(data_type::_S, 00137 A.Width()); 00138 intermediate_matrix_type sketch_of_A_CIRC_CIRC(data_type::_S, 00139 A.Width()); 00140 00141 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00142 transform(*this); 00143 00144 transform.apply(A, sketch_of_A_STAR_RD, tag); 00145 00146 sketch_of_A_CIRC_CIRC = sketch_of_A_STAR_RD; 00147 00148 boost::mpi::communicator world; 00149 MPI_Comm mpi_world(world); 00150 elem::Grid grid(mpi_world); 00151 int rank = world.rank(); 00152 if (rank == 0) { 00153 sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix(); 00154 } 00155 } 00156 00157 }; 00158 00159 } } 00161 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_LOCAL_HPP