Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/base/Gemv.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_GEMV_HPP
00002 #define SKYLARK_GEMV_HPP
00003 
00004 #include <boost/mpi.hpp>
00005 #include "exception.hpp"
00006 
00007 // Defines a generic Gemv function that recieves a wider set of matrices
00008 
00009 #if SKYLARK_HAVE_ELEMENTAL
00010 
00011 namespace skylark { namespace base {
00012 
00013 template<typename T>
00014 inline void Gemv(elem::Orientation oA,
00015     T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& x,
00016     T beta, elem::Matrix<T>& y) {
00017     elem::Gemv(oA, alpha, A, x, beta, y);
00018 }
00019 
00020 template<typename T>
00021 inline void Gemv(elem::Orientation oA,
00022     T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& x,
00023     elem::Matrix<T>& y) {
00024     elem::Gemv(oA, alpha, A, x, y);
00025 }
00026 
00027 template<typename T>
00028 inline void Gemv(elem::Orientation oA,
00029     T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& x,
00030     T beta, elem::DistMatrix<T>& y) {
00031     elem::Gemv(oA, alpha, A, x, beta, y);
00032 }
00033 
00034 template<typename T>
00035 inline void Gemv(elem::Orientation oA,
00036     T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& x,
00037     elem::DistMatrix<T>& y) {
00038     elem::Gemv(oA, alpha, A, x, y);
00039 }
00040 
00046 template<typename T>
00047 inline void Gemv(elem::Orientation oA,
00048     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00049     const elem::DistMatrix<T, elem::VC, elem::STAR>& x,
00050     T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& y) {
00051     // TODO verify sizes etc.
00052     // TODO verify matching grids.
00053 
00054     if (oA == elem::TRANSPOSE) {
00055         boost::mpi::communicator comm(y.Grid().Comm(), boost::mpi::comm_attach);
00056         elem::Matrix<T> ylocal(y.Matrix());
00057         elem::Gemv(elem::TRANSPOSE,
00058             alpha, A.LockedMatrix(), x.LockedMatrix(),
00059             beta / T(comm.size()), ylocal);
00060         boost::mpi::all_reduce(comm,
00061             ylocal.Buffer(), ylocal.MemorySize(), y.Buffer(),
00062             std::plus<T>());
00063     } else {
00064         SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation());
00065     }
00066 }
00067 
00068 template<typename T>
00069 inline void Gemv(elem::Orientation oA,
00070     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00071     const elem::DistMatrix<T, elem::VC, elem::STAR>& x,
00072     elem::DistMatrix<T, elem::STAR, elem::STAR>& y) {
00073 
00074     int y_height = (oA == elem::NORMAL ? A.Height() : A.Width());
00075     elem::Zeros(y, y_height, 1);
00076     base::Gemv(oA, alpha, A, x, T(0), y);
00077 }
00078 
00079 template<typename T>
00080 inline void Gemv(elem::Orientation oA,
00081     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00082     const elem::DistMatrix<T, elem::STAR, elem::STAR>& x,
00083     T beta, elem::DistMatrix<T, elem::VC, elem::STAR>& y) {
00084     // TODO verify sizes etc.
00085 
00086     if (oA == elem::NORMAL) {
00087         elem::Gemv(elem::NORMAL,
00088             alpha, A.LockedMatrix(), x.LockedMatrix(),
00089             beta, y.Matrix());
00090     } else {
00091         SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation());
00092     }
00093 }
00094 
00095 template<typename T>
00096 inline void Gemv(elem::Orientation oA,
00097     T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A,
00098     const elem::DistMatrix<T, elem::STAR, elem::STAR>& x,
00099     elem::DistMatrix<T, elem::VC, elem::STAR>& y) {
00100 
00101     int y_height = (oA == elem::NORMAL ? A.Height() : A.Width());
00102     elem::Zeros(y, y_height, 1);
00103     base::Gemv(oA, alpha, A, x, T(0), y);
00104 }
00105 
00106 } } // namespace skylark::base
00107 
00108 #endif // SKYLARK_HAVE_ELEMENTAL
00109 
00110 #endif // SKYLARK_GEMV_HPP