Skylark (Sketching Library)
0.1
|
00001 #include <boost/mpi.hpp> 00002 #include <elemental.hpp> 00003 #include <skylark.hpp> 00004 #include <iostream> 00005 #include <boost/test/minimal.hpp> 00006 00007 00008 #include "test_utils.hpp" 00009 00010 00013 typedef elem::Matrix<double> dense_matrix_t; 00014 typedef elem::DistMatrix<double> dist_dense_matrix_t; 00015 typedef elem::DistMatrix<double, elem::CIRC, elem::CIRC> 00016 dist_CIRC_CIRC_dense_matrix_t; 00017 00018 typedef dist_dense_matrix_t input_matrix_t; 00019 typedef dist_dense_matrix_t output_matrix_t; 00020 typedef skylark::sketch::JLT_t<input_matrix_t, output_matrix_t> 00021 sketch_transform_t; 00022 00023 typedef skylark::sketch::JLT_t<dense_matrix_t, dense_matrix_t> 00024 sketch_transform_local_t; 00025 00026 00027 int test_main(int argc, char* argv[]) { 00028 00030 elem::Initialize (argc, argv); 00031 00033 boost::mpi::environment env(argc, argv); 00034 boost::mpi::communicator world; 00035 00036 MPI_Comm mpi_world(world); 00037 elem::Grid grid(mpi_world); 00038 00040 int height = 20; 00041 int width = 10; 00042 int sketch_size = 5; 00043 00044 dist_CIRC_CIRC_dense_matrix_t A_CIRC_CIRC(grid); 00045 input_matrix_t A(grid); 00046 elem::Uniform(A_CIRC_CIRC, height, width); 00047 A = A_CIRC_CIRC; 00048 int size; 00049 00050 00052 size = width; 00053 skylark::base::context_t context_rw(0); 00054 output_matrix_t sketched_A_rw(height, sketch_size); 00055 sketch_transform_t sketch_transform_rw(size, sketch_size, context_rw); 00056 sketch_transform_rw.apply(A, sketched_A_rw, 00057 skylark::sketch::rowwise_tag()); 00058 dist_CIRC_CIRC_dense_matrix_t sketched_A_rw_CIRC_CIRC = sketched_A_rw; 00059 00060 if(world.rank() == 0) { 00061 dense_matrix_t sketched_A_rw_gathered = 00062 sketched_A_rw_CIRC_CIRC.Matrix(); 00063 00064 skylark::base::context_t context_rw_local(0); 00065 dense_matrix_t A_rw_local = A_CIRC_CIRC.Matrix(); 00066 dense_matrix_t sketched_A_rw_local(height, sketch_size); 00067 sketch_transform_local_t sketch_transform_rw_local(size, sketch_size, 00068 context_rw_local); 00069 sketch_transform_rw_local.apply(A_rw_local, sketched_A_rw_local, 00070 skylark::sketch::rowwise_tag()); 00071 00072 if (!equal(sketched_A_rw_gathered, sketched_A_rw_local)) 00073 BOOST_FAIL("Rowwise sketching resuts are not equal"); 00074 } 00075 00076 00078 size = height; 00079 skylark::base::context_t context_cw(0); 00080 output_matrix_t sketched_A_cw(sketch_size, width); 00081 sketch_transform_t sketch_transform_cw(size, sketch_size, context_cw); 00082 sketch_transform_cw.apply(A, sketched_A_cw, 00083 skylark::sketch::columnwise_tag()); 00084 dist_CIRC_CIRC_dense_matrix_t sketched_A_cw_CIRC_CIRC = sketched_A_cw; 00085 00086 if(world.rank() == 0) { 00087 dense_matrix_t sketched_A_cw_gathered = 00088 sketched_A_cw_CIRC_CIRC.Matrix(); 00089 00090 skylark::base::context_t context_cw_local(0); 00091 dense_matrix_t A_cw_local = A_CIRC_CIRC.Matrix(); 00092 dense_matrix_t sketched_A_cw_local(sketch_size, width); 00093 sketch_transform_local_t sketch_transform_cw_local(size, sketch_size, 00094 context_cw_local); 00095 sketch_transform_cw_local.apply(A_cw_local, sketched_A_cw_local, 00096 skylark::sketch::columnwise_tag()); 00097 00098 if (!equal(sketched_A_cw_gathered, sketched_A_cw_local)) 00099 BOOST_FAIL("Columnwise sketching resuts are not equal"); 00100 } 00101 00102 00103 elem::Finalize(); 00104 return 0; 00105 }