Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_LOCAL_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_LOCAL_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 #include "dense_transform_Elemental_mc_mr.hpp" 00013 00014 namespace skylark { namespace sketch { 00018 template <typename ValueType, 00019 template <typename> class ValueDistribution> 00020 struct dense_transform_t < 00021 elem::DistMatrix<ValueType>, 00022 elem::Matrix<ValueType>, 00023 ValueDistribution> : 00024 public dense_transform_data_t<ValueDistribution> { 00025 00026 // Typedef matrix and distribution types so that we can use them regularly 00027 typedef ValueType value_type; 00028 typedef elem::DistMatrix<value_type> matrix_type; 00029 typedef elem::Matrix<value_type> output_matrix_type; 00030 typedef ValueDistribution<value_type> value_distribution_type; 00031 typedef dense_transform_data_t<ValueDistribution> data_type; 00032 00036 dense_transform_t (int N, int S, double scale, base::context_t& context) 00037 : data_type (N, S, scale, context) { 00038 00039 } 00040 00044 dense_transform_t (dense_transform_t<matrix_type, 00045 output_matrix_type, 00046 ValueDistribution>& other) 00047 : data_type(other) {} 00048 00052 dense_transform_t(const data_type& other_data) 00053 : data_type(other_data) {} 00054 00055 00059 template <typename Dimension> 00060 void apply (const matrix_type& A, 00061 output_matrix_type& sketch_of_A, 00062 Dimension dimension) const { 00063 try { 00064 apply_impl_dist(A, sketch_of_A, dimension); 00065 } catch (std::logic_error e) { 00066 SKYLARK_THROW_EXCEPTION ( 00067 base::elemental_exception() 00068 << base::error_msg(e.what()) ); 00069 } catch(boost::mpi::exception e) { 00070 SKYLARK_THROW_EXCEPTION ( 00071 base::mpi_exception() 00072 << base::error_msg(e.what()) ); 00073 } 00074 } 00075 00076 int get_N() const { return this->_N; } 00077 int get_S() const { return this->_S; } 00079 const sketch_transform_data_t* get_data() const { return this; } 00080 00081 private: 00082 00087 void apply_impl_dist(const matrix_type& A, 00088 output_matrix_type& sketch_of_A, 00089 skylark::sketch::rowwise_tag tag) const { 00090 00091 typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC> 00092 intermediate_matrix_type; 00093 00094 matrix_type sketch_of_A_MC_MR(A.Height(), 00095 data_type::_S); 00096 intermediate_matrix_type sketch_of_A_CIRC_CIRC(A.Height(), 00097 data_type::_S); 00098 00099 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00100 transform(*this); 00101 00102 transform.apply(A, sketch_of_A_MC_MR, tag); 00103 00104 sketch_of_A_CIRC_CIRC = sketch_of_A_MC_MR; 00105 00106 boost::mpi::communicator world; 00107 MPI_Comm mpi_world(world); 00108 elem::Grid grid(mpi_world); 00109 int rank = world.rank(); 00110 if (rank == 0) { 00111 sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix(); 00112 } 00113 } 00114 00115 00116 void apply_impl_dist(const matrix_type& A, 00117 output_matrix_type& sketch_of_A, 00118 skylark::sketch::columnwise_tag tag) const { 00119 00120 typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC> 00121 intermediate_matrix_type; 00122 00123 matrix_type sketch_of_A_MC_MR(data_type::_S, 00124 A.Width()); 00125 intermediate_matrix_type sketch_of_A_CIRC_CIRC(data_type::_S, 00126 A.Width()); 00127 00128 dense_transform_t<matrix_type, matrix_type, ValueDistribution> 00129 transform(*this); 00130 00131 transform.apply(A, sketch_of_A_MC_MR, tag); 00132 00133 sketch_of_A_CIRC_CIRC = sketch_of_A_MC_MR; 00134 00135 boost::mpi::communicator world; 00136 MPI_Comm mpi_world(world); 00137 elem::Grid grid(mpi_world); 00138 int rank = world.rank(); 00139 if (rank == 0) { 00140 sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix(); 00141 } 00142 } 00143 00144 }; 00145 00146 } } 00148 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_LOCAL_HPP