Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/base/svd.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_SVD_HPP
00002 #define SKYLARK_SVD_HPP
00003 
00004 #if SKYLARK_HAVE_ELEMENTAL
00005 #include <elemental.hpp>
00006 #endif
00007 
00008 #include "Gemm.hpp"
00009 
00010 namespace skylark { namespace base {
00011 
00012 #if SKYLARK_HAVE_ELEMENTAL
00013 
00014 
00015 template<typename T>
00016 void SVD(elem::Matrix<T>& A, elem::Matrix< elem::Base<T> >& S,
00017     elem::Matrix<T>& V) {
00018     elem::SVD(A, S, V);
00019 }
00020 
00021 template<typename T>
00022 void SVD(const elem::Matrix<T>& A, elem::Matrix<T>& U,
00023     elem::Matrix<elem::Base<T> >& S,
00024     elem::Matrix<T>& V) {
00025 
00026     U = A;
00027     elem::SVD(U, S, V);
00028 }
00029 
00030 template<typename T>
00031 void SVD(elem::DistMatrix<T, elem::STAR, elem::STAR>& A,
00032     elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S,
00033     elem::DistMatrix<T, elem::STAR, elem::STAR>& V) {
00034     elem::SVD(A.Matrix(), S.Matrix(), V.Matrix());
00035     A.Resize(A.Matrix().Height(), A.Matrix().Width());
00036     S.Resize(S.Matrix().Height(), S.Matrix().Width());
00037     V.Resize(V.Matrix().Height(), V.Matrix().Width());
00038 }
00039 
00040 template<typename T>
00041 void SVD(const elem::DistMatrix<T, elem::STAR, elem::STAR>& A,
00042     elem::DistMatrix<T, elem::STAR, elem::STAR>& U,
00043     elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S,
00044     elem::DistMatrix<T, elem::STAR, elem::STAR>& V) {
00045 
00046     U = A;
00047     elem::SVD(U.Matrix(), S.Matrix(), V.Matrix());
00048     U.Resize(U.Matrix().Height(), U.Matrix().Width());
00049     S.Resize(S.Matrix().Height(), S.Matrix().Width());
00050     V.Resize(V.Matrix().Height(), V.Matrix().Width());
00051 }
00052 
00053 template<typename T>
00054 void SVD(elem::DistMatrix<T>& A,
00055     elem::DistMatrix<elem::Base<T>, elem::VR, elem::STAR>& S,
00056     elem::DistMatrix<T>& V) {
00057     elem::SVD(A, S, V);
00058 }
00059 
00060 template<typename T>
00061 void SVD(const elem::DistMatrix<T>& A, elem::DistMatrix<T>& U,
00062     elem::DistMatrix<elem::Base<T>, elem::VR, elem::STAR>& S,
00063     elem::DistMatrix<T>& V) {
00064 
00065     U = A;
00066     elem::SVD(U, S, V);
00067 }
00068 
00069 template<typename T>
00070 void SVD(elem::DistMatrix<T, elem::VC,    elem::STAR>& A,
00071     elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S,
00072     elem::DistMatrix<T, elem::STAR, elem::STAR>& V) {
00073 
00074     elem::DistMatrix<T, elem::VC,   elem::STAR> Q;
00075     elem::DistMatrix<T, elem::STAR, elem::STAR> R;
00076     elem::Matrix<T> U_tilda;
00077 
00078     // tall and skinny QR (TSQR)
00079     Q = A;
00080     elem::qr::ExplicitTS(Q, R);
00081 
00082     // local SVD of R
00083     base::SVD(R, S, V);
00084     S.Resize(S.Height(), 1);
00085     V.Resize(V.Height(), V.Width());
00086 
00087     // Compute U
00088     A.Resize(Q.Height(), S.Height());
00089     base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), A);
00090 }
00091 
00092 template<typename T>
00093 void SVD(const elem::DistMatrix<T, elem::VC,    elem::STAR>& A,
00094     elem::DistMatrix<T, elem::VC,    elem::STAR>& U,
00095     elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S,
00096     elem::DistMatrix<T, elem::STAR, elem::STAR>& V) {
00097 
00098     elem::DistMatrix<T, elem::VC,   elem::STAR> Q;
00099     elem::DistMatrix<T, elem::STAR, elem::STAR> R;
00100 
00101     // tall and skinny QR (TSQR)
00102     Q = A;
00103     elem::qr::ExplicitTS(Q, R);
00104 
00105     // local SVD of R
00106      base::SVD(R, S, V);
00107     S.Resize(S.Height(), 1);
00108     V.Resize(V.Height(), V.Width());
00109 
00110     // Compute U
00111     U.Resize(Q.Height(), S.Height());
00112     base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), U);
00113 }
00114 
00115 template<typename T>
00116 void SVD(elem::DistMatrix<T, elem::VR,    elem::STAR>& A,
00117     elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S,
00118     elem::DistMatrix<T, elem::STAR, elem::STAR>& V) {
00119 
00120     elem::DistMatrix<T, elem::VR,   elem::STAR> Q;
00121     elem::DistMatrix<T, elem::STAR, elem::STAR> R;
00122 
00123     // tall and skinny QR (TSQR)
00124     Q = A;
00125     elem::qr::ExplicitTS(A, R);
00126 
00127     // local SVD of R
00128     base::SVD(R, S, V);
00129     S.Resize(S.Height(), 1);
00130     V.Resize(V.Height(), V.Width());
00131 
00132     // Compute U
00133     A.Resize(A.Height(), S.Height());
00134     base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), A);
00135 }
00136 
00137 template<typename T>
00138 void SVD(const elem::DistMatrix<T, elem::VR,    elem::STAR>& A,
00139     elem::DistMatrix<T, elem::VR,    elem::STAR>& U,
00140     elem::DistMatrix<elem::Base<T>, elem::STAR, elem::STAR>& S,
00141     elem::DistMatrix<T, elem::STAR, elem::STAR>& V) {
00142 
00143     elem::DistMatrix<T, elem::VR,   elem::STAR> Q;
00144     elem::DistMatrix<T, elem::STAR, elem::STAR> R;
00145 
00146     // tall and skinny QR (TSQR)
00147     Q = A;
00148     elem::qr::ExplicitTS(Q, R);
00149 
00150     // local SVD of R
00151     base::SVD(R, S, V);
00152     S.Resize(S.Height(), 1);
00153     V.Resize(V.Height(), V.Width());
00154 
00155     // Compute U
00156     U.Resize(Q.Height(), S.Height());
00157     base::Gemm(elem::NORMAL, elem::NORMAL, T(1), Q, R, T(0), U);
00158 }
00159 
00160 template<typename T>
00161 void SVD(const elem::DistMatrix<T, elem::STAR, elem::VC>& A,
00162                elem::DistMatrix<T, elem::STAR, elem::STAR>& U,
00163                elem::DistMatrix<T, elem::STAR, elem::STAR>& S,
00164                elem::DistMatrix<T, elem::VC,   elem::STAR>& V) {
00165 
00166     elem::DistMatrix<T, elem::VC, elem::STAR> A_U_STAR;
00167     elem::Adjoint(A, A_U_STAR);
00168     SVD(A_U_STAR, V, S, U);
00169 }
00170 
00171 
00172 template<typename T>
00173 void SVD(const elem::DistMatrix<T, elem::STAR, elem::VR>& A,
00174                elem::DistMatrix<T, elem::STAR, elem::STAR>& U,
00175                elem::DistMatrix<T, elem::STAR, elem::STAR>& S,
00176                elem::DistMatrix<T, elem::VR,   elem::STAR>& V) {
00177 
00178     elem::DistMatrix<T, elem::VR, elem::STAR> A_U_STAR;
00179     elem::Adjoint(A, A_U_STAR);
00180     SVD(A_U_STAR, V, S, U);
00181 }
00182 
00183 
00184 #endif // SKYLARK_HAVE_ELEMENTAL
00185 
00186 } } // namespace skylark::base
00187 
00188 #endif // SKYLARK_SVD_HPP