Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/sketch/dense_transform_Elemental_star_rowdist.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_HPP
00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_HPP
00003 
00004 #include "../base/base.hpp"
00005 
00006 #include "transforms.hpp"
00007 #include "dense_transform_data.hpp"
00008 #include "../utility/comm.hpp"
00009 #include "../utility/get_communicator.hpp"
00010 
00011 #include "sketch_params.hpp"
00012 
00013 namespace skylark { namespace sketch {
00017 template <typename ValueType,
00018           elem::Distribution RowDist,
00019           template <typename> class ValueDistribution>
00020 struct dense_transform_t <
00021     elem::DistMatrix<ValueType, elem::STAR, RowDist>,
00022     elem::DistMatrix<ValueType, elem::STAR, RowDist>,
00023     ValueDistribution> :
00024         public dense_transform_data_t<ValueDistribution> {
00025     // Typedef matrix and distribution types so that we can use them regularly
00026     typedef ValueType value_type;
00027     typedef elem::DistMatrix<value_type, elem::STAR, RowDist> matrix_type;
00028     typedef elem::DistMatrix<value_type, elem::STAR, RowDist>
00029     output_matrix_type;
00030     typedef ValueDistribution<value_type> value_distribution_type;
00031     typedef dense_transform_data_t<ValueDistribution> data_type;
00032 
00036     dense_transform_t (int N, int S, double scale, base::context_t& context)
00037         : data_type (N, S, scale, context) {
00038 
00039     }
00040 
00041 
00045     dense_transform_t (dense_transform_t<matrix_type,
00046                                          output_matrix_type,
00047                                          ValueDistribution>& other)
00048         : data_type(other) {}
00049 
00053     dense_transform_t(const data_type& other_data)
00054         : data_type(other_data) {}
00055 
00056 
00060     template <typename Dimension>
00061     void apply (const matrix_type& A,
00062                 output_matrix_type& sketch_of_A,
00063                 Dimension dimension) const {
00064 
00065         switch(RowDist) {
00066         case elem::VR:
00067         case elem::VC:
00068             try {
00069                 apply_impl_vdist (A, sketch_of_A, dimension);
00070             } catch (std::logic_error e) {
00071                 SKYLARK_THROW_EXCEPTION (
00072                     base::elemental_exception()
00073                         << base::error_msg(e.what()) );
00074             } catch(boost::mpi::exception e) {
00075                 SKYLARK_THROW_EXCEPTION (
00076                     base::mpi_exception()
00077                         << base::error_msg(e.what()) );
00078             }
00079 
00080             break;
00081 
00082         default:
00083             SKYLARK_THROW_EXCEPTION (
00084                 base::unsupported_matrix_distribution() );
00085         }
00086     }
00087 
00088     int get_N() const { return this->_N; } 
00089     int get_S() const { return this->_S; } 
00091     const sketch_transform_data_t* get_data() const { return this; }
00092 
00093 private:
00094 
00095     // Communication demanding scenario: Memory-oblivious mode
00096     // TODO: Block-by-block mode
00097     void inner_panel_gemm(const matrix_type& A,
00098                           output_matrix_type& sketch_of_A,
00099                           skylark::sketch::rowwise_tag) const {
00100 
00101         const elem::Grid& grid = A.Grid();
00102 
00103         elem::DistMatrix<value_type, elem::STAR, RowDist> R(grid);
00104         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00105             sketch_of_A_STAR_STAR(grid);
00106 
00107         data_type::realize_matrix_view(R);
00108 
00109         // TODO: is alignment necessary?
00110 
00111         // Global size of the result of the Local Gemm that follows
00112         sketch_of_A_STAR_STAR.Resize(A.Height(),
00113                                      R.Height());
00114 
00115         // Local Gemm
00116         base::Gemm(elem::NORMAL,
00117                    elem::TRANSPOSE,
00118                    value_type(1),
00119                    A.LockedMatrix(),
00120                    R.LockedMatrix(),
00121                    sketch_of_A_STAR_STAR.Matrix());
00122 
00123         // Reduce-scatter within process grid
00124         sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR);
00125     }
00126 
00127 
00131     void inner_panel_gemm(const matrix_type& A,
00132                           output_matrix_type& sketch_of_A,
00133                           skylark::sketch::columnwise_tag) const {
00134 
00135         const elem::Grid& grid = A.Grid();
00136 
00137         elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid);
00138         elem::DistMatrix<value_type, elem::STAR, RowDist>
00139             A_Left(grid),
00140             A_Right(grid),
00141             A0(grid),
00142             A1(grid),
00143             A2(grid);
00144         elem::DistMatrix<value_type, elem::STAR, RowDist>
00145             sketch_of_A_Top(grid),
00146             sketch_of_A_Bottom(grid),
00147             sketch_of_A0(grid),
00148             sketch_of_A1(grid),
00149             sketch_of_A2(grid),
00150             sketch_of_A1_Left(grid),
00151             sketch_of_A1_Right(grid),
00152             sketch_of_A10(grid),
00153             sketch_of_A11(grid),
00154             sketch_of_A12(grid);
00155 
00156         // TODO: are alignments necessary?
00157 
00158         elem::PartitionDown
00159         ( sketch_of_A,
00160           sketch_of_A_Top, sketch_of_A_Bottom, 0 );
00161 
00162         // TODO: Allow for different blocksizes in "down" and "right" directions
00163         int blocksize = get_blocksize();
00164         if (blocksize == 0) {
00165             blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width());
00166         }
00167         int base = 0;
00168         while (sketch_of_A_Bottom.Height() > 0) {
00169 
00170             int b = std::min(sketch_of_A_Bottom.Height(), blocksize);
00171             data_type::realize_matrix_view(R1, base, 0,
00172                                                b,    A.Height());
00173 
00174             elem::RepartitionDown
00175             ( sketch_of_A_Top,    sketch_of_A0,
00176                               
00177                                   sketch_of_A1,
00178               sketch_of_A_Bottom, sketch_of_A2, b );
00179 
00180 
00181             elem::LockedPartitionRight
00182             ( A,
00183               A_Left, A_Right, 0 );
00184 
00185             elem::PartitionRight
00186             ( sketch_of_A1,
00187               sketch_of_A1_Left, sketch_of_A1_Right, 0 );
00188 
00189             while(A_Right.Width() > 0) {
00190 
00191                 elem::LockedRepartitionRight
00192                  ( A_Left,      A_Right,
00193                    A0,      A1, A2,      b );
00194 
00195                 elem::RepartitionRight
00196                 ( sketch_of_A1_Left,                sketch_of_A1_Right,
00197                   sketch_of_A10,      sketch_of_A11, sketch_of_A12,     b );
00198 
00199                 // Local Gemm
00200                 base::Gemm(elem::NORMAL,
00201                            elem::NORMAL,
00202                            value_type(1),
00203                            R1.LockedMatrix(),
00204                            A1.LockedMatrix(),
00205                            sketch_of_A11.Matrix());
00206 
00207                 elem::SlideLockedPartitionRight
00208                 ( A_Left,      A_Right,
00209                   A0,     A1,  A2 );
00210 
00211                 elem::SlidePartitionRight
00212                 ( sketch_of_A1_Left,                 sketch_of_A1_Right,
00213                   sketch_of_A10,     sketch_of_A11,  sketch_of_A12 );
00214             }
00215 
00216             base = base + b;
00217 
00218             elem::SlidePartitionDown
00219             ( sketch_of_A_Top,    sketch_of_A0,
00220                                   sketch_of_A1,
00221                               
00222               sketch_of_A_Bottom, sketch_of_A2 );
00223         }
00224     }
00225 
00226 
00231     // Communication demanding scenario: Memory-oblivious mode
00232     // TODO: Block-by-block mode
00233     void outer_panel_gemm(const matrix_type& A,
00234                           output_matrix_type& sketch_of_A,
00235                           skylark::sketch::rowwise_tag) const {
00236 
00237         const elem::Grid& grid = A.Grid();
00238 
00239         elem::DistMatrix<value_type, RowDist, elem::STAR> R(grid);
00240         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00241             A_STAR_STAR(grid);
00242 
00243         // TODO: are alignments necessary?
00244         R.AlignWith(sketch_of_A);
00245         A_STAR_STAR.AlignWith(sketch_of_A);
00246 
00247         data_type::realize_matrix_view(R);
00248 
00249         // Allgather within process grid
00250         A_STAR_STAR = A;
00251 
00252         // Zero sketch_of_A
00253         elem::Zero(sketch_of_A);
00254 
00255         // Local Gemm
00256         base::Gemm(elem::NORMAL,
00257                    elem::TRANSPOSE,
00258                    value_type(1),
00259                    A_STAR_STAR.LockedMatrix(),
00260                    R.LockedMatrix(),
00261                    value_type(1),
00262                    sketch_of_A.Matrix());
00263 
00264     }
00265 
00266 
00270     void outer_panel_gemm(const matrix_type& A,
00271                           output_matrix_type& sketch_of_A,
00272                           skylark::sketch::columnwise_tag) const {
00273 
00274         const elem::Grid& grid = A.Grid();
00275 
00276         elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid);
00277         elem::DistMatrix<value_type, elem::STAR, RowDist>
00278             A_Top(grid),
00279             A_Bottom(grid),
00280             A0(grid),
00281             A1(grid),
00282             A2(grid);
00283 
00284         // Zero sketch_of_A
00285         elem::Zero(sketch_of_A);
00286 
00287         // TODO: is alignment necessary?
00288         R1.AlignWith(sketch_of_A);
00289 
00290         elem::LockedPartitionDown
00291         ( A,
00292           A_Top, A_Bottom, 0 );
00293 
00294         int blocksize = get_blocksize();
00295         if (blocksize == 0) {
00296             blocksize = A_Bottom.Height();
00297         }
00298         int base = 0;
00299         while (A_Bottom.Height() > 0) {
00300 
00301             int b = std::min(A_Bottom.Height(), blocksize);
00302             data_type::realize_matrix_view(R1, 0,                   base,
00303                                                sketch_of_A.Height(), b);
00304 
00305             elem::RepartitionDown
00306             ( A_Top,    A0,
00307                     
00308                         A1,
00309               A_Bottom, A2, b );
00310 
00311             // Local Gemm
00312             base::Gemm(elem::NORMAL,
00313                        elem::NORMAL,
00314                        value_type(1),
00315                        R1.LockedMatrix(),
00316                        A1.LockedMatrix(),
00317                        value_type(1),
00318                        sketch_of_A.Matrix());
00319 
00320             base = base + b;
00321 
00322             elem::SlidePartitionDown
00323             ( A_Top,    A0,
00324                         A1,
00325                     
00326               A_Bottom, A2 );
00327         }
00328     }
00329 
00330 
00335     // Communication demanding scenario: Memory-oblivious mode
00336     // TODO: Block-by-block mode
00337     void matrix_panel_gemm(const matrix_type& A,
00338                           output_matrix_type& sketch_of_A,
00339                           skylark::sketch::rowwise_tag) const {
00340 
00341         const elem::Grid& grid = A.Grid();
00342 
00343         elem::DistMatrix<value_type, elem::STAR, RowDist> R(grid);
00344         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00345             sketch_of_A_STAR_STAR(grid);
00346 
00347         // TODO: are alignments necessary?
00348         R.AlignWith(sketch_of_A);
00349         sketch_of_A_STAR_STAR.AlignWith(sketch_of_A);
00350 
00351         data_type::realize_matrix_view(R);
00352 
00353         // Global size of the result of the Local Gemm that follows
00354         sketch_of_A_STAR_STAR.Resize(sketch_of_A.Height(),
00355                                       sketch_of_A.Width());
00356 
00357         // Local Gemm
00358         base::Gemm(elem::NORMAL,
00359                    elem::TRANSPOSE,
00360                    value_type(1),
00361                    A.LockedMatrix(),
00362                    R.LockedMatrix(),
00363                    sketch_of_A_STAR_STAR.Matrix());
00364 
00365         // Reduce-scatter within process grid
00366         sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR);
00367 
00368     }
00369 
00370 
00374     void panel_matrix_gemm(const matrix_type& A,
00375                           output_matrix_type& sketch_of_A,
00376                           skylark::sketch::columnwise_tag) const {
00377 
00378         const elem::Grid& grid = A.Grid();
00379 
00380         elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid);
00381         elem::DistMatrix<value_type, elem::STAR, RowDist>
00382             sketch_of_A_Top(grid),
00383             sketch_of_A_Bottom(grid),
00384             sketch_of_A0(grid),
00385             sketch_of_A1(grid),
00386             sketch_of_A2(grid);
00387 
00388         // TODO: is alignment necessary?
00389         R1.AlignWith(A);
00390 
00391         elem::PartitionDown
00392         ( sketch_of_A,
00393           sketch_of_A_Top, sketch_of_A_Bottom, 0 );
00394 
00395         int blocksize = get_blocksize();
00396         if (blocksize == 0) {
00397             blocksize = sketch_of_A_Bottom.Height();
00398         }
00399         int base = 0;
00400         while (sketch_of_A_Bottom.Height() > 0) {
00401 
00402             int b = std::min(sketch_of_A_Bottom.Height(), blocksize);
00403             data_type::realize_matrix_view(R1, base, 0,
00404                                                b,    A.Height());
00405 
00406             elem::RepartitionDown
00407             ( sketch_of_A_Top,    sketch_of_A0,
00408                               
00409                                   sketch_of_A1,
00410               sketch_of_A_Bottom, sketch_of_A2, b );
00411 
00412             // Local Gemm
00413             base::Gemm(elem::NORMAL,
00414                        elem::NORMAL,
00415                        value_type(1),
00416                        R1.LockedMatrix(),
00417                        A.LockedMatrix(),
00418                        sketch_of_A1.Matrix());
00419 
00420             base = base + b;
00421 
00422             elem::SlidePartitionDown
00423             ( sketch_of_A_Top,    sketch_of_A0,
00424                                   sketch_of_A1,
00425                               
00426               sketch_of_A_Bottom, sketch_of_A2 );
00427         }
00428     }
00429 
00430 
00431     void sketch_gemm(const matrix_type& A,
00432         output_matrix_type& sketch_of_A,
00433         skylark::sketch::rowwise_tag tag) const {
00434 
00435         const int sketch_height = sketch_of_A.Height();
00436         const int sketch_width  = sketch_of_A.Width();
00437         const int width         = A.Width();
00438 
00439         const double factor = get_factor();
00440 
00441         if((sketch_height * factor <= width) &&
00442             (sketch_width * factor <= width))
00443             inner_panel_gemm(A, sketch_of_A, tag);
00444         else if((sketch_height >= width * factor) &&
00445             (sketch_width >= width * factor))
00446             outer_panel_gemm(A, sketch_of_A, tag);
00447         else
00448             matrix_panel_gemm(A, sketch_of_A, tag);
00449     }
00450 
00451 
00452     void sketch_gemm(const matrix_type& A,
00453         output_matrix_type& sketch_of_A,
00454         skylark::sketch::columnwise_tag tag) const {
00455 
00456         const int sketch_height = sketch_of_A.Height();
00457         const int sketch_width  = sketch_of_A.Width();
00458         const int height        = A.Height();
00459 
00460         const double factor = get_factor();
00461 
00462         if((sketch_height * factor <= height) &&
00463             (sketch_width * factor <= height))
00464             inner_panel_gemm(A, sketch_of_A, tag);
00465         else if((sketch_height >= height * factor) &&
00466             (sketch_width >= height * factor))
00467             outer_panel_gemm(A, sketch_of_A, tag);
00468         else
00469             panel_matrix_gemm(A, sketch_of_A, tag);
00470     }
00471 
00472 
00473     void apply_impl_vdist (const matrix_type& A,
00474                           output_matrix_type& sketch_of_A,
00475                           skylark::sketch::rowwise_tag tag) const {
00476 
00477         sketch_gemm(A, sketch_of_A, tag);
00478     }
00479 
00480 
00481     void apply_impl_vdist (const matrix_type& A,
00482                           output_matrix_type& sketch_of_A,
00483                           skylark::sketch::columnwise_tag tag) const {
00484 
00485         sketch_gemm(A, sketch_of_A, tag);
00486     }
00487 
00488 };
00489 
00490 } } 
00492 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_HPP