Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/base/Gemm.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_GEMM_HPP
00002 #define SKYLARK_GEMM_HPP
00003 
00004 #include <boost/mpi.hpp>
00005 #include "exception.hpp"
00006 #include "sparse_matrix.hpp"
00007 #include "computed_matrix.hpp"
00008 #include "../utility/typer.hpp"
00009 
00010 #include "Gemm_detail.hpp"
00011 
00012 // Defines a generic Gemm function that receives both dense and sparse matrices.
00013 
00014 namespace skylark { namespace base {
00015 
00016 #if SKYLARK_HAVE_ELEMENTAL
00017 
00022 template<typename T>
00023 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00024     T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& B,
00025     T beta, elem::Matrix<T>& C) {
00026     elem::Gemm(oA, oB, alpha, A, B, beta, C);
00027 }
00028 
00029 template<typename T>
00030 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00031     T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& B,
00032     elem::Matrix<T>& C) {
00033     elem::Gemm(oA, oB, alpha, A, B, C);
00034 }
00035 
00036 template<typename T>
00037 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00038     T alpha, const elem::DistMatrix<T, elem::STAR, elem::STAR>& A, 
00039     const elem::DistMatrix<T, elem::STAR, elem::STAR>& B,
00040     T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& C) {
00041     elem::Gemm(oA, oB, alpha, A.LockedMatrix(), B.LockedMatrix(), beta, C.Matrix());
00042 }
00043 
00044 template<typename T>
00045 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00046     T alpha, const elem::DistMatrix<T, elem::STAR, elem::STAR>& A,
00047     const elem::DistMatrix<T, elem::STAR, elem::STAR>& B,
00048     elem::DistMatrix<T, elem::STAR, elem::STAR>& C) {
00049     elem::Gemm(oA, oB, alpha, A.LockedMatrix(), B.LockedMatrix(), C.Matrix());
00050 }
00051 
00052 template<typename T>
00053 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00054     T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& B,
00055     T beta, elem::DistMatrix<T>& C) {
00056     elem::Gemm(oA, oB, alpha, A, B, beta, C);
00057 }
00058 
00059 template<typename T>
00060 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00061     T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& B,
00062     elem::DistMatrix<T>& C) {
00063     elem::Gemm(oA, oB, alpha, A, B, C);
00064 }
00065 
00071 template<typename T>
00072 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00073     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00074     const elem::DistMatrix<T, elem::VC, elem::STAR>& B,
00075     T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& C) {
00076     // TODO verify sizes etc.
00077 
00078     if ((oA == elem::TRANSPOSE || oA == elem::ADJOINT) && oB == elem::NORMAL) {
00079         boost::mpi::communicator comm(C.Grid().Comm(), boost::mpi::comm_attach);
00080         elem::Matrix<T> Clocal(C.Matrix());
00081         elem::Gemm(oA, elem::NORMAL,
00082             alpha, A.LockedMatrix(), B.LockedMatrix(),
00083             beta / T(comm.size()), Clocal);
00084         boost::mpi::all_reduce(comm,
00085             Clocal.Buffer(), Clocal.MemorySize(), C.Matrix().Buffer(),
00086             std::plus<T>());
00087     } else {
00088         SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation());
00089     }
00090 }
00091 
00092 template<typename T>
00093 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00094     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00095     const elem::DistMatrix<T, elem::VC, elem::STAR>& B,
00096     elem::DistMatrix<T, elem::STAR, elem::STAR>& C) {
00097 
00098     int C_height = (oA == elem::NORMAL ? A.Height() : A.Width());
00099     int C_width = (oB == elem::NORMAL ? B.Width() : B.Height());
00100     elem::Zeros(C, C_height, C_width);
00101     base::Gemm(oA, oB, alpha, A, B, T(0), C);
00102 }
00103 
00104 template<typename T>
00105 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00106     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00107     const elem::DistMatrix<T, elem::STAR, elem::STAR>& B,
00108     T beta, elem::DistMatrix<T, elem::VC, elem::STAR>& C) {
00109     // TODO verify sizes etc.
00110 
00111     if (oA == elem::NORMAL && oB == elem::NORMAL) {
00112         elem::Gemm(elem::NORMAL, elem::NORMAL,
00113             alpha, A.LockedMatrix(), B.LockedMatrix(),
00114             beta, C.Matrix());
00115     } else {
00116         SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation());
00117     }
00118 }
00119 
00120 template<typename T>
00121 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00122     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00123     const elem::DistMatrix<T, elem::STAR, elem::STAR>& B,
00124     elem::DistMatrix<T, elem::VC, elem::STAR>& C) {
00125 
00126     int C_height = (oA == elem::NORMAL ? A.Height() : A.Width());
00127     int C_width = (oB == elem::NORMAL ? B.Width() : B.Height());
00128     elem::Zeros(C, C_height, C_width);
00129     base::Gemm(oA, oB, alpha, A, B, T(0), C);
00130 }
00131 
00132 
00133 template<typename T>
00134 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00135     T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A,
00136     const elem::DistMatrix<T, elem::VR, elem::STAR>& B,
00137     T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& C) {
00138     // TODO verify sizes etc.
00139 
00140     if ((oA == elem::TRANSPOSE || oA == elem::ADJOINT) && oB == elem::NORMAL) {
00141         boost::mpi::communicator comm(C.Grid().Comm(), boost::mpi::comm_attach);
00142         elem::Matrix<T> Clocal(C.Matrix());
00143         elem::Gemm(oA, elem::NORMAL,
00144             alpha, A.LockedMatrix(), B.LockedMatrix(),
00145             beta / T(comm.size()), Clocal);
00146         boost::mpi::all_reduce(comm,
00147             Clocal.Buffer(), Clocal.MemorySize(), C.Matrix().Buffer(),
00148             std::plus<T>());
00149     } else {
00150         SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation());
00151     }
00152 }
00153 
00154 template<typename T>
00155 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00156     T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A,
00157     const elem::DistMatrix<T, elem::VR, elem::STAR>& B,
00158     elem::DistMatrix<T, elem::STAR, elem::STAR>& C) {
00159 
00160     int C_height = (oA == elem::NORMAL ? A.Height() : A.Width());
00161     int C_width = (oB == elem::NORMAL ? B.Width() : B.Height());
00162     elem::Zeros(C, C_height, C_width);
00163     base::Gemm(oA, oB, alpha, A, B, T(0), C);
00164 }
00165 
00166 template<typename T>
00167 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00168     T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A,
00169     const elem::DistMatrix<T, elem::STAR, elem::STAR>& B,
00170     T beta, elem::DistMatrix<T, elem::VR, elem::STAR>& C) {
00171     // TODO verify sizes etc.
00172 
00173     if (oA == elem::NORMAL && oB == elem::NORMAL) {
00174         elem::Gemm(elem::NORMAL, elem::NORMAL,
00175             alpha, A.LockedMatrix(), B.LockedMatrix(),
00176             beta, C.Matrix());
00177     } else {
00178         SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation());
00179     }
00180 }
00181 
00182 template<typename T>
00183 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00184     T alpha, const elem::DistMatrix<T, elem::VR, elem::STAR>& A,
00185     const elem::DistMatrix<T, elem::STAR, elem::STAR>& B,
00186     elem::DistMatrix<T, elem::VR, elem::STAR>& C) {
00187 
00188     int C_height = (oA == elem::NORMAL ? A.Height() : A.Width());
00189     int C_width = (oB == elem::NORMAL ? B.Width() : B.Height());
00190     elem::Zeros(C, C_height, C_width);
00191     base::Gemm(oA, oB, alpha, A, B, T(0), C);
00192 }
00193 
00198 template<typename T>
00199 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00200     T alpha, const elem::Matrix<T>& A, const sparse_matrix_t<T>& B,
00201     T beta, elem::Matrix<T>& C) {
00202     // TODO verify sizes etc.
00203 
00204     const int* indptr = B.indptr();
00205     const int* indices = B.indices();
00206     const T *values = B.locked_values();
00207 
00208     int k = A.Width();
00209     int n = B.width();
00210     int m = A.Height();
00211 
00212     // NN
00213     if (oA == elem::NORMAL && oB == elem::NORMAL) {
00214 
00215         elem::Scal(beta, C);
00216 
00217         elem::Matrix<T> Ac;
00218         elem::Matrix<T> Cc;
00219 
00220 #       if SKYLARK_HAVE_OPENMP
00221 #       pragma omp parallel for private(Cc, Ac)
00222 #       endif
00223         for(int col = 0; col < n; col++) {
00224             elem::View(Cc, C, 0, col, m, 1);
00225             for (int j = indptr[col]; j < indptr[col + 1]; j++) {
00226                 int row = indices[j];
00227                 T val = values[j];
00228                 elem::LockedView(Ac, A, 0, row, m, 1);
00229                 elem::Axpy(alpha * val, Ac, Cc);
00230             }
00231         }
00232     }
00233 
00234     // NT
00235     if (oA == elem::NORMAL && oB == elem::TRANSPOSE) {
00236 
00237         elem::Scal(beta, C);
00238 
00239         elem::Matrix<T> Ac;
00240         elem::Matrix<T> Cc;
00241 
00242         // Now, we simply think of B has being in CSR mode...
00243         int row = 0;
00244         for(int row = 0; row < n; row++) {
00245             elem::LockedView(Ac, A, 0, row, m, 1);
00246 #           if SKYLARK_HAVE_OPENMP
00247 #           pragma omp parallel for private(Cc)
00248 #           endif
00249             for (int j = indptr[row]; j < indptr[row + 1]; j++) {
00250                 int col = indices[j];
00251                 T val = values[j];
00252                 elem::View(Cc, C, 0, col, m, 1);
00253                 elem::Axpy(alpha * val, Ac, Cc);
00254             }
00255         }
00256     }
00257 
00258 
00259     // TN - TODO: Not tested!
00260     if (oA == elem::TRANSPOSE && oB == elem::NORMAL) {
00261         double *c = C.Buffer();
00262         int ldc = C.LDim();
00263 
00264         const double *a = A.LockedBuffer();
00265         int lda = A.LDim();
00266 
00267 #       if SKYLARK_HAVE_OPENMP
00268 #       pragma omp parallel for collapse(2)
00269 #       endif
00270         for (int j = 0; j < n; j++)
00271             for(int row = 0; row < k; row++) {
00272                 c[j * ldc + row] *= beta;
00273                  for (int l = indptr[j]; l < indptr[j + 1]; l++) {
00274                      int rr = indices[l];
00275                      T val = values[l];
00276                      c[j * ldc + row] += val * a[j * lda + rr];
00277                  }
00278             }
00279     }
00280 
00281     // TT - TODO: Not tested!
00282     if (oA == elem::TRANSPOSE && oB == elem::TRANSPOSE) {
00283         elem::Scal(beta, C);
00284 
00285         double *c = C.Buffer();
00286         int ldc = C.LDim();
00287 
00288         const double *a = A.LockedBuffer();
00289         int lda = A.LDim();
00290 
00291 #       if SKYLARK_HAVE_OPENMP
00292 #       pragma omp parallel for
00293 #       endif
00294         for(int row = 0; row < k; row++)
00295             for(int rb = 0; rb < n; rb++)
00296                 for (int l = indptr[rb]; l < indptr[rb + 1]; l++) {
00297                     int col = indices[l];
00298                     c[col * ldc + row] += values[l] * a[row * lda + rb];
00299                 }
00300     }
00301 }
00302 
00303 template<typename T>
00304 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00305     T alpha, const sparse_matrix_t<T>& A, const elem::Matrix<T>& B,
00306     T beta, elem::Matrix<T>& C) {
00307     // TODO verify sizes etc.
00308 
00309     const int* indptr = A.indptr();
00310     const int* indices = A.indices();
00311     const double *values = A.locked_values();
00312 
00313     int k = A.width();
00314     int n = B.Width();
00315     int m = B.Height();
00316 
00317     // NN
00318     if (oA == elem::NORMAL && oB == elem::NORMAL) {
00319 
00320         elem::Scal(beta, C);
00321 
00322         double *c = C.Buffer();
00323         int ldc = C.LDim();
00324 
00325         const double *b = B.LockedBuffer();
00326         int ldb = B.LDim();
00327 
00328 #       if SKYLARK_HAVE_OPENMP
00329 #       pragma omp parallel for
00330 #       endif
00331         for(int i = 0; i < n; i++)
00332             for(int col = 0; col < k; col++)
00333                  for (int j = indptr[col]; j < indptr[col + 1]; j++) {
00334                      int row = indices[j];
00335                      T val = values[j];
00336                      c[i * ldc + row] += alpha * val * b[i * ldb + col];
00337                  }
00338     }
00339 
00340     // NT
00341     if (oA == elem::NORMAL && oB == elem::TRANSPOSE) {
00342 
00343         elem::Scal(beta, C);
00344 
00345         elem::Matrix<T> Bc;
00346         elem::Matrix<T> BTr;
00347         elem::Matrix<T> Cr;
00348 
00349         for(int col = 0; col < k; col++) {
00350             elem::LockedView(Bc, B, 0, col, m, 1);
00351             elem::Transpose(Bc, BTr);
00352 #           if SKYLARK_HAVE_OPENMP
00353 #           pragma omp parallel for private(Cr)
00354 #           endif
00355             for (int j = indptr[col]; j < indptr[col + 1]; j++) {
00356                 int row = indices[j];
00357                 T val = values[j];
00358                 elem::View(Cr, C, row, 0, 1, m);
00359                 elem::Axpy(alpha * val, BTr, Cr);
00360             }
00361         }
00362     }
00363 
00364     // TN - TODO: Not tested!
00365     if (oA == elem::TRANSPOSE && oB == elem::NORMAL) {
00366         double *c = C.Buffer();
00367         int ldc = C.LDim();
00368 
00369         const double *b = B.LockedBuffer();
00370         int ldb = B.LDim();
00371 
00372 #       if SKYLARK_HAVE_OPENMP
00373 #       pragma omp parallel for collapse(2)
00374 #       endif
00375         for (int j = 0; j < n; j++)
00376             for(int row = 0; row < k; row++) {
00377                 c[j * ldc + row] *= beta;
00378                  for (int l = indptr[row]; l < indptr[row + 1]; l++) {
00379                      int col = indices[l];
00380                      T val = values[l];
00381                      c[j * ldc + row] += val * b[j * ldb + col];
00382                  }
00383             }
00384     }
00385 
00386     // TT - TODO: Not tested!
00387     if (oA == elem::TRANSPOSE && oB == elem::TRANSPOSE) {
00388 
00389         elem::Scal(beta, C);
00390 
00391         elem::Matrix<T> Bc;
00392         elem::Matrix<T> BTr;
00393         elem::Matrix<T> Cr;
00394 
00395 #       if SKYLARK_HAVE_OPENMP
00396 #       pragma omp parallel for private(Cr, Bc, BTr)
00397 #       endif
00398         for(int row = 0; row < k; row++) {
00399             elem::View(Cr, C, row, 0, 1, m);
00400             for (int l = indptr[row]; l < indptr[row + 1]; l++) {
00401                 int col = indices[l];
00402                 T val = values[l];
00403                 elem::LockedView(Bc, B, 0, col, m, 1);
00404                 elem::Transpose(Bc, BTr);
00405                 elem::Axpy(alpha * val, BTr, Cr);
00406             }
00407         }
00408     }
00409 }
00410 
00411 
00412 #if SKYLARK_HAVE_COMBBLAS
00413 
00418 
00419 template<typename index_type, typename value_type, elem::Distribution col_d>
00420 void Gemm(elem::Orientation oA, elem::Orientation oB, double alpha,
00421           const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A,
00422           const elem::DistMatrix<value_type, elem::STAR, elem::STAR> &B,
00423           double beta,
00424           elem::DistMatrix<value_type, col_d, elem::STAR> &C) {
00425 
00426     if(oA == elem::NORMAL && oB == elem::NORMAL) {
00427 
00428         if(A.getnol() != B.Height())
00429             SKYLARK_THROW_EXCEPTION (
00430                 base::combblas_exception()
00431                     << base::error_msg("Gemm: Dimensions do not agree"));
00432 
00433         if(A.getnrow() != C.Height())
00434             SKYLARK_THROW_EXCEPTION (
00435                 base::combblas_exception()
00436                     << base::error_msg("Gemm: Dimensions do not agree"));
00437 
00438         if(B.Width() != C.Width())
00439             SKYLARK_THROW_EXCEPTION (
00440                 base::combblas_exception()
00441                     << base::error_msg("Gemm: Dimensions do not agree"));
00442 
00443         //XXX: simple heuristic to decide what to communicate (improve!)
00444         //     or just if A.getncol() < B.Width..
00445         if(A.getnnz() < B.Height() * B.Width())
00446             detail::outer_panel_mixed_gemm_impl_nn(alpha, A, B, beta, C);
00447         else
00448             detail::inner_panel_mixed_gemm_impl_nn(alpha, A, B, beta, C);
00449     }
00450 }
00451 
00453 template<typename index_type, typename value_type, elem::Distribution col_d>
00454 void Gemm(elem::Orientation oA, elem::Orientation oB, double alpha,
00455           const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A,
00456           const elem::DistMatrix<value_type, col_d, elem::STAR> &B,
00457           double beta,
00458           elem::DistMatrix<value_type, elem::STAR, elem::STAR> &C) {
00459 
00460     if(oA == elem::TRANSPOSE && oB == elem::NORMAL) {
00461 
00462         if(A.getrow() != B.Height())
00463             SKYLARK_THROW_EXCEPTION (
00464                 base::combblas_exception()
00465                     << base::error_msg("Gemm: Dimensions do not agree"));
00466 
00467         if(A.getncol() != C.Height())
00468             SKYLARK_THROW_EXCEPTION (
00469                     base::combblas_exception()
00470                     << base::error_msg("Gemm: Dimensions do not agree"));
00471 
00472         if(B.Width() != C.Width())
00473             SKYLARK_THROW_EXCEPTION (
00474                 base::combblas_exception()
00475                     << base::error_msg("Gemm: Dimensions do not agree"));
00476 
00477         detail::outer_panel_mixed_gemm_impl_tn(alpha, A, B, beta, C);
00478     }
00479 
00480 }
00481 
00482 #endif // SKYLARK_HAVE_COMBBLAS
00483 #endif // SKYLARK_HAVE_ELEMENTAL
00484 
00485 /* All combinations with computed matrix */
00486 
00487 template<typename CT, typename RT, typename OT>
00488 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00489     typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT>& A,
00490     const RT& B, typename utility::typer_t<OT>::value_type beta, OT& C) {
00491     base::Gemm(oA, oB, alpha, A.materialize(), B, beta, C);
00492 }
00493 
00494 template<typename CT, typename RT, typename OT>
00495 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00496     typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT>& A,
00497     const RT& B, OT& C) {
00498     base::Gemm(oA, oB, alpha, A.materialize(), B, C);
00499 }
00500 
00501 template<typename CT, typename RT, typename OT>
00502 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00503     typename utility::typer_t<OT>::value_type alpha, const RT& A,
00504     const computed_matrix_t<CT>& B,
00505     typename utility::typer_t<OT>::value_type beta, OT& C) {
00506     base::Gemm(oA, oB, alpha, A, B.materialize(), beta, C);
00507 }
00508 
00509 template<typename CT, typename RT, typename OT>
00510 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00511     typename utility::typer_t<OT>::value_type alpha, const RT& A,
00512     const computed_matrix_t<CT>& B, OT& C) {
00513     base::Gemm(oA, oB, alpha, A, B.materialize(), C);
00514 }
00515 
00516 template<typename CT1, typename CT2, typename OT>
00517 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00518     typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT1>& A,
00519     const computed_matrix_t<CT2>& B, typename utility::typer_t<OT>::value_type beta,
00520     OT& C) {
00521     base::Gemm(oA, oB, alpha, A.materialize(), B.materialize(), beta, C);
00522 }
00523 
00524 template<typename CT1, typename CT2, typename OT>
00525 inline void Gemm(elem::Orientation oA, elem::Orientation oB,
00526     typename utility::typer_t<OT>::value_type alpha, const computed_matrix_t<CT1>& A,
00527     const computed_matrix_t<CT2>& B, OT& C) {
00528     base::Gemm(oA, oB, alpha, A.materialize(), B.materialize(), C);
00529 }
00530 
00531 
00532 } } // namespace skylark::base
00533 #endif // SKYLARK_GEMM_HPP