Skylark (Sketching Library)
0.1
|
00001 #ifndef SKYLARK_SVD_HPP 00002 #define SKYLARK_SVD_HPP 00003 00004 #if SKYLARK_HAVE_ELEMENTAL 00005 #include <elemental.hpp> 00006 #endif 00007 00008 #include "Gemm.hpp" 00009 00010 namespace skylark { namespace base { 00011 00012 #if SKYLARK_HAVE_ELEMENTAL 00013 00014 00015 template<typename T> 00016 void SVD(elem::Matrix<T>& A, elem::Matrix< elem::Base<T> >& S, 00017 elem::Matrix<T>& V) { 00018 elem::SVD(A, S, V); 00019 } 00020 00021 template<typename T> 00022 void SVD(const elem::Matrix<T>& A, elem::Matrix<T>& U, 00023 elem::Matrix<elem::Base<T> >& S, 00024 elem::Matrix<T>& V) { 00025 00026 U = A; 00027 elem::SVD(U, S, V); 00028 } 00029 00030 template<typename T> 00031 void SVD(elem::DistMatrix<T, elem::STAR, elem::STAR>& A, 00032 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S, 00033 elem::DistMatrix<T, elem::STAR, elem::STAR>& V) { 00034 elem::SVD(A.Matrix(), S.Matrix(), V.Matrix()); 00035 A.Resize(A.Matrix().Height(), A.Matrix().Width()); 00036 S.Resize(S.Matrix().Height(), S.Matrix().Width()); 00037 V.Resize(V.Matrix().Height(), V.Matrix().Width()); 00038 } 00039 00040 template<typename T> 00041 void SVD(const elem::DistMatrix<T, elem::STAR, elem::STAR>& A, 00042 elem::DistMatrix<T, elem::STAR, elem::STAR>& U, 00043 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S, 00044 elem::DistMatrix<T, elem::STAR, elem::STAR>& V) { 00045 00046 U = A; 00047 elem::SVD(U.Matrix(), S.Matrix(), V.Matrix()); 00048 U.Resize(U.Matrix().Height(), U.Matrix().Width()); 00049 S.Resize(S.Matrix().Height(), S.Matrix().Width()); 00050 V.Resize(V.Matrix().Height(), V.Matrix().Width()); 00051 } 00052 00053 template<typename T> 00054 void SVD(elem::DistMatrix<T>& A, 00055 elem::DistMatrix<elem::Base<T>, elem::VR, elem::STAR>& S, 00056 elem::DistMatrix<T>& V) { 00057 elem::SVD(A, S, V); 00058 } 00059 00060 template<typename T> 00061 void SVD(const elem::DistMatrix<T>& A, elem::DistMatrix<T>& U, 00062 elem::DistMatrix<elem::Base<T>, elem::VR, elem::STAR>& S, 00063 elem::DistMatrix<T>& V) { 00064 00065 U = A; 00066 elem::SVD(U, S, V); 00067 } 00068 00069 template<typename T> 00070 void SVD(elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00071 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S, 00072 elem::DistMatrix<T, elem::STAR, elem::STAR>& V) { 00073 00074 elem::DistMatrix<T, elem::VC, elem::STAR> Q; 00075 elem::DistMatrix<T, elem::STAR, elem::STAR> R; 00076 elem::Matrix<T> U_tilda; 00077 00078 // tall and skinny QR (TSQR) 00079 Q = A; 00080 elem::qr::ExplicitTS(Q, R); 00081 00082 // local SVD of R 00083 base::SVD(R, S, V); 00084 S.Resize(S.Height(), 1); 00085 V.Resize(V.Height(), V.Width()); 00086 00087 // Compute U 00088 A.Resize(Q.Height(), S.Height()); 00089 base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), A); 00090 } 00091 00092 template<typename T> 00093 void SVD(const elem::DistMatrix<T, elem::VC, elem::STAR>& A, 00094 elem::DistMatrix<T, elem::VC, elem::STAR>& U, 00095 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S, 00096 elem::DistMatrix<T, elem::STAR, elem::STAR>& V) { 00097 00098 elem::DistMatrix<T, elem::VC, elem::STAR> Q; 00099 elem::DistMatrix<T, elem::STAR, elem::STAR> R; 00100 00101 // tall and skinny QR (TSQR) 00102 Q = A; 00103 elem::qr::ExplicitTS(Q, R); 00104 00105 // local SVD of R 00106 base::SVD(R, S, V); 00107 S.Resize(S.Height(), 1); 00108 V.Resize(V.Height(), V.Width()); 00109 00110 // Compute U 00111 U.Resize(Q.Height(), S.Height()); 00112 base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), U); 00113 } 00114 00115 template<typename T> 00116 void SVD(elem::DistMatrix<T, elem::VR, elem::STAR>& A, 00117 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S, 00118 elem::DistMatrix<T, elem::STAR, elem::STAR>& V) { 00119 00120 elem::DistMatrix<T, elem::VR, elem::STAR> Q; 00121 elem::DistMatrix<T, elem::STAR, elem::STAR> R; 00122 00123 // tall and skinny QR (TSQR) 00124 Q = A; 00125 elem::qr::ExplicitTS(A, R); 00126 00127 // local SVD of R 00128 base::SVD(R, S, V); 00129 S.Resize(S.Height(), 1); 00130 V.Resize(V.Height(), V.Width()); 00131 00132 // Compute U 00133 A.Resize(A.Height(), S.Height()); 00134 base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), A); 00135 } 00136 00137 template<typename T> 00138 void SVD(const elem::DistMatrix<T, elem::VR, elem::STAR>& A, 00139 elem::DistMatrix<T, elem::VR, elem::STAR>& U, 00140 elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S, 00141 elem::DistMatrix<T, elem::STAR, elem::STAR>& V) { 00142 00143 elem::DistMatrix<T, elem::VR, elem::STAR> Q; 00144 elem::DistMatrix<T, elem::STAR, elem::STAR> R; 00145 00146 // tall and skinny QR (TSQR) 00147 Q = A; 00148 elem::qr::ExplicitTS(Q, R); 00149 00150 // local SVD of R 00151 base::SVD(R, S, V); 00152 S.Resize(S.Height(), 1); 00153 V.Resize(V.Height(), V.Width()); 00154 00155 // Compute U 00156 U.Resize(Q.Height(), S.Height()); 00157 base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), U); 00158 } 00159 00160 template<typename T> 00161 void SVD(const elem::DistMatrix<T, elem::STAR, elem::VC>& A, 00162 elem::DistMatrix<T, elem::STAR, elem::STAR>& U, 00163 elem::DistMatrix<T, elem::STAR, elem::STAR>& S, 00164 elem::DistMatrix<T, elem::VC, elem::STAR>& V) { 00165 00166 elem::DistMatrix<T, elem::VC, elem::STAR> A_U_STAR; 00167 elem::Adjoint(A, A_U_STAR); 00168 SVD(A_U_STAR, V, S, U); 00169 } 00170 00171 00172 template<typename T> 00173 void SVD(const elem::DistMatrix<T, elem::STAR, elem::VR>& A, 00174 elem::DistMatrix<T, elem::STAR, elem::STAR>& U, 00175 elem::DistMatrix<T, elem::STAR, elem::STAR>& S, 00176 elem::DistMatrix<T, elem::VR, elem::STAR>& V) { 00177 00178 elem::DistMatrix<T, elem::VR, elem::STAR> A_U_STAR; 00179 elem::Adjoint(A, A_U_STAR); 00180 SVD(A_U_STAR, V, S, U); 00181 } 00182 00183 00184 #endif // SKYLARK_HAVE_ELEMENTAL 00185 00186 } } // namespace skylark::base 00187 00188 #endif // SKYLARK_SVD_HPP