Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/tests/unit/test_utils.hpp
Go to the documentation of this file.
00001 #ifndef TEST_UTILS_HPP
00002 #define TEST_UTILS_HPP
00003 
00004 #include "../../base/svd.hpp"
00005 
00006 #if SKYLARK_HAVE_BOOST
00007 #include <boost/test/minimal.hpp>
00008 #endif
00009 
00010 #if SKYLARK_HAVE_ELEMENTAL
00011 #include <elemental.hpp>
00012 #endif
00013 
00014 
00015 #if SKYLARK_HAVE_ELEMENTAL
00016 
00017 template<typename MatrixType>
00018 MatrixType operator-(MatrixType& A, MatrixType& B) {
00019     MatrixType C;
00020     elem::Copy(A, C);
00021     elem::Axpy(-1.0, B, C);
00022     return C;
00023 }
00024 
00025 
00026 template<typename MatrixType>
00027 bool equal(MatrixType& A, MatrixType& B,  double threshold=1.e-4) {
00028     MatrixType C = A - B;
00029     double diff_norm = elem::Norm(C);
00030     if (diff_norm < threshold) {
00031         return true;
00032     }
00033     return false;
00034 }
00035 
00036 
00037 template<typename InputMatrixType,
00038          typename LeftSingularVectorsMatrixType,
00039          typename SingularValuesMatrixType,
00040          typename RightSingularVectorsMatrixType>
00041 bool equal_svd_product(InputMatrixType& A,
00042     LeftSingularVectorsMatrixType& U,
00043     SingularValuesMatrixType& S,
00044     RightSingularVectorsMatrixType& V,
00045     double threshold=1e-4) {
00046 
00047     elem::DistMatrix<double> S_CIRC_CIRC = S;
00048     std::vector<double> values(S_CIRC_CIRC.Buffer(),
00049         S_CIRC_CIRC.Buffer() + S_CIRC_CIRC.Height());
00050     elem::Diagonal(S_CIRC_CIRC, values);
00051     elem::DistMatrix<double> S_MC_MR = S_CIRC_CIRC;
00052 
00053     elem::DistMatrix<double> A_MC_MR = A;
00054     elem::DistMatrix<double> U_MC_MR = U;
00055     elem::DistMatrix<double> V_MC_MR = V;
00056     elem::DistMatrix<double> US_MC_MR;
00057     elem::DistMatrix<double> USVt_MC_MR;
00058 
00059     US_MC_MR.Resize(U.Height(), S_CIRC_CIRC.Width());
00060     elem::Zero(US_MC_MR);
00061     USVt_MC_MR.Resize(U.Height(), V.Height());
00062 
00063     elem::Zero(USVt_MC_MR);
00064     elem::Gemm(elem::NORMAL, elem::NORMAL,    1.0, U_MC_MR,
00065         S_MC_MR, 0.0, US_MC_MR);
00066     elem::Gemm(elem::NORMAL, elem::TRANSPOSE, 1.0, US_MC_MR,
00067         V_MC_MR, 0.0, USVt_MC_MR);
00068 
00069     return equal(A_MC_MR, USVt_MC_MR, threshold);
00070 }
00071 
00072 
00073 #if SKYLARK_HAVE_BOOST
00074 
00075 void check(elem::DistMatrix<double>& A,
00076     double threshold=1e-4) {
00077     elem::DistMatrix<double> U, V;
00078     elem::DistMatrix<double, elem::VR, elem::STAR> S_VR_STAR;
00079     skylark::base::SVD(A, U, S_VR_STAR, V);
00080     bool passed = equal_svd_product(A, U, S_VR_STAR, V, threshold);
00081     if (!passed) {
00082         BOOST_FAIL("Failure in [MC, MR] case");
00083     }
00084 
00085 }
00086 
00087 
00088 template<elem::Distribution ColDist>
00089 void check(elem::DistMatrix<double, ColDist, elem::STAR>& A,
00090     double threshold=1e-4) {
00091     elem::DistMatrix<double, ColDist, elem::STAR> A_CD_STAR, U_CD_STAR;
00092     elem::DistMatrix<double, elem::STAR, elem::STAR> S_STAR_STAR, V_STAR_STAR;
00093     A_CD_STAR = A;
00094     skylark::base::SVD(A_CD_STAR, U_CD_STAR, S_STAR_STAR, V_STAR_STAR);
00095     bool passed = equal_svd_product(A_CD_STAR,
00096         U_CD_STAR, S_STAR_STAR, V_STAR_STAR, threshold);
00097     if (!passed) {
00098         BOOST_FAIL("Failure in [VC/VR, *] case");
00099     }
00100 }
00101 
00102 template<elem::Distribution RowDist>
00103 void check(elem::DistMatrix<double, elem::STAR, RowDist>& A,
00104     double threshold=1e-4) {
00105     elem::DistMatrix<double, RowDist, elem::STAR> V_RD_STAR;
00106     elem::DistMatrix<double, elem::STAR, elem::STAR> S_STAR_STAR, U_STAR_STAR;
00107     skylark::base::SVD(A, U_STAR_STAR, S_STAR_STAR, V_RD_STAR);
00108     bool passed = equal_svd_product(A,
00109         U_STAR_STAR, S_STAR_STAR, V_RD_STAR, threshold);
00110     if (!passed) {
00111         BOOST_FAIL("Failure in [*, VC/VR] case");
00112     }
00113 
00114 }
00115 
00116 
00117 #endif // SKYLARK_HAVE_BOOST
00118 
00119 #endif // SKYLARK_HAVE_ELEMENTAL
00120 
00121 #endif // TEST_UTILS_HPP