Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_COMM_HPP 00002 #define SKYLARK_COMM_HPP 00003 00004 #include <elemental.hpp> 00005 00006 #include "../base/exception.hpp" 00007 00008 namespace skylark { 00009 namespace utility { 00010 00011 template<typename T> 00012 void collect_dist_matrix(boost::mpi::communicator& comm, bool here, 00013 const elem::DistMatrix<T> &DA, elem::Matrix<T> &A) { 00014 00015 // Technically the following should work for any row and col distributions. 00016 // But it seems to not work well for VR/VC type distributions. 00017 // And it is probably not the best way to do it for these distributions 00018 // anyway. 00019 00020 try { 00021 elem::AxpyInterface<T> interface; 00022 interface.Attach(elem::GLOBAL_TO_LOCAL, DA); 00023 if (here) { 00024 elem::Zero(A); 00025 interface.Axpy(1.0, A, 0, 0); 00026 } 00027 interface.Detach(); 00028 } catch (std::logic_error e) { 00029 SKYLARK_THROW_EXCEPTION (base::elemental_exception() 00030 << base::error_msg(e.what()) ); 00031 } 00032 } 00033 00034 template<typename T, elem::Distribution RowDist> 00035 void collect_dist_matrix(boost::mpi::communicator& comm, bool here, 00036 const elem::DistMatrix<T, RowDist, elem::STAR> &DA, 00037 elem::Matrix<T> &A) { 00038 00039 if (RowDist == elem::VR || RowDist == elem::VC) { 00040 // TODO this is probably the most laziest way to do it. 00041 // Must be possible to do it much better (less communication). 00042 00043 try { 00044 elem::Matrix<T> A0(DA.Height(), DA.Width(), DA.Height()); 00045 const elem::Matrix<T> &A_local = DA.LockedMatrix(); 00046 elem::Zero(A0); 00047 for(int j = 0; j < A_local.Width(); j++) 00048 for(int i = 0; i < A_local.Height(); i++) 00049 A0.Set(DA.ColShift() + i * DA.ColStride(), j, 00050 A_local.Get(i, j)); 00051 00052 boost::mpi::reduce (comm, 00053 A0.LockedBuffer(), 00054 A0.MemorySize(), 00055 A.Buffer(), 00056 std::plus<T>(), 00057 0); 00058 00059 } catch (std::logic_error e) { 00060 SKYLARK_THROW_EXCEPTION( base::elemental_exception() 00061 << base::error_msg(e.what()) ); 00062 } catch(boost::mpi::exception e) { 00063 SKYLARK_THROW_EXCEPTION(base::mpi_exception() 00064 << base::error_msg(e.what()) ); 00065 } 00066 00067 } else { 00068 SKYLARK_THROW_EXCEPTION ( base::unsupported_matrix_distribution() ); 00069 } 00070 } 00071 00072 template<typename T, elem::Distribution ColDist> 00073 void collect_dist_matrix(boost::mpi::communicator& comm, bool here, 00074 const elem::DistMatrix<T, elem::STAR, ColDist> &DA, 00075 elem::Matrix<T> &A) { 00076 00077 if (ColDist == elem::VR || ColDist == elem::VC) { 00078 // TODO this is probably the most laziest way to do it. 00079 // Must be possible to do it much better (less communication). 00080 00081 try { 00082 elem::Matrix<T> A0(DA.Height(), DA.Width(), DA.Height()); 00083 const elem::Matrix<T> &A_local = DA.LockedMatrix(); 00084 elem::Zero(A0); 00085 for(int j = 0; j < A_local.Width(); j++) 00086 for(int i = 0; i < A_local.Height(); i++) 00087 A0.Set(i, DA.RowShift() + j * DA.RowStride(), 00088 A_local.Get(i, j)); 00089 00090 boost::mpi::reduce (comm, 00091 A0.LockedBuffer(), 00092 A0.MemorySize(), 00093 A.Buffer(), 00094 std::plus<T>(), 00095 0); 00096 00097 } catch (std::logic_error e) { 00098 SKYLARK_THROW_EXCEPTION (base::elemental_exception() 00099 << base::error_msg(e.what()) ); 00100 } catch(boost::mpi::exception e) { 00101 SKYLARK_THROW_EXCEPTION (base::mpi_exception() 00102 << base::error_msg(e.what()) ); 00103 } 00104 00105 } else { 00106 SKYLARK_THROW_EXCEPTION ( base::unsupported_matrix_distribution() ); 00107 } 00108 } 00109 00110 } // namespace sketch 00111 } // namespace skylark 00112 00113 #endif