// Skylark (Sketching Library), version 0.1
00001 #include <skylark.hpp> 00002 #include <boost/mpi.hpp> 00003 #include <elemental.hpp> 00004 #include <iostream> 00005 #include "../base/QR.hpp" 00006 #include <cfloat> 00007 #include <vector> 00008 00009 00011 typedef elem::DistMatrix<double> dist_matrix_t; 00012 typedef elem::Matrix<double> matrix_t; 00013 typedef elem::DistMatrix<double, elem::VR, elem::STAR> vr_star_dist_matrix_t; 00014 typedef elem::DistMatrix<double, elem::STAR, elem::STAR> star_star_matrix_t; 00015 typedef skylark::sketch::JLT_t<dist_matrix_t, dist_matrix_t> sketch_transform_t; 00016 00017 using namespace std; 00018 00019 int main(int argc, char* argv[]) { 00020 00022 #ifdef SKYLARK_HAVE_OPENMP 00023 int provided; 00024 MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); 00025 #endif 00026 boost::mpi::environment env(argc, argv); 00027 boost::mpi::communicator world; 00028 MPI_Comm mpi_world(world); 00029 elem::Grid grid(mpi_world); 00030 00032 elem::Initialize (argc, argv); 00033 00035 skylark::base::context_t context(0); 00036 00038 dist_matrix_t A(grid), B(grid), C(grid); 00039 elem::Uniform(B, 5000, 100); 00040 skylark::base::qr::Explicit(B); 00041 00042 elem::Uniform(C, 100, 100); 00043 skylark::base::qr::Explicit(C); 00044 00045 //star_star_matrix_t S(100,100); 00046 dist_matrix_t S(100,100); 00047 elem::Zero(S); 00048 00049 vector<double> diag(100); 00050 00051 for( int j=0; j<100; ++j ) 00052 { 00053 diag[j] = exp(-j)*100; 00054 std::cout << exp(-j) *100 << "\n"; 00055 } 00056 00057 elem::Diagonal(S, diag); 00058 dist_matrix_t tmp(grid); 00059 00060 elem::Gemm(elem::NORMAL, elem::NORMAL, double(1), B, S, tmp); 00061 elem::Gemm(elem::NORMAL, elem::ADJOINT, double(1), tmp, C, A); 00062 00063 dist_matrix_t U(grid), V(grid); 00064 vr_star_dist_matrix_t S1; 00065 00066 dist_matrix_t A1(A); 00067 elem::SVD(A1,S1,V); 00068 00069 elem::Print(S1, "S1"); 00070 00072 dist_matrix_t A2(A); 00073 dist_matrix_t U1(grid), V1(grid); 00074 vr_star_dist_matrix_t S2; 00075 00076 int 
sketch_size = 50; 00077 int target_rank = 10; 00078 00079 skylark::nla::rand_svd_params_t params(sketch_size-target_rank); 00080 00081 skylark::nla::randsvd_t<skylark::sketch::JLT_t> rand_svd; 00082 rand_svd(A2, target_rank, U1, S2, V1, params, context); 00083 00084 elem::Print(S2, "S2"); 00085 00086 elem::Finalize(); 00087 return 0; 00088 }