Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/algorithms/Krylov/LSQR.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_LSQR_HPP
00002 #define SKYLARK_LSQR_HPP
00003 
00004 #include "../../base/base.hpp"
00005 #include "../../utility/elem_extender.hpp"
00006 #include "../../utility/typer.hpp"
00007 #include "../../utility/external/print.hpp"
00008 #include "precond.hpp"
00009 
00010 namespace skylark { namespace nla {
00011 
00012 // We could have a version that is independent of Elemental, but that would
00013 // be tedious (converting between [STAR,STAR] and vector<T>), and really
00014 // Elemental is fundamental to Skylark.
00015 #if SKYLARK_HAVE_ELEMENTAL
00016 
00022 template<typename MatrixType, typename RhsType, typename SolType>
00023 int LSQR(const MatrixType& A, const RhsType& B, SolType& X,
00024     iter_params_t params = iter_params_t(),
00025     const precond_t<SolType>& R = id_precond_t<SolType>()) {
00026 
00027     typedef typename utility::typer_t<MatrixType>::value_type value_t;
00028     typedef typename utility::typer_t<MatrixType>::index_type index_t;
00029 
00030     typedef MatrixType matrix_type;
00031     typedef RhsType rhs_type;        // Also serves as "long" vector type.
00032     typedef SolType sol_type;        // Also serves as "short" vector type.
00033 
00034     typedef utility::print_t<rhs_type> rhs_print_t;
00035     typedef utility::print_t<sol_type> sol_print_t;
00036 
00037     typedef utility::elem_extender_t<
00038         elem::DistMatrix<value_t, elem::STAR, elem::STAR> >
00039         scalar_cont_type;
00040 
00041     bool log_lev1 = params.am_i_printing && params.log_level >= 1;
00042     bool log_lev2 = params.am_i_printing && params.log_level >= 2;
00043 
00045     index_t m = base::Height(A);
00046     index_t n = base::Width(A);
00047     index_t k = base::Width(B);
00048 
00050     const value_t eps = 32*std::numeric_limits<value_t>::epsilon();
00051     if (params.tolerance<eps) params.tolerance=eps;
00052     else if (params.tolerance>=1.0) params.tolerance=(1-eps);
00053     else {} /* nothing */
00054 
00056     // We set the grid and rank for beta, and all other scalar containers
00057     // just copy from him to get that to be set right (not for the values).
00058     rhs_type U(B);
00059     scalar_cont_type beta(k, 1, A.Grid(), A.Root()), i_beta(beta);
00060     base::ColumnNrm2(U, beta);
00061     for (index_t i=0; i<k; ++i)
00062         i_beta[i] = 1 / beta[i];
00063     base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_beta, U);
00064     rhs_print_t::apply(U, "U Init", params.am_i_printing, params.debug_level);
00065 
00066     sol_type V(X);     // No need to really copy, just want sizes&comm correct.
00067     base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, A, U, V);
00068     R.apply_adjoint(V);
00069     scalar_cont_type alpha(beta), i_alpha(beta);
00070     base::ColumnNrm2(V, alpha);
00071     for (index_t i=0; i<k; ++i)
00072         i_alpha[i] = 1 / alpha[i];
00073     base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_alpha, V);
00074     sol_type Z(V);
00075     R.apply(Z);
00076     sol_print_t::apply(V, "V Init", params.am_i_printing, params.debug_level);
00077 
00078     /* Create W=Z and X=0 */
00079     base::Zero(X);
00080     sol_type W(Z);
00081     scalar_cont_type phibar(beta), rhobar(alpha), nrm_r(beta);
00082         // /!\ Actually copied for init
00083     scalar_cont_type nrm_a(beta), cnd_a(beta), sq_d(beta), nrm_ar_0(beta);
00084     base::Zero(nrm_a); base::Zero(cnd_a); base::Zero(sq_d);
00085     elem::Hadamard(alpha, beta, nrm_ar_0);
00086 
00088     for (index_t i=0; i<k; ++i)
00089         if (nrm_ar_0[i]==0)
00090             return 0;
00091 
00092     scalar_cont_type nrm_x(beta), sq_x(beta), z(beta), cs2(beta), sn2(beta);
00093     elem::Zero(nrm_x); elem::Zero(sq_x); elem::Zero(z); elem::Zero(sn2);
00094     for (index_t i=0; i<k; ++i)
00095         cs2[i] = -1.0;
00096 
00097     int max_n_stag = 3;
00098     std::vector<int> stag(k, 0);
00099 
00100     /* Reset the iteration limit if none was specified */
00101     if (0>params.iter_lim) params.iter_lim = std::max(20, 2*std::min(m,n));
00102 
00103     /* More varaibles */
00104     sol_type AU(X);
00105     scalar_cont_type minus_beta(beta), rho(beta);
00106     scalar_cont_type cs(beta), sn(beta), theta(beta), phi(beta);
00107     scalar_cont_type phi_by_rho(beta), minus_theta_by_rho(beta), nrm_ar(beta);
00108     scalar_cont_type nrm_w(beta), sq_w(beta), gamma(beta);
00109     scalar_cont_type delta(beta), gambar(beta), rhs(beta), zbar(beta);
00110 
00112     for (index_t itn=0; itn<params.iter_lim; ++itn) {
00113 
00115         elem::Scal(-1.0, alpha);   // Can safely overwrite based on subseq ops.
00116         base::DiagonalScale(elem::RIGHT, elem::NORMAL, alpha, U);
00117         base::Gemm(elem::NORMAL, elem::NORMAL, 1.0, A, Z, 1.0, U);
00118         base::ColumnNrm2(U, beta);
00119         for (index_t i=0; i<k; ++i)
00120             i_beta[i] = 1 / beta[i];
00121         base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_beta, U);
00122 
00124         for (index_t i=0; i<k; ++i) {
00125             double a = nrm_a[i], b = alpha[i], c = beta[i];
00126             nrm_a[i] = sqrt(a*a + b*b + c*c);
00127         }
00128 
00130         for (index_t i=0; i<k; ++i)
00131             minus_beta[i] = -beta[i];
00132         base::DiagonalScale(elem::RIGHT, elem::NORMAL, minus_beta, V);
00133         base::Gemm(elem::ADJOINT, elem::NORMAL, 1.0, A, U, AU);
00134         R.apply_adjoint(AU);
00135         base::Axpy(1.0, AU, V);
00136         base::ColumnNrm2(V, alpha);
00137         for (index_t i=0; i<k; ++i)
00138             i_alpha[i] = 1 / alpha[i];
00139         base::DiagonalScale(elem::RIGHT, elem::NORMAL, i_alpha, V);
00140         Z = V; R.apply(Z);
00141 
00143         for (index_t i=0; i<k; ++i) {
00144             rho[i] = sqrt((rhobar[i]*rhobar[i]) + (beta[i]*beta[i]));
00145             cs[i] = rhobar[i]/rho[i];
00146             sn[i] =  beta[i]/rho[i];  
00147             theta[i] = sn[i]*alpha[i];
00148             rhobar[i] = -cs[i]*alpha[i];
00149             phi[i] = cs[i]*phibar[i];
00150             phibar[i] =  sn[i]*phibar[i];
00151         }
00152 
00154         for (index_t i=0; i<k; ++i)
00155             phi_by_rho[i] = phi[i]/rho[i];
00156         base::Axpy(phi_by_rho, W, X);
00157         sol_print_t::apply(X, "X", params.am_i_printing, params.debug_level);
00158 
00159         for (index_t i=0; i<k; ++i)
00160             minus_theta_by_rho[i] = -theta[i]/rho[i];
00161         base::DiagonalScale(elem::RIGHT, elem::NORMAL, minus_theta_by_rho, W);
00162         base::Axpy(1.0, Z, W);
00163         sol_print_t::apply(W, "W", params.am_i_printing, params.debug_level);
00164 
00166         nrm_r = phibar;
00167 
00169         for (index_t i=0; i<k; ++i) {
00170             nrm_ar[i] = std::abs(phibar[i]*alpha[i]*cs[i]);
00171 
00172             if (log_lev2)
00173                 params.log_stream << "LSQR: Iteration " << i << "/" << itn 
00174                                   << ": " << nrm_ar[i]
00175                                   << std::endl;
00176 
00178             if (nrm_ar[i]<(params.tolerance*nrm_ar_0[i])) {
00179                 if (log_lev1)
00180                     params.log_stream << "LSQR: Convergence (S1)!" << std::endl;
00181                 return -2;
00182             }
00183 
00184             if (nrm_ar[i]<(eps*nrm_a[i]*nrm_r[i])) {
00185                 if (log_lev1)
00186                     params.log_stream << "LSQR: Convergence (S2)!" << std::endl;
00187                 return -3;
00188             }
00189         }
00190 
00192         base::ColumnNrm2(W, nrm_w);
00193         for (index_t i=0; i<k; ++i) {
00194             sq_w[i] = nrm_w[i]*nrm_w[i];
00195             sq_d[i] += sq_w[i]/(rho[i]*rho[i]);
00196             cnd_a[i] = nrm_a[i]*sqrt(sq_d[i]);
00197 
00199             if (cnd_a[i]>(1.0/eps)) {
00200                 if (log_lev1)
00201                     params.log_stream << "LSQR: Stopping (S3)!" << std::endl;
00202                 return -4;
00203             }
00204         }
00205 
00207         for (index_t i=0; i<k; ++i) {
00208             if (std::abs(phi[i]/rho[i])*nrm_w[i] < (eps*nrm_x[i]))
00209                 stag[i]++;
00210             else
00211                 stag[i] = 0;
00212 
00213             if (stag[i] >= max_n_stag) {
00214                 if (log_lev1)
00215                     params.log_stream << "LSQR: Stagnation." << std::endl;
00216                 return -5;
00217             }
00218         }
00219 
00221         for (index_t i=0; i<k; ++i) {
00222             delta[i] =  sn2[i]*rho[i];
00223             gambar[i] = -cs2[i]*rho[i];
00224             rhs[i] = phi[i] - delta[i]*z[i];
00225             zbar[i] = rhs[i]/gambar[i];
00226             nrm_x[i] = sqrt(sq_x[i] + (zbar[i]*zbar[i]));
00227             gamma[i] = sqrt((gambar[i]*gambar[i]) + (theta[i]*theta[i]));
00228             cs2[i] = gambar[i]/gamma[i];
00229             sn2[i] = theta[i]/gamma[i];
00230             z[i] = rhs[i]/gamma[i];
00231             sq_x[i] += z[i]*z[i];
00232         }
00233     }
00234     if (log_lev1)
00235         params.log_stream << "LSQR: No convergence within iteration limit."
00236                           << std::endl;
00237 
00238     return -6;
00239 }
00240 
00241 #endif
00242 
00243 } } 
00245 #endif // SKYLARK_LSQR_HPP