Skylark (Sketching Library)
Version 0.1
00001 #ifndef TEST_UTILS_HPP 00002 #define TEST_UTILS_HPP 00003 00004 #include "../../base/svd.hpp" 00005 00006 #if SKYLARK_HAVE_BOOST 00007 #include <boost/test/minimal.hpp> 00008 #endif 00009 00010 #if SKYLARK_HAVE_ELEMENTAL 00011 #include <elemental.hpp> 00012 #endif 00013 00014 00015 #if SKYLARK_HAVE_ELEMENTAL 00016 00017 template<typename MatrixType> 00018 MatrixType operator-(MatrixType& A, MatrixType& B) { 00019 MatrixType C; 00020 elem::Copy(A, C); 00021 elem::Axpy(-1.0, B, C); 00022 return C; 00023 } 00024 00025 00026 template<typename MatrixType> 00027 bool equal(MatrixType& A, MatrixType& B, double threshold=1.e-4) { 00028 MatrixType C = A - B; 00029 double diff_norm = elem::Norm(C); 00030 if (diff_norm < threshold) { 00031 return true; 00032 } 00033 return false; 00034 } 00035 00036 00037 template<typename InputMatrixType, 00038 typename LeftSingularVectorsMatrixType, 00039 typename SingularValuesMatrixType, 00040 typename RightSingularVectorsMatrixType> 00041 bool equal_svd_product(InputMatrixType& A, 00042 LeftSingularVectorsMatrixType& U, 00043 SingularValuesMatrixType& S, 00044 RightSingularVectorsMatrixType& V, 00045 double threshold=1e-4) { 00046 00047 elem::DistMatrix<double> S_CIRC_CIRC = S; 00048 std::vector<double> values(S_CIRC_CIRC.Buffer(), 00049 S_CIRC_CIRC.Buffer() + S_CIRC_CIRC.Height()); 00050 elem::Diagonal(S_CIRC_CIRC, values); 00051 elem::DistMatrix<double> S_MC_MR = S_CIRC_CIRC; 00052 00053 elem::DistMatrix<double> A_MC_MR = A; 00054 elem::DistMatrix<double> U_MC_MR = U; 00055 elem::DistMatrix<double> V_MC_MR = V; 00056 elem::DistMatrix<double> US_MC_MR; 00057 elem::DistMatrix<double> USVt_MC_MR; 00058 00059 US_MC_MR.Resize(U.Height(), S_CIRC_CIRC.Width()); 00060 elem::Zero(US_MC_MR); 00061 USVt_MC_MR.Resize(U.Height(), V.Height()); 00062 00063 elem::Zero(USVt_MC_MR); 00064 elem::Gemm(elem::NORMAL, elem::NORMAL, 1.0, U_MC_MR, 00065 S_MC_MR, 0.0, US_MC_MR); 00066 elem::Gemm(elem::NORMAL, elem::TRANSPOSE, 1.0, US_MC_MR, 00067 V_MC_MR, 
0.0, USVt_MC_MR); 00068 00069 return equal(A_MC_MR, USVt_MC_MR, threshold); 00070 } 00071 00072 00073 #if SKYLARK_HAVE_BOOST 00074 00075 void check(elem::DistMatrix<double>& A, 00076 double threshold=1e-4) { 00077 elem::DistMatrix<double> U, V; 00078 elem::DistMatrix<double, elem::VR, elem::STAR> S_VR_STAR; 00079 skylark::base::SVD(A, U, S_VR_STAR, V); 00080 bool passed = equal_svd_product(A, U, S_VR_STAR, V, threshold); 00081 if (!passed) { 00082 BOOST_FAIL("Failure in [MC, MR] case"); 00083 } 00084 00085 } 00086 00087 00088 template<elem::Distribution ColDist> 00089 void check(elem::DistMatrix<double, ColDist, elem::STAR>& A, 00090 double threshold=1e-4) { 00091 elem::DistMatrix<double, ColDist, elem::STAR> A_CD_STAR, U_CD_STAR; 00092 elem::DistMatrix<double, elem::STAR, elem::STAR> S_STAR_STAR, V_STAR_STAR; 00093 A_CD_STAR = A; 00094 skylark::base::SVD(A_CD_STAR, U_CD_STAR, S_STAR_STAR, V_STAR_STAR); 00095 bool passed = equal_svd_product(A_CD_STAR, 00096 U_CD_STAR, S_STAR_STAR, V_STAR_STAR, threshold); 00097 if (!passed) { 00098 BOOST_FAIL("Failure in [VC/VR, *] case"); 00099 } 00100 } 00101 00102 template<elem::Distribution RowDist> 00103 void check(elem::DistMatrix<double, elem::STAR, RowDist>& A, 00104 double threshold=1e-4) { 00105 elem::DistMatrix<double, RowDist, elem::STAR> V_RD_STAR; 00106 elem::DistMatrix<double, elem::STAR, elem::STAR> S_STAR_STAR, U_STAR_STAR; 00107 skylark::base::SVD(A, U_STAR_STAR, S_STAR_STAR, V_RD_STAR); 00108 bool passed = equal_svd_product(A, 00109 U_STAR_STAR, S_STAR_STAR, V_RD_STAR, threshold); 00110 if (!passed) { 00111 BOOST_FAIL("Failure in [*, VC/VR] case"); 00112 } 00113 00114 } 00115 00116 00117 #endif // SKYLARK_HAVE_BOOST 00118 00119 #endif // SKYLARK_HAVE_ELEMENTAL 00120 00121 #endif // TEST_UTILS_HPP