Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_NORM_HPP 00002 #define SKYLARK_NORM_HPP 00003 00004 #include <boost/mpi.hpp> 00005 00006 // Defines a generic Gemm function that recieves a wider set of matrices 00007 00008 #if SKYLARK_HAVE_ELEMENTAL 00009 00010 namespace skylark { namespace base { 00011 00012 template<typename T> 00013 inline elem::Base<T> Nrm2(const elem::Matrix<T>& x) { 00014 return elem::Nrm2(x); 00015 } 00016 00017 template<typename T> 00018 inline elem::Base<T> Nrm2(const elem::DistMatrix<T>& x) { 00019 return elem::Nrm2(x); 00020 } 00021 00022 template<typename T> 00023 inline elem::Base<T> Nrm2(const elem::DistMatrix<T, elem::VC, elem::STAR>& x) { 00024 boost::mpi::communicator comm(x.DistComm(), boost::mpi::comm_attach); 00025 T local = elem::Nrm2(x.LockedMatrix()); 00026 T snrm = boost::mpi::all_reduce(comm, local * local, std::plus<T>()); 00027 return sqrt(snrm); 00028 } 00029 00030 template<typename T> 00031 inline elem::Base<T> Nrm2(const elem::DistMatrix<T, elem::VR, elem::STAR>& x) { 00032 boost::mpi::communicator comm(x.DistComm(), boost::mpi::comm_attach); 00033 T local = elem::Nrm2(x.LockedMatrix()); 00034 T snrm = boost::mpi::all_reduce(comm, local * local, std::plus<T>()); 00035 return sqrt(snrm); 00036 } 00037 00038 template<typename T> 00039 inline elem::Base<T> Nrm2(const elem::DistMatrix<T, elem::STAR, elem::STAR>& x) { 00040 return elem::Nrm2(x.LockedMatrix()); 00041 } 00042 00043 template<typename T> 00044 inline void ColumnNrm2(const elem::DistMatrix<T, elem::STAR, elem::STAR>& A, 00045 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& N) { 00046 00047 double *n = N.Buffer(); 00048 const double *a = A.LockedBuffer(); 00049 for(int j = 0; j < A.Width(); j++) { 00050 n[j] = 0.0; 00051 for(int i = 0; i < A.LocalHeight(); i++) 00052 n[j] += a[j * A.LDim() + i] * a[j * A.LDim() + i]; 00053 n[j] = sqrt(n[j]); 00054 } 00055 } 00056 00057 template<typename T> 00058 inline void ColumnNrm2(const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00059 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& N) { 00060 00061 std::vector<T> n(A.Width(), 1); 00062 const elem::Matrix<T> &Al = A.LockedMatrix(); 00063 const double *a = Al.LockedBuffer(); 00064 for(int j = 0; j < Al.Width(); j++) { 00065 n[j] = 0.0; 00066 for(int i = 0; i < Al.Height(); i++) 00067 n[j] += a[j * Al.LDim() + i] * a[j * Al.LDim() + i]; 00068 } 00069 N.Resize(A.Width(), 1); 00070 elem::Zero(N); 00071 boost::mpi::communicator comm(N.Grid().Comm(), boost::mpi::comm_attach); 00072 boost::mpi::all_reduce(comm, n.data(), A.Width(), N.Buffer(), std::plus<T>()); 00073 for(int j = 0; j < A.Width(); j++) 00074 N.Set(j, 0, sqrt(N.Get(j, 0))); 00075 } 00076 00077 } } // namespace skylark::base 00078 00079 #endif // SKYLARK_HAVE_ELEMENTAL 00080 00081 #endif // SKYLARK_NORM_HPP