Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/sketch/dense_transform_Elemental_coldist_star.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_HPP
00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_HPP
00003 
00004 #include "../base/base.hpp"
00005 
00006 #include "transforms.hpp"
00007 #include "dense_transform_data.hpp"
00008 #include "../utility/comm.hpp"
00009 #include "../utility/get_communicator.hpp"
00010 
00011 #include "sketch_params.hpp"
00012 
00013 namespace skylark { namespace sketch {
00017 template <typename ValueType,
00018           elem::Distribution ColDist,
00019           template <typename> class ValueDistribution>
00020 struct dense_transform_t <
00021     elem::DistMatrix<ValueType, ColDist, elem::STAR>,
00022     elem::DistMatrix<ValueType, ColDist, elem::STAR>,
00023     ValueDistribution> :
00024         public dense_transform_data_t<ValueDistribution> {
00025     // Typedef matrix and distribution types so that we can use them regularly
00026     typedef ValueType value_type;
00027     typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type;
00028     typedef elem::DistMatrix<value_type, ColDist, elem::STAR>
00029     output_matrix_type;
00030     typedef ValueDistribution<value_type> value_distribution_type;
00031     typedef dense_transform_data_t<ValueDistribution> data_type;
00032 
00036     dense_transform_t (int N, int S, double scale, base::context_t& context)
00037         : data_type (N, S, scale, context) {
00038 
00039     }
00040 
00041 
00045     dense_transform_t (dense_transform_t<matrix_type,
00046                                          output_matrix_type,
00047                                          ValueDistribution>& other)
00048         : data_type(other) {}
00049 
00053     dense_transform_t(const data_type& other_data)
00054         : data_type(other_data) {}
00055 
00056 
00060     template <typename Dimension>
00061     void apply (const matrix_type& A,
00062                 output_matrix_type& sketch_of_A,
00063                 Dimension dimension) const {
00064 
00065         switch(ColDist) {
00066         case elem::VR:
00067         case elem::VC:
00068             try {
00069                 apply_impl_vdist (A, sketch_of_A, dimension);
00070             } catch (std::logic_error e) {
00071                 SKYLARK_THROW_EXCEPTION (
00072                     base::elemental_exception()
00073                         << base::error_msg(e.what()) );
00074             } catch(boost::mpi::exception e) {
00075                 SKYLARK_THROW_EXCEPTION (
00076                     base::mpi_exception()
00077                         << base::error_msg(e.what()) );
00078             }
00079 
00080             break;
00081 
00082         default:
00083             SKYLARK_THROW_EXCEPTION (
00084                 base::unsupported_matrix_distribution() );
00085         }
00086     }
00087 
00088     int get_N() const { return this->_N; } 
00089     int get_S() const { return this->_S; } 
00091     const sketch_transform_data_t* get_data() const { return this; }
00092 
00093 private:
00094 
00098     void inner_panel_gemm(const matrix_type& A,
00099                           output_matrix_type& sketch_of_A,
00100                           skylark::sketch::rowwise_tag) const {
00101 
00102         const elem::Grid& grid = A.Grid();
00103 
00104         elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid);
00105         elem::DistMatrix<value_type, ColDist, elem::STAR>
00106             A_Top(grid),
00107             A_Bottom(grid),
00108             A0(grid),
00109             A1(grid),
00110             A2(grid);
00111         elem::DistMatrix<value_type, ColDist, elem::STAR>
00112             sketch_of_A_Left(grid),
00113             sketch_of_A_Right(grid),
00114             sketch_of_A0(grid),
00115             sketch_of_A1(grid),
00116             sketch_of_A2(grid),
00117             sketch_of_A1_Top(grid),
00118             sketch_of_A1_Bottom(grid),
00119             sketch_of_A10(grid),
00120             sketch_of_A11(grid),
00121             sketch_of_A12(grid);
00122 
00123         // TODO: are alignments necessary?
00124 
00125         elem::PartitionRight
00126         ( sketch_of_A,
00127           sketch_of_A_Left, sketch_of_A_Right, 0 );
00128 
00129         // TODO: Allow for different blocksizes in "down" and "right" directions
00130         int blocksize = get_blocksize();
00131         if (blocksize == 0) {
00132             blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width());
00133         }
00134         int base = 0;
00135         while (sketch_of_A_Right.Width() > 0) {
00136 
00137             int b = std::min(sketch_of_A_Right.Width(), blocksize);
00138             data_type::realize_matrix_view(R1, base, 0,
00139                                                b,    A.Width());
00140 
00141             elem::RepartitionRight
00142             ( sketch_of_A_Left,                sketch_of_A_Right,
00143               sketch_of_A0,      sketch_of_A1, sketch_of_A2,      b );
00144 
00145             elem::LockedPartitionDown
00146             ( A, A_Top,
00147                  A_Bottom, 0 );
00148 
00149             elem::PartitionDown
00150             ( sketch_of_A1, sketch_of_A1_Top,
00151                             sketch_of_A1_Bottom, 0 );
00152 
00153             while(A_Bottom.Height() > 0) {
00154                 elem::LockedRepartitionDown
00155                 ( A_Top,    A0,
00156                         
00157                             A1,
00158                   A_Bottom, A2, b );
00159 
00160                 elem::RepartitionDown
00161                 ( sketch_of_A1_Top,    sketch_of_A10,
00162                                    
00163                                        sketch_of_A11,
00164                   sketch_of_A1_Bottom, sketch_of_A12, b );
00165 
00166                 // Local Gemm
00167                 base::Gemm(elem::NORMAL,
00168                            elem::TRANSPOSE,
00169                            value_type(1),
00170                            A1.LockedMatrix(),
00171                            R1.LockedMatrix(),
00172                            sketch_of_A11.Matrix());
00173 
00174                 elem::SlideLockedPartitionDown
00175                 ( A_Top,    A0,
00176                             A1,
00177                         
00178                   A_Bottom, A2 );
00179 
00180                 elem::SlidePartitionDown
00181                 ( sketch_of_A1_Top,    sketch_of_A10,
00182                                        sketch_of_A11,
00183                                    
00184                   sketch_of_A1_Bottom, sketch_of_A12 );
00185             }
00186 
00187             base = base + b;
00188 
00189             elem::SlidePartitionRight
00190             ( sketch_of_A_Left,                sketch_of_A_Right,
00191               sketch_of_A0,     sketch_of_A1,  sketch_of_A2 );
00192         }
00193     }
00194 
00195 
00200     // Communication demanding scenario: Memory-oblivious mode
00201     // TODO: Block-by-block mode
00202     void inner_panel_gemm(const matrix_type& A,
00203                           output_matrix_type& sketch_of_A,
00204                           skylark::sketch::columnwise_tag) const {
00205 
00206         const elem::Grid& grid = A.Grid();
00207 
00208         elem::DistMatrix<value_type, elem::STAR, ColDist> R(grid);
00209         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00210             sketch_of_A_STAR_STAR(grid);
00211 
00212         data_type::realize_matrix_view(R);
00213 
00214         // TODO: is alignment necessary?
00215 
00216         // Global size of the result of the Local Gemm that follows
00217         sketch_of_A_STAR_STAR.Resize(R.Height(),
00218                                      A.Width());
00219 
00220         // Local Gemm
00221         base::Gemm(elem::NORMAL,
00222                    elem::NORMAL,
00223                    value_type(1),
00224                    R.LockedMatrix(),
00225                    A.LockedMatrix(),
00226                    sketch_of_A_STAR_STAR.Matrix());
00227 
00228         // Reduce-scatter within process grid
00229         sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR);
00230 
00231     }
00232 
00233 
00237     void outer_panel_gemm(const matrix_type& A,
00238                           output_matrix_type& sketch_of_A,
00239                           skylark::sketch::rowwise_tag) const {
00240 
00241         const elem::Grid& grid = A.Grid();
00242 
00243         elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid);
00244         elem::DistMatrix<value_type, ColDist, elem::STAR>
00245             A_Left(grid),
00246             A_Right(grid),
00247             A0(grid),
00248             A1(grid),
00249             A2(grid);
00250 
00251         // Zero sketch_of_A
00252         elem::Zero(sketch_of_A);
00253 
00254         // TODO: is alignment necessary?
00255         R1.AlignWith(sketch_of_A);
00256 
00257         elem::LockedPartitionRight
00258         ( A,
00259           A_Left, A_Right, 0 );
00260 
00261         int blocksize = get_blocksize();
00262         if (blocksize == 0) {
00263             blocksize = A_Right.Width();
00264         }
00265         int base = 0;
00266         while (A_Right.Width() > 0) {
00267 
00268             int b = std::min(A_Right.Width(), blocksize);
00269             data_type::realize_matrix_view(R1, 0,                   base,
00270                                                sketch_of_A.Width(), b);
00271 
00272             elem::RepartitionRight
00273             ( A_Left,      A_Right,
00274               A0,      A1, A2,      b );
00275 
00276             // Local Gemm
00277             base::Gemm(elem::NORMAL,
00278                        elem::TRANSPOSE,
00279                        value_type(1),
00280                        A1.LockedMatrix(),
00281                        R1.LockedMatrix(),
00282                        value_type(1),
00283                        sketch_of_A.Matrix());
00284 
00285             base = base + b;
00286 
00287             elem::SlidePartitionRight
00288             ( A_Left,      A_Right,
00289               A0,     A1,  A2 );
00290 
00291         }
00292     }
00293 
00294 
00299     // Communication demanding scenario: Memory-oblivious mode
00300     // TODO: Block-by-block mode
00301     void outer_panel_gemm(const matrix_type& A,
00302                           output_matrix_type& sketch_of_A,
00303                           skylark::sketch::columnwise_tag) const {
00304 
00305         const elem::Grid& grid = A.Grid();
00306 
00307         elem::DistMatrix<value_type, ColDist, elem::STAR> R(grid);
00308         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00309             A_STAR_STAR(grid);
00310 
00311         // TODO: are alignments necessary?
00312         R.AlignWith(sketch_of_A);
00313         A_STAR_STAR.AlignWith(sketch_of_A);
00314 
00315         data_type::realize_matrix_view(R);
00316 
00317         // Zero sketch_of_A
00318         elem::Zero(sketch_of_A);
00319 
00320         // Allgather within process grid
00321         A_STAR_STAR = A;
00322 
00323         // Local Gemm
00324         base::Gemm(elem::NORMAL,
00325                    elem::NORMAL,
00326                    value_type(1),
00327                    R.LockedMatrix(),
00328                    A_STAR_STAR.LockedMatrix(),
00329                    value_type(1),
00330                    sketch_of_A.Matrix());
00331     }
00332 
00333 
00337     void matrix_panel_gemm(const matrix_type& A,
00338                           output_matrix_type& sketch_of_A,
00339                           skylark::sketch::rowwise_tag) const {
00340 
00341         const elem::Grid& grid = A.Grid();
00342 
00343         elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid);
00344         elem::DistMatrix<value_type, ColDist, elem::STAR>
00345             sketch_of_A_Left(grid),
00346             sketch_of_A_Right(grid),
00347             sketch_of_A0(grid),
00348             sketch_of_A1(grid),
00349             sketch_of_A2(grid);
00350 
00351         // TODO: is alignment necessary?
00352         R1.AlignWith(sketch_of_A);
00353 
00354         elem::PartitionRight
00355         ( sketch_of_A,
00356           sketch_of_A_Left, sketch_of_A_Right, 0 );
00357 
00358         int blocksize = get_blocksize();
00359         if (blocksize == 0) {
00360             blocksize = sketch_of_A_Right.Width();
00361         }
00362         int base = 0;
00363         while (sketch_of_A_Right.Width() > 0) {
00364 
00365             int b = std::min(sketch_of_A_Right.Width(), blocksize);
00366             data_type::realize_matrix_view(R1, base, 0,
00367                                                b,    A.Width());
00368 
00369             elem::RepartitionRight
00370             ( sketch_of_A_Left,                sketch_of_A_Right,
00371               sketch_of_A0,      sketch_of_A1, sketch_of_A2,      b );
00372 
00373             // Local Gemm
00374             base::Gemm(elem::NORMAL,
00375                        elem::TRANSPOSE,
00376                        value_type(1),
00377                        A.LockedMatrix(),
00378                        R1.LockedMatrix(),
00379                        sketch_of_A1.Matrix());
00380 
00381             base = base + b;
00382 
00383             elem::SlidePartitionRight
00384             ( sketch_of_A_Left,                sketch_of_A_Right,
00385               sketch_of_A0,     sketch_of_A1,  sketch_of_A2 );
00386 
00387         }
00388     }
00389 
00390 
00395     // Communication demanding scenario: Memory-oblivious mode
00396     // TODO: Block-by-block mode
00397     void panel_matrix_gemm(const matrix_type& A,
00398                           output_matrix_type& sketch_of_A,
00399                           skylark::sketch::columnwise_tag) const {
00400 
00401         const elem::Grid& grid = A.Grid();
00402 
00403         elem::DistMatrix<value_type, elem::STAR, ColDist> R(grid);
00404         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00405             sketch_of_A_STAR_STAR(grid);
00406 
00407         // TODO: is alignment necessary?
00408         sketch_of_A_STAR_STAR.AlignWith(A);
00409 
00410         data_type::realize_matrix_view(R);
00411 
00412         // Global size of the result of the Local Gemm that follows
00413         sketch_of_A_STAR_STAR.Resize(R.Height(),
00414                                      A.Width());
00415 
00416         // Local Gemm
00417         base::Gemm(elem::NORMAL,
00418                    elem::NORMAL,
00419                    value_type(1),
00420                    R.LockedMatrix(),
00421                    A.LockedMatrix(),
00422                    sketch_of_A_STAR_STAR.Matrix());
00423 
00424         // Reduce-scatter within process grid
00425         sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR);
00426     }
00427 
00428 
00429     void sketch_gemm(const matrix_type& A,
00430         output_matrix_type& sketch_of_A,
00431         skylark::sketch::rowwise_tag tag) const {
00432 
00433         const int sketch_height = sketch_of_A.Height();
00434         const int sketch_width  = sketch_of_A.Width();
00435         const int width         = A.Width();
00436 
00437         const double factor = get_factor();
00438 
00439         if((sketch_height * factor <= width) &&
00440             (sketch_width * factor <= width)) {
00441             inner_panel_gemm(A, sketch_of_A, tag);
00442         } else if((sketch_height >= width * factor) &&
00443             (sketch_width >= width * factor))
00444             outer_panel_gemm(A, sketch_of_A, tag);
00445         else
00446             matrix_panel_gemm(A, sketch_of_A, tag);
00447     }
00448 
00449 
00450     void sketch_gemm(const matrix_type& A,
00451         output_matrix_type& sketch_of_A,
00452         skylark::sketch::columnwise_tag tag) const {
00453 
00454         const int sketch_height = sketch_of_A.Height();
00455         const int sketch_width  = sketch_of_A.Width();
00456         const int height         = A.Height();
00457 
00458         const double factor = get_factor();
00459 
00460         if((sketch_height * factor <= height) &&
00461             (sketch_width * factor <= height))
00462             inner_panel_gemm(A, sketch_of_A, tag);
00463         else if((sketch_height >= height * factor) &&
00464             (sketch_width >= height * factor))
00465             outer_panel_gemm(A, sketch_of_A, tag);
00466         else
00467             panel_matrix_gemm(A, sketch_of_A, tag);
00468     }
00469 
00470 
00471     void apply_impl_vdist (const matrix_type& A,
00472                           output_matrix_type& sketch_of_A,
00473                           skylark::sketch::rowwise_tag tag) const {
00474 
00475         sketch_gemm(A, sketch_of_A, tag);
00476     }
00477 
00478 
00479     void apply_impl_vdist (const matrix_type& A,
00480                           output_matrix_type& sketch_of_A,
00481                           skylark::sketch::columnwise_tag tag) const {
00482 
00483         sketch_gemm(A, sketch_of_A, tag);
00484     }
00485 };
00486 
00487 } } 
00489 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_HPP