Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_LOCAL_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_LOCAL_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 #include "dense_transform_Elemental_coldist_star.hpp" 00013 00014 namespace skylark { namespace sketch { 00018 template <typename ValueType, 00019 elem::Distribution ColDist, 00020 template <typename> class ValueDistribution> 00021 struct dense_transform_t < 00022 elem::DistMatrix<ValueType, ColDist, elem::STAR>, 00023 elem::Matrix<ValueType>, 00024 ValueDistribution > : 00025 public dense_transform_data_t<ValueDistribution> { 00026 // Typedef matrix and distribution types so that we can use them regularly 00027 typedef ValueType value_type; 00028 typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type; 00029 typedef elem::Matrix<value_type> output_matrix_type; 00030 typedef ValueDistribution<value_type> value_distribution_type; 00031 typedef dense_transform_data_t<ValueDistribution> data_type; 00032 00036 dense_transform_t (int N, int S, double scale, base::context_t& context) 00037 : data_type (N, S, scale, context) { 00038 00039 } 00040 00041 00045 dense_transform_t (dense_transform_t<matrix_type, 00046 output_matrix_type, 00047 ValueDistribution>& other) 00048 : data_type(other) {} 00049 00053 dense_transform_t(const data_type& other_data) 00054 : data_type(other_data) {} 00055 00056 00060 template <typename Dimension> 00061 void apply (const matrix_type& A, 00062 output_matrix_type& sketch_of_A, 00063 Dimension dimension) const { 00064 00065 switch(ColDist) { 00066 case elem::VR: 00067 case elem::VC: 00068 try { 00069 apply_impl_vdist (A, sketch_of_A, dimension); 00070 } catch (std::logic_error e) { 00071 SKYLARK_THROW_EXCEPTION ( 00072 base::elemental_exception() 00073 << base::error_msg(e.what()) ); 00074 } catch(boost::mpi::exception e) { 00075 SKYLARK_THROW_EXCEPTION ( 00076 base::mpi_exception() 00077 << base::error_msg(e.what()) ); 00078 } 00079 00080 break; 00081 00082 default: 00083 SKYLARK_THROW_EXCEPTION ( 00084 base::unsupported_matrix_distribution() ); 00085 } 00086 } 00087 00088 int get_N() const { return this->_N; } 00089 int get_S() const { return this->_S; } 00091 const sketch_transform_data_t* get_data() const { return this; } 00092 00093 private: 00094 00099 void apply_impl_vdist(const matrix_type& A, 00100 output_matrix_type& sketch_of_A, 00101 skylark::sketch::rowwise_tag tag) const { 00102 00103 typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC> 00104 intermediate_matrix_type; 00105 00106 matrix_type sketch_of_A_CD_STAR(A.Height(), 00107 data_type::_S); 00108 intermediate_matrix_type sketch_of_A_CIRC_CIRC(A.Height(), 00109 data_type::_S); 00110 00111 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00112 transform(*this); 00113 00114 transform.apply(A, sketch_of_A_CD_STAR, tag); 00115 00116 sketch_of_A_CIRC_CIRC = sketch_of_A_CD_STAR; 00117 00118 boost::mpi::communicator world; 00119 MPI_Comm mpi_world(world); 00120 elem::Grid grid(mpi_world); 00121 int rank = world.rank(); 00122 if (rank == 0) { 00123 sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix(); 00124 } 00125 } 00126 00127 00128 void apply_impl_vdist(const matrix_type& A, 00129 output_matrix_type& sketch_of_A, 00130 skylark::sketch::columnwise_tag tag) const { 00131 00132 typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC> 00133 intermediate_matrix_type; 00134 00135 matrix_type sketch_of_A_CD_STAR(data_type::_S, 00136 A.Width()); 00137 intermediate_matrix_type sketch_of_A_CIRC_CIRC(data_type::_S, 00138 A.Width()); 00139 00140 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00141 transform(*this); 00142 00143 transform.apply(A, sketch_of_A_CD_STAR, tag); 00144 00145 sketch_of_A_CIRC_CIRC = sketch_of_A_CD_STAR; 00146 00147 boost::mpi::communicator world; 00148 MPI_Comm mpi_world(world); 00149 elem::Grid grid(mpi_world); 00150 int rank = world.rank(); 00151 if (rank == 0) { 00152 sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix(); 00153 } 00154 } 00155 00156 }; 00157 00158 } } 00160 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_LOCAL_HPP