// Skylark (Sketching Library) 0.1
00001 #ifndef SKYLARK_SKETCHED_SVD_ELEMENTAL_HPP 00002 #define SKYLARK_SKETCHED_SVD_ELEMENTAL_HPP 00003 00004 #include <elemental.hpp> 00005 #include "../base/context.hpp" 00006 #include "../base/exception.hpp" 00007 #include "../utility/get_communicator.hpp" 00008 00009 namespace skylark { namespace nla { 00010 #if 0 00011 00012 template <typename ValueType> 00013 struct sketch_svd_t <elem::DistMatrix<ValueType, elem::MC, elem::MR>, 00014 elem::DistMatrix<ValueType, elem::MC, elem::MR>, 00015 skylark::sketch::JLT_t < 00016 elem::DistMatrix<ValueType, elem::MC, elem::MR>, 00017 elem::DistMatrix<ValueType, elem::MC, elem::MR> > > { 00018 typedef ValueType value_type; 00019 typedef elem::DistMatrix<value_type, elem::MC, elem::MR> matrix_type; 00020 typedef elem::DistMatrix<value_type, elem::MC, elem::MR> output_matrix_type; 00021 typedef SketchTransformType sketch_transform_type; 00022 typedef sketch_transform_type::output_matrix_type sketched_matrix_type; 00023 00035 static void apply (int k, 00036 int sketch_size, 00037 int q, 00038 const matrix_type& A, 00039 output_matrix_type& U, 00040 output_matrix_type& S, 00041 output_matrix_type& V, 00042 skylark::base::context_t& context) { 00043 00044 int height = A.Height(); 00045 int width = A.Width(); 00046 00054 if (k > std::min(height, width)) || 00055 (sketch_size > width) || 00056 (sketch_size < k) { 00057 SKYLARK_THROW_EXCEPTION( 00058 base::elemental_exception() 00059 << base::error_msg(e.what()) ); 00060 } 00061 00065 sketch_transform_type sketch_transform(width, sketch_size, context); 00066 sketched_matrix_type Y(height, sketch_size); 00067 sketch_transform.apply(A, Y, skylark:sketch::rowwise_tag()); 00068 00070 sketched_matrix_type Q(Y); 00071 elem::qr::Explicit(Q); 00072 00074 for(int step = 0; step < q; step++) { 00076 elem::Gemm(elem::ADJOINT, elem::NORMAL, 0.0, A, Q, Y); 00077 sketched_matrix_type Q(Y); 00078 elem::qr::Explicit(Q); 00079 00081 elem::Gemm(elem::NORMAL, elem::NORMAL, 0.0, A, Q, Y); 00082 
sketched_matrix_type Q(Y); 00083 elem::qr::Explicit(Q); 00084 } 00085 00087 sketched_matrix_type B; 00088 elem::Gemm(elem::ADJOINT, elem::NORMAL, value_type(0), Q, A, B); 00089 00091 elem::DistMatrix<elem::VR, elem::STAR> Sigma; 00092 elem::SVD(B, Sigma, V); 00093 S = Sigma; 00094 00096 elem::Gemm(elem::NORMAL, elem::NORMAL, Q, B, value_type(0), U); 00097 } 00098 }; 00099 #endif 00100 00101 #if 0 00102 /*************************************************************************/ 00103 /* Everything below here is Vikas' code; I am keeping it for posterity */ 00104 /*************************************************************************/ 00105 00106 // TODO this should be templated ASAP. They confuse codes that want to define 00107 // these later. 00108 typedef elem::Matrix<double> MatrixType; 00109 typedef elem::DistMatrix<double, elem::VR, elem::STAR> DistMatrixType; 00110 00111 // Takes an m x nA matrix A, m x nB matrix B and computes C = A'*B which is 00112 // small nA x nB. A and B are row-partitioned together. So computation of C 00113 // boils down to C = sum_i A_i^T*B_i i.e. an mpi reduce operation. Note: we 00114 // need to pass the context so we can call mpi::reduce with the associated 00115 // communicator. 
00116 void Gemm(DistMatrixType& A, 00117 DistMatrixType& B, 00118 MatrixType& C, 00119 skylark::base::context_t& context) { 00120 00121 int mA = A.Height(); 00122 int nA = A.Width(); 00123 int mB = B.Height(); 00124 int nB = B.Width(); 00125 00126 MatrixType C_local(nA, nB); 00127 Gemm(elem::ADJOINT, 00128 elem::NORMAL, 00129 1.0, 00130 A.LockedMatrix(), 00131 B.LockedMatrix(), 00132 0.0, 00133 C_local); 00134 00135 // get communicator from matrix 00136 boost::mpi::communicator comm = skylark::utility::get_communicator(A); 00137 00138 boost::mpi::reduce (comm, 00139 C_local.LockedBuffer(), 00140 C_local.MemorySize(), 00141 C.Buffer(), 00142 std::plus<double>(), 00143 0); 00144 } 00145 00146 // Takes a m x n distributed matrix A and a n x l local matrix B and computes the distributed matrix A*B by local matrix multiplication. 00147 void Gemm(DistMatrixType& A, MatrixType& B, DistMatrixType& C) { 00148 Gemm(elem::NORMAL, 00149 elem::NORMAL, 00150 1.0, 00151 A.LockedMatrix(), 00152 B, 00153 0.0, 00154 C.Matrix()); 00155 } 00156 00157 //templatize later 00158 void SVD(DistMatrixType& A, 00159 DistMatrixType& U, 00160 MatrixType& s, 00161 MatrixType& V, 00162 int l, 00163 int q, 00164 skylark::base::context_t& context) { 00165 00166 int m = A.Height(); 00167 int n = A.Width(); 00168 00169 // Create an n x l JLT Sketch 00170 skylark::sketch::JLT_t<DistMatrixType, DistMatrixType> JLT (n, l, context); 00171 00172 // Create space to hold the sketched result 00173 DistMatrixType Y(m,l); 00174 //Y.Resize(m,l); 00175 00176 JLT.apply (A, Y, skylark::sketch::rowwise_tag()); 00177 00178 // TO DO : need to do power iterations here 00179 00180 // call Explicit QR on Y. Y is overwritten with Q where Y = QR. 00181 // NOTE: Type conversions below. 
00182 elem::DistMatrix<double> Q(Y); 00183 elem::QR( Q ); 00184 DistMatrixType Q2(Q); 00185 Q2 = Q; 00186 00187 //Compute B = Q'A of size l x n 00188 00189 MatrixType B(l, n); 00190 Gemm(Q2, A, B, context); 00191 00192 // Get SVD of B - Note B is overwritten by U where B = U diag(s) V' is the 00193 // SVD of B. 00194 elem::SVD(B, s, V); 00195 00196 // Write U = Q B 00197 Gemm(Q2, B, U); 00198 } 00199 #endif 00200 00201 } } 00203 #endif // SKYLARK_SKETCHED_SVD_ELEMENTAL_HPP