Skylark (Sketching Library) 0.1
ml/BlockADMM.hpp
#ifndef SKYLARK_BLOCKADMM_HPP
#define SKYLARK_BLOCKADMM_HPP

#include <elemental.hpp>
#include <skylark.hpp>
#include <cmath>
#include <boost/mpi.hpp>

#ifdef SKYLARK_HAVE_OPENMP
#include <omp.h>
#endif

#include "../utility/timer.hpp"
#include "hilbert.hpp"

// Columns are examples, rows are features.
typedef elem::DistMatrix<double, elem::STAR, elem::VC> DistInputMatrixType;

// Rows are examples, columns are target values.
typedef elem::DistMatrix<double, elem::VC, elem::STAR> DistTargetMatrixType;

typedef elem::Matrix<double> LocalMatrixType;
typedef skylark::base::sparse_matrix_t<double> sparse_matrix_t;
template <class T>
class BlockADMMSolver
{
public:

    typedef skylark::sketch::sketch_transform_t<T, LocalMatrixType>
    feature_transform_t;
    typedef std::vector<const feature_transform_t *> feature_transform_array_t;


    // No feature transforms (aka just linear regression).
    BlockADMMSolver(const lossfunction* loss,
        const regularization* regularizer,
        double lambda, // regularization parameter
        int NumFeatures,
        int NumFeaturePartitions = 1);

    // Easy interface, aka kernel based.
    template<typename Kernel, typename MapTypeTag>
    BlockADMMSolver(skylark::base::context_t& context,
        const lossfunction* loss,
        const regularization* regularizer,
        double lambda, // regularization parameter
        int NumFeatures,
        Kernel kernel,
        MapTypeTag tag,
        int NumFeaturePartitions = 1);

    // Guru interface.
    BlockADMMSolver(const lossfunction* loss,
        const regularization* regularizer,
        const feature_transform_array_t& featureMaps,
        double lambda, // regularization parameter
        bool ScaleFeatureMaps = true);

    void set_nthreads(int NumThreads) { this->NumThreads = NumThreads; }
    void set_rho(double RHO) { this->RHO = RHO; }
    void set_maxiter(int MAXITER) { this->MAXITER = MAXITER; }
    void set_tol(double TOL) { this->TOL = TOL; }
    void set_cache_transform(bool CacheTransforms) { this->CacheTransforms = CacheTransforms; }

    ~BlockADMMSolver();

    void InitializeFactorizationCache();
    void InitializeTransformCache(int n);

    skylark::ml::model_t<T, LocalMatrixType>* train(T& X,
        LocalMatrixType& Y, T& Xv, LocalMatrixType& Yv,
        const boost::mpi::communicator& comm);

    int get_numfeatures() { return NumFeatures; }

    feature_transform_array_t& get_feature_maps() { return featureMaps; }

private:

    feature_transform_array_t featureMaps;
    int NumFeatures;
    int NumFeaturePartitions;
    lossfunction* loss;
    regularization* regularizer;
    std::vector<int> starts, finishes;
    bool ScaleFeatureMaps;
    bool OwnFeatureMaps;
    LocalMatrixType **Cache;
    LocalMatrixType **TransformCache;
    int NumThreads;

    double lambda;
    double RHO;
    int MAXITER;
    double TOL;

    bool CacheTransforms;
};

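// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this header). The loss/regularizer
// class names (hingeloss, l2), the kernel type (gaussian_t), and the map tag
// below are assumptions based on the companion hilbert.hpp machinery;
// substitute whatever your build actually provides.
//
//   skylark::base::context_t context(12345);
//   lossfunction *loss = new hingeloss();      // assumed loss class
//   regularization *reg = new l2();            // assumed regularizer class
//   BlockADMMSolver<LocalMatrixType> solver(context, loss, reg,
//       0.01,                                  // lambda
//       10000,                                 // NumFeatures (total random features)
//       gaussian_t(d, 10.0),                   // assumed kernel, d = input dim
//       regular_feature_transform_tag(),       // assumed map tag
//       8);                                    // NumFeaturePartitions
//   solver.set_rho(1.0);
//   solver.set_maxiter(100);
//   skylark::ml::model_t<LocalMatrixType, LocalMatrixType> *model =
//       solver.train(X, Y, Xv, Yv, comm);
// ---------------------------------------------------------------------------
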
template <class T>
void BlockADMMSolver<T>::InitializeFactorizationCache() {
    Cache = new LocalMatrixType* [NumFeaturePartitions];
    for(int j = 0; j < NumFeaturePartitions; j++) {
        int start = starts[j];
        int finish = finishes[j];
        int sj = finish - start + 1;
        Cache[j] = new elem::Matrix<double>(sj, sj);
    }
}

template <class T>
void BlockADMMSolver<T>::InitializeTransformCache(int n) {
    TransformCache = new LocalMatrixType* [NumFeaturePartitions];
    for(int j = 0; j < NumFeaturePartitions; j++) {
        int start = starts[j];
        int finish = finishes[j];
        int sj = finish - start + 1;
        TransformCache[j] = new elem::Matrix<double>(sj, n);
    }
}

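// The two caches above are sized per feature block j (sj = finishes[j] -
// starts[j] + 1 rows). During the first training iteration, Cache[j] is
// overwritten with inverse(Z_j * Z_j' + I), the factorization reused by every
// subsequent local ridge solve, while TransformCache[j] (when enabled via
// set_cache_transform) memoizes the transformed block Z_j so later iterations
// can skip the sketching transform entirely.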

// No feature transforms (aka just linear regression).
template <class T>
BlockADMMSolver<T>::BlockADMMSolver(
        const lossfunction* loss,
        const regularization* regularizer,
        double lambda, // regularization parameter
        int NumFeatures,
        int NumFeaturePartitions) :
        NumFeatures(NumFeatures),
        NumFeaturePartitions(NumFeaturePartitions),
        starts(NumFeaturePartitions), finishes(NumFeaturePartitions),
        NumThreads(1), RHO(1.0), MAXITER(1000), TOL(0.1) {

    this->loss = const_cast<lossfunction *> (loss);
    this->regularizer = const_cast<regularization *> (regularizer);
    this->lambda = lambda;
    int blksize = int(ceil(double(NumFeatures) / NumFeaturePartitions));
    for(int i = 0; i < NumFeaturePartitions; i++) {
        starts[i] = i * blksize;
        finishes[i] = std::min((i + 1) * blksize, NumFeatures) - 1;
    }
    this->ScaleFeatureMaps = false;
    OwnFeatureMaps = false;
    InitializeFactorizationCache();
    TransformCache = NULL; // allocated lazily in train() when caching is on
    CacheTransforms = false;
}

// Easy interface, aka kernel based.
template<class T>
template<typename Kernel, typename MapTypeTag>
BlockADMMSolver<T>::BlockADMMSolver(skylark::base::context_t& context,
    const lossfunction* loss,
    const regularization* regularizer,
    double lambda, // regularization parameter
    int NumFeatures,
    Kernel kernel,
    MapTypeTag tag,
    int NumFeaturePartitions) :
    featureMaps(NumFeaturePartitions),
    NumFeatures(NumFeatures), NumFeaturePartitions(NumFeaturePartitions),
    starts(NumFeaturePartitions), finishes(NumFeaturePartitions),
    NumThreads(1), RHO(1.0), MAXITER(1000), TOL(0.1) {

    this->loss = const_cast<lossfunction *> (loss);
    this->regularizer = const_cast<regularization *> (regularizer);
    this->lambda = lambda;
    int blksize = int(ceil(double(NumFeatures) / NumFeaturePartitions));
    for(int i = 0; i < NumFeaturePartitions; i++) {
        starts[i] = i * blksize;
        finishes[i] = std::min((i + 1) * blksize, NumFeatures) - 1;
        int sj = finishes[i] - starts[i] + 1;
        featureMaps[i] =
            kernel.template create_rft< T, LocalMatrixType >(sj, tag, context);
    }
    this->ScaleFeatureMaps = true;
    OwnFeatureMaps = true;
    InitializeFactorizationCache();
    TransformCache = NULL; // allocated lazily in train() when caching is on
    CacheTransforms = false;
}

// Guru interface.
template <class T>
BlockADMMSolver<T>::BlockADMMSolver(const lossfunction* loss,
    const regularization* regularizer,
    const feature_transform_array_t &featureMaps,
    double lambda, // regularization parameter
    bool ScaleFeatureMaps) :
    featureMaps(featureMaps),
    NumFeaturePartitions(featureMaps.size()),
    starts(NumFeaturePartitions), finishes(NumFeaturePartitions),
    NumThreads(1), RHO(1.0), MAXITER(1000), TOL(0.1) {

    this->loss = const_cast<lossfunction *> (loss);
    this->regularizer = const_cast<regularization *> (regularizer);
    this->lambda = lambda;
    NumFeatures = 0;
    for(int i = 0; i < NumFeaturePartitions; i++) {
        starts[i] = NumFeatures;
        finishes[i] = NumFeatures + featureMaps[i]->get_S() - 1;
        NumFeatures += featureMaps[i]->get_S();
    }
    this->ScaleFeatureMaps = ScaleFeatureMaps;
    OwnFeatureMaps = false;
    InitializeFactorizationCache();
    TransformCache = NULL; // allocated lazily in train() when caching is on
    CacheTransforms = false;
}

template <class T>
BlockADMMSolver<T>::~BlockADMMSolver() {
    for(int i = 0; i < NumFeaturePartitions; i++) {
        delete Cache[i];
        if (OwnFeatureMaps)
            delete featureMaps[i];
    }
    delete[] Cache;
    // Release the transform cache if train() allocated one.
    if (TransformCache != NULL) {
        for(int i = 0; i < NumFeaturePartitions; i++)
            delete TransformCache[i];
        delete[] TransformCache;
    }
}

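// train() runs consensus ADMM on the regularized empirical-risk problem
//
//     minimize_W   sum_i loss(o_i, y_i) + lambda * regularizer(W)
//     subject to   o_i = Z(x_i)' W,
//
// where Z is the blocked random feature transform. Each of the
// NumFeaturePartitions blocks keeps its own primal copy Wi and block duals
// mu_ij / ZtObar_ij, the output variables O / Obar carry the loss splitting
// with dual nu, and rank 0 holds the global consensus variables W, Wbar, mu.
// This is a reading of the code below, not a statement of the authors'
// derivation.
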
template <class T>
skylark::ml::model_t<T, LocalMatrixType>* BlockADMMSolver<T>::train(
    T& X, LocalMatrixType& Y, T& Xv, LocalMatrixType& Yv,
    const boost::mpi::communicator& comm) {

    int rank = comm.rank();
    int size = comm.size();

    int P = size;

    int ni = skylark::base::Width(X);
    int d = skylark::base::Height(X);
    int targets = GetNumTargets(comm, Y);

    skylark::ml::model_t<T, LocalMatrixType>* model =
        new skylark::ml::model_t<T, LocalMatrixType>(featureMaps,
            ScaleFeatureMaps, NumFeatures, targets);

    elem::Matrix<double> Wbar;
    elem::View(Wbar, model->get_coef());

    // Number of classes/targets; to be generalized.
    int k = Wbar.Width();

    int D = NumFeatures;

    // TODO: throw an exception if D != Wbar.Height().

    LocalMatrixType O(k, ni);    // local outputs on this rank's examples
    elem::MakeZeros(O);

    LocalMatrixType Obar(k, ni); // output consensus variable
    elem::MakeZeros(Obar);

    LocalMatrixType nu(k, ni);   // dual variable for the output split
    elem::MakeZeros(nu);

    LocalMatrixType W, mu, Wi, mu_ij, ZtObar_ij;

    if(rank == 0) {
        elem::Zeros(W,  D, k);
        elem::Zeros(mu, D, k);
    }
    elem::Zeros(Wi, D, k);
    elem::Zeros(mu_ij, D, k);
    elem::Zeros(ZtObar_ij, D, k);

    int iter = 0;

    // int ni = O.LocalWidth();

    //elem::Matrix<double> x = X.Matrix();
    //elem::Matrix<double> y = Y.Matrix();

    double localloss = loss->evaluate(O, Y);
    double totalloss, accuracy, obj;

    int Dk = D*k;
    int nik = ni*k;
    int start, finish, sj;

    boost::mpi::timer timer;

    LocalMatrixType sum_o, del_o, wbar_output;
    elem::Zeros(del_o, k, ni);
    LocalMatrixType Yp(Yv.Height(), k);
    LocalMatrixType Yp_labels(Yv.Height(), 1);

    /*LocalMatrixType wbar_tmp;
    //if (NumThreads > 1)
    elem::Zeros(wbar_tmp, k, ni);*/

    if (CacheTransforms)
        InitializeTransformCache(ni);

    SKYLARK_TIMER_INITIALIZE(ITERATIONS_PROFILE);
    SKYLARK_TIMER_INITIALIZE(COMMUNICATION_PROFILE);
    SKYLARK_TIMER_INITIALIZE(TRANSFORM_PROFILE);
    SKYLARK_TIMER_INITIALIZE(ZTRANSFORM_PROFILE);
    SKYLARK_TIMER_INITIALIZE(ZMULT_PROFILE);
    SKYLARK_TIMER_INITIALIZE(PROXLOSS_PROFILE);
    SKYLARK_TIMER_INITIALIZE(BARRIER_PROFILE);
    SKYLARK_TIMER_INITIALIZE(PREDICTION_PROFILE);

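    // Each pass of the loop below performs one ADMM round:
    //   1. broadcast the consensus iterate Wbar from rank 0;
    //   2. dual pre-updates: mu_ij -= Wbar, Obar -= nu;
    //   3. proximal steps: loss prox on the outputs (O), regularizer prox on
    //      the weights (W, rank 0 only);
    //   4. per-partition local solves (the OpenMP for loop);
    //   5. output consensus and dual updates for O / Obar / nu;
    //   6. an MPI reduce of the local Wi into Wbar, followed by the
    //      consensus average and dual update for mu on rank 0.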
    while(iter < MAXITER) {

        SKYLARK_TIMER_RESTART(ITERATIONS_PROFILE);

        iter++;

        SKYLARK_TIMER_RESTART(COMMUNICATION_PROFILE);
        broadcast(comm, Wbar.Buffer(), Dk, 0);
        SKYLARK_TIMER_ACCUMULATE(COMMUNICATION_PROFILE);

        // mu_ij = mu_ij - Wbar
        elem::Axpy(-1.0, Wbar, mu_ij);

        // Obar = Obar - nu
        elem::Axpy(-1.0, nu, Obar);

        SKYLARK_TIMER_RESTART(PROXLOSS_PROFILE);
        loss->proxoperator(Obar, 1.0/RHO, Y, O);
        SKYLARK_TIMER_ACCUMULATE(PROXLOSS_PROFILE);

        if(rank == 0) {
            regularizer->proxoperator(Wbar, lambda/RHO, mu, W);
        }

        elem::Zeros(sum_o, k, ni);
        elem::Zeros(wbar_output, k, ni);

        int j;
        const feature_transform_t* featureMap;

        SKYLARK_TIMER_RESTART(TRANSFORM_PROFILE);

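        // Per-partition work: apply (or re-read from TransformCache) the
        // block's feature transform to get Z_j (sj x ni), lazily build the
        // cached inverse(Z_j Z_j' + I) on the first iteration, assemble the
        // right-hand side from Wbar, mu_ij, ZtObar_ij and the output
        // residual, and solve for this block's rows of Wi. With OpenMP,
        // partitions run in parallel and the shared accumulators wbar_output
        // and sum_o are updated inside critical sections.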
#       ifdef SKYLARK_HAVE_OPENMP
#       pragma omp parallel for if(NumThreads > 1) private(j, start, finish, sj, featureMap) num_threads(NumThreads)
#       endif
        for(j = 0; j < NumFeaturePartitions; j++) {
            start = starts[j];
            finish = finishes[j];
            sj = finish - start + 1;

            elem::Matrix<double> z(sj, ni);

            if (CacheTransforms && (iter > 1)) {
                elem::View(z, *TransformCache[j], 0, 0, sj, ni);
            }
            else {
                if (featureMaps.size() > 0) {
                    featureMap = featureMaps[j];

                    SKYLARK_TIMER_RESTART(ZTRANSFORM_PROFILE);
                    featureMap->apply(X, z, skylark::sketch::columnwise_tag());
                    SKYLARK_TIMER_ACCUMULATE(ZTRANSFORM_PROFILE);

                    if (ScaleFeatureMaps)
                        elem::Scal(sqrt(double(sj) / d), z);
                } else {
                    // For the linear case just use Z = X, no slicing business.
                    // skylark::base::ColumnView<double>(z, x, );
                    // VIEWS on SPARSE MATRICES: elem::View(z, x, start, 0, sj, ni);
                }
            }

            elem::Matrix<double> tmp(sj, k);
            elem::Matrix<double> rhs(sj, k);
            elem::Matrix<double> o(k, ni);

            if(iter == 1) {
                // Cache[j] = inverse(z * z' + I), reused in every iteration.
                elem::Matrix<double> Ones;
                elem::Ones(Ones, sj, 1);
                elem::Gemm(elem::NORMAL, elem::TRANSPOSE, 1.0, z, z, 0.0, *Cache[j]);
                Cache[j]->UpdateDiagonal(Ones);
                elem::Inverse(*Cache[j]);

                if (CacheTransforms) {
                    *TransformCache[j] = z;
                    // DEBUG
                    std::cout << "CACHING TRANSFORMS..." << std::endl;
                    elem::Write(*TransformCache[0], "FeatureMatrix.asc", elem::ASCII, "");
                }
            }

            elem::View(tmp, Wbar, start, 0, sj, k); // tmp = Wbar[J,:]

            LocalMatrixType wbar_tmp;
            elem::Zeros(wbar_tmp, k, ni);

            if (NumThreads > 1) {
                elem::Gemm(elem::TRANSPOSE, elem::NORMAL, 1.0, tmp, z, 0.0, wbar_tmp);

#               ifdef SKYLARK_HAVE_OPENMP
#               pragma omp critical
#               endif
                elem::Axpy(1.0, wbar_tmp, wbar_output);
            } else
                elem::Gemm(elem::TRANSPOSE, elem::NORMAL, 1.0, tmp, z, 1.0, wbar_output);

            rhs = tmp; // rhs = Wbar[J,:]
            elem::View(tmp, mu_ij, start, 0, sj, k); // tmp = mu_ij[J,:]
            elem::Axpy(-1.0, tmp, rhs); // rhs = Wbar[J,:] - mu_ij[J,:]
            elem::View(tmp, ZtObar_ij, start, 0, sj, k);
            elem::Axpy(+1.0, tmp, rhs); // rhs = rhs + ZtObar_ij[J,:]

            SKYLARK_TIMER_RESTART(ZMULT_PROFILE);
            elem::Matrix<double> dsum = del_o;
            elem::Axpy(NumFeaturePartitions + 1.0, nu, dsum);
            elem::Gemm(elem::NORMAL, elem::TRANSPOSE, 1.0/(NumFeaturePartitions + 1.0), z, dsum, 1.0, rhs); // rhs += z * (del_o/(NumFeaturePartitions+1) + nu)'
            SKYLARK_TIMER_ACCUMULATE(ZMULT_PROFILE);

            elem::View(tmp, Wi, start, 0, sj, k);
            elem::Gemm(elem::NORMAL, elem::NORMAL, 1.0, *Cache[j], rhs, 0.0, tmp); // tmp = Wi[J,:] = Cache[j]*rhs

            SKYLARK_TIMER_RESTART(ZMULT_PROFILE);
            elem::Gemm(elem::TRANSPOSE, elem::NORMAL, 1.0, tmp, z, 0.0, o); // o = Wi[J,:]' * z
            SKYLARK_TIMER_ACCUMULATE(ZMULT_PROFILE);

            // mu_ij[J,:] = mu_ij[J,:] + Wi[J,:]
            elem::View(tmp, mu_ij, start, 0, sj, k); // tmp = mu_ij[J,:]
            elem::View(rhs, Wi, start, 0, sj, k);
            elem::Axpy(+1.0, rhs, tmp);

            // ZtObar_ij[J,:] = z * o'
            elem::View(tmp, ZtObar_ij, start, 0, sj, k);
            elem::Gemm(elem::NORMAL, elem::TRANSPOSE, 1.0, z, o, 0.0, tmp);

            // sum_o += o
            if (NumThreads > 1) {
#               ifdef SKYLARK_HAVE_OPENMP
#               pragma omp critical
#               endif
                elem::Axpy(1.0, o, sum_o);
            } else
                elem::Axpy(1.0, o, sum_o);

            z.Empty();
        }

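        // Output-side consensus: fold the per-partition predictions into the
        // residual del_o = O - sum_o, evaluate validation accuracy and the
        // training objective, then recompute Obar as a weighted combination
        // of O and the averaged outputs and update its dual nu = nu + O - Obar.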
        SKYLARK_TIMER_ACCUMULATE(TRANSFORM_PROFILE);

        localloss = 0.0;
        elem::Scal(-1.0, sum_o);
        elem::Axpy(+1.0, O, sum_o); // sum_o = O - sum_o
        del_o = sum_o;

        SKYLARK_TIMER_RESTART(PREDICTION_PROFILE);
        if (skylark::base::Width(Xv) > 0) {
            elem::MakeZeros(Yp);
            elem::MakeZeros(Yp_labels);
            model->predict(Xv, Yp_labels, Yp, NumThreads);
            accuracy = model->evaluate(Yv, Yp, comm);
        }
        SKYLARK_TIMER_ACCUMULATE(PREDICTION_PROFILE);

        localloss += loss->evaluate(wbar_output, Y);

        SKYLARK_TIMER_RESTART(COMMUNICATION_PROFILE);
        reduce(comm, localloss, totalloss, std::plus<double>(), 0);
        SKYLARK_TIMER_ACCUMULATE(COMMUNICATION_PROFILE);

        if(rank == 0) {
            obj = totalloss + lambda*regularizer->evaluate(Wbar);
            if (skylark::base::Width(Xv) <= 0) {
                std::cout << "iteration " << iter << " objective " << obj
                          << " time " << timer.elapsed() << " seconds" << std::endl;
            } else {
                std::cout << "iteration " << iter << " objective " << obj
                          << " accuracy " << accuracy
                          << " time " << timer.elapsed() << " seconds" << std::endl;
            }
        }

        // Obar = O - sum_o / (NumFeaturePartitions + 1)
        elem::Copy(O, Obar);
        elem::Scal(1.0/(NumFeaturePartitions + 1.0), sum_o);
        elem::Axpy(-1.0, sum_o, Obar);

        // nu = nu + O - Obar
        elem::Axpy(+1.0, O, nu);
        elem::Axpy(-1.0, Obar, nu);

        // Wbar = reduce(Wi) across ranks
        SKYLARK_TIMER_RESTART(COMMUNICATION_PROFILE);
        boost::mpi::reduce(comm,
            Wi.LockedBuffer(),
            Wi.MemorySize(),
            Wbar.Buffer(),
            std::plus<double>(),
            0);
        SKYLARK_TIMER_ACCUMULATE(COMMUNICATION_PROFILE);

        if(rank == 0) {
            // Wbar = (sum_i Wi + W) / (P + 1)
            elem::Axpy(1.0, W, Wbar);
            elem::Scal(1.0/(P+1), Wbar);

            // mu = mu + W - Wbar
            elem::Axpy(+1.0, W, mu);
            elem::Axpy(-1.0, Wbar, mu);
        }

        SKYLARK_TIMER_RESTART(BARRIER_PROFILE);
        comm.barrier();
        SKYLARK_TIMER_ACCUMULATE(BARRIER_PROFILE);

        SKYLARK_TIMER_ACCUMULATE(ITERATIONS_PROFILE);
    }

    SKYLARK_TIMER_PRINT(ITERATIONS_PROFILE, comm);
    SKYLARK_TIMER_PRINT(COMMUNICATION_PROFILE, comm);
    SKYLARK_TIMER_PRINT(TRANSFORM_PROFILE, comm);
    SKYLARK_TIMER_PRINT(ZTRANSFORM_PROFILE, comm);
    SKYLARK_TIMER_PRINT(ZMULT_PROFILE, comm);
    SKYLARK_TIMER_PRINT(PROXLOSS_PROFILE, comm);
    SKYLARK_TIMER_PRINT(BARRIER_PROFILE, comm);
    SKYLARK_TIMER_PRINT(PREDICTION_PROFILE, comm);

    return model;
}


#endif /* SKYLARK_BLOCKADMM_HPP */