Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 00013 namespace skylark { namespace sketch { 00017 template <typename ValueType, 00018 template <typename> class ValueDistribution> 00019 struct dense_transform_t < 00020 elem::DistMatrix<ValueType>, 00021 elem::DistMatrix<ValueType>, 00022 ValueDistribution> : 00023 public dense_transform_data_t<ValueDistribution> { 00024 00025 // Typedef matrix and distribution types so that we can use them regularly 00026 typedef ValueType value_type; 00027 typedef elem::DistMatrix<value_type> matrix_type; 00028 typedef elem::DistMatrix<value_type> output_matrix_type; 00029 typedef ValueDistribution<value_type> value_distribution_type; 00030 typedef dense_transform_data_t<ValueDistribution> data_type; 00031 00032 00036 dense_transform_t (int N, int S, double scale, base::context_t& context) 00037 : data_type (N, S, scale, context) { 00038 00039 } 00040 00044 dense_transform_t (dense_transform_t<matrix_type, 00045 output_matrix_type, 00046 ValueDistribution>& other) 00047 : data_type(other) { 00048 00049 } 00050 00054 dense_transform_t(const data_type& other_data) 00055 : data_type(other_data) { 00056 00057 } 00058 00062 template <typename Dimension> 00063 void apply (const matrix_type& A, 00064 output_matrix_type& sketch_of_A, 00065 Dimension dimension) const { 00066 try { 00067 apply_impl_dist(A, sketch_of_A, dimension); 00068 } catch (std::logic_error e) { 00069 SKYLARK_THROW_EXCEPTION ( 00070 base::elemental_exception() 00071 << base::error_msg(e.what()) ); 00072 } catch(boost::mpi::exception e) { 00073 SKYLARK_THROW_EXCEPTION ( 00074 base::mpi_exception() 00075 << base::error_msg(e.what()) ); 00076 } 00077 } 00078 00079 int get_N() const { return this->_N; } 00080 int get_S() const { return this->_S; } 00082 const sketch_transform_data_t* get_data() const { return this; } 00083 00084 private: 00085 00089 void inner_panel_gemm(const matrix_type& A, 00090 output_matrix_type& sketch_of_A, 00091 skylark::sketch::rowwise_tag) const { 00092 00093 const elem::Grid& grid = A.Grid(); 00094 00095 elem::DistMatrix<value_type, elem::STAR, elem::VR> R1(grid); 00096 elem::DistMatrix<value_type> 00097 A_Top(grid), 00098 A_Bottom(grid), 00099 A0(grid), 00100 A1(grid), 00101 A2(grid); 00102 elem::DistMatrix<value_type> 00103 sketch_of_A_Left(grid), 00104 sketch_of_A_Right(grid), 00105 sketch_of_A0(grid), 00106 sketch_of_A1(grid), 00107 sketch_of_A2(grid), 00108 sketch_of_A1_Top(grid), 00109 sketch_of_A1_Bottom(grid), 00110 sketch_of_A10(grid), 00111 sketch_of_A11(grid), 00112 sketch_of_A12(grid); 00113 elem::DistMatrix<value_type, elem::STAR, elem::VR> 00114 A1_STAR_VR(grid); 00115 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00116 sketch_of_A11_STAR_STAR(grid); 00117 00118 // TODO: are alignments necessary? 00119 00120 elem::PartitionRight 00121 ( sketch_of_A, 00122 sketch_of_A_Left, sketch_of_A_Right, 0 ); 00123 00124 // TODO: Allow for different blocksizes in "down" and "right" directions 00125 int blocksize = get_blocksize(); 00126 if (blocksize == 0) { 00127 blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width()); 00128 } 00129 int base = 0; 00130 while (sketch_of_A_Right.Width() > 0) { 00131 00132 int b = std::min(sketch_of_A_Right.Width(), blocksize); 00133 data_type::realize_matrix_view(R1, base, 0, 00134 b, A.Width()); 00135 00136 elem::RepartitionRight 00137 ( sketch_of_A_Left, sketch_of_A_Right, 00138 sketch_of_A0, sketch_of_A1, sketch_of_A2, b ); 00139 00140 // TODO: is alignment necessary? 00141 A1_STAR_VR.AlignWith(R1); 00142 00143 elem::LockedPartitionDown 00144 ( A, 00145 A_Top, A_Bottom, 0 ); 00146 00147 elem::PartitionDown 00148 ( sketch_of_A1, 00149 sketch_of_A1_Top, sketch_of_A1_Bottom, 0 ); 00150 00151 while(A_Bottom.Height() > 0) { 00152 00153 elem::LockedRepartitionDown 00154 ( A_Top, A0, 00155 00156 A1, 00157 A_Bottom, A2, b ); 00158 00159 elem::RepartitionDown 00160 ( sketch_of_A1_Top, sketch_of_A10, 00161 00162 sketch_of_A11, 00163 sketch_of_A1_Bottom, sketch_of_A12, b ); 00164 00165 // Alltoall within process columns 00166 A1_STAR_VR = A1; 00167 00168 // Global size of the result of the Local Gemm that follows 00169 sketch_of_A11_STAR_STAR.Resize(A1_STAR_VR.Height(), 00170 R1.Height()); 00171 00172 // Local Gemm 00173 base::Gemm(elem::NORMAL, 00174 elem::TRANSPOSE, 00175 value_type(1), 00176 A1_STAR_VR.LockedMatrix(), 00177 R1.LockedMatrix(), 00178 sketch_of_A11_STAR_STAR.Matrix()); 00179 00180 // Reduce-scatter within process grid 00181 sketch_of_A11.SumScatterFrom(sketch_of_A11_STAR_STAR); 00182 00183 elem::SlideLockedPartitionDown 00184 ( A_Top, A0, 00185 A1, 00186 00187 A_Bottom, A2 ); 00188 00189 elem::SlidePartitionDown 00190 ( sketch_of_A1_Top, sketch_of_A10, 00191 sketch_of_A11, 00192 00193 sketch_of_A1_Bottom, sketch_of_A12 ); 00194 00195 } 00196 00197 base = base + b; 00198 00199 elem::SlidePartitionRight 00200 ( sketch_of_A_Left, sketch_of_A_Right, 00201 sketch_of_A0, sketch_of_A1, sketch_of_A2 ); 00202 00203 } 00204 } 00205 00206 00210 void inner_panel_gemm(const matrix_type& A, 00211 output_matrix_type& sketch_of_A, 00212 skylark::sketch::columnwise_tag) const { 00213 00214 const elem::Grid& grid = A.Grid(); 00215 00216 elem::DistMatrix<value_type, elem::STAR, elem::VC> R1(grid); 00217 elem::DistMatrix<value_type> 00218 A_Left(grid), 00219 A_Right(grid), 00220 A0(grid), 00221 A1(grid), 00222 A2(grid); 00223 elem::DistMatrix<value_type> 00224 sketch_of_A_Top(grid), 00225 sketch_of_A_Bottom(grid), 00226 sketch_of_A0(grid), 00227 sketch_of_A1(grid), 00228 sketch_of_A2(grid), 00229 sketch_of_A1_Left(grid), 00230 sketch_of_A1_Right(grid), 00231 sketch_of_A10(grid), 00232 sketch_of_A11(grid), 00233 sketch_of_A12(grid); 00234 elem::DistMatrix<value_type, elem::VC, elem::STAR> 00235 A1_VC_STAR(grid); 00236 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00237 sketch_of_A11_STAR_STAR(grid); 00238 00239 // TODO: are alignments necessary? 00240 00241 elem::PartitionDown 00242 ( sketch_of_A, 00243 sketch_of_A_Top, sketch_of_A_Bottom, 0 ); 00244 00245 // TODO: Allow for different blocksizes in "down" and "right" directions 00246 int blocksize = get_blocksize(); 00247 if (blocksize == 0) { 00248 blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width()); 00249 } 00250 int base = 0; 00251 while (sketch_of_A_Bottom.Height() > 0) { 00252 00253 int b = std::min(sketch_of_A_Bottom.Height(), blocksize); 00254 data_type::realize_matrix_view(R1, base, 0, 00255 b, A.Height()); 00256 00257 elem::RepartitionDown 00258 ( sketch_of_A_Top, sketch_of_A0, 00259 00260 sketch_of_A1, 00261 sketch_of_A_Bottom, sketch_of_A2, b ); 00262 00263 // TODO: is alignment necessary? 00264 A1_VC_STAR.AlignWith(R1); 00265 00266 elem::LockedPartitionRight 00267 ( A, 00268 A_Left, A_Right, 0 ); 00269 00270 elem::PartitionRight 00271 ( sketch_of_A1, 00272 sketch_of_A1_Left, sketch_of_A1_Right, 0 ); 00273 00274 while(A_Right.Width() > 0) { 00275 00276 elem::LockedRepartitionRight 00277 ( A_Left, A_Right, 00278 A0, A1, A2, b ); 00279 00280 elem::RepartitionRight 00281 ( sketch_of_A1_Left, sketch_of_A1_Right, 00282 sketch_of_A10, sketch_of_A11, sketch_of_A12, b); 00283 00284 // Alltoall within process rows 00285 A1_VC_STAR = A1; 00286 00287 // Global size of the result of the Local Gemm that follows 00288 sketch_of_A11_STAR_STAR.Resize(R1.Height(), 00289 A1_VC_STAR.Width()); 00290 00291 // Local Gemm 00292 base::Gemm(elem::NORMAL, 00293 elem::NORMAL, 00294 value_type(1), 00295 R1.LockedMatrix(), 00296 A1_VC_STAR.LockedMatrix(), 00297 value_type(0), 00298 sketch_of_A11_STAR_STAR.Matrix()); 00299 00300 // Reduce-scatter within process grid 00301 sketch_of_A11.SumScatterFrom(sketch_of_A11_STAR_STAR); 00302 00303 elem::SlideLockedPartitionRight 00304 ( A_Left, A_Right, 00305 A0, A1, A2 ); 00306 00307 elem::SlidePartitionRight 00308 ( sketch_of_A1_Left, sketch_of_A1_Right, 00309 sketch_of_A10, sketch_of_A11, sketch_of_A12 ); 00310 00311 } 00312 00313 base = base + b; 00314 00315 elem::SlidePartitionDown 00316 ( sketch_of_A_Top, sketch_of_A0, 00317 sketch_of_A1, 00318 00319 sketch_of_A_Bottom, sketch_of_A2 ); 00320 00321 } 00322 } 00323 00324 00328 void outer_panel_gemm(const matrix_type& A, 00329 output_matrix_type& sketch_of_A, 00330 skylark::sketch::rowwise_tag) const { 00331 00332 00333 const elem::Grid& grid = A.Grid(); 00334 00335 elem::DistMatrix<value_type, elem::MR, elem::STAR> R1(grid); 00336 elem::DistMatrix<value_type> 00337 A_Left(grid), 00338 A_Right(grid), 00339 A0(grid), 00340 A1(grid), 00341 A2(grid); 00342 elem::DistMatrix<value_type, elem::MC, elem::STAR> 00343 A1_MC_STAR(grid); 00344 00345 // Zero sketch_of_A 00346 elem::Zero(sketch_of_A); 00347 00348 // TODO: are alignments necessary? 00349 R1.AlignWith(sketch_of_A); 00350 A1_MC_STAR.AlignWith(sketch_of_A); 00351 00352 elem::LockedPartitionRight 00353 ( A, 00354 A_Left, A_Right, 0 ); 00355 00356 int blocksize = get_blocksize(); 00357 if (blocksize == 0) { 00358 blocksize = A_Right.Width(); 00359 } 00360 int base = 0; 00361 while (A_Right.Width() > 0) { 00362 00363 int b = std::min(A_Right.Width(), blocksize); 00364 data_type::realize_matrix_view(R1, 0, base, 00365 sketch_of_A.Width(), b); 00366 00367 elem::RepartitionRight 00368 ( A_Left, A_Right, 00369 A0, A1, A2, b ); 00370 00371 // Allgather within process rows 00372 A1_MC_STAR = A1; 00373 00374 // Local Gemm 00375 base::Gemm(elem::NORMAL, 00376 elem::TRANSPOSE, 00377 value_type(1), 00378 A1_MC_STAR.LockedMatrix(), 00379 R1.LockedMatrix(), 00380 value_type(1), 00381 sketch_of_A.Matrix()); 00382 00383 base = base + b; 00384 00385 elem::SlidePartitionRight 00386 ( A_Left, A_Right, 00387 A0, A1, A2 ); 00388 00389 } 00390 } 00391 00392 00396 void outer_panel_gemm(const matrix_type& A, 00397 output_matrix_type& sketch_of_A, 00398 skylark::sketch::columnwise_tag) const { 00399 00400 const elem::Grid& grid = A.Grid(); 00401 00402 elem::DistMatrix<value_type, elem::MC, elem::STAR> R1(grid); 00403 elem::DistMatrix<value_type> 00404 A_Top(grid), 00405 A_Bottom(grid), 00406 A0(grid), 00407 A1(grid), 00408 A2(grid); 00409 elem::DistMatrix<value_type, elem::MR, elem::STAR> 00410 A1Trans_MR_STAR(grid); 00411 00412 // Zero sketch_of_A 00413 elem::Zero(sketch_of_A); 00414 00415 // TODO: are alignments necessary? 00416 R1.AlignWith(sketch_of_A); 00417 A1Trans_MR_STAR.AlignWith(sketch_of_A); 00418 00419 elem::LockedPartitionDown 00420 ( A, 00421 A_Top, A_Bottom, 0 ); 00422 00423 int blocksize = get_blocksize(); 00424 if (blocksize == 0) { 00425 blocksize = A_Bottom.Height(); 00426 } 00427 int base = 0; 00428 while (A_Bottom.Height() > 0) { 00429 00430 int b = std::min(A_Bottom.Height(), blocksize); 00431 data_type::realize_matrix_view(R1, 0, base, 00432 sketch_of_A.Height(), b); 00433 00434 elem::RepartitionDown 00435 ( A_Top, A0, 00436 00437 A1, 00438 A_Bottom, A2, b ); 00439 00440 00441 // Global size of the target of Allgather that follows 00442 A1Trans_MR_STAR.Resize(A1.Width(), 00443 A1.Height()); 00444 00445 // Allgather within process columns 00446 // TODO: Describe cache benefits from transposition: 00447 // why not simply use A1[STAR, MR]? 00448 A1.TransposeColAllGather(A1Trans_MR_STAR); 00449 00450 // Local Gemm 00451 base::Gemm(elem::NORMAL, 00452 elem::TRANSPOSE, 00453 value_type(1), 00454 R1.LockedMatrix(), 00455 A1Trans_MR_STAR.LockedMatrix(), 00456 value_type(1), 00457 sketch_of_A.Matrix()); 00458 00459 base = base + b; 00460 00461 elem::SlidePartitionDown 00462 ( A_Top, A0, 00463 A1, 00464 00465 A_Bottom, A2 ); 00466 } 00467 } 00468 00469 00473 void matrix_panel_gemm(const matrix_type& A, 00474 output_matrix_type& sketch_of_A, 00475 skylark::sketch::rowwise_tag) const { 00476 00477 const elem::Grid& grid = A.Grid(); 00478 00479 elem::DistMatrix<value_type, elem::STAR, elem::MR> R1(grid); 00480 elem::DistMatrix<value_type> 00481 sketch_of_A_Left(grid), 00482 sketch_of_A_Right(grid), 00483 sketch_of_A0(grid), 00484 sketch_of_A1(grid), 00485 sketch_of_A2(grid); 00486 elem::DistMatrix<value_type, elem::MC, elem::STAR> 00487 sketch_of_A_temp(grid); 00488 00489 // TODO: are alignments necessary? 00490 R1.AlignWith(sketch_of_A); 00491 sketch_of_A_temp.AlignWith(sketch_of_A); 00492 00493 elem::PartitionRight 00494 ( sketch_of_A, 00495 sketch_of_A_Left, sketch_of_A_Right, 0 ); 00496 00497 int blocksize = get_blocksize(); 00498 if (blocksize == 0) { 00499 blocksize = sketch_of_A_Right.Width(); 00500 } 00501 int base = 0; 00502 while (sketch_of_A_Right.Width() > 0) { 00503 00504 int b = std::min(sketch_of_A_Right.Width(), blocksize); 00505 data_type::realize_matrix_view(R1, base, 0, 00506 b, A.Width()); 00507 00508 elem::RepartitionRight 00509 ( sketch_of_A_Left, sketch_of_A_Right, 00510 sketch_of_A0, sketch_of_A1, sketch_of_A2, b ); 00511 00512 // Global size of the result of the Local Gemm that follows 00513 sketch_of_A_temp.Resize(sketch_of_A.Height(), 00514 R1.Height()); 00515 00516 // Local Gemm 00517 base::Gemm(elem::NORMAL, 00518 elem::TRANSPOSE, 00519 value_type(1), 00520 A.LockedMatrix(), 00521 R1.LockedMatrix(), 00522 sketch_of_A_temp.Matrix()); 00523 00524 // Reduce-scatter within row communicators 00525 sketch_of_A1.RowSumScatterFrom(sketch_of_A_temp); 00526 00527 base = base + b; 00528 00529 elem::SlidePartitionRight 00530 ( sketch_of_A_Left, sketch_of_A_Right, 00531 sketch_of_A0, sketch_of_A1, sketch_of_A2 ); 00532 00533 } 00534 } 00535 00536 00540 void panel_matrix_gemm(const matrix_type& A, 00541 output_matrix_type& sketch_of_A, 00542 skylark::sketch::columnwise_tag) const { 00543 00544 const elem::Grid& grid = A.Grid(); 00545 00546 elem::DistMatrix<value_type, elem::STAR, elem::MC> R1(grid); 00547 elem::DistMatrix<value_type> 00548 sketch_of_A_Top(grid), 00549 sketch_of_A_Bottom(grid), 00550 sketch_of_A0(grid), 00551 sketch_of_A1(grid), 00552 sketch_of_A2(grid); 00553 elem::DistMatrix<value_type, elem::STAR, elem::MR> 00554 sketch_of_A_temp(grid); 00555 00556 // TODO: are alignments necessary? 00557 R1.AlignWith(A); 00558 sketch_of_A_temp.AlignWith(A); 00559 00560 elem::PartitionDown 00561 ( sketch_of_A, 00562 sketch_of_A_Top, sketch_of_A_Bottom, 0 ); 00563 00564 int blocksize = get_blocksize(); 00565 if (blocksize == 0) { 00566 blocksize = sketch_of_A_Bottom.Height(); 00567 } 00568 int base = 0; 00569 while (sketch_of_A_Bottom.Height() > 0) { 00570 00571 int b = std::min(sketch_of_A_Bottom.Height(), blocksize); 00572 00573 data_type::realize_matrix_view(R1, base, 0, 00574 b, A.Height()); 00575 00576 elem::RepartitionDown 00577 ( sketch_of_A_Top, sketch_of_A0, 00578 00579 sketch_of_A1, 00580 sketch_of_A_Bottom, sketch_of_A2, b ); 00581 00582 // Global size of the result of the Local Gemm that follows 00583 sketch_of_A_temp.Resize(R1.Height(), 00584 A.Width()); 00585 00586 // Local Gemm 00587 base::Gemm(elem::NORMAL, 00588 elem::NORMAL, 00589 value_type(1), 00590 R1.LockedMatrix(), 00591 A.LockedMatrix(), 00592 sketch_of_A_temp.Matrix()); 00593 00594 // Reduce-scatter within column communicators 00595 sketch_of_A1.ColSumScatterFrom(sketch_of_A_temp); 00596 00597 // Reduce-scatter within column communicators 00598 // sketch_of_A1.ColSumScatterUpdate(value_type(1), 00599 // sketch_of_A_temp); 00600 00601 base = base + b; 00602 00603 elem::SlidePartitionDown 00604 ( sketch_of_A_Top, sketch_of_A0, 00605 sketch_of_A1, 00606 00607 sketch_of_A_Bottom, sketch_of_A2 ); 00608 } 00609 } 00610 00611 00612 void sketch_gemm(const matrix_type& A, 00613 output_matrix_type& sketch_of_A, 00614 skylark::sketch::rowwise_tag tag) const { 00615 00616 const int sketch_height = sketch_of_A.Height(); 00617 const int sketch_width = sketch_of_A.Width(); 00618 const int width = A.Width(); 00619 00620 const double factor = get_factor(); 00621 00622 if((sketch_height * factor <= width) && 00623 (sketch_width * factor <= width)) 00624 inner_panel_gemm(A, sketch_of_A, tag); 00625 else if((sketch_height >= width * factor) && 00626 (sketch_width >= width * factor)) 00627 outer_panel_gemm(A, sketch_of_A, tag); 00628 else 00629 matrix_panel_gemm(A, sketch_of_A, tag); 00630 } 00631 00632 00633 void sketch_gemm(const matrix_type& A, 00634 output_matrix_type& sketch_of_A, 00635 skylark::sketch::columnwise_tag tag) const { 00636 00637 const int sketch_height = sketch_of_A.Height(); 00638 const int sketch_width = sketch_of_A.Width(); 00639 const int height = A.Height(); 00640 00641 const double factor = get_factor(); 00642 00643 if((sketch_height * factor <= height) && 00644 (sketch_width * factor <= height)) 00645 inner_panel_gemm(A, sketch_of_A, tag); 00646 else if((sketch_height >= height * factor) && 00647 (sketch_width >= height * factor)) 00648 outer_panel_gemm(A, sketch_of_A, tag); 00649 else 00650 panel_matrix_gemm(A, sketch_of_A, tag); 00651 } 00652 00653 void apply_impl_dist (const matrix_type& A, 00654 output_matrix_type& sketch_of_A, 00655 skylark::sketch::rowwise_tag tag) const { 00656 00657 sketch_gemm(A, sketch_of_A, tag); 00658 } 00659 00660 00661 void apply_impl_dist (const matrix_type& A, 00662 output_matrix_type& sketch_of_A, 00663 skylark::sketch::columnwise_tag tag) const { 00664 00665 sketch_gemm(A, sketch_of_A, tag); 00666 } 00667 00668 }; 00669 00670 } } 00672 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_MC_MR_HPP