Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/nla/sketched_svd_Elemental.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_SKETCHED_SVD_ELEMENTAL_HPP
00002 #define SKYLARK_SKETCHED_SVD_ELEMENTAL_HPP
00003 
00004 #include <elemental.hpp>
00005 #include "../base/context.hpp"
00006 #include "../base/exception.hpp"
00007 #include "../utility/get_communicator.hpp"
00008 
00009 namespace skylark { namespace nla {
#if 0
// NOTE(review): this specialization is compiled out. Enabling it requires the
// primary template `sketch_svd_t` to be declared and the header providing
// `skylark::sketch::JLT_t` / `skylark::sketch::rowwise_tag` to be included;
// neither is pulled in by the includes at the top of this file.

/**
 * Sketched (randomized) SVD of a [MC,MR] Elemental matrix using a JLT sketch.
 *
 * Steps performed by apply():
 *   1. Y = A * Omega, Omega being a width -> sketch_size JLT (rowwise sketch).
 *   2. Q = orthonormal basis of Y (explicit QR).
 *   3. q power iterations: Q <- orth(A^H Q), then Q <- orth(A Q).
 *   4. B = Q^H A.
 *   5. SVD of B (B overwritten by its left singular vectors), U = Q * B.
 */
template <typename ValueType>
struct sketch_svd_t <elem::DistMatrix<ValueType, elem::MC, elem::MR>,
                     elem::DistMatrix<ValueType, elem::MC, elem::MR>,
                     skylark::sketch::JLT_t <
                         elem::DistMatrix<ValueType, elem::MC, elem::MR>,
                         elem::DistMatrix<ValueType, elem::MC, elem::MR> > > {
    typedef ValueType value_type;
    typedef elem::DistMatrix<value_type, elem::MC, elem::MR> matrix_type;
    typedef elem::DistMatrix<value_type, elem::MC, elem::MR> output_matrix_type;
    // The sketch this specialization is selected for (third template argument).
    typedef skylark::sketch::JLT_t<matrix_type, matrix_type>
        sketch_transform_type;
    // Dependent type inside a template: `typename` is required.
    typedef typename sketch_transform_type::output_matrix_type
        sketched_matrix_type;

    /**
     * @param k            Target rank. It is only validated here
     *                     (k <= min(height, width), k <= sketch_size).
     *                     NOTE(review): outputs are not truncated to k
     *                     columns; confirm the intended contract.
     * @param sketch_size  Number of columns the sketch reduces A to.
     * @param q            Number of power iterations (q <= 0: none).
     * @param A            Input matrix (height x width).
     * @param U            Output: Q times the left singular vectors of B.
     * @param S            Output: singular values of B.
     * @param V            Output: right singular vectors of B.
     * @param context      Skylark context handed to the sketch constructor.
     * @throws base::elemental_exception if the size parameters are invalid.
     */
    static void apply (int k,
                       int sketch_size,
                       int q,
                       const matrix_type& A,
                       output_matrix_type& U,
                       output_matrix_type& S,
                       output_matrix_type& V,
                       skylark::base::context_t& context) {

        const int height = A.Height();
        const int width = A.Width();

        if ((k > std::min(height, width)) ||
            (sketch_size > width) ||
            (sketch_size < k)) {
            SKYLARK_THROW_EXCEPTION(
                base::elemental_exception()
                    << base::error_msg(
                        "sketch_svd: require k <= min(height, width) and "
                        "k <= sketch_size <= width"));
        }

        // Y = A * Omega: sketch the rows of A down to sketch_size columns.
        sketch_transform_type sketch_transform(width, sketch_size, context);
        sketched_matrix_type Y(height, sketch_size);
        sketch_transform.apply(A, Y, skylark::sketch::rowwise_tag());

        // Q: orthonormal basis of range(Y).
        sketched_matrix_type Q(Y);
        elem::qr::Explicit(Q);

        // Power iterations. Q is reassigned (not redeclared) so every pass
        // refines the basis that the next pass and B below actually use.
        // Alpha must be 1: an alpha of 0 would make every product zero.
        // Assumes the Elemental Gemm overload without beta resizes its
        // output (Y changes shape between the two products) - confirm for
        // the Elemental version in use.
        for (int step = 0; step < q; step++) {
            // Q <- orth(A^H Q)
            elem::Gemm(elem::ADJOINT, elem::NORMAL, value_type(1), A, Q, Y);
            Q = Y;
            elem::qr::Explicit(Q);

            // Q <- orth(A Q)
            elem::Gemm(elem::NORMAL, elem::NORMAL, value_type(1), A, Q, Y);
            Q = Y;
            elem::qr::Explicit(Q);
        }

        // B = Q^H A (sketch_size x width).
        sketched_matrix_type B;
        elem::Gemm(elem::ADJOINT, elem::NORMAL, value_type(1), Q, A, B);

        // SVD of the small matrix; B is overwritten with its left singular
        // vectors. NOTE(review): assumes a real value_type, since the singular
        // values are stored with the same scalar type.
        elem::DistMatrix<value_type, elem::VR, elem::STAR> Sigma;
        elem::SVD(B, Sigma, V);
        S = Sigma;

        // U = Q * B: lift the small left singular vectors back to height rows.
        elem::Gemm(elem::NORMAL, elem::NORMAL, value_type(1), Q, B, U);
    }
};
#endif
00100 
#if 0
/*************************************************************************/
/* Everything below here is Vikas' code; I am keeping it for posterity   */
/*************************************************************************/

// TODO this should be templated ASAP. They confuse codes that want to define
// these later.
typedef elem::Matrix<double> MatrixType;
typedef elem::DistMatrix<double, elem::VR, elem::STAR> DistMatrixType;

// Takes an m x nA matrix A and an m x nB matrix B and computes C = A'*B, which
// is small (nA x nB). A and B are row-partitioned together, so computing C
// boils down to C = sum_i A_i^T*B_i, i.e. an MPI reduction of local products.
// An all-reduce is used so that EVERY rank holds the full C: SVD() below uses
// the result on all ranks, not only on a root.
//
// Preconditions (not checked): A and B have the same height and the same row
// alignment; C is already nA x nB with leading dimension nA, so its buffer
// holds nA*nB contiguous values.
// NOTE(review): `context` is unused; the communicator is taken from A.
void Gemm(DistMatrixType& A,
          DistMatrixType& B,
          MatrixType& C,
          skylark::base::context_t& context) {

    const int nA = A.Width();
    const int nB = B.Width();

    // Local contribution A_i^T * B_i from the rows owned by this rank.
    MatrixType C_local(nA, nB);
    elem::Gemm(elem::ADJOINT,
               elem::NORMAL,
               1.0,
               A.LockedMatrix(),
               B.LockedMatrix(),
               0.0,
               C_local);

    // get communicator from matrix
    boost::mpi::communicator comm = skylark::utility::get_communicator(A);

    // Sum the local products; every rank receives the result.
    boost::mpi::all_reduce (comm,
                            C_local.LockedBuffer(),
                            nA * nB,
                            C.Buffer(),
                            std::plus<double>());
}

// Takes an m x n distributed matrix A and an n x l local matrix B and computes
// the distributed matrix C = A*B by purely local multiplication: each rank
// multiplies the rows of A it owns by its copy of B.
//
// Preconditions (not checked): B is identical on every rank, and C is already
// m x l with the same row distribution/alignment as A.
void Gemm(DistMatrixType& A, MatrixType& B, DistMatrixType& C) {
    elem::Gemm(elem::NORMAL,
               elem::NORMAL,
               1.0,
               A.LockedMatrix(),
               B,
               0.0,
               C.Matrix());
}

//templatize later
// Sketched SVD of the m x n matrix A using an l-column JLT sketch:
//   Y = A*Omega, Q = orth(Y), B = Q'*A (l x n), B = U_B diag(s) V', U = Q*U_B.
// On return U holds Q*U_B (distributed like Q2), s the singular values and V
// the right singular vectors of B.
// NOTE(review): `q` (power iterations) is accepted but not used yet.
void SVD(DistMatrixType& A,
         DistMatrixType& U,
         MatrixType& s,
         MatrixType& V,
         int l,
         int q,
         skylark::base::context_t& context) {

    const int m = A.Height();
    const int n = A.Width();

    // Create an n x l JLT Sketch
    skylark::sketch::JLT_t<DistMatrixType, DistMatrixType> JLT (n, l, context);

    // Sketched result Y = A*Omega, m x l.
    DistMatrixType Y(m, l);
    JLT.apply (A, Y, skylark::sketch::rowwise_tag());

    // TO DO : need to do power iterations here

    // Explicit QR: qr::Explicit overwrites its argument with Q where Y = QR
    // (plain elem::QR would leave packed Householder data, not Q).
    // NOTE: the QR runs on a [MC,MR] copy, then Q is converted to [VR,STAR].
    elem::DistMatrix<double> Q(Y);
    elem::qr::Explicit(Q);
    DistMatrixType Q2(Q);

    // Compute B = Q'A of size l x n, available on every rank.
    MatrixType B(l, n);
    Gemm(Q2, A, B, context);

    // Get SVD of B - Note B is overwritten by U where B = U diag(s) V' is the
    // SVD of B.
    elem::SVD(B, s, V);

    // Write U = Q B. The local product needs U to be m x l with Q2's row
    // distribution; assigning Q2 first provides that shape (presumably also
    // aligning U with Q2 - confirm for the Elemental version in use).
    // Assumes l <= n so that B is l x l after the SVD.
    U = Q2;
    Gemm(Q2, B, U);
}
#endif
00200 
00201 } } 
00203 #endif // SKYLARK_SKETCHED_SVD_ELEMENTAL_HPP