// Skylark (Sketching Library), version 0.1
00001 #ifndef SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_HPP 00002 #define SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_HPP 00003 00004 #include "../base/base.hpp" 00005 00006 #include "transforms.hpp" 00007 #include "dense_transform_data.hpp" 00008 #include "../utility/comm.hpp" 00009 #include "../utility/get_communicator.hpp" 00010 00011 #include "sketch_params.hpp" 00012 00013 namespace skylark { namespace sketch { 00017 template <typename ValueType, 00018 elem::Distribution ColDist, 00019 template <typename> class ValueDistribution> 00020 struct dense_transform_t < 00021 elem::DistMatrix<ValueType, ColDist, elem::STAR>, 00022 elem::DistMatrix<ValueType, ColDist, elem::STAR>, 00023 ValueDistribution> : 00024 public dense_transform_data_t<ValueDistribution> { 00025 // Typedef matrix and distribution types so that we can use them regularly 00026 typedef ValueType value_type; 00027 typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type; 00028 typedef elem::DistMatrix<value_type, ColDist, elem::STAR> 00029 output_matrix_type; 00030 typedef ValueDistribution<value_type> value_distribution_type; 00031 typedef dense_transform_data_t<ValueDistribution> data_type; 00032 00036 dense_transform_t (int N, int S, double scale, base::context_t& context) 00037 : data_type (N, S, scale, context) { 00038 00039 } 00040 00041 00045 dense_transform_t (dense_transform_t<matrix_type, 00046 output_matrix_type, 00047 ValueDistribution>& other) 00048 : data_type(other) {} 00049 00053 dense_transform_t(const data_type& other_data) 00054 : data_type(other_data) {} 00055 00056 00060 template <typename Dimension> 00061 void apply (const matrix_type& A, 00062 output_matrix_type& sketch_of_A, 00063 Dimension dimension) const { 00064 00065 switch(ColDist) { 00066 case elem::VR: 00067 case elem::VC: 00068 try { 00069 apply_impl_vdist (A, sketch_of_A, dimension); 00070 } catch (std::logic_error e) { 00071 SKYLARK_THROW_EXCEPTION ( 00072 base::elemental_exception() 00073 << 
base::error_msg(e.what()) ); 00074 } catch(boost::mpi::exception e) { 00075 SKYLARK_THROW_EXCEPTION ( 00076 base::mpi_exception() 00077 << base::error_msg(e.what()) ); 00078 } 00079 00080 break; 00081 00082 default: 00083 SKYLARK_THROW_EXCEPTION ( 00084 base::unsupported_matrix_distribution() ); 00085 } 00086 } 00087 00088 int get_N() const { return this->_N; } 00089 int get_S() const { return this->_S; } 00091 const sketch_transform_data_t* get_data() const { return this; } 00092 00093 private: 00094 00098 void inner_panel_gemm(const matrix_type& A, 00099 output_matrix_type& sketch_of_A, 00100 skylark::sketch::rowwise_tag) const { 00101 00102 const elem::Grid& grid = A.Grid(); 00103 00104 elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid); 00105 elem::DistMatrix<value_type, ColDist, elem::STAR> 00106 A_Top(grid), 00107 A_Bottom(grid), 00108 A0(grid), 00109 A1(grid), 00110 A2(grid); 00111 elem::DistMatrix<value_type, ColDist, elem::STAR> 00112 sketch_of_A_Left(grid), 00113 sketch_of_A_Right(grid), 00114 sketch_of_A0(grid), 00115 sketch_of_A1(grid), 00116 sketch_of_A2(grid), 00117 sketch_of_A1_Top(grid), 00118 sketch_of_A1_Bottom(grid), 00119 sketch_of_A10(grid), 00120 sketch_of_A11(grid), 00121 sketch_of_A12(grid); 00122 00123 // TODO: are alignments necessary? 
00124 00125 elem::PartitionRight 00126 ( sketch_of_A, 00127 sketch_of_A_Left, sketch_of_A_Right, 0 ); 00128 00129 // TODO: Allow for different blocksizes in "down" and "right" directions 00130 int blocksize = get_blocksize(); 00131 if (blocksize == 0) { 00132 blocksize = std::min(sketch_of_A.Height(), sketch_of_A.Width()); 00133 } 00134 int base = 0; 00135 while (sketch_of_A_Right.Width() > 0) { 00136 00137 int b = std::min(sketch_of_A_Right.Width(), blocksize); 00138 data_type::realize_matrix_view(R1, base, 0, 00139 b, A.Width()); 00140 00141 elem::RepartitionRight 00142 ( sketch_of_A_Left, sketch_of_A_Right, 00143 sketch_of_A0, sketch_of_A1, sketch_of_A2, b ); 00144 00145 elem::LockedPartitionDown 00146 ( A, A_Top, 00147 A_Bottom, 0 ); 00148 00149 elem::PartitionDown 00150 ( sketch_of_A1, sketch_of_A1_Top, 00151 sketch_of_A1_Bottom, 0 ); 00152 00153 while(A_Bottom.Height() > 0) { 00154 elem::LockedRepartitionDown 00155 ( A_Top, A0, 00156 00157 A1, 00158 A_Bottom, A2, b ); 00159 00160 elem::RepartitionDown 00161 ( sketch_of_A1_Top, sketch_of_A10, 00162 00163 sketch_of_A11, 00164 sketch_of_A1_Bottom, sketch_of_A12, b ); 00165 00166 // Local Gemm 00167 base::Gemm(elem::NORMAL, 00168 elem::TRANSPOSE, 00169 value_type(1), 00170 A1.LockedMatrix(), 00171 R1.LockedMatrix(), 00172 sketch_of_A11.Matrix()); 00173 00174 elem::SlideLockedPartitionDown 00175 ( A_Top, A0, 00176 A1, 00177 00178 A_Bottom, A2 ); 00179 00180 elem::SlidePartitionDown 00181 ( sketch_of_A1_Top, sketch_of_A10, 00182 sketch_of_A11, 00183 00184 sketch_of_A1_Bottom, sketch_of_A12 ); 00185 } 00186 00187 base = base + b; 00188 00189 elem::SlidePartitionRight 00190 ( sketch_of_A_Left, sketch_of_A_Right, 00191 sketch_of_A0, sketch_of_A1, sketch_of_A2 ); 00192 } 00193 } 00194 00195 00200 // Communication demanding scenario: Memory-oblivious mode 00201 // TODO: Block-by-block mode 00202 void inner_panel_gemm(const matrix_type& A, 00203 output_matrix_type& sketch_of_A, 00204 skylark::sketch::columnwise_tag) 
const { 00205 00206 const elem::Grid& grid = A.Grid(); 00207 00208 elem::DistMatrix<value_type, elem::STAR, ColDist> R(grid); 00209 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00210 sketch_of_A_STAR_STAR(grid); 00211 00212 data_type::realize_matrix_view(R); 00213 00214 // TODO: is alignment necessary? 00215 00216 // Global size of the result of the Local Gemm that follows 00217 sketch_of_A_STAR_STAR.Resize(R.Height(), 00218 A.Width()); 00219 00220 // Local Gemm 00221 base::Gemm(elem::NORMAL, 00222 elem::NORMAL, 00223 value_type(1), 00224 R.LockedMatrix(), 00225 A.LockedMatrix(), 00226 sketch_of_A_STAR_STAR.Matrix()); 00227 00228 // Reduce-scatter within process grid 00229 sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR); 00230 00231 } 00232 00233 00237 void outer_panel_gemm(const matrix_type& A, 00238 output_matrix_type& sketch_of_A, 00239 skylark::sketch::rowwise_tag) const { 00240 00241 const elem::Grid& grid = A.Grid(); 00242 00243 elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid); 00244 elem::DistMatrix<value_type, ColDist, elem::STAR> 00245 A_Left(grid), 00246 A_Right(grid), 00247 A0(grid), 00248 A1(grid), 00249 A2(grid); 00250 00251 // Zero sketch_of_A 00252 elem::Zero(sketch_of_A); 00253 00254 // TODO: is alignment necessary? 
00255 R1.AlignWith(sketch_of_A); 00256 00257 elem::LockedPartitionRight 00258 ( A, 00259 A_Left, A_Right, 0 ); 00260 00261 int blocksize = get_blocksize(); 00262 if (blocksize == 0) { 00263 blocksize = A_Right.Width(); 00264 } 00265 int base = 0; 00266 while (A_Right.Width() > 0) { 00267 00268 int b = std::min(A_Right.Width(), blocksize); 00269 data_type::realize_matrix_view(R1, 0, base, 00270 sketch_of_A.Width(), b); 00271 00272 elem::RepartitionRight 00273 ( A_Left, A_Right, 00274 A0, A1, A2, b ); 00275 00276 // Local Gemm 00277 base::Gemm(elem::NORMAL, 00278 elem::TRANSPOSE, 00279 value_type(1), 00280 A1.LockedMatrix(), 00281 R1.LockedMatrix(), 00282 value_type(1), 00283 sketch_of_A.Matrix()); 00284 00285 base = base + b; 00286 00287 elem::SlidePartitionRight 00288 ( A_Left, A_Right, 00289 A0, A1, A2 ); 00290 00291 } 00292 } 00293 00294 00299 // Communication demanding scenario: Memory-oblivious mode 00300 // TODO: Block-by-block mode 00301 void outer_panel_gemm(const matrix_type& A, 00302 output_matrix_type& sketch_of_A, 00303 skylark::sketch::columnwise_tag) const { 00304 00305 const elem::Grid& grid = A.Grid(); 00306 00307 elem::DistMatrix<value_type, ColDist, elem::STAR> R(grid); 00308 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00309 A_STAR_STAR(grid); 00310 00311 // TODO: are alignments necessary? 
00312 R.AlignWith(sketch_of_A); 00313 A_STAR_STAR.AlignWith(sketch_of_A); 00314 00315 data_type::realize_matrix_view(R); 00316 00317 // Zero sketch_of_A 00318 elem::Zero(sketch_of_A); 00319 00320 // Allgather within process grid 00321 A_STAR_STAR = A; 00322 00323 // Local Gemm 00324 base::Gemm(elem::NORMAL, 00325 elem::NORMAL, 00326 value_type(1), 00327 R.LockedMatrix(), 00328 A_STAR_STAR.LockedMatrix(), 00329 value_type(1), 00330 sketch_of_A.Matrix()); 00331 } 00332 00333 00337 void matrix_panel_gemm(const matrix_type& A, 00338 output_matrix_type& sketch_of_A, 00339 skylark::sketch::rowwise_tag) const { 00340 00341 const elem::Grid& grid = A.Grid(); 00342 00343 elem::DistMatrix<value_type, elem::STAR, elem::STAR> R1(grid); 00344 elem::DistMatrix<value_type, ColDist, elem::STAR> 00345 sketch_of_A_Left(grid), 00346 sketch_of_A_Right(grid), 00347 sketch_of_A0(grid), 00348 sketch_of_A1(grid), 00349 sketch_of_A2(grid); 00350 00351 // TODO: is alignment necessary? 00352 R1.AlignWith(sketch_of_A); 00353 00354 elem::PartitionRight 00355 ( sketch_of_A, 00356 sketch_of_A_Left, sketch_of_A_Right, 0 ); 00357 00358 int blocksize = get_blocksize(); 00359 if (blocksize == 0) { 00360 blocksize = sketch_of_A_Right.Width(); 00361 } 00362 int base = 0; 00363 while (sketch_of_A_Right.Width() > 0) { 00364 00365 int b = std::min(sketch_of_A_Right.Width(), blocksize); 00366 data_type::realize_matrix_view(R1, base, 0, 00367 b, A.Width()); 00368 00369 elem::RepartitionRight 00370 ( sketch_of_A_Left, sketch_of_A_Right, 00371 sketch_of_A0, sketch_of_A1, sketch_of_A2, b ); 00372 00373 // Local Gemm 00374 base::Gemm(elem::NORMAL, 00375 elem::TRANSPOSE, 00376 value_type(1), 00377 A.LockedMatrix(), 00378 R1.LockedMatrix(), 00379 sketch_of_A1.Matrix()); 00380 00381 base = base + b; 00382 00383 elem::SlidePartitionRight 00384 ( sketch_of_A_Left, sketch_of_A_Right, 00385 sketch_of_A0, sketch_of_A1, sketch_of_A2 ); 00386 00387 } 00388 } 00389 00390 00395 // Communication demanding scenario: 
Memory-oblivious mode 00396 // TODO: Block-by-block mode 00397 void panel_matrix_gemm(const matrix_type& A, 00398 output_matrix_type& sketch_of_A, 00399 skylark::sketch::columnwise_tag) const { 00400 00401 const elem::Grid& grid = A.Grid(); 00402 00403 elem::DistMatrix<value_type, elem::STAR, ColDist> R(grid); 00404 elem::DistMatrix<value_type, elem::STAR, elem::STAR> 00405 sketch_of_A_STAR_STAR(grid); 00406 00407 // TODO: is alignment necessary? 00408 sketch_of_A_STAR_STAR.AlignWith(A); 00409 00410 data_type::realize_matrix_view(R); 00411 00412 // Global size of the result of the Local Gemm that follows 00413 sketch_of_A_STAR_STAR.Resize(R.Height(), 00414 A.Width()); 00415 00416 // Local Gemm 00417 base::Gemm(elem::NORMAL, 00418 elem::NORMAL, 00419 value_type(1), 00420 R.LockedMatrix(), 00421 A.LockedMatrix(), 00422 sketch_of_A_STAR_STAR.Matrix()); 00423 00424 // Reduce-scatter within process grid 00425 sketch_of_A.SumScatterFrom(sketch_of_A_STAR_STAR); 00426 } 00427 00428 00429 void sketch_gemm(const matrix_type& A, 00430 output_matrix_type& sketch_of_A, 00431 skylark::sketch::rowwise_tag tag) const { 00432 00433 const int sketch_height = sketch_of_A.Height(); 00434 const int sketch_width = sketch_of_A.Width(); 00435 const int width = A.Width(); 00436 00437 const double factor = get_factor(); 00438 00439 if((sketch_height * factor <= width) && 00440 (sketch_width * factor <= width)) { 00441 inner_panel_gemm(A, sketch_of_A, tag); 00442 } else if((sketch_height >= width * factor) && 00443 (sketch_width >= width * factor)) 00444 outer_panel_gemm(A, sketch_of_A, tag); 00445 else 00446 matrix_panel_gemm(A, sketch_of_A, tag); 00447 } 00448 00449 00450 void sketch_gemm(const matrix_type& A, 00451 output_matrix_type& sketch_of_A, 00452 skylark::sketch::columnwise_tag tag) const { 00453 00454 const int sketch_height = sketch_of_A.Height(); 00455 const int sketch_width = sketch_of_A.Width(); 00456 const int height = A.Height(); 00457 00458 const double factor = 
get_factor(); 00459 00460 if((sketch_height * factor <= height) && 00461 (sketch_width * factor <= height)) 00462 inner_panel_gemm(A, sketch_of_A, tag); 00463 else if((sketch_height >= height * factor) && 00464 (sketch_width >= height * factor)) 00465 outer_panel_gemm(A, sketch_of_A, tag); 00466 else 00467 panel_matrix_gemm(A, sketch_of_A, tag); 00468 } 00469 00470 00471 void apply_impl_vdist (const matrix_type& A, 00472 output_matrix_type& sketch_of_A, 00473 skylark::sketch::rowwise_tag tag) const { 00474 00475 sketch_gemm(A, sketch_of_A, tag); 00476 } 00477 00478 00479 void apply_impl_vdist (const matrix_type& A, 00480 output_matrix_type& sketch_of_A, 00481 skylark::sketch::columnwise_tag tag) const { 00482 00483 sketch_gemm(A, sketch_of_A, tag); 00484 } 00485 }; 00486 00487 } } 00489 #endif // SKYLARK_DENSE_TRANSFORM_ELEMENTAL_COLDIST_STAR_HPP