Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/utility/comm.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_COMM_HPP
00002 #define SKYLARK_COMM_HPP
00003 
00004 #include <elemental.hpp>
00005 
00006 #include "../base/exception.hpp"
00007 
00008 namespace skylark {
00009 namespace utility {
00010 
00011 template<typename T>
00012 void collect_dist_matrix(boost::mpi::communicator& comm, bool here,
00013     const elem::DistMatrix<T> &DA, elem::Matrix<T> &A) {
00014 
00015     // Technically the following should work for any row and col distributions.
00016     // But it seems to not work well for VR/VC type distributions.
00017     // And it is probably not the best way to do it for these distributions
00018     // anyway.
00019 
00020     try {
00021         elem::AxpyInterface<T> interface;
00022         interface.Attach(elem::GLOBAL_TO_LOCAL, DA);
00023         if (here) {
00024             elem::Zero(A);
00025             interface.Axpy(1.0, A, 0, 0);
00026         }
00027         interface.Detach();
00028     } catch (std::logic_error e) {
00029         SKYLARK_THROW_EXCEPTION (base::elemental_exception()
00030             << base::error_msg(e.what()) );
00031     }
00032 }
00033 
00034 template<typename T, elem::Distribution RowDist>
00035 void collect_dist_matrix(boost::mpi::communicator& comm, bool here,
00036     const elem::DistMatrix<T, RowDist, elem::STAR> &DA,
00037     elem::Matrix<T> &A) {
00038 
00039     if (RowDist == elem::VR || RowDist == elem::VC) {
00040         // TODO this is probably the most laziest way to do it.
00041         //      Must be possible to do it much better (less communication).
00042 
00043         try {
00044             elem::Matrix<T> A0(DA.Height(), DA.Width(), DA.Height());
00045             const elem::Matrix<T> &A_local = DA.LockedMatrix();
00046             elem::Zero(A0);
00047             for(int j = 0; j < A_local.Width(); j++)
00048                 for(int i = 0; i < A_local.Height(); i++)
00049                     A0.Set(DA.ColShift() + i * DA.ColStride(), j,
00050                         A_local.Get(i, j));
00051 
00052             boost::mpi::reduce (comm,
00053                 A0.LockedBuffer(),
00054                 A0.MemorySize(),
00055                 A.Buffer(),
00056                 std::plus<T>(),
00057                 0);
00058 
00059         } catch (std::logic_error e) {
00060             SKYLARK_THROW_EXCEPTION( base::elemental_exception()
00061                 << base::error_msg(e.what()) );
00062         } catch(boost::mpi::exception e) {
00063             SKYLARK_THROW_EXCEPTION(base::mpi_exception()
00064                 << base::error_msg(e.what()) );
00065         }
00066 
00067     } else {
00068         SKYLARK_THROW_EXCEPTION ( base::unsupported_matrix_distribution() );
00069     }
00070 }
00071 
00072 template<typename T, elem::Distribution ColDist>
00073 void collect_dist_matrix(boost::mpi::communicator& comm, bool here,
00074     const elem::DistMatrix<T, elem::STAR, ColDist> &DA,
00075     elem::Matrix<T> &A) {
00076 
00077     if (ColDist == elem::VR || ColDist == elem::VC) {
00078         // TODO this is probably the most laziest way to do it.
00079         //      Must be possible to do it much better (less communication).
00080 
00081         try {
00082             elem::Matrix<T> A0(DA.Height(), DA.Width(), DA.Height());
00083             const elem::Matrix<T> &A_local = DA.LockedMatrix();
00084             elem::Zero(A0);
00085             for(int j = 0; j < A_local.Width(); j++)
00086                 for(int i = 0; i < A_local.Height(); i++)
00087                     A0.Set(i, DA.RowShift() + j * DA.RowStride(),
00088                         A_local.Get(i, j));
00089 
00090             boost::mpi::reduce (comm,
00091                 A0.LockedBuffer(),
00092                 A0.MemorySize(),
00093                 A.Buffer(),
00094                 std::plus<T>(),
00095                 0);
00096 
00097         } catch (std::logic_error e) {
00098             SKYLARK_THROW_EXCEPTION (base::elemental_exception()
00099                 << base::error_msg(e.what()) );
00100         } catch(boost::mpi::exception e) {
00101             SKYLARK_THROW_EXCEPTION (base::mpi_exception()
00102                 << base::error_msg(e.what()) );
00103         }
00104 
00105     } else {
00106         SKYLARK_THROW_EXCEPTION ( base::unsupported_matrix_distribution() );
00107     }
00108 }
00109 
00110 } // namespace sketch
00111 } // namespace skylark
00112 
00113 #endif