Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_GEMV_HPP 00002 #define SKYLARK_GEMV_HPP 00003 00004 #include <boost/mpi.hpp> 00005 #include "exception.hpp" 00006 00007 // Defines a generic Gemv function that recieves a wider set of matrices 00008 00009 #if SKYLARK_HAVE_ELEMENTAL 00010 00011 namespace skylark { namespace base { 00012 00013 template<typename T> 00014 inline void Gemv(elem::Orientation oA, 00015 T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& x, 00016 T beta, elem::Matrix<T>& y) { 00017 elem::Gemv(oA, alpha, A, x, beta, y); 00018 } 00019 00020 template<typename T> 00021 inline void Gemv(elem::Orientation oA, 00022 T alpha, const elem::Matrix<T>& A, const elem::Matrix<T>& x, 00023 elem::Matrix<T>& y) { 00024 elem::Gemv(oA, alpha, A, x, y); 00025 } 00026 00027 template<typename T> 00028 inline void Gemv(elem::Orientation oA, 00029 T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& x, 00030 T beta, elem::DistMatrix<T>& y) { 00031 elem::Gemv(oA, alpha, A, x, beta, y); 00032 } 00033 00034 template<typename T> 00035 inline void Gemv(elem::Orientation oA, 00036 T alpha, const elem::DistMatrix<T>& A, const elem::DistMatrix<T>& x, 00037 elem::DistMatrix<T>& y) { 00038 elem::Gemv(oA, alpha, A, x, y); 00039 } 00040 00046 template<typename T> 00047 inline void Gemv(elem::Orientation oA, 00048 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00049 const elem::DistMatrix<T, elem::VC, elem::STAR>& x, 00050 T beta, elem::DistMatrix<T, elem::STAR, elem::STAR>& y) { 00051 // TODO verify sizes etc. 00052 // TODO verify matching grids. 00053 00054 if (oA == elem::TRANSPOSE) { 00055 boost::mpi::communicator comm(y.Grid().Comm(), boost::mpi::comm_attach); 00056 elem::Matrix<T> ylocal(y.Matrix()); 00057 elem::Gemv(elem::TRANSPOSE, 00058 alpha, A.LockedMatrix(), x.LockedMatrix(), 00059 beta / T(comm.size()), ylocal); 00060 boost::mpi::all_reduce(comm, 00061 ylocal.Buffer(), ylocal.MemorySize(), y.Buffer(), 00062 std::plus<T>()); 00063 } else { 00064 SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation()); 00065 } 00066 } 00067 00068 template<typename T> 00069 inline void Gemv(elem::Orientation oA, 00070 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00071 const elem::DistMatrix<T, elem::VC, elem::STAR>& x, 00072 elem::DistMatrix<T, elem::STAR, elem::STAR>& y) { 00073 00074 int y_height = (oA == elem::NORMAL ? A.Height() : A.Width()); 00075 elem::Zeros(y, y_height, 1); 00076 base::Gemv(oA, alpha, A, x, T(0), y); 00077 } 00078 00079 template<typename T> 00080 inline void Gemv(elem::Orientation oA, 00081 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00082 const elem::DistMatrix<T, elem::STAR, elem::STAR>& x, 00083 T beta, elem::DistMatrix<T, elem::VC, elem::STAR>& y) { 00084 // TODO verify sizes etc. 00085 00086 if (oA == elem::NORMAL) { 00087 elem::Gemv(elem::NORMAL, 00088 alpha, A.LockedMatrix(), x.LockedMatrix(), 00089 beta, y.Matrix()); 00090 } else { 00091 SKYLARK_THROW_EXCEPTION(base::unsupported_base_operation()); 00092 } 00093 } 00094 00095 template<typename T> 00096 inline void Gemv(elem::Orientation oA, 00097 T alpha, const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00098 const elem::DistMatrix<T, elem::STAR, elem::STAR>& x, 00099 elem::DistMatrix<T, elem::VC, elem::STAR>& y) { 00100 00101 int y_height = (oA == elem::NORMAL ? A.Height() : A.Width()); 00102 elem::Zeros(y, y_height, 1); 00103 base::Gemv(oA, alpha, A, x, T(0), y); 00104 } 00105 00106 } } // namespace skylark::base 00107 00108 #endif // SKYLARK_HAVE_ELEMENTAL 00109 00110 #endif // SKYLARK_GEMV_HPP