Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 00013 namespace skylark { namespace sketch { 00017 template <typename ValueType, 00018 elem::Distribution RowDist, 00019 template <typename> class ValueDistribution> 00020 struct dense_transform_t < 00021 elem::DistMatrix<ValueType, elem::STAR, RowDist>, 00022 elem::DistMatrix<ValueType, elem::STAR, RowDist>, 00023 ValueDistribution> : 00024 public dense_transform_data_t<ValueDistribution> { 00025 // Typedef matrix and distribution types so that we can use them regularly 00026 typedef ValueType value_type; 00027 typedef elem::DistMatrix<value_type, elem::STAR, RowDist> matrix_type; 00028 typedef elem::DistMatrix<value_type, elem::STAR, RowDist> 00029 output_matrix_type; 00030 typedef ValueDistribution<value_type> value_distribution_type; 00031 typedef dense_transform_data_t<ValueDistribution> data_type; 00032 00036 dense_transform_t (int N, int S, double scale, base::context_t& context) 00037 : data_type (N, S, scale, context) { 00038 00039 } 00040 00041 00045 dense_transform_t (dense_transform_t<matrix_type, 00046 output_matrix_type, 00047 ValueDistribution>& other) 00048 : data_type(other) {} 00049 00053 dense_transform_t(const data_type& other_data) 00054 : data_type(other_data) {} 00055 00056 00060 template <typename Dimension> 00061 void apply (const matrix_type& A, 00062 output_matrix_type& sketch_of_A, 00063 Dimension dimension) const { 00064 00065 switch(RowDist) { 00066 case elem::VR: 00067 case elem::VC: 00068 try { 00069 apply_impl_vdist (A, sketch_of_A, dimension); 00070 } catch (std::logic_error e) { 00071 SKYLARK_THROW_EXCEPTION ( 00072 base::elemental_exception() 00073 << base::error_msg(e.what()) ); 00074 } catch(boost::mpi::exception e) { 00075 SKYLARK_THROW_EXCEPTION ( 00076 base::mpi_exception() 00077 << base::error_msg(e.what()) ); 00078 } 00079 00080 break; 00081 00082 default: 00083 SKYLARK_THROW_EXCEPTION ( 00084 base::unsupported_matrix_distribution() ); 00085 } 00086 } 00087 00088 int get_N() const { return this->_N; } 00089 int get_S() const { return this->_S; } 00091 const sketch_transform_data_t* get_data() const { return this; } 00092 00093 private: 00094 00095 // Communication demanding scenario: Memory-oblivious mode 00096 // TODO: Block-by-block mode 00097 void inner_panel_gemm(const matrix_type& A, 00098 output_matrix_type& sketch_of_A, 00099 skylark::sketch::rowwise_tag) const { 00100 00101 const elem::Grid& grid = A.Grid(); 00102 00103 elem::DistMatrix<value_type, elem::STAR, RowDist> R(grid); 00104 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00105 sketch_of_A_STAR_STAR(grid); 00106 00107 data_type::realize_matrix_view(R); 00108 00109 // TODO: is alignment necessary? 00110 00111 // Global size of the result of the Local Gemm that follows 00112 sketch_of_A_STAR_STAR.Resize(A.Height(), 00113 R.Height()); 00114 00115 // Local Gemm 00116 base::Gemm(elem::NORMAL, 00117 elem::TRANSPOSE, 00118 value_type(1), 00119 A.LockedMatrix(), 00120 R.LockedMatrix(), 00121 sketch_of_A_STAR_STAR.Matrix()); 00122 00123 // Reduce-scatter within process grid 00124 sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR); 00125 } 00126 00127 00131 void inner_panel_gemm(const matrix_type& A, 00132 output_matrix_type& sketch_of_A, 00133 skylark::sketch::columnwise_tag) const { 00134 00135 const elem::Grid& grid = A.Grid(); 00136 00137 elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid); 00138 elem::DistMatrix<value_type, elem::STAR, RowDist> 00139 A_Left(grid), 00140 A_Right(grid), 00141 A0(grid), 00142 A1(grid), 00143 A2(grid); 00144 elem::DistMatrix<value_type, elem::STAR, RowDist> 00145 sketch_of_A_Top(grid), 00146 sketch_of_A_Bottom(grid), 00147 sketch_of_A0(grid), 00148 sketch_of_A1(grid), 00149 sketch_of_A2(grid), 00150 sketch_of_A1_Left(grid), 00151 sketch_of_A1_Right(grid), 00152 sketch_of_A10(grid), 00153 sketch_of_A11(grid), 00154 sketch_of_A12(grid); 00155 00156 // TODO: are alignments necessary? 00157 00158 elem::PartitionDown 00159 ( sketch_of_A, 00160 sketch_of_A_Top, sketch_of_A_Bottom, 0 ); 00161 00162 // TODO: Allow for different blocksizes in "down" and "right" directions 00163 int blocksize = get_blocksize(); 00164 if (blocksize == 0) { 00165 blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width()); 00166 } 00167 int base = 0; 00168 while (sketch_of_A_Bottom.Height() > 0) { 00169 00170 int b = std::min(sketch_of_A_Bottom.Height(), blocksize); 00171 data_type::realize_matrix_view(R1, base, 0, 00172 b, A.Height()); 00173 00174 elem::RepartitionDown 00175 ( sketch_of_A_Top, sketch_of_A0, 00176 00177 sketch_of_A1, 00178 sketch_of_A_Bottom, sketch_of_A2, b ); 00179 00180 00181 elem::LockedPartitionRight 00182 ( A, 00183 A_Left, A_Right, 0 ); 00184 00185 elem::PartitionRight 00186 ( sketch_of_A1, 00187 sketch_of_A1_Left, sketch_of_A1_Right, 0 ); 00188 00189 while(A_Right.Width() > 0) { 00190 00191 elem::LockedRepartitionRight 00192 ( A_Left, A_Right, 00193 A0, A1, A2, b ); 00194 00195 elem::RepartitionRight 00196 ( sketch_of_A1_Left, sketch_of_A1_Right, 00197 sketch_of_A10, sketch_of_A11, sketch_of_A12, b ); 00198 00199 // Local Gemm 00200 base::Gemm(elem::NORMAL, 00201 elem::NORMAL, 00202 value_type(1), 00203 R1.LockedMatrix(), 00204 A1.LockedMatrix(), 00205 sketch_of_A11.Matrix()); 00206 00207 elem::SlideLockedPartitionRight 00208 ( A_Left, A_Right, 00209 A0, A1, A2 ); 00210 00211 elem::SlidePartitionRight 00212 ( sketch_of_A1_Left, sketch_of_A1_Right, 00213 sketch_of_A10, sketch_of_A11, sketch_of_A12 ); 00214 } 00215 00216 base = base + b; 00217 00218 elem::SlidePartitionDown 00219 ( sketch_of_A_Top, sketch_of_A0, 00220 sketch_of_A1, 00221 00222 sketch_of_A_Bottom, sketch_of_A2 ); 00223 } 00224 } 00225 00226 00231 // Communication demanding scenario: Memory-oblivious mode 00232 // TODO: Block-by-block mode 00233 void outer_panel_gemm(const matrix_type& A, 00234 output_matrix_type& sketch_of_A, 00235 skylark::sketch::rowwise_tag) const { 00236 00237 const elem::Grid& grid = A.Grid(); 00238 00239 elem::DistMatrix<value_type, RowDist, elem::STAR> R(grid); 00240 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00241 A_STAR_STAR(grid); 00242 00243 // TODO: are alignments necessary? 00244 R.AlignWith(sketch_of_A); 00245 A_STAR_STAR.AlignWith(sketch_of_A); 00246 00247 data_type::realize_matrix_view(R); 00248 00249 // Allgather within process grid 00250 A_STAR_STAR = A; 00251 00252 // Zero sketch_of_A 00253 elem::Zero(sketch_of_A); 00254 00255 // Local Gemm 00256 base::Gemm(elem::NORMAL, 00257 elem::TRANSPOSE, 00258 value_type(1), 00259 A_STAR_STAR.LockedMatrix(), 00260 R.LockedMatrix(), 00261 value_type(1), 00262 sketch_of_A.Matrix()); 00263 00264 } 00265 00266 00270 void outer_panel_gemm(const matrix_type& A, 00271 output_matrix_type& sketch_of_A, 00272 skylark::sketch::columnwise_tag) const { 00273 00274 const elem::Grid& grid = A.Grid(); 00275 00276 elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid); 00277 elem::DistMatrix<value_type, elem::STAR, RowDist> 00278 A_Top(grid), 00279 A_Bottom(grid), 00280 A0(grid), 00281 A1(grid), 00282 A2(grid); 00283 00284 // Zero sketch_of_A 00285 elem::Zero(sketch_of_A); 00286 00287 // TODO: is alignment necessary? 00288 R1.AlignWith(sketch_of_A); 00289 00290 elem::LockedPartitionDown 00291 ( A, 00292 A_Top, A_Bottom, 0 ); 00293 00294 int blocksize = get_blocksize(); 00295 if (blocksize == 0) { 00296 blocksize = A_Bottom.Height(); 00297 } 00298 int base = 0; 00299 while (A_Bottom.Height() > 0) { 00300 00301 int b = std::min(A_Bottom.Height(), blocksize); 00302 data_type::realize_matrix_view(R1, 0, base, 00303 sketch_of_A.Height(), b); 00304 00305 elem::RepartitionDown 00306 ( A_Top, A0, 00307 00308 A1, 00309 A_Bottom, A2, b ); 00310 00311 // Local Gemm 00312 base::Gemm(elem::NORMAL, 00313 elem::NORMAL, 00314 value_type(1), 00315 R1.LockedMatrix(), 00316 A1.LockedMatrix(), 00317 value_type(1), 00318 sketch_of_A.Matrix()); 00319 00320 base = base + b; 00321 00322 elem::SlidePartitionDown 00323 ( A_Top, A0, 00324 A1, 00325 00326 A_Bottom, A2 ); 00327 } 00328 } 00329 00330 00335 // Communication demanding scenario: Memory-oblivious mode 00336 // TODO: Block-by-block mode 00337 void matrix_panel_gemm(const matrix_type& A, 00338 output_matrix_type& sketch_of_A, 00339 skylark::sketch::rowwise_tag) const { 00340 00341 const elem::Grid& grid = A.Grid(); 00342 00343 elem::DistMatrix<value_type, elem::STAR, RowDist> R(grid); 00344 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00345 sketch_of_A_STAR_STAR(grid); 00346 00347 // TODO: are alignments necessary? 00348 R.AlignWith(sketch_of_A); 00349 sketch_of_A_STAR_STAR.AlignWith(sketch_of_A); 00350 00351 data_type::realize_matrix_view(R); 00352 00353 // Global size of the result of the Local Gemm that follows 00354 sketch_of_A_STAR_STAR.Resize(sketch_of_A.Height(), 00355 sketch_of_A.Width()); 00356 00357 // Local Gemm 00358 base::Gemm(elem::NORMAL, 00359 elem::TRANSPOSE, 00360 value_type(1), 00361 A.LockedMatrix(), 00362 R.LockedMatrix(), 00363 sketch_of_A_STAR_STAR.Matrix()); 00364 00365 // Reduce-scatter within process grid 00366 sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR); 00367 00368 } 00369 00370 00374 void panel_matrix_gemm(const matrix_type& A, 00375 output_matrix_type& sketch_of_A, 00376 skylark::sketch::columnwise_tag) const { 00377 00378 const elem::Grid& grid = A.Grid(); 00379 00380 elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid); 00381 elem::DistMatrix<value_type, elem::STAR, RowDist> 00382 sketch_of_A_Top(grid), 00383 sketch_of_A_Bottom(grid), 00384 sketch_of_A0(grid), 00385 sketch_of_A1(grid), 00386 sketch_of_A2(grid); 00387 00388 // TODO: is alignment necessary? 00389 R1.AlignWith(A); 00390 00391 elem::PartitionDown 00392 ( sketch_of_A, 00393 sketch_of_A_Top, sketch_of_A_Bottom, 0 ); 00394 00395 int blocksize = get_blocksize(); 00396 if (blocksize == 0) { 00397 blocksize = sketch_of_A_Bottom.Height(); 00398 } 00399 int base = 0; 00400 while (sketch_of_A_Bottom.Height() > 0) { 00401 00402 int b = std::min(sketch_of_A_Bottom.Height(), blocksize); 00403 data_type::realize_matrix_view(R1, base, 0, 00404 b, A.Height()); 00405 00406 elem::RepartitionDown 00407 ( sketch_of_A_Top, sketch_of_A0, 00408 00409 sketch_of_A1, 00410 sketch_of_A_Bottom, sketch_of_A2, b ); 00411 00412 // Local Gemm 00413 base::Gemm(elem::NORMAL, 00414 elem::NORMAL, 00415 value_type(1), 00416 R1.LockedMatrix(), 00417 A.LockedMatrix(), 00418 sketch_of_A1.Matrix()); 00419 00420 base = base + b; 00421 00422 elem::SlidePartitionDown 00423 ( sketch_of_A_Top, sketch_of_A0, 00424 sketch_of_A1, 00425 00426 sketch_of_A_Bottom, sketch_of_A2 ); 00427 } 00428 } 00429 00430 00431 void sketch_gemm(const matrix_type& A, 00432 output_matrix_type& sketch_of_A, 00433 skylark::sketch::rowwise_tag tag) const { 00434 00435 const int sketch_height = sketch_of_A.Height(); 00436 const int sketch_width = sketch_of_A.Width(); 00437 const int width = A.Width(); 00438 00439 const double factor = get_factor(); 00440 00441 if((sketch_height * factor <= width) && 00442 (sketch_width * factor <= width)) 00443 inner_panel_gemm(A, sketch_of_A, tag); 00444 else if((sketch_height >= width * factor) && 00445 (sketch_width >= width * factor)) 00446 outer_panel_gemm(A, sketch_of_A, tag); 00447 else 00448 matrix_panel_gemm(A, sketch_of_A, tag); 00449 } 00450 00451 00452 void sketch_gemm(const matrix_type& A, 00453 output_matrix_type& sketch_of_A, 00454 skylark::sketch::columnwise_tag tag) const { 00455 00456 const int sketch_height = sketch_of_A.Height(); 00457 const int sketch_width = sketch_of_A.Width(); 00458 const int height = A.Height(); 00459 00460 const double factor = get_factor(); 00461 00462 if((sketch_height * factor <= height) && 00463 (sketch_width * factor <= height)) 00464 inner_panel_gemm(A, sketch_of_A, tag); 00465 else if((sketch_height >= height * factor) && 00466 (sketch_width >= height * factor)) 00467 outer_panel_gemm(A, sketch_of_A, tag); 00468 else 00469 panel_matrix_gemm(A, sketch_of_A, tag); 00470 } 00471 00472 00473 void apply_impl_vdist (const matrix_type& A, 00474 output_matrix_type& sketch_of_A, 00475 skylark::sketch::rowwise_tag tag) const { 00476 00477 sketch_gemm(A, sketch_of_A, tag); 00478 } 00479 00480 00481 void apply_impl_vdist (const matrix_type& A, 00482 output_matrix_type& sketch_of_A, 00483 skylark::sketch::columnwise_tag tag) const { 00484 00485 sketch_gemm(A, sketch_of_A, tag); 00486 } 00487 00488 }; 00489 00490 } } 00492 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_STAR_ROWDIST_HPP