Skylark (Sketching Library)
0.1
#ifndef SKYLARK_GEMM_DETAIL_HPP
#define SKYLARK_GEMM_DETAIL_HPP

#if SKYLARK_HAVE_COMBBLAS
#include <CombBLAS.h>
#include <CommGrid.h>
#endif

#if SKYLARK_HAVE_ELEMENTAL
#include <elemental.hpp>
#endif

#include "../utility/external/view.hpp"
#include "../utility/external/combblas_comm_grid.hpp"

namespace skylark { namespace base { namespace detail {

#if SKYLARK_HAVE_ELEMENTAL && SKYLARK_HAVE_COMBBLAS

template<typename index_type, typename value_type>
inline void mixed_gemm_local_part_nn (
        const double alpha,
        const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A,
        const elem::DistMatrix<value_type, elem::STAR, elem::STAR> &S,
        const double beta,
        std::vector<value_type> &local_matrix) {

    typedef SpDCCols< index_type, value_type > col_t;
    typedef SpParMat< index_type, value_type, col_t > matrix_type;
    matrix_type &_A = const_cast<matrix_type&>(A);
    col_t &data = _A.seq();

    //FIXME
    local_matrix.resize(S.Width() * data.getnrow(), 0);
    size_t cb_col_offset = utility::cb_my_col_offset(A);

    for(typename col_t::SpColIter col = data.begcol();
        col != data.endcol(); col++) {
        for(typename col_t::SpColIter::NzIter nz = data.begnz(col);
            nz != data.endnz(col); nz++) {

            // we want the local index here to fill the local dense matrix
            index_type rowid = nz.rowid();
            // the column index needs to be global
            index_type colid = col.colid() + cb_col_offset;

            // compute the application of S to yield a partial row in the result
            for(size_t bcol = 0; bcol < S.Width(); ++bcol) {
                local_matrix[rowid * S.Width() + bcol] +=
                    alpha * S.Get(colid, bcol) * nz.value();
            }
        }
    }
}


//FIXME: benchmark against one-sided
template<typename index_type, typename value_type, elem::Distribution col_d>
inline void inner_panel_mixed_gemm_impl_nn(
        const double alpha,
        const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A,
        const elem::DistMatrix<value_type, elem::STAR, elem::STAR> &S,
        const double beta,
        elem::DistMatrix<value_type, col_d, elem::STAR> &C) {

    int n_proc_side = A.getcommgrid()->GetGridRows();
    int output_width = S.Width();
    int output_height = A.getnrow();

    size_t rank = A.getcommgrid()->GetRank();
    size_t cb_row_offset = utility::cb_my_row_offset(A);

    typedef SpDCCols< index_type, value_type > col_t;
    typedef SpParMat< index_type, value_type, col_t > matrix_type;
    matrix_type &_A = const_cast<matrix_type&>(A);
    col_t &data = _A.seq();

    // 1) compute the local values still using the CombBLAS distribution (2D
    //    processor grid). We assume the result is dense.
    std::vector<value_type> local_matrix;
    mixed_gemm_local_part_nn(alpha, A, S, 0.0, local_matrix);

    // 2) reduce first along the rows so that each processor owns the values
    //    in the output row of the SOMETHING/* matrix and the values for
    //    processors in the same processor column.
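    // Note on step 2: A is laid out on a square CombBLAS processor grid, so
    // after step 1 every processor in a grid row holds partial sums for the
    // same block of output rows. The plus-reduction over the row communicator
    // below combines those partial rows onto one processor per grid row
    // before they are handed over to the owning Elemental process.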
    boost::mpi::communicator my_row_comm(
        A.getcommgrid()->GetRowWorld(), boost::mpi::comm_duplicate);

    // storage for values destined for other processors in the same processor
    // column: owner grid row -> (row, values)
    typedef std::vector<std::pair<int, std::vector<value_type> > > for_rank_t;
    std::vector<for_rank_t> for_rank(n_proc_side);

    for(size_t local_row = 0; local_row < data.getnrow(); ++local_row) {

        size_t row = local_row + cb_row_offset;

        // the owner for VR/* and VC/* matrices is independent of the column
        size_t target_proc = C.Owner(row, 0);

        // if the target processor is not in the current row communicator,
        // reduce to the processor in the same grid row that shares the
        // target's processor column instead
        if(!A.getcommgrid()->OnSameProcRow(target_proc))
            target_proc = static_cast<int>(rank / n_proc_side) *
                n_proc_side + target_proc % n_proc_side;

        size_t target_row_rank = A.getcommgrid()->GetRankInProcRow(target_proc);

        // reduce the partial row (FIXME: if the resulting matrix is still
        // expected to be sparse, change this to communicate only the nnz).
        // Working on local_width columns concurrently per column processing
        // group.
        size_t local_width = S.Width();
        const value_type* buffer = &local_matrix[local_row * local_width];
        std::vector<value_type> new_values(local_width);
        boost::mpi::reduce(my_row_comm, buffer, local_width,
            &new_values[0], std::plus<value_type>(), target_row_rank);

        // the processor stores the result directly if it is the owning rank
        // of that row, and otherwise saves it for the subsequent communication
        // along the processor column
        if(rank == C.Owner(row, 0)) {
            int elem_lrow = C.LocalRow(row);
            for(size_t idx = 0; idx < local_width; ++idx) {
                int elem_lcol = C.LocalCol(idx);
                C.SetLocal(elem_lrow, elem_lcol,
                    new_values[idx] + beta * C.GetLocal(elem_lrow, elem_lcol));
            }
        } else if (rank == target_proc) {
            // store for the later communication along the processor column
            for_rank[C.Owner(row, 0) / n_proc_side].push_back(
                std::make_pair(row, new_values));
        }
    }

    // 3) gather the remaining values along the processor columns: we exchange
    //    all the buffered values with the other processors in the same
    //    processor column and then add them to our local part of C.
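    // Note on step 3: rows whose Elemental owner lives in a different grid
    // row were reduced onto the processor that shares this grid row and the
    // owner's processor column (see step 2). The gathers over the column
    // communicator below deliver those buffered (row, values) pairs to the
    // owning processor, which accumulates them into its local part of C.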
    boost::mpi::communicator my_col_comm(
        A.getcommgrid()->GetColWorld(), boost::mpi::comm_duplicate);

    std::vector<for_rank_t> new_values;
    for(int i = 0; i < n_proc_side; ++i)
        boost::mpi::gather(my_col_comm, for_rank[i], new_values, i);

    // insert the gathered values into the local part of C
    for(size_t proc = 0; proc < new_values.size(); ++proc) {
        const for_rank_t &cur = new_values[proc];

        for(size_t i = 0; i < cur.size(); ++i) {
            int elem_lrow = C.LocalRow(cur[i].first);
            for(size_t j = 0; j < cur[i].second.size(); ++j) {
                size_t elem_lcol = C.LocalCol(j);
                C.SetLocal(elem_lrow, elem_lcol,
                    cur[i].second[j] + beta *
                    C.GetLocal(elem_lrow, elem_lcol));
            }
        }
    }
}


//FIXME: benchmark against one-sided
template<typename index_type, typename value_type, elem::Distribution col_d>
inline void outer_panel_mixed_gemm_impl_nn(
        const double alpha,
        const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A,
        const elem::DistMatrix<value_type, elem::STAR, elem::STAR> &S,
        const double beta,
        elem::DistMatrix<value_type, col_d, elem::STAR> &C) {

    utility::combblas_slab_view_t<index_type, value_type> cbview(A, false);

    //FIXME: factor
    size_t slab_size = 2 * C.Grid().Height();
    for(size_t cur_row_idx = 0; cur_row_idx < cbview.nrows();
        cur_row_idx += slab_size) {

        size_t cur_slab_size =
            std::min(slab_size, cbview.nrows() - cur_row_idx);

        // get the next slab_size rows of A
        elem::DistMatrix<value_type, col_d, elem::STAR>
            A_row(cur_slab_size, S.Height());

        cbview.extract_elemental_row_slab_view(A_row, cur_slab_size);

        // assemble the distributed slab of A
        for(size_t l_col_idx = 0; l_col_idx < A_row.LocalWidth();
            l_col_idx++) {

            size_t g_col_idx = l_col_idx * A_row.RowStride()
                + A_row.RowShift();

            for(size_t l_row_idx = 0; l_row_idx < A_row.LocalHeight();
                ++l_row_idx) {

                size_t g_row_idx = l_row_idx * A_row.ColStride()
                    + A_row.ColShift() + cur_row_idx;

                A_row.SetLocal(l_row_idx, l_col_idx,
                    cbview(g_row_idx, g_col_idx));
            }
        }

        elem::DistMatrix<value_type, col_d, elem::STAR>
            C_slice(cur_slab_size, C.Width());
        elem::View(C_slice, C, cur_row_idx, 0, cur_slab_size, C.Width());
        elem::LocalGemm(elem::NORMAL, elem::NORMAL, alpha, A_row, S,
            beta, C_slice);
    }
}


//FIXME: benchmark against one-sided
template<typename index_type, typename value_type, elem::Distribution col_d>
inline void outer_panel_mixed_gemm_impl_tn(
        const double alpha,
        const SpParMat<index_type, value_type, SpDCCols<index_type, value_type> > &A,
        const elem::DistMatrix<value_type, col_d, elem::STAR> &S,
        const double beta,
        elem::DistMatrix<value_type, elem::STAR, elem::STAR> &C) {

    elem::DistMatrix<value_type, elem::STAR, elem::STAR>
        tmp_C(C.Height(), C.Width());
    elem::Zero(tmp_C);

    utility::combblas_slab_view_t<index_type, value_type> cbview(A, false);

    //FIXME: factor
    size_t slab_size = 2 * S.Grid().Height();
    for(size_t cur_row_idx = 0; cur_row_idx < cbview.ncols();
        cur_row_idx += slab_size) {

        size_t cur_slab_size =
            std::min(slab_size, cbview.ncols() - cur_row_idx);

        // get the next slab_size columns of A
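        // Note: in the transposed case a slab covers cur_slab_size columns of
        // A, i.e. rows of the product A^T * S. Each processor accumulates the
        // contributions it can form from its locally owned rows of S into the
        // replicated tmp_C; the partial results of all ranks are combined by
        // the all_reduce at the end of this routine.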
        elem::DistMatrix<value_type, elem::STAR, elem::STAR>
            A_row(cur_slab_size, S.Height());

        // the transposed access is by column
        //cbview.extract_elemental_column_slab_view(A_row, cur_slab_size);
        cbview.extract_full_slab_view(cur_slab_size);

        // matrix multiplication (FIXME: only iterate over the nonzeros)
        for(size_t l_row_idx = 0; l_row_idx < A_row.LocalHeight();
            ++l_row_idx) {

            size_t g_row_idx = l_row_idx * A_row.ColStride()
                + A_row.ColShift() + cur_row_idx;

            for(size_t l_col_idx = 0; l_col_idx < A_row.LocalWidth();
                l_col_idx++) {

                //XXX: should be the same as l_col_idx
                size_t g_col_idx = l_col_idx * A_row.RowStride()
                    + A_row.RowShift();

                // continue if we don't own values of S in this row
                if(!S.IsLocalRow(g_col_idx))
                    continue;

                // get the transposed value
                value_type val = alpha * cbview(g_col_idx, g_row_idx);

                for(size_t s_col_idx = 0; s_col_idx < S.LocalWidth();
                    s_col_idx++) {

                    tmp_C.UpdateLocal(g_row_idx, s_col_idx,
                        val * S.GetLocal(S.LocalRow(g_col_idx), s_col_idx));
                }
            }
        }
    }

    //FIXME: scaling
    if(A.getcommgrid()->GetRank() == 0) {
        for(size_t col_idx = 0; col_idx < C.Width(); col_idx++)
            for(size_t row_idx = 0; row_idx < C.Height(); row_idx++)
                tmp_C.UpdateLocal(row_idx, col_idx,
                    beta * C.GetLocal(row_idx, col_idx));
    }

    //FIXME: Use utility getter
    boost::mpi::communicator world(
        A.getcommgrid()->GetWorld(), boost::mpi::comm_duplicate);
    boost::mpi::all_reduce(world,
        tmp_C.LockedBuffer(),
        C.Height() * C.Width(),
        C.Buffer(),
        std::plus<value_type>());
}

#endif // SKYLARK_HAVE_ELEMENTAL && SKYLARK_HAVE_COMBBLAS

} } } // namespace skylark::base::detail

#endif //SKYLARK_GEMM_DETAIL_HPP
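
The routines above live in skylark::base::detail and are normally reached through the library's higher-level Gemm front end rather than called directly. The fragment below is only a minimal usage sketch under stated assumptions: the function name mixed_gemm_example, the int/double instantiation, the [VC, STAR] distribution of C, and the include path are illustrative choices, not part of this header. It assumes A, S, and C have already been constructed and filled elsewhere, and that SKYLARK_HAVE_ELEMENTAL and SKYLARK_HAVE_COMBBLAS are defined so the detail routines are compiled in. Which of the inner-panel and outer-panel variants is faster depends on the matrix shapes and the process grid, which is what the benchmarking FIXMEs in the header point at.

#include <elemental.hpp>
#include <CombBLAS.h>
#include "gemm_detail.hpp"  // this header; the real include path depends on the source layout

typedef SpDCCols<int, double> col_t;
typedef SpParMat<int, double, col_t> sparse_matrix_t;

// C := 1.0 * A * S + 0.0 * C, where A is a CombBLAS sparse matrix distributed
// over a square processor grid, S is a replicated ([STAR, STAR]) Elemental
// factor and C is a [VC, STAR] distributed Elemental result.
void mixed_gemm_example(
        const sparse_matrix_t &A,
        const elem::DistMatrix<double, elem::STAR, elem::STAR> &S,
        elem::DistMatrix<double, elem::VC, elem::STAR> &C) {

    // outer-panel variant: streams row slabs of A through elem::LocalGemm
    skylark::base::detail::outer_panel_mixed_gemm_impl_nn(1.0, A, S, 0.0, C);

    // the inner-panel variant computes the same product by reducing partial
    // rows over the CombBLAS processor grid instead:
    // skylark::base::detail::inner_panel_mixed_gemm_impl_nn(1.0, A, S, 0.0, C);
}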