Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/sketch/dense_transform_Elemental_mc_mr.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_HPP
00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_HPP
00003 
00004 #include "../base/base.hpp"
00005 
00006 #include "transforms.hpp"
00007 #include "dense_transform_data.hpp"
00008 #include "../utility/comm.hpp"
00009 #include "../utility/get_communicator.hpp"
00010 
00011 #include "sketch_params.hpp"
00012 
00013 namespace skylark { namespace sketch {
00017 template <typename ValueType,
00018           template <typename> class ValueDistribution>
00019 struct dense_transform_t <
00020     elem::DistMatrix<ValueType>,
00021     elem::DistMatrix<ValueType>,
00022     ValueDistribution> :
00023         public dense_transform_data_t<ValueDistribution> {
00024 
00025     // Typedef matrix and distribution types so that we can use them regularly
          // Input and output matrices are the same elem::DistMatrix<value_type> type
          // in this specialization (default distribution template arguments).
00026     typedef ValueType value_type;
00027     typedef elem::DistMatrix<value_type> matrix_type;
00028     typedef elem::DistMatrix<value_type> output_matrix_type;
          // ValueDistribution instantiated for the scalar type.
00029     typedef ValueDistribution<value_type> value_distribution_type;
          // Base class of this struct; the constructors forward to it and the
          // kernels call its realize_matrix_view().
00030     typedef dense_transform_data_t<ValueDistribution> data_type;
00031 
00032 
00036     dense_transform_t (int N, int S, double scale, base::context_t& context)
00037         : data_type (N, S, scale, context) {
              // Regular constructor: forwards N, S, scale and the context unchanged
              // to the dense_transform_data_t base; nothing else is initialized here.
              // NOTE(review): N and S are presumably the input and sketch dimensions
              // -- confirm against dense_transform_data_t.
00038 
00039     }
00040 
00044     dense_transform_t (const dense_transform_t<matrix_type,
00045                                                output_matrix_type,
00046                                                ValueDistribution>& other)
00047         : data_type(other) {
              // Copy constructor: copy-constructs the data_type base from `other`.
              // The source is taken by const reference so that const objects and
              // temporaries select this constructor directly; the base already
              // accepts a const data_type& (see the constructor from data_type).
00048 
00049     }
00050 
00054     dense_transform_t(const data_type& other_data)
00055         : data_type(other_data) {
              // Builds a transform from existing transform data by copy-constructing
              // the data_type base from other_data.
              // Not declared explicit, so a data_type converts implicitly to this type.
00056 
00057     }
00058 
00062     template <typename Dimension>
00063     void apply (const matrix_type& A,
00064                 output_matrix_type& sketch_of_A,
00065                 Dimension dimension) const {
              // Applies the sketching transform to A and writes the result into
              // sketch_of_A. `dimension` is a tag value (rowwise_tag or
              // columnwise_tag) that selects the apply_impl_dist overload.
              //
              // std::logic_error is rethrown as base::elemental_exception and
              // boost::mpi::exception as base::mpi_exception, each carrying the
              // original what() text. Both are caught by const reference, so the
              // exception object is neither copied nor sliced on entry to a handler.
00066         try {
00067             apply_impl_dist(A, sketch_of_A, dimension);
00068         } catch (const std::logic_error& e) {
00069             SKYLARK_THROW_EXCEPTION (
00070                 base::elemental_exception()
00071                     << base::error_msg(e.what()) );
00072         } catch(const boost::mpi::exception& e) {
00073                 SKYLARK_THROW_EXCEPTION (
00074                     base::mpi_exception()
00075                         << base::error_msg(e.what()) );
00076         }
00077     }
00078 
00079     int get_N() const { return this->_N; } // Returns the inherited _N member (presumably the input dimension N -- confirm in dense_transform_data_t).
00080     int get_S() const { return this->_S; } // Returns the inherited _S member (presumably the sketch dimension S -- confirm in dense_transform_data_t).
00082     const sketch_transform_data_t* get_data() const { return this; } // Returns this object viewed through a const sketch_transform_data_t pointer; no copy is made.
00083 
00084 private:
00085 
00089     void inner_panel_gemm(const matrix_type& A,
00090                           output_matrix_type& sketch_of_A,
00091                           skylark::sketch::rowwise_tag) const {
              // Row-wise sketch with "inner panel" blocking. sketch_of_A is walked
              // left to right in column blocks of width b; for each block a slice R1
              // of the transform is realized, and A and the current sketch column
              // block are walked top to bottom in panels of b rows. Each panel of A
              // is redistributed to [STAR,VR], multiplied locally by R1^T into a
              // [STAR,STAR] buffer, and that buffer is sum-scattered into the
              // matching block (sketch_of_A11) of the output.
00092 
00093         const elem::Grid& grid = A.Grid();
00094 
00095         elem::DistMatrix<value_type, elem::STAR, elem::VR> R1(grid);
00096         elem::DistMatrix<value_type>
00097             A_Top(grid),
00098             A_Bottom(grid),
00099             A0(grid),
00100             A1(grid),
00101             A2(grid);
00102         elem::DistMatrix<value_type>
00103             sketch_of_A_Left(grid),
00104             sketch_of_A_Right(grid),
00105             sketch_of_A0(grid),
00106             sketch_of_A1(grid),
00107             sketch_of_A2(grid),
00108             sketch_of_A1_Top(grid),
00109             sketch_of_A1_Bottom(grid),
00110             sketch_of_A10(grid),
00111             sketch_of_A11(grid),
00112             sketch_of_A12(grid);
00113         elem::DistMatrix<value_type, elem::STAR, elem::VR>
00114             A1_STAR_VR(grid);
00115         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00116             sketch_of_A11_STAR_STAR(grid);
00117 
00118         // TODO: are alignments necessary?
00119 
00120         elem::PartitionRight
00121         ( sketch_of_A,
00122           sketch_of_A_Left, sketch_of_A_Right, 0 );
00123 
00124         // TODO: Allow for different blocksizes in "down" and "right" directions
              // A configured blocksize of 0 means "pick one": the smaller of the two
              // sketch dimensions is used.
              // NOTE(review): if that minimum is itself 0 while the sketch still has
              // columns, b below is 0 -- confirm the outer loop still terminates.
00125         int blocksize = get_blocksize();
00126         if (blocksize == 0) {
00127             blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width());
00128         }
              // `base` counts the sketch columns already processed; it is passed to
              // realize_matrix_view as the first offset of the next R1 slice.
00129         int base = 0;
00130         while (sketch_of_A_Right.Width() > 0) {
00131 
00132             int b = std::min(sketch_of_A_Right.Width(), blocksize);
                  // Arguments look like (target, row offset, column offset, height,
                  // width) -- confirm against data_type::realize_matrix_view.
00133             data_type::realize_matrix_view(R1, base, 0,
00134                                                b,    A.Width());
00135 
00136             elem::RepartitionRight
00137             ( sketch_of_A_Left,                sketch_of_A_Right,
00138               sketch_of_A0,      sketch_of_A1, sketch_of_A2,      b );
00139 
00140             // TODO: is alignment necessary?
00141             A1_STAR_VR.AlignWith(R1);
00142 
                  // Restart the top-to-bottom sweep over A (read-only views) and over
                  // the current sketch column block for every new R1 slice.
00143             elem::LockedPartitionDown
00144             ( A,
00145               A_Top, A_Bottom, 0 );
00146 
00147             elem::PartitionDown
00148             ( sketch_of_A1,
00149               sketch_of_A1_Top, sketch_of_A1_Bottom, 0 );
00150 
00151             while(A_Bottom.Height() > 0) {
00152 
00153                 elem::LockedRepartitionDown
00154                 ( A_Top,    A0,
00155                         
00156                             A1,
00157                   A_Bottom, A2, b );
00158 
00159                 elem::RepartitionDown
00160                 ( sketch_of_A1_Top,    sketch_of_A10,
00161                                    
00162                                        sketch_of_A11,
00163                   sketch_of_A1_Bottom, sketch_of_A12, b );
00164 
00165                 // Alltoall within process columns
00166                 A1_STAR_VR = A1;
00167 
00168                 // Global size of the result of the Local Gemm that follows
00169                 sketch_of_A11_STAR_STAR.Resize(A1_STAR_VR.Height(),
00170                                                R1.Height());
00171 
00172                 // Local Gemm
                      // NOTE(review): this call passes no beta argument, whereas the
                      // column-wise inner_panel_gemm passes value_type(0) -- presumably
                      // both overwrite the target; confirm against base::Gemm.
00173                 base::Gemm(elem::NORMAL,
00174                            elem::TRANSPOSE,
00175                            value_type(1),
00176                            A1_STAR_VR.LockedMatrix(),
00177                            R1.LockedMatrix(),
00178                            sketch_of_A11_STAR_STAR.Matrix());
00179 
00180                 // Reduce-scatter within process grid
00181                 sketch_of_A11.SumScatterFrom(sketch_of_A11_STAR_STAR);
00182 
00183                 elem::SlideLockedPartitionDown
00184                 ( A_Top,    A0,
00185                             A1,
00186                         
00187                   A_Bottom, A2 );
00188 
00189                 elem::SlidePartitionDown
00190                 ( sketch_of_A1_Top,    sketch_of_A10,
00191                                        sketch_of_A11,
00192                                    
00193                   sketch_of_A1_Bottom, sketch_of_A12 );
00194 
00195             }
00196 
00197             base = base + b;
00198 
00199             elem::SlidePartitionRight
00200             ( sketch_of_A_Left,                sketch_of_A_Right,
00201               sketch_of_A0,     sketch_of_A1,  sketch_of_A2 );
00202 
00203         }
00204     }
00205 
00206 
00210     void inner_panel_gemm(const matrix_type& A,
00211                           output_matrix_type& sketch_of_A,
00212                           skylark::sketch::columnwise_tag) const {
              // Column-wise sketch with "inner panel" blocking. sketch_of_A is walked
              // top to bottom in row blocks of height b; for each block a slice R1 of
              // the transform is realized, and A and the current sketch row block are
              // walked left to right in panels of b columns. Each panel of A is
              // redistributed to [VC,STAR], multiplied locally by R1 into a
              // [STAR,STAR] buffer, and that buffer is sum-scattered into the
              // matching block (sketch_of_A11) of the output.
00213 
00214         const elem::Grid& grid = A.Grid();
00215 
00216         elem::DistMatrix<value_type, elem::STAR, elem::VC> R1(grid);
00217         elem::DistMatrix<value_type>
00218             A_Left(grid),
00219             A_Right(grid),
00220             A0(grid),
00221             A1(grid),
00222             A2(grid);
00223         elem::DistMatrix<value_type>
00224             sketch_of_A_Top(grid),
00225             sketch_of_A_Bottom(grid),
00226             sketch_of_A0(grid),
00227             sketch_of_A1(grid),
00228             sketch_of_A2(grid),
00229             sketch_of_A1_Left(grid),
00230             sketch_of_A1_Right(grid),
00231             sketch_of_A10(grid),
00232             sketch_of_A11(grid),
00233             sketch_of_A12(grid);
00234         elem::DistMatrix<value_type, elem::VC, elem::STAR>
00235             A1_VC_STAR(grid);
00236         elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00237             sketch_of_A11_STAR_STAR(grid);
00238 
00239         // TODO: are alignments necessary?
00240 
00241         elem::PartitionDown
00242         ( sketch_of_A,
00243           sketch_of_A_Top, sketch_of_A_Bottom, 0 );
00244 
00245         // TODO: Allow for different blocksizes in "down" and "right" directions
              // A configured blocksize of 0 means "pick one": the smaller of the two
              // sketch dimensions is used.
              // NOTE(review): if that minimum is itself 0 while the sketch still has
              // rows, b below is 0 -- confirm the outer loop still terminates.
00246         int blocksize = get_blocksize();
00247         if (blocksize == 0) {
00248             blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width());
00249         }
              // `base` counts the sketch rows already processed; it is passed to
              // realize_matrix_view as the first offset of the next R1 slice.
00250         int base = 0;
00251         while (sketch_of_A_Bottom.Height() > 0) {
00252 
00253             int b = std::min(sketch_of_A_Bottom.Height(), blocksize);
                  // Arguments look like (target, row offset, column offset, height,
                  // width) -- confirm against data_type::realize_matrix_view.
00254             data_type::realize_matrix_view(R1, base, 0,
00255                                                b,    A.Height());
00256 
00257             elem::RepartitionDown
00258             ( sketch_of_A_Top,     sketch_of_A0,
00259                                
00260                                    sketch_of_A1,
00261               sketch_of_A_Bottom, sketch_of_A2,  b );
00262 
00263             // TODO: is alignment necessary?
00264             A1_VC_STAR.AlignWith(R1);
00265 
                  // Restart the left-to-right sweep over A (read-only views) and over
                  // the current sketch row block for every new R1 slice.
00266             elem::LockedPartitionRight
00267             ( A,
00268               A_Left, A_Right, 0 );
00269 
00270             elem::PartitionRight
00271             ( sketch_of_A1,
00272               sketch_of_A1_Left, sketch_of_A1_Right, 0 );
00273 
00274             while(A_Right.Width() > 0) {
00275 
00276                 elem::LockedRepartitionRight
00277                 ( A_Left,      A_Right,
00278                   A0,      A1, A2,      b );
00279 
00280                 elem::RepartitionRight
00281                 ( sketch_of_A1_Left,                 sketch_of_A1_Right,
00282                   sketch_of_A10,      sketch_of_A11, sketch_of_A12,      b);
00283 
00284                 // Alltoall within process rows
00285                 A1_VC_STAR = A1;
00286 
00287                 // Global size of the result of the Local Gemm that follows
00288                 sketch_of_A11_STAR_STAR.Resize(R1.Height(),
00289                                                A1_VC_STAR.Width());
00290 
00291                 // Local Gemm
                      // beta is value_type(0): the buffer is overwritten, not accumulated.
00292                 base::Gemm(elem::NORMAL,
00293                            elem::NORMAL,
00294                            value_type(1),
00295                            R1.LockedMatrix(),
00296                            A1_VC_STAR.LockedMatrix(),
00297                            value_type(0),
00298                            sketch_of_A11_STAR_STAR.Matrix());
00299 
00300                 // Reduce-scatter within process grid
00301                 sketch_of_A11.SumScatterFrom(sketch_of_A11_STAR_STAR);
00302 
00303                 elem::SlideLockedPartitionRight
00304                 ( A_Left,      A_Right,
00305                   A0,     A1,  A2 );
00306 
00307                 elem::SlidePartitionRight
00308                 ( sketch_of_A1_Left,                 sketch_of_A1_Right,
00309                   sketch_of_A10,     sketch_of_A11,  sketch_of_A12 );
00310 
00311             }
00312 
00313             base = base + b;
00314 
00315             elem::SlidePartitionDown
00316             ( sketch_of_A_Top,    sketch_of_A0,
00317                                   sketch_of_A1,
00318                               
00319               sketch_of_A_Bottom, sketch_of_A2 );
00320 
00321         }
00322     }
00323 
00324 
00328     void outer_panel_gemm(const matrix_type& A,
00329                           output_matrix_type& sketch_of_A,
00330                           skylark::sketch::rowwise_tag) const {
              // Row-wise sketch with "outer panel" blocking. sketch_of_A is zeroed
              // first; A is then walked left to right in column panels of width b.
              // For each panel a slice R1 of the transform is realized, the panel is
              // gathered into [MC,STAR], and the local product A1 * R1^T is
              // accumulated (beta = 1) directly into the local part of sketch_of_A.
00331 
00332 
00333         const elem::Grid& grid = A.Grid();
00334 
00335         elem::DistMatrix<value_type, elem::MR, elem::STAR> R1(grid);
00336         elem::DistMatrix<value_type>
00337             A_Left(grid),
00338             A_Right(grid),
00339             A0(grid),
00340             A1(grid),
00341             A2(grid);
00342         elem::DistMatrix<value_type, elem::MC, elem::STAR>
00343             A1_MC_STAR(grid);
00344 
00345         // Zero sketch_of_A
00346         elem::Zero(sketch_of_A);
00347 
00348         // TODO: are alignments necessary?
00349         R1.AlignWith(sketch_of_A);
00350         A1_MC_STAR.AlignWith(sketch_of_A);
00351 
00352         elem::LockedPartitionRight
00353         ( A,
00354           A_Left, A_Right, 0 );
00355 
              // A configured blocksize of 0 means "one block": the full width of A
              // (A_Right covers all of A at this point) is processed in one pass.
00356         int blocksize = get_blocksize();
00357         if (blocksize == 0) {
00358             blocksize = A_Right.Width();
00359         }
              // `base` counts the columns of A already consumed; it is passed to
              // realize_matrix_view as the second offset of the next R1 slice.
00360         int base = 0;
00361         while (A_Right.Width() > 0) {
00362 
00363             int b = std::min(A_Right.Width(), blocksize);
                  // Arguments look like (target, row offset, column offset, height,
                  // width) -- confirm against data_type::realize_matrix_view.
00364             data_type::realize_matrix_view(R1, 0,                   base,
00365                                                sketch_of_A.Width(), b);
00366 
00367             elem::RepartitionRight
00368             ( A_Left,       A_Right,
00369               A0,      A1,  A2,      b );
00370 
00371             // Allgather within process rows
00372             A1_MC_STAR = A1;
00373 
00374             // Local Gemm
                  // beta is value_type(1): each panel's contribution is added to the
                  // sketch, which is why sketch_of_A was zeroed above.
00375             base::Gemm(elem::NORMAL,
00376                        elem::TRANSPOSE,
00377                        value_type(1),
00378                        A1_MC_STAR.LockedMatrix(),
00379                        R1.LockedMatrix(),
00380                        value_type(1),
00381                        sketch_of_A.Matrix());
00382 
00383             base = base + b;
00384 
00385             elem::SlidePartitionRight
00386             ( A_Left,      A_Right,
00387               A0,     A1,  A2 );
00388 
00389         }
00390     }
00391 
00392 
00396     void outer_panel_gemm(const matrix_type& A,
00397                           output_matrix_type& sketch_of_A,
00398                           skylark::sketch::columnwise_tag) const {
              // Column-wise sketch with "outer panel" blocking. sketch_of_A is zeroed
              // first; A is then walked top to bottom in row panels of height b. For
              // each panel a slice R1 of the transform is realized, the transposed
              // panel is gathered into [MR,STAR], and the local product
              // R1 * (A1^T)^T is accumulated (beta = 1) directly into the local part
              // of sketch_of_A.
00399 
00400         const elem::Grid& grid = A.Grid();
00401 
00402         elem::DistMatrix<value_type, elem::MC, elem::STAR> R1(grid);
00403         elem::DistMatrix<value_type>
00404             A_Top(grid),
00405             A_Bottom(grid),
00406             A0(grid),
00407             A1(grid),
00408             A2(grid);
00409         elem::DistMatrix<value_type, elem::MR, elem::STAR>
00410             A1Trans_MR_STAR(grid);
00411 
00412         // Zero sketch_of_A
00413         elem::Zero(sketch_of_A);
00414 
00415         // TODO: are alignments necessary?
00416         R1.AlignWith(sketch_of_A);
00417         A1Trans_MR_STAR.AlignWith(sketch_of_A);
00418 
00419         elem::LockedPartitionDown
00420         ( A,
00421           A_Top, A_Bottom, 0 );
00422 
              // A configured blocksize of 0 means "one block": the full height of A
              // (A_Bottom covers all of A at this point) is processed in one pass.
00423         int blocksize = get_blocksize();
00424         if (blocksize == 0) {
00425             blocksize = A_Bottom.Height();
00426         }
              // `base` counts the rows of A already consumed; it is passed to
              // realize_matrix_view as the second offset of the next R1 slice.
00427         int base = 0;
00428         while (A_Bottom.Height() > 0) {
00429 
00430             int b = std::min(A_Bottom.Height(), blocksize);
                  // Arguments look like (target, row offset, column offset, height,
                  // width) -- confirm against data_type::realize_matrix_view.
00431             data_type::realize_matrix_view(R1, 0,                   base,
00432                                                sketch_of_A.Height(), b);
00433 
00434             elem::RepartitionDown
00435              ( A_Top,    A0,
00436                      
00437                          A1,
00438                A_Bottom, A2, b );
00439 
00440 
00441             // Global size of the target of Allgather that follows
00442             A1Trans_MR_STAR.Resize(A1.Width(),
00443                                    A1.Height());
00444 
00445             // Allgather within process columns
00446             // TODO: Describe cache benefits from transposition:
00447             //       why not simply use A1[STAR, MR]?
00448             A1.TransposeColAllGather(A1Trans_MR_STAR);
00449 
00450             // Local Gemm
                  // beta is value_type(1): each panel's contribution is added to the
                  // sketch, which is why sketch_of_A was zeroed above.
00451             base::Gemm(elem::NORMAL,
00452                        elem::TRANSPOSE,
00453                        value_type(1),
00454                        R1.LockedMatrix(),
00455                        A1Trans_MR_STAR.LockedMatrix(),
00456                        value_type(1),
00457                        sketch_of_A.Matrix());
00458 
00459             base = base + b;
00460 
00461             elem::SlidePartitionDown
00462             ( A_Top,    A0,
00463                         A1,
00464                     
00465               A_Bottom, A2 );
00466         }
00467     }
00468 
00469 
00473     void matrix_panel_gemm(const matrix_type& A,
00474                           output_matrix_type& sketch_of_A,
00475                           skylark::sketch::rowwise_tag) const {
              // Row-wise sketch with "matrix-panel" blocking. A itself is never
              // partitioned: its whole local part is used in every iteration. Only
              // sketch_of_A is walked, left to right in column blocks of width b.
              // For each block a slice R1 of the transform is realized, the local
              // product A * R1^T is formed in an [MC,STAR] temporary, and that
              // temporary is sum-scattered across process rows into the current
              // sketch block.
00476 
00477         const elem::Grid& grid = A.Grid();
00478 
00479         elem::DistMatrix<value_type, elem::STAR, elem::MR> R1(grid);
00480         elem::DistMatrix<value_type>
00481             sketch_of_A_Left(grid),
00482             sketch_of_A_Right(grid),
00483             sketch_of_A0(grid),
00484             sketch_of_A1(grid),
00485             sketch_of_A2(grid);
00486         elem::DistMatrix<value_type, elem::MC, elem::STAR>
00487             sketch_of_A_temp(grid);
00488 
00489         // TODO: are alignments necessary?
00490         R1.AlignWith(sketch_of_A);
00491         sketch_of_A_temp.AlignWith(sketch_of_A);
00492 
00493         elem::PartitionRight
00494         ( sketch_of_A,
00495           sketch_of_A_Left, sketch_of_A_Right, 0 );
00496 
              // A configured blocksize of 0 means "one block": the full sketch width
              // (sketch_of_A_Right covers the whole sketch here) is done in one pass.
00497         int blocksize = get_blocksize();
00498         if (blocksize == 0) {
00499             blocksize = sketch_of_A_Right.Width();
00500         }
              // `base` counts the sketch columns already produced; it is passed to
              // realize_matrix_view as the first offset of the next R1 slice.
00501         int base = 0;
00502         while (sketch_of_A_Right.Width() > 0) {
00503 
00504             int b = std::min(sketch_of_A_Right.Width(), blocksize);
                  // Arguments look like (target, row offset, column offset, height,
                  // width) -- confirm against data_type::realize_matrix_view.
00505             data_type::realize_matrix_view(R1, base, 0,
00506                                                b,    A.Width());
00507 
00508             elem::RepartitionRight
00509             ( sketch_of_A_Left,                sketch_of_A_Right,
00510               sketch_of_A0,      sketch_of_A1, sketch_of_A2,      b );
00511 
00512             // Global size of the result of the Local Gemm that follows
00513             sketch_of_A_temp.Resize(sketch_of_A.Height(),
00514                                     R1.Height());
00515 
00516             // Local Gemm
                  // NOTE(review): no beta argument is passed here -- presumably this
                  // overload overwrites the target; confirm against base::Gemm.
00517             base::Gemm(elem::NORMAL,
00518                        elem::TRANSPOSE,
00519                        value_type(1),
00520                        A.LockedMatrix(),
00521                        R1.LockedMatrix(),
00522                        sketch_of_A_temp.Matrix());
00523 
00524             // Reduce-scatter within row communicators
00525             sketch_of_A1.RowSumScatterFrom(sketch_of_A_temp);
00526 
00527             base = base + b;
00528 
00529             elem::SlidePartitionRight
00530             ( sketch_of_A_Left,                sketch_of_A_Right,
00531               sketch_of_A0,     sketch_of_A1,  sketch_of_A2 );
00532 
00533         }
00534     }
00535 
00536 
00540     void panel_matrix_gemm(const matrix_type& A,
00541                           output_matrix_type& sketch_of_A,
00542                           skylark::sketch::columnwise_tag) const {
              // Column-wise sketch with "panel-matrix" blocking. A itself is never
              // partitioned: its whole local part is used in every iteration. Only
              // sketch_of_A is walked, top to bottom in row blocks of height b. For
              // each block a slice R1 of the transform is realized, the local
              // product R1 * A is formed in a [STAR,MR] temporary, and that
              // temporary is sum-scattered across process columns into the current
              // sketch block.
00543 
00544         const elem::Grid& grid = A.Grid();
00545 
00546         elem::DistMatrix<value_type, elem::STAR, elem::MC> R1(grid);
00547         elem::DistMatrix<value_type>
00548             sketch_of_A_Top(grid),
00549             sketch_of_A_Bottom(grid),
00550             sketch_of_A0(grid),
00551             sketch_of_A1(grid),
00552             sketch_of_A2(grid);
00553         elem::DistMatrix<value_type, elem::STAR, elem::MR>
00554             sketch_of_A_temp(grid);
00555 
00556         // TODO: are alignments necessary?
00557         R1.AlignWith(A);
00558         sketch_of_A_temp.AlignWith(A);
00559 
00560         elem::PartitionDown
00561         ( sketch_of_A,
00562           sketch_of_A_Top, sketch_of_A_Bottom, 0 );
00563 
              // A configured blocksize of 0 means "one block": the full sketch height
              // (sketch_of_A_Bottom covers the whole sketch here) is done in one pass.
00564         int blocksize = get_blocksize();
00565         if (blocksize == 0) {
00566             blocksize = sketch_of_A_Bottom.Height();
00567         }
              // `base` counts the sketch rows already produced; it is passed to
              // realize_matrix_view as the first offset of the next R1 slice.
00568         int base = 0;
00569         while (sketch_of_A_Bottom.Height() > 0) {
00570 
00571             int b = std::min(sketch_of_A_Bottom.Height(), blocksize);
00572 
                  // Arguments look like (target, row offset, column offset, height,
                  // width) -- confirm against data_type::realize_matrix_view.
00573             data_type::realize_matrix_view(R1, base, 0,
00574                                                b,    A.Height());
00575 
00576             elem::RepartitionDown
00577             ( sketch_of_A_Top,     sketch_of_A0,
00578                                
00579                                    sketch_of_A1,
00580               sketch_of_A_Bottom, sketch_of_A2, b );
00581 
00582             // Global size of the result of the Local Gemm that follows
00583             sketch_of_A_temp.Resize(R1.Height(),
00584                                     A.Width());
00585 
00586             // Local Gemm
                  // NOTE(review): no beta argument is passed here -- presumably this
                  // overload overwrites the target; confirm against base::Gemm.
00587             base::Gemm(elem::NORMAL,
00588                        elem::NORMAL,
00589                        value_type(1),
00590                        R1.LockedMatrix(),
00591                        A.LockedMatrix(),
00592                        sketch_of_A_temp.Matrix());
00593 
00594             // Reduce-scatter within column communicators
00595             sketch_of_A1.ColSumScatterFrom(sketch_of_A_temp);
00596 
                  // Disabled alternative kept for reference: an accumulating update
                  // instead of the overwrite above.
00597             // Reduce-scatter within column communicators
00598             // sketch_of_A1.ColSumScatterUpdate(value_type(1),
00599             //    sketch_of_A_temp);
00600 
00601             base = base + b;
00602 
00603             elem::SlidePartitionDown
00604             ( sketch_of_A_Top,    sketch_of_A0,
00605                                   sketch_of_A1,
00606                               
00607               sketch_of_A_Bottom, sketch_of_A2 );
00608         }
00609     }
00610 
00611 
00612     void sketch_gemm(const matrix_type& A,
00613         output_matrix_type& sketch_of_A,
00614         skylark::sketch::rowwise_tag tag) const {
              // Picks the row-wise GEMM kernel by comparing both sketch dimensions
              // with the width of A, scaled by get_factor():
              //   - both sketch dimensions * factor <= width   -> inner_panel_gemm
              //   - both sketch dimensions >= width * factor   -> outer_panel_gemm
              //   - anything else                              -> matrix_panel_gemm
              // The comparisons are carried out in double because factor is a double.
00615 
00616         const int sketch_height = sketch_of_A.Height();
00617         const int sketch_width  = sketch_of_A.Width();
00618         const int width         = A.Width();
00619 
00620         const double factor = get_factor();
00621 
00622         if((sketch_height * factor <= width) &&
00623             (sketch_width * factor <= width))
00624             inner_panel_gemm(A, sketch_of_A, tag);
00625         else if((sketch_height >= width * factor) &&
00626             (sketch_width >= width * factor))
00627             outer_panel_gemm(A, sketch_of_A, tag);
00628         else
00629             matrix_panel_gemm(A, sketch_of_A, tag);
00630     }
00631 
00632 
00633     void sketch_gemm(const matrix_type& A,
00634         output_matrix_type& sketch_of_A,
00635         skylark::sketch::columnwise_tag tag) const {
              // Picks the column-wise GEMM kernel by comparing both sketch dimensions
              // with the height of A, scaled by get_factor():
              //   - both sketch dimensions * factor <= height  -> inner_panel_gemm
              //   - both sketch dimensions >= height * factor  -> outer_panel_gemm
              //   - anything else                              -> panel_matrix_gemm
              // The comparisons are carried out in double because factor is a double.
00636 
00637         const int sketch_height = sketch_of_A.Height();
00638         const int sketch_width  = sketch_of_A.Width();
00639         const int height         = A.Height();
00640 
00641         const double factor = get_factor();
00642 
00643         if((sketch_height * factor <= height) &&
00644             (sketch_width * factor <= height))
00645             inner_panel_gemm(A, sketch_of_A, tag);
00646         else if((sketch_height >= height * factor) &&
00647             (sketch_width >= height * factor))
00648             outer_panel_gemm(A, sketch_of_A, tag);
00649         else
00650             panel_matrix_gemm(A, sketch_of_A, tag);
00651     }
00652 
00653     void apply_impl_dist (const matrix_type& A,
00654                           output_matrix_type& sketch_of_A,
00655                           skylark::sketch::rowwise_tag tag) const {
              // Row-wise overload reached from apply(): forwards unchanged to the
              // row-wise sketch_gemm, which chooses the actual kernel.
00656 
00657         sketch_gemm(A, sketch_of_A, tag);
00658     }
00659 
00660 
00661     void apply_impl_dist (const matrix_type& A,
00662                           output_matrix_type& sketch_of_A,
00663                           skylark::sketch::columnwise_tag tag) const {
              // Column-wise overload reached from apply(): forwards unchanged to the
              // column-wise sketch_gemm, which chooses the actual kernel.
00664 
00665         sketch_gemm(A, sketch_of_A, tag);
00666     }
00667 
00668 };
00669 
00670 } } 
00672 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_HPP