// Skylark (Sketching Library), version 0.1
00001 #ifndef SKYLARK_LSQR_HPP 00002 #define SKYLARK_LSQR_HPP 00003 00004 #include "../../base/base.hpp" 00005 #include "../../utility/elem_extender.hpp" 00006 #include "../../utility/typer.hpp" 00007 #include "../../utility/external/print.hpp" 00008 #include "precond.hpp" 00009 00010 namespace skylark { namespace nla { 00011 00012 // We can have a version that is indpendent of Elemental. But that will 00013 // be tedious (convert between [STAR,STAR] and vector<T>, and really 00014 // elemental is a very fudmanetal to Skylark. 00015 #if SKYLARK_HAVE_ELEMENTAL 00016 00022 template<typename MatrixType, typename RhsType, typename SolType> 00023 int LSQR(const MatrixType& A, const RhsType& B, SolType& X, 00024 iter_params_t params = iter_params_t(), 00025 const precond_t<SolType>& R = id_precond_t<SolType>()) { 00026 00027 typedef typename utility::typer_t<MatrixType>::value_type value_t; 00028 typedef typename utility::typer_t<MatrixType>::index_type index_t; 00029 00030 typedef MatrixType matrix_type; 00031 typedef RhsType rhs_type; // Also serves as "long" vector type. 00032 typedef SolType sol_type; // Also serves as "short" vector type. 
00033 00034 typedef utility::print_t<rhs_type> rhs_print_t; 00035 typedef utility::print_t<sol_type> sol_print_t; 00036 00037 typedef utility::elem_extender_t< 00038 elem::DistMatrix<value_t, elem::STAR, elem::STAR> > 00039 scalar_cont_type; 00040 00041 bool log_lev1 = params.am_i_printing && params.log_level >= 1; 00042 bool log_lev2 = params.am_i_printing && params.log_level >= 2; 00043 00045 index_t m = base::Height(A); 00046 index_t n = base::Width(A); 00047 index_t k = base::Width(B); 00048 00050 const value_t eps = 32*std::numeric_limits<value_t>::epsilon(); 00051 if (params.tolerance<eps) params.tolerance=eps; 00052 else if (params.tolerance>=1.0) params.tolerance=(1-eps); 00053 else {} /* nothing */ 00054 00056 // We set the grid and rank for beta, and all other scalar containers 00057 // just copy from him to get that to be set right (not for the values). 00058 rhs_type U(B); 00059 scalar_cont_type beta(k, 1, A.Grid(), A.Root()), i_beta(beta); 00060 base::ColumnNrm2(U, beta); 00061 for (index_t i=0; i<k; ++i) 00062 i_beta[i] = 1 / beta[i]; 00063 base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_beta, U); 00064 rhs_print_t::apply(U, "U Init", params.am_i_printing, params.debug_level); 00065 00066 sol_type V(X); // No need to really copy, just want sizes&comm correct. 
00067 base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, A, U, V); 00068 R.apply_adjoint(V); 00069 scalar_cont_type alpha(beta), i_alpha(beta); 00070 base::ColumnNrm2(V, alpha); 00071 for (index_t i=0; i<k; ++i) 00072 i_alpha[i] = 1 / alpha[i]; 00073 base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_alpha, V); 00074 sol_type Z(V); 00075 R.apply(Z); 00076 sol_print_t::apply(V, "V Init", params.am_i_printing, params.debug_level); 00077 00078 /* Create W=Z and X=0 */ 00079 base::Zero(X); 00080 sol_type W(Z); 00081 scalar_cont_type phibar(beta), rhobar(alpha), nrm_r(beta); 00082 // /!\ Actually copied for init 00083 scalar_cont_type nrm_a(beta), cnd_a(beta), sq_d(beta), nrm_ar_0(beta); 00084 base::Zero(nrm_a); base::Zero(cnd_a); base::Zero(sq_d); 00085 elem::Hadamard(alpha, beta, nrm_ar_0); 00086 00088 for (index_t i=0; i<k; ++i) 00089 if (nrm_ar_0[i]==0) 00090 return 0; 00091 00092 scalar_cont_type nrm_x(beta), sq_x(beta), z(beta), cs2(beta), sn2(beta); 00093 elem::Zero(nrm_x); elem::Zero(sq_x); elem::Zero(z); elem::Zero(sn2); 00094 for (index_t i=0; i<k; ++i) 00095 cs2[i] = -1.0; 00096 00097 int max_n_stag = 3; 00098 std::vector<int> stag(k, 0); 00099 00100 /* Reset the iteration limit if none was specified */ 00101 if (0>params.iter_lim) params.iter_lim = std::max(20, 2*std::min(m,n)); 00102 00103 /* More varaibles */ 00104 sol_type AU(X); 00105 scalar_cont_type minus_beta(beta), rho(beta); 00106 scalar_cont_type cs(beta), sn(beta), theta(beta), phi(beta); 00107 scalar_cont_type phi_by_rho(beta), minus_theta_by_rho(beta), nrm_ar(beta); 00108 scalar_cont_type nrm_w(beta), sq_w(beta), gamma(beta); 00109 scalar_cont_type delta(beta), gambar(beta), rhs(beta), zbar(beta); 00110 00112 for (index_t itn=0; itn<params.iter_lim; ++itn) { 00113 00115 elem::Scal(-1.0, alpha); // Can safely overwrite based on subseq ops. 
00116 base::DiagonalScale(elem::RIGHT, elem::NORMAL, alpha, U); 00117 base::Gemm(elem::NORMAL, elem::NORMAL, 1.0, A, Z, 1.0, U); 00118 base::ColumnNrm2(U, beta); 00119 for (index_t i=0; i<k; ++i) 00120 i_beta[i] = 1 / beta[i]; 00121 base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_beta, U); 00122 00124 for (index_t i=0; i<k; ++i) { 00125 double a = nrm_a[i], b = alpha[i], c = beta[i]; 00126 nrm_a[i] = sqrt(a*a + b*b + c*c); 00127 } 00128 00130 for (index_t i=0; i<k; ++i) 00131 minus_beta[i] = -beta[i]; 00132 base::DiagonalScale(elem::RIGHT, elem::NORMAL, minus_beta, V); 00133 base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, A, U, AU); 00134 R.apply_adjoint(AU); 00135 base::Axpy(1.0, AU, V); 00136 base::ColumnNrm2(V, alpha); 00137 for (index_t i=0; i<k; ++i) 00138 i_alpha[i] = 1 / alpha[i]; 00139 base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_alpha, V); 00140 Z = V; R.apply(Z); 00141 00143 for (index_t i=0; i<k; ++i) { 00144 rho[i] = sqrt((rhobar[i]*rhobar[i]) + (beta[i]*beta[i])); 00145 cs[i] = rhobar[i]/rho[i]; 00146 sn[i] = beta[i]/rho[i]; 00147 theta[i] = sn[i]*alpha[i]; 00148 rhobar[i] = -cs[i]*alpha[i]; 00149 phi[i] = cs[i]*phibar[i]; 00150 phibar[i] = sn[i]*phibar[i]; 00151 } 00152 00154 for (index_t i=0; i<k; ++i) 00155 phi_by_rho[i] = phi[i]/rho[i]; 00156 base::Axpy(phi_by_rho, W, X); 00157 sol_print_t::apply(X, "X", params.am_i_printing, params.debug_level); 00158 00159 for (index_t i=0; i<k; ++i) 00160 minus_theta_by_rho[i] = -theta[i]/rho[i]; 00161 base::DiagonalScale(elem::RIGHT, elem::NORMAL, minus_theta_by_rho, W); 00162 base::Axpy(1.0, Z, W); 00163 sol_print_t::apply(W, "W", params.am_i_printing, params.debug_level); 00164 00166 nrm_r = phibar; 00167 00169 for (index_t i=0; i<k; ++i) { 00170 nrm_ar[i] = std::abs(phibar[i]*alpha[i]*cs[i]); 00171 00172 if (log_lev2) 00173 params.log_stream << "LSQR: Iteration " << i << "/" << itn 00174 << ": " << nrm_ar[i] 00175 << std::endl; 00176 00178 if (nrm_ar[i]<(params.tolerance*nrm_ar_0[i])) { 00179 if (log_lev1) 
00180 params.log_stream << "LSQR: Convergence (S1)!" << std::endl; 00181 return -2; 00182 } 00183 00184 if (nrm_ar[i]<(eps*nrm_a[i]*nrm_r[i])) { 00185 if (log_lev1) 00186 params.log_stream << "LSQR: Convergence (S2)!" << std::endl; 00187 return -3; 00188 } 00189 } 00190 00192 base::ColumnNrm2(W, nrm_w); 00193 for (index_t i=0; i<k; ++i) { 00194 sq_w[i] = nrm_w[i]*nrm_w[i]; 00195 sq_d[i] += sq_w[i]/(rho[i]*rho[i]); 00196 cnd_a[i] = nrm_a[i]*sqrt(sq_d[i]); 00197 00199 if (cnd_a[i]>(1.0/eps)) { 00200 if (log_lev1) 00201 params.log_stream << "LSQR: Stopping (S3)!" << std::endl; 00202 return -4; 00203 } 00204 } 00205 00207 for (index_t i=0; i<k; ++i) { 00208 if (std::abs(phi[i]/rho[i])*nrm_w[i] < (eps*nrm_x[i])) 00209 stag[i]++; 00210 else 00211 stag[i] = 0; 00212 00213 if (stag[i] >= max_n_stag) { 00214 if (log_lev1) 00215 params.log_stream << "LSQR: Stagnation." << std::endl; 00216 return -5; 00217 } 00218 } 00219 00221 for (index_t i=0; i<k; ++i) { 00222 delta[i] = sn2[i]*rho[i]; 00223 gambar[i] = -cs2[i]*rho[i]; 00224 rhs[i] = phi[i] - delta[i]*z[i]; 00225 zbar[i] = rhs[i]/gambar[i]; 00226 nrm_x[i] = sqrt(sq_x[i] + (zbar[i]*zbar[i])); 00227 gamma[i] = sqrt((gambar[i]*gambar[i]) + (theta[i]*theta[i])); 00228 cs2[i] = gambar[i]/gamma[i]; 00229 sn2[i] = theta[i]/gamma[i]; 00230 z[i] = rhs[i]/gamma[i]; 00231 sq_x[i] += z[i]*z[i]; 00232 } 00233 } 00234 if (log_lev1) 00235 params.log_stream << "LSQR: No convergence within iteration limit." 00236 << std::endl; 00237 00238 return -6; 00239 } 00240 00241 #endif 00242 00243 } } 00245 #endif // SKYLARK_LSQR_HPP