Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/tests/unit/MixedGemmTest.cpp
Go to the documentation of this file.
#include <cmath>
#include <iostream>
#include <string>
#include <vector>

#include <boost/mpi.hpp>
#include <boost/test/minimal.hpp>

#include <elemental.hpp>
#include <CombBLAS.h>
#include <SpParMat.h>

#include <skylark.hpp>

#include "../../base/Gemm.hpp"
#include "../../base/Gemm_detail.hpp"
00014 
00016 typedef elem::Matrix<double> MatrixType;
00017 typedef elem::DistMatrix<double, elem::VC, elem::STAR> DistMatrixVCSType;
00018 typedef elem::DistMatrix<double> DistMatrixType;
00019 
00020 typedef SpDCCols< size_t, double> col_t;
00021 typedef SpParMat< size_t, double, col_t > cbDistMatrixType;
00022 
00023 static const size_t matrix_size = 50;
00024 
00025 static MatrixType nn_expected;
00026 static MatrixType tn_expected;
00027 static MatrixType nt_expected;
00028 static MatrixType tt_expected;
00029 
00030 template <typename dist_matrix_t>
00031 void check_matrix(const dist_matrix_t &result, const MatrixType &expected,
00032                   const std::string error) {
00033 
00034     elem::DistMatrix<double, elem::STAR, elem::STAR> full_result = result;
00035     for(size_t j = 0; j < full_result.Height(); j++ )
00036         for(size_t i = 0; i < full_result.Width(); i++ ) {
00037             if(full_result.GetLocal(j, i) != expected.Get(j, i)) {
00038                 std::cout << result.GetLocal(j, i) << " != "
00039                           << expected.Get(j, i)
00040                           << " at index (" << j << ", " << i << ")"
00041                           << std::endl;
00042                 BOOST_FAIL(error.c_str());
00043             }
00044         }
00045 }
00046 
00047 
00048 int test_main(int argc, char *argv[]) {
00049 
00050     namespace mpi = boost::mpi;
00051 
00052 #ifdef SKYLARK_HAVE_OPENMP
00053     int provided;
00054     MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
00055 #endif
00056 
00057     mpi::environment env (argc, argv);
00058     mpi::communicator world;
00059 
00060     elem::Initialize (argc, argv);
00061     MPI_Comm mpi_world(world);
00062     elem::Grid grid (mpi_world);
00063 
00064     // compute local expected value
00065     MatrixType localA(matrix_size, matrix_size);
00066     for( size_t j = 0; j < localA.Height(); j++ ) {
00067         for( size_t i = 0; i < localA.Width(); i++ ) {
00068             double value = j * matrix_size + i + 1;
00069             localA.Set(j, i, value);
00070         }
00071     }
00072 
00073     elem::Ones(nn_expected, matrix_size, matrix_size);
00074     elem::Gemm(elem::NORMAL, elem::NORMAL, -1.0, localA, localA,
00075                 1.5, nn_expected);
00076     elem::Ones(nt_expected, matrix_size, matrix_size);
00077     elem::Gemm(elem::NORMAL, elem::TRANSPOSE, -1.0, localA, localA,
00078                 1.5, nt_expected);
00079     elem::Ones(tn_expected, matrix_size, matrix_size);
00080     elem::Gemm(elem::TRANSPOSE, elem::NORMAL, -1.0, localA, localA,
00081                 1.5, tn_expected);
00082     elem::Ones(tt_expected, matrix_size, matrix_size);
00083     elem::Gemm(elem::TRANSPOSE, elem::TRANSPOSE, -1.0, localA, localA,
00084                 1.5, tt_expected);
00085 
00086 
00087     // prepare an Elemental matrix with the test data
00088     double val = 0.0;
00089     elem::DistMatrix<double, elem::STAR, elem::STAR>
00090         A_stst(matrix_size, matrix_size, grid);
00091     for( size_t j = 0; j < A_stst.LocalHeight(); j++ ) {
00092         for( size_t i = 0; i < A_stst.LocalWidth(); i++ ) {
00093             val = (j * A_stst.ColStride() + A_stst.ColShift()) * matrix_size +
00094                    i * A_stst.RowStride() + A_stst.RowShift() + 1;
00095             A_stst.SetLocal(j, i, val);
00096         }
00097     }
00098 
00099     // and fill a CombBLAS sparse matrix (with the same data)
00100     FullyDistVec<size_t, double> cols(matrix_size * matrix_size, 0.0);
00101     FullyDistVec<size_t, double> rows(matrix_size * matrix_size, 0.0);
00102     FullyDistVec<size_t, double> vals(matrix_size * matrix_size, 0.0);
00103 
00104     for(size_t i = 0; i < matrix_size * matrix_size; ++i) {
00105         rows.SetElement(i, floor(i / matrix_size));
00106         cols.SetElement(i, i % matrix_size);
00107         vals.SetElement(i, static_cast<double>(i+1));
00108     }
00109 
00110     cbDistMatrixType B(matrix_size, matrix_size, rows, cols, vals);
00111 
00112 
00113 
00114     //std::vector<double> local_matrix;
00115     //skylark::base::detail::mixed_gemm_local_part_tt(-1.0, B, A_stst, 0.0,
00116             //local_matrix);
00117     //for(size_t idx = 0; idx < local_matrix.size(); idx++)
00118         //std::cout << local_matrix[idx] << std::endl;
00119 
00120 
00121     if(world.rank() == 0)
00122         std::cout << "Testing CombBLAS^T x Elemental (VX/*) = Elemental (*/*) :";
00123 
00124     elem::DistMatrix<double, elem::STAR, elem::STAR>
00125         result_stst(matrix_size, matrix_size, grid);
00126     for( size_t j = 0; j < result_stst.LocalHeight(); j++ )
00127         for( size_t i = 0; i < result_stst.LocalWidth(); i++ )
00128             result_stst.SetLocal(j, i, 1.0);
00129 
00130     DistMatrixVCSType A_vcs = A_stst;
00131     skylark::base::detail::outer_panel_mixed_gemm_impl_tn(
00132             -1.0, B, A_vcs, 1.5, result_stst);
00133     check_matrix(result_stst, tn_expected,
00134                  "Result of outer panel TN gemm not as expected");
00135 
00136     if(world.rank() == 0)
00137         std::cout << "outer panel: OK" << std::endl;
00138 
00139     if(world.rank() == 0)
00140         std::cout << "Testing CombBLAS x Elemental (*/*) = Elemental (VX/*) :";
00141 
00142     DistMatrixVCSType result_vcs(matrix_size, matrix_size, grid);
00143     for( size_t j = 0; j < result_vcs.LocalHeight(); j++ )
00144         for( size_t i = 0; i < result_vcs.LocalWidth(); i++ )
00145             result_vcs.SetLocal(j, i, 1.0);
00146 
00147     skylark::base::detail::outer_panel_mixed_gemm_impl_nn(
00148             -1.0, B, A_stst, 1.5, result_vcs);
00149     check_matrix(result_vcs, nn_expected,
00150                  "Result of outer panel NN gemm not as expected");
00151 
00152     if(world.rank() == 0)
00153         std::cout << "outer panel: OK" << std::endl;
00154 
00155     for( size_t j = 0; j < result_vcs.LocalHeight(); j++ )
00156         for( size_t i = 0; i < result_vcs.LocalWidth(); i++ )
00157             result_vcs.SetLocal(j, i, 1.0);
00158 
00159     skylark::base::detail::inner_panel_mixed_gemm_impl_nn(
00160             -1.0, B, A_stst, 1.5, result_vcs);
00161     check_matrix(result_vcs, nn_expected,
00162                  "Result of inner panel NN gemm not as expected");
00163 
00164     if(world.rank() == 0)
00165         std::cout << "inner panel: OK" << std::endl;
00166 
00167     elem::Finalize();
00168 
00169     return 0;
00170 }
00171