Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/sketch/dense_transform_Elemental_coldist_star_local.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_LOCAL_HPP
00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_LOCAL_HPP
00003 
00004 #include "../base/base.hpp"
00005 
00006 #include "transforms.hpp"
00007 #include "dense_transform_data.hpp"
00008 #include "../utility/comm.hpp"
00009 #include "../utility/get_communicator.hpp"
00010 
00011 #include "sketch_params.hpp"
00012 #include "dense_transform_Elemental_coldist_star.hpp"
00013 
00014 namespace skylark { namespace sketch {
00018 template <typename ValueType,
00019           elem::Distribution ColDist,
00020           template <typename> class ValueDistribution>
00021 struct dense_transform_t <
00022     elem::DistMatrix<ValueType, ColDist, elem::STAR>,
00023     elem::Matrix<ValueType>,
00024     ValueDistribution > :
00025         public dense_transform_data_t<ValueDistribution> {
00026     // Typedef matrix and distribution types so that we can use them regularly
00027     typedef ValueType value_type;
00028     typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type;
00029     typedef elem::Matrix<value_type> output_matrix_type;
00030     typedef ValueDistribution<value_type> value_distribution_type;
00031     typedef dense_transform_data_t<ValueDistribution> data_type;
00032 
00036     dense_transform_t (int N, int S, double scale, base::context_t& context)
00037         : data_type (N, S, scale, context) {
00038 
00039     }
00040 
00041 
00045     dense_transform_t (dense_transform_t<matrix_type,
00046                                          output_matrix_type,
00047                                          ValueDistribution>& other)
00048         : data_type(other) {}
00049 
00053     dense_transform_t(const data_type& other_data)
00054         : data_type(other_data) {}
00055 
00056 
00060     template <typename Dimension>
00061     void apply (const matrix_type& A,
00062                 output_matrix_type& sketch_of_A,
00063                 Dimension dimension) const {
00064 
00065         switch(ColDist) {
00066         case elem::VR:
00067         case elem::VC:
00068             try {
00069                 apply_impl_vdist (A, sketch_of_A, dimension);
00070             } catch (std::logic_error e) {
00071                 SKYLARK_THROW_EXCEPTION (
00072                     base::elemental_exception()
00073                         << base::error_msg(e.what()) );
00074             } catch(boost::mpi::exception e) {
00075                 SKYLARK_THROW_EXCEPTION (
00076                     base::mpi_exception()
00077                         << base::error_msg(e.what()) );
00078             }
00079 
00080             break;
00081 
00082         default:
00083             SKYLARK_THROW_EXCEPTION (
00084                base::unsupported_matrix_distribution() );
00085         }
00086     }
00087 
00088     int get_N() const { return this->_N; } 
00089     int get_S() const { return this->_S; } 
00091     const sketch_transform_data_t* get_data() const { return this; }
00092 
00093 private:
00094 
00099     void apply_impl_vdist(const matrix_type& A,
00100                          output_matrix_type& sketch_of_A,
00101                          skylark::sketch::rowwise_tag tag) const {
00102 
00103         typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC>
00104             intermediate_matrix_type;
00105 
00106         matrix_type sketch_of_A_CD_STAR(A.Height(),
00107                                         data_type::_S);
00108         intermediate_matrix_type sketch_of_A_CIRC_CIRC(A.Height(),
00109                                         data_type::_S);
00110 
00111         dense_transform_t<matrix_type, matrix_type, ValueDistribution>
00112             transform(*this);
00113 
00114         transform.apply(A, sketch_of_A_CD_STAR, tag);
00115 
00116         sketch_of_A_CIRC_CIRC = sketch_of_A_CD_STAR;
00117 
00118         boost::mpi::communicator world;
00119         MPI_Comm mpi_world(world);
00120         elem::Grid grid(mpi_world);
00121         int rank = world.rank();
00122         if (rank == 0) {
00123             sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix();
00124         }
00125     }
00126 
00127 
00128     void apply_impl_vdist(const matrix_type& A,
00129                          output_matrix_type& sketch_of_A,
00130                          skylark::sketch::columnwise_tag tag) const {
00131 
00132         typedef elem::DistMatrix<value_type, elem::CIRC, elem::CIRC>
00133             intermediate_matrix_type;
00134 
00135         matrix_type sketch_of_A_CD_STAR(data_type::_S,
00136                                         A.Width());
00137         intermediate_matrix_type sketch_of_A_CIRC_CIRC(data_type::_S,
00138                                         A.Width());
00139 
00140         dense_transform_t<matrix_type, matrix_type, ValueDistribution>
00141             transform(*this);
00142 
00143         transform.apply(A, sketch_of_A_CD_STAR, tag);
00144 
00145         sketch_of_A_CIRC_CIRC = sketch_of_A_CD_STAR;
00146 
00147         boost::mpi::communicator world;
00148         MPI_Comm mpi_world(world);
00149         elem::Grid grid(mpi_world);
00150         int rank = world.rank();
00151         if (rank == 0) {
00152             sketch_of_A = sketch_of_A_CIRC_CIRC.Matrix();
00153         }
00154     }
00155 
00156 };
00157 
00158 } } 
00160 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_LOCAL_HPP