Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_GEMM_HPP 00002 #define SKYLARK_GEMM_HPP 00003 00004 #include <boost/mpi.hpp> 00005 #include "exception.hpp" 00006 #include "sparse_matrix.hpp" 00007 #include "computed_matrix.hpp" 00008 #include "../utility/typer.hpp" 00009 00010 #include "Gemm_detail.hpp" 00011 00012 // Defines a generic Gemm function that receives both dense and sparse matrices. 00013 00014 namespace skylark { namespace base { 00015 00016 #if SKYLARK_HAVE_ELEMENTAL 00017 00022 template<typename T> 00023 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00024 T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& B, 00025 T beta, elem::Matrix<T>& C) { 00026 elem::Gemm(oA, oB, alpha, A, B, beta, C); 00027 } 00028 00029 template<typename T> 00030 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00031 T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& B, 00032 elem::Matrix<T>& C) { 00033 elem::Gemm(oA, oB, alpha, A, B, C); 00034 } 00035 00036 template<typename T> 00037 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00038 T alpha, const elem::DistMatrix<T, elem::STAR, elem::STAR>& A, 00039 const elem::DistMatrix<T, elem::STAR, elem::STAR>& B, 00040 T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& C) { 00041 elem::Gemm(oA, oB, alpha, A.LockedMatrix(), B.LockedMatrix(), beta, C.Matrix()); 00042 } 00043 00044 template<typename T> 00045 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00046 T alpha, const elem::DistMatrix<T, elem::STAR, elem::STAR>& A, 00047 const elem::DistMatrix<T, elem::STAR, elem::STAR>& B, 00048 elem::DistMatrix<T, elem::STAR, elem::STAR>& C) { 00049 elem::Gemm(oA, oB, alpha, A.LockedMatrix(), B.LockedMatrix(), C.Matrix()); 00050 } 00051 00052 template<typename T> 00053 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00054 T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& B, 00055 T beta, elem::DistMatrix<T>& C) { 00056 elem::Gemm(oA, oB, alpha, A, B, beta, C); 00057 } 00058 00059 template<typename T> 00060 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00061 T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& B, 00062 elem::DistMatrix<T>& C) { 00063 elem::Gemm(oA, oB, alpha, A, B, C); 00064 } 00065 00071 template<typename T> 00072 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00073 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00074 const elem::DistMatrix<T, elem::VC, elem::STAR>& B, 00075 T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& C) { 00076 // TODO verify sizes etc. 00077 00078 if ((oA == elem::TRANSPOSE || oA == elem::ADJOINT) && oB == elem::NORMAL) { 00079 boost::mpi::communicator comm(C.Grid().Comm(), boost::mpi::comm_attach); 00080 elem::Matrix<T> Clocal(C.Matrix()); 00081 elem::Gemm(oA, elem::NORMAL, 00082 alpha, A.LockedMatrix(), B.LockedMatrix(), 00083 beta / T(comm.size()), Clocal); 00084 boost::mpi::all_reduce(comm, 00085 Clocal.Buffer(), Clocal.MemorySize(), C.Matrix().Buffer(), 00086 std::plus<T>()); 00087 } else { 00088 SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation()); 00089 } 00090 } 00091 00092 template<typename T> 00093 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00094 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00095 const elem::DistMatrix<T, elem::VC, elem::STAR>& B, 00096 elem::DistMatrix<T, elem::STAR, elem::STAR>& C) { 00097 00098 int C_height = (oA == elem::NORMAL ? A.Height() : A.Width()); 00099 int C_width = (oB == elem::NORMAL ? B.Width() : B.Height()); 00100 elem::Zeros(C, C_height, C_width); 00101 base::Gemm(oA, oB, alpha, A, B, T(0), C); 00102 } 00103 00104 template<typename T> 00105 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00106 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00107 const elem::DistMatrix<T, elem::STAR, elem::STAR>& B, 00108 T beta, elem::DistMatrix<T, elem::VC, elem::STAR>& C) { 00109 // TODO verify sizes etc. 00110 00111 if (oA == elem::NORMAL && oB == elem::NORMAL) { 00112 elem::Gemm(elem::NORMAL, elem::NORMAL, 00113 alpha, A.LockedMatrix(), B.LockedMatrix(), 00114 beta, C.Matrix()); 00115 } else { 00116 SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation()); 00117 } 00118 } 00119 00120 template<typename T> 00121 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00122 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00123 const elem::DistMatrix<T, elem::STAR, elem::STAR>& B, 00124 elem::DistMatrix<T, elem::VC, elem::STAR>& C) { 00125 00126 int C_height = (oA == elem::NORMAL ? A.Height() : A.Width()); 00127 int C_width = (oB == elem::NORMAL ? B.Width() : B.Height()); 00128 elem::Zeros(C, C_height, C_width); 00129 base::Gemm(oA, oB, alpha, A, B, T(0), C); 00130 } 00131 00132 00133 template<typename T> 00134 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00135 T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A, 00136 const elem::DistMatrix<T, elem::VR, elem::STAR>& B, 00137 T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& C) { 00138 // TODO verify sizes etc. 00139 00140 if ((oA == elem::TRANSPOSE || oA == elem::ADJOINT) && oB == elem::NORMAL) { 00141 boost::mpi::communicator comm(C.Grid().Comm(), boost::mpi::comm_attach); 00142 elem::Matrix<T> Clocal(C.Matrix()); 00143 elem::Gemm(oA, elem::NORMAL, 00144 alpha, A.LockedMatrix(), B.LockedMatrix(), 00145 beta / T(comm.size()), Clocal); 00146 boost::mpi::all_reduce(comm, 00147 Clocal.Buffer(), Clocal.MemorySize(), C.Matrix().Buffer(), 00148 std::plus<T>()); 00149 } else { 00150 SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation()); 00151 } 00152 } 00153 00154 template<typename T> 00155 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00156 T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A, 00157 const elem::DistMatrix<T, elem::VR, elem::STAR>& B, 00158 elem::DistMatrix<T, elem::STAR, elem::STAR>& C) { 00159 00160 int C_height = (oA == elem::NORMAL ? A.Height() : A.Width()); 00161 int C_width = (oB == elem::NORMAL ? B.Width() : B.Height()); 00162 elem::Zeros(C, C_height, C_width); 00163 base::Gemm(oA, oB, alpha, A, B, T(0), C); 00164 } 00165 00166 template<typename T> 00167 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00168 T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A, 00169 const elem::DistMatrix<T, elem::STAR, elem::STAR>& B, 00170 T beta, elem::DistMatrix<T, elem::VR, elem::STAR>& C) { 00171 // TODO verify sizes etc. 00172 00173 if (oA == elem::NORMAL && oB == elem::NORMAL) { 00174 elem::Gemm(elem::NORMAL, elem::NORMAL, 00175 alpha, A.LockedMatrix(), B.LockedMatrix(), 00176 beta, C.Matrix()); 00177 } else { 00178 SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation()); 00179 } 00180 } 00181 00182 template<typename T> 00183 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00184 T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A, 00185 const elem::DistMatrix<T, elem::STAR, elem::STAR>& B, 00186 elem::DistMatrix<T, elem::VR, elem::STAR>& C) { 00187 00188 int C_height = (oA == elem::NORMAL ? A.Height() : A.Width()); 00189 int C_width = (oB == elem::NORMAL ? B.Width() : B.Height()); 00190 elem::Zeros(C, C_height, C_width); 00191 base::Gemm(oA, oB, alpha, A, B, T(0), C); 00192 } 00193 00198 template<typename T> 00199 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00200 T alpha, const elem::Matrix<T>& A, const sparse_matrix_t<T>& B, 00201 T beta, elem::Matrix<T>& C) { 00202 // TODO verify sizes etc. 00203 00204 const int* indptr = B.indptr(); 00205 const int* indices = B.indices(); 00206 const T *values = B.locked_values(); 00207 00208 int k = A.Width(); 00209 int n = B.width(); 00210 int m = A.Height(); 00211 00212 // NN 00213 if (oA == elem::NORMAL && oB == elem::NORMAL) { 00214 00215 elem::Scal(beta, C); 00216 00217 elem::Matrix<T> Ac; 00218 elem::Matrix<T> Cc; 00219 00220 # if SKYLARK_HAVE_OPENMP 00221 # pragma omp parallel for private(Cc, Ac) 00222 # endif 00223 for(int col = 0; col < n; col++) { 00224 elem::View(Cc, C, 0, col, m, 1); 00225 for (int j = indptr[col]; j < indptr[col + 1]; j++) { 00226 int row = indices[j]; 00227 T val = values[j]; 00228 elem::LockedView(Ac, A, 0, row, m, 1); 00229 elem::Axpy(alpha * val, Ac, Cc); 00230 } 00231 } 00232 } 00233 00234 // NT 00235 if (oA == elem::NORMAL && oB == elem::TRANSPOSE) { 00236 00237 elem::Scal(beta, C); 00238 00239 elem::Matrix<T> Ac; 00240 elem::Matrix<T> Cc; 00241 00242 // Now, we simply think of B has being in CSR mode... 00243 int row = 0; 00244 for(int row = 0; row < n; row++) { 00245 elem::LockedView(Ac, A, 0, row, m, 1); 00246 # if SKYLARK_HAVE_OPENMP 00247 # pragma omp parallel for private(Cc) 00248 # endif 00249 for (int j = indptr[row]; j < indptr[row + 1]; j++) { 00250 int col = indices[j]; 00251 T val = values[j]; 00252 elem::View(Cc, C, 0, col, m, 1); 00253 elem::Axpy(alpha * val, Ac, Cc); 00254 } 00255 } 00256 } 00257 00258 00259 // TN - TODO: Not tested! 00260 if (oA == elem::TRANSPOSE && oB == elem::NORMAL) { 00261 double *c = C.Buffer(); 00262 int ldc = C.LDim(); 00263 00264 const double *a = A.LockedBuffer(); 00265 int lda = A.LDim(); 00266 00267 # if SKYLARK_HAVE_OPENMP 00268 # pragma omp parallel for collapse(2) 00269 # endif 00270 for (int j = 0; j < n; j++) 00271 for(int row = 0; row < k; row++) { 00272 c[j * ldc + row] *= beta; 00273 for (int l = indptr[j]; l < indptr[j + 1]; l++) { 00274 int rr = indices[l]; 00275 T val = values[l]; 00276 c[j * ldc + row] += val * a[j * lda + rr]; 00277 } 00278 } 00279 } 00280 00281 // TT - TODO: Not tested! 00282 if (oA == elem::TRANSPOSE && oB == elem::TRANSPOSE) { 00283 elem::Scal(beta, C); 00284 00285 double *c = C.Buffer(); 00286 int ldc = C.LDim(); 00287 00288 const double *a = A.LockedBuffer(); 00289 int lda = A.LDim(); 00290 00291 # if SKYLARK_HAVE_OPENMP 00292 # pragma omp parallel for 00293 # endif 00294 for(int row = 0; row < k; row++) 00295 for(int rb = 0; rb < n; rb++) 00296 for (int l = indptr[rb]; l < indptr[rb + 1]; l++) { 00297 int col = indices[l]; 00298 c[col * ldc + row] += values[l] * a[row * lda + rb]; 00299 } 00300 } 00301 } 00302 00303 template<typename T> 00304 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00305 T alpha, const sparse_matrix_t<T>& A, const elem::Matrix<T>& B, 00306 T beta, elem::Matrix<T>& C) { 00307 // TODO verify sizes etc. 00308 00309 const int* indptr = A.indptr(); 00310 const int* indices = A.indices(); 00311 const double *values = A.locked_values(); 00312 00313 int k = A.width(); 00314 int n = B.Width(); 00315 int m = B.Height(); 00316 00317 // NN 00318 if (oA == elem::NORMAL && oB == elem::NORMAL) { 00319 00320 elem::Scal(beta, C); 00321 00322 double *c = C.Buffer(); 00323 int ldc = C.LDim(); 00324 00325 const double *b = B.LockedBuffer(); 00326 int ldb = B.LDim(); 00327 00328 # if SKYLARK_HAVE_OPENMP 00329 # pragma omp parallel for 00330 # endif 00331 for(int i = 0; i < n; i++) 00332 for(int col = 0; col < k; col++) 00333 for (int j = indptr[col]; j < indptr[col + 1]; j++) { 00334 int row = indices[j]; 00335 T val = values[j]; 00336 c[i * ldc + row] += alpha * val * b[i * ldb + col]; 00337 } 00338 } 00339 00340 // NT 00341 if (oA == elem::NORMAL && oB == elem::TRANSPOSE) { 00342 00343 elem::Scal(beta, C); 00344 00345 elem::Matrix<T> Bc; 00346 elem::Matrix<T> BTr; 00347 elem::Matrix<T> Cr; 00348 00349 for(int col = 0; col < k; col++) { 00350 elem::LockedView(Bc, B, 0, col, m, 1); 00351 elem::Transpose(Bc, BTr); 00352 # if SKYLARK_HAVE_OPENMP 00353 # pragma omp parallel for private(Cr) 00354 # endif 00355 for (int j = indptr[col]; j < indptr[col + 1]; j++) { 00356 int row = indices[j]; 00357 T val = values[j]; 00358 elem::View(Cr, C, row, 0, 1, m); 00359 elem::Axpy(alpha * val, BTr, Cr); 00360 } 00361 } 00362 } 00363 00364 // TN - TODO: Not tested! 00365 if (oA == elem::TRANSPOSE && oB == elem::NORMAL) { 00366 double *c = C.Buffer(); 00367 int ldc = C.LDim(); 00368 00369 const double *b = B.LockedBuffer(); 00370 int ldb = B.LDim(); 00371 00372 # if SKYLARK_HAVE_OPENMP 00373 # pragma omp parallel for collapse(2) 00374 # endif 00375 for (int j = 0; j < n; j++) 00376 for(int row = 0; row < k; row++) { 00377 c[j * ldc + row] *= beta; 00378 for (int l = indptr[row]; l < indptr[row + 1]; l++) { 00379 int col = indices[l]; 00380 T val = values[l]; 00381 c[j * ldc + row] += val * b[j * ldb + col]; 00382 } 00383 } 00384 } 00385 00386 // TT - TODO: Not tested! 00387 if (oA == elem::TRANSPOSE && oB == elem::TRANSPOSE) { 00388 00389 elem::Scal(beta, C); 00390 00391 elem::Matrix<T> Bc; 00392 elem::Matrix<T> BTr; 00393 elem::Matrix<T> Cr; 00394 00395 # if SKYLARK_HAVE_OPENMP 00396 # pragma omp parallel for private(Cr, Bc, BTr) 00397 # endif 00398 for(int row = 0; row < k; row++) { 00399 elem::View(Cr, C, row, 0, 1, m); 00400 for (int l = indptr[row]; l < indptr[row + 1]; l++) { 00401 int col = indices[l]; 00402 T val = values[l]; 00403 elem::LockedView(Bc, B, 0, col, m, 1); 00404 elem::Transpose(Bc, BTr); 00405 elem::Axpy(alpha * val, BTr, Cr); 00406 } 00407 } 00408 } 00409 } 00410 00411 00412 #if SKYLARK_HAVE_COMBBLAS 00413 00418 00419 template<typename index_type, typename value_type, elem::Distribution col_d> 00420 void Gemm(elem::Orientation oA, elem::Orientation oB, double alpha, 00421 const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A, 00422 const elem::DistMatrix<value_type, elem::STAR, elem::STAR> &B, 00423 double beta, 00424 elem::DistMatrix<value_type, col_d, elem::STAR> &C) { 00425 00426 if(oA == elem::NORMAL && oB == elem::NORMAL) { 00427 00428 if(A.getnol() != B.Height()) 00429 SKYLARK_THROW_EXCEPTION ( 00430 base::combblas_exception() 00431 << base::error_msg("Gemm: Dimensions do not agree")); 00432 00433 if(A.getnrow() != C.Height()) 00434 SKYLARK_THROW_EXCEPTION ( 00435 base::combblas_exception() 00436 << base::error_msg("Gemm: Dimensions do not agree")); 00437 00438 if(B.Width() != C.Width()) 00439 SKYLARK_THROW_EXCEPTION ( 00440 base::combblas_exception() 00441 << base::error_msg("Gemm: Dimensions do not agree")); 00442 00443 //XXX: simple heuristic to decide what to communicate (improve!) 00444 // or just if A.getncol() < B.Width.. 00445 if(A.getnnz() < B.Height() * B.Width()) 00446 detail::outer_panel_mixed_gemm_impl_nn(alpha, A, B, beta, C); 00447 else 00448 detail::inner_panel_mixed_gemm_impl_nn(alpha, A, B, beta, C); 00449 } 00450 } 00451 00453 template<typename index_type, typename value_type, elem::Distribution col_d> 00454 void Gemm(elem::Orientation oA, elem::Orientation oB, double alpha, 00455 const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A, 00456 const elem::DistMatrix<value_type, col_d, elem::STAR> &B, 00457 double beta, 00458 elem::DistMatrix<value_type, elem::STAR, elem::STAR> &C) { 00459 00460 if(oA == elem::TRANSPOSE && oB == elem::NORMAL) { 00461 00462 if(A.getrow() != B.Height()) 00463 SKYLARK_THROW_EXCEPTION ( 00464 base::combblas_exception() 00465 << base::error_msg("Gemm: Dimensions do not agree")); 00466 00467 if(A.getncol() != C.Height()) 00468 SKYLARK_THROW_EXCEPTION ( 00469 base::combblas_exception() 00470 << base::error_msg("Gemm: Dimensions do not agree")); 00471 00472 if(B.Width() != C.Width()) 00473 SKYLARK_THROW_EXCEPTION ( 00474 base::combblas_exception() 00475 << base::error_msg("Gemm: Dimensions do not agree")); 00476 00477 detail::outer_panel_mixed_gemm_impl_tn(alpha, A, B, beta, C); 00478 } 00479 00480 } 00481 00482 #endif // SKYLARK_HAVE_COMBBLAS 00483 #endif // SKYLARK_HAVE_ELEMENTAL 00484 00485 /* All combinations with computed matrix */ 00486 00487 template<typename CT, typename RT, typename OT> 00488 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00489 typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT>& A, 00490 const RT& B, typename utility::typer_t<OT>::value_type beta, OT& C) { 00491 base::Gemm(oA, oB, alpha, A.materialize(), B, beta, C); 00492 } 00493 00494 template<typename CT, typename RT, typename OT> 00495 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00496 typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT>& A, 00497 const RT& B, OT& C) { 00498 base::Gemm(oA, oB, alpha, A.materialize(), B, C); 00499 } 00500 00501 template<typename CT, typename RT, typename OT> 00502 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00503 typename utility::typer_t<OT>::value_type alpha, const RT& A, 00504 const computed_matrix_t<CT>& B, 00505 typename utility::typer_t<OT>::value_type beta, OT& C) { 00506 base::Gemm(oA, oB, alpha, A, B.materialize(), beta, C); 00507 } 00508 00509 template<typename CT, typename RT, typename OT> 00510 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00511 typename utility::typer_t<OT>::value_type alpha, const RT& A, 00512 const computed_matrix_t<CT>& B, OT& C) { 00513 base::Gemm(oA, oB, alpha, A, B.materialize(), C); 00514 } 00515 00516 template<typename CT1, typename CT2, typename OT> 00517 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00518 typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT1>& A, 00519 const computed_matrix_t<CT2>& B, typename utility::typer_t<OT>::value_type beta, 00520 OT& C) { 00521 base::Gemm(oA, oB, alpha, A.materialize(), B.materialize(), beta, C); 00522 } 00523 00524 template<typename CT1, typename CT2, typename OT> 00525 inline void Gemm(elem::Orientation oA, elem::Orientation oB, 00526 typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT1>& A, 00527 const computed_matrix_t<CT2>& B, OT& C) { 00528 base::Gemm(oA, oB, alpha, A.materialize(), B.materialize(), C); 00529 } 00530 00531 00532 } } // namespace skylark::base 00533 #endif // SKYLARK_GEMM_HPP