Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/sketch/FJLT_Elemental.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_FJLT_ELEMENTAL_HPP
00002 #define SKYLARK_FJLT_ELEMENTAL_HPP
00003 
00004 #include "../utility/get_communicator.hpp"
00005 
00006 namespace skylark { namespace sketch {
00010 template <typename ValueType, elem::Distribution ColDist>
00011 struct FJLT_t <
00012     elem::DistMatrix<ValueType, ColDist, elem::STAR>,
00013     elem::Matrix<ValueType> > :
00014         public FJLT_data_t,
00015         virtual public sketch_transform_t<elem::DistMatrix<ValueType,
00016                                                            ColDist,
00017                                                            elem::STAR>,
00018                                           elem::Matrix<ValueType> > {
00019     // Typedef value, matrix, transform, distribution and transform data types
00020     // so that we can use them regularly and consistently.
00021     typedef ValueType value_type;
00022     typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type;
00023     typedef elem::Matrix<value_type> output_matrix_type;
00024     typedef elem::DistMatrix<ValueType,
00025                              elem::STAR, ColDist> intermediate_type;
00026     typedef fft_futs<double>::DCT_t transform_type;
00027     typedef utility::rademacher_distribution_t<value_type>
00028     underlying_value_distribution_type;
00029 
00030     typedef FJLT_data_t data_type;
00031     typedef data_type::params_t params_t;
00032 
00033 protected:
00034     typedef RFUT_t<intermediate_type,
00035                    transform_type,
00036                    underlying_value_distribution_type> underlying_type;
00037 
00038 public:
00039     FJLT_t(int N, int S, base::context_t& context)
00040         : data_type (N, S, context) {
00041 
00042     }
00043 
00044     FJLT_t(int N, int S, const params_t& params, base::context_t& context)
00045         : data_type (N, S, params, context) {
00046 
00047     }
00048 
00049     FJLT_t(const boost::property_tree::ptree &pt)
00050         : data_type(pt) {
00051 
00052     }
00053 
00054     template <typename OtherInputMatrixType,
00055               typename OtherOutputMatrixType>
00056     FJLT_t(const FJLT_t<OtherInputMatrixType, OtherOutputMatrixType>& other)
00057         : data_type(other) {
00058 
00059     }
00060 
00061     FJLT_t(const data_type& other_data)
00062         : data_type(other_data) {
00063 
00064     }
00065 
00070     void apply (const matrix_type& A,
00071                 output_matrix_type& sketch_of_A,
00072                 columnwise_tag dimension) const {
00073         switch (ColDist) {
00074         case elem::VR:
00075         case elem::VC:
00076             try {
00077                 apply_impl_vdist (A, sketch_of_A, dimension);
00078             } catch (std::logic_error e) {
00079                 SKYLARK_THROW_EXCEPTION (
00080                     base::elemental_exception()
00081                         << base::error_msg(e.what()) );
00082             } catch(boost::mpi::exception e) {
00083                 SKYLARK_THROW_EXCEPTION (
00084                     base::mpi_exception()
00085                         << base::error_msg(e.what()) );
00086             }
00087 
00088             break;
00089 
00090         default:
00091             SKYLARK_THROW_EXCEPTION (
00092                 base::unsupported_matrix_distribution() );
00093 
00094         }
00095     }
00096 
00097 
00102     void apply (const matrix_type& A,
00103                 output_matrix_type& sketch_of_A,
00104                 rowwise_tag dimension) const {
00105         switch (ColDist) {
00106         case elem::VR:
00107         case elem::VC:
00108             try {
00109                 apply_impl_vdist (A, sketch_of_A, dimension);
00110             } catch (std::logic_error e) {
00111                 SKYLARK_THROW_EXCEPTION (
00112                     base::elemental_exception()
00113                         << base::error_msg(e.what()) );
00114             } catch(boost::mpi::exception e) {
00115                 SKYLARK_THROW_EXCEPTION (
00116                     base::mpi_exception()
00117                         << base::error_msg(e.what()) );
00118             }
00119 
00120             break;
00121 
00122         default:
00123             SKYLARK_THROW_EXCEPTION (
00124                 base::unsupported_matrix_distribution() );
00125 
00126         }
00127     }
00128 
00129     int get_N() const { return this->_N; } 
00130     int get_S() const { return this->_S; } 
00132     const sketch_transform_data_t* get_data() const { return this; }
00133 
00134 private:
00139     void apply_impl_vdist(const matrix_type& A,
00140                     output_matrix_type& sketch_A,
00141                     skylark::sketch::columnwise_tag) const {
00142 
00143         // Rearrange the matrix to fit the underlying transform
00144         intermediate_type inter_A(A.Grid());
00145         inter_A = A;
00146 
00147         // Apply the underlying transform
00148         underlying_type underlying(*data_type::underlying_data);
00149         underlying.apply(inter_A, inter_A,
00150             skylark::sketch::columnwise_tag());
00151 
00152         // Create the sampled and scaled matrix -- still in distributed mode
00153         intermediate_type dist_sketch_A(data_type::_S,
00154             inter_A.Width(), inter_A.Grid());
00155         double scale = sqrt((double)data_type::_N / (double)data_type::_S);
00156         for (int j = 0; j < inter_A.LocalWidth(); j++)
00157             for (int i = 0; i < data_type::_S; i++) {
00158                 int row = data_type::samples[i];
00159                 dist_sketch_A.Matrix().Set(i, j,
00160                     scale * inter_A.Matrix().Get(row, j));
00161             }
00162 
00163         // get communicator from matrix
00164         boost::mpi::communicator comm = skylark::utility::get_communicator(A);
00165         skylark::utility::collect_dist_matrix(comm, comm.rank() == 0,
00166             dist_sketch_A, sketch_A);
00167     }
00168 
00173     void apply_impl_vdist(const matrix_type& A,
00174                     output_matrix_type& sketch_of_A,
00175                     skylark::sketch::rowwise_tag) const {
00176 
00177         // TODO This is a quick&dirty hack - uses the columnwise implementation.
00178         matrix_type A_t(A.Grid());
00179         elem::Transpose(A, A_t);
00180         output_matrix_type sketch_of_A_t(sketch_of_A.Width(),
00181             sketch_of_A.Height());
00182         apply_impl_vdist(A_t, sketch_of_A_t,
00183             skylark::sketch::columnwise_tag());
00184         elem::Transpose(sketch_of_A_t, sketch_of_A);
00185      }
00186 };
00187 
00191 template <typename ValueType, elem::Distribution ColDist>
00192 struct FJLT_t <
00193     elem::DistMatrix<ValueType, ColDist, elem::STAR>,
00194     elem::DistMatrix<ValueType, elem::STAR, elem::STAR> > :
00195         public FJLT_data_t,
00196         virtual public sketch_transform_t<elem::DistMatrix<ValueType,
00197                                                            ColDist,
00198                                                            elem::STAR>,
00199                                           elem::DistMatrix<ValueType,
00200                                                            elem::STAR,
00201                                                            elem::STAR> > {
00202     // Typedef value, matrix, transform, distribution and transform data types
00203     // so that we can use them regularly and consistently.
00204     typedef ValueType value_type;
00205     typedef elem::DistMatrix<value_type, ColDist, elem::STAR> matrix_type;
00206     typedef elem::DistMatrix<value_type, elem::STAR, elem::STAR>
00207     output_matrix_type;
00208     typedef elem::DistMatrix<ValueType,
00209                              elem::STAR, ColDist> intermediate_type;
00210     typedef fft_futs<double>::DCT_t transform_type;
00211     typedef utility::rademacher_distribution_t<value_type>
00212     underlying_value_distribution_type;
00213 
00214     typedef FJLT_data_t data_type;
00215     typedef data_type::params_t params_t;
00216 
00217 protected:
00218     typedef RFUT_t<intermediate_type,
00219                    transform_type,
00220                    underlying_value_distribution_type> underlying_type;
00221 
00222 public:
00223 
00224     FJLT_t(int N, int S, base::context_t& context)
00225         : data_type (N, S, context) {
00226 
00227     }
00228 
00229     FJLT_t(int N, int S, const params_t& params, base::context_t& context)
00230         : data_type (N, S, params, context) {
00231 
00232     }
00233 
00234     template <typename OtherInputMatrixType,
00235               typename OtherOutputMatrixType>
00236     FJLT_t(const FJLT_t<OtherInputMatrixType, OtherOutputMatrixType>& other)
00237         : data_type(other) {
00238 
00239     }
00240 
00241     FJLT_t(const data_type& other_data)
00242         : data_type(other_data) {
00243 
00244     }
00245 
00246     FJLT_t(const boost::property_tree::ptree &pt)
00247         : data_type(pt) {
00248 
00249     }
00250 
00255     void apply (const matrix_type& A,
00256                 output_matrix_type& sketch_of_A,
00257                 columnwise_tag dimension) const {
00258         switch (ColDist) {
00259         case elem::VR:
00260         case elem::VC:
00261             try {
00262                 apply_impl_vdist (A, sketch_of_A, dimension);
00263             } catch (std::logic_error e) {
00264                 SKYLARK_THROW_EXCEPTION (
00265                     base::elemental_exception()
00266                         << base::error_msg(e.what()) );
00267             } catch(boost::mpi::exception e) {
00268                 SKYLARK_THROW_EXCEPTION (
00269                     base::mpi_exception()
00270                         << base::error_msg(e.what()) );
00271             }
00272 
00273             break;
00274 
00275         default:
00276             SKYLARK_THROW_EXCEPTION (
00277                 base::unsupported_matrix_distribution() );
00278 
00279         }
00280     }
00281 
00282 
00287     void apply (const matrix_type& A,
00288                 output_matrix_type& sketch_of_A,
00289                 rowwise_tag dimension) const {
00290         switch (ColDist) {
00291         case elem::VR:
00292         case elem::VC:
00293             try {
00294                 apply_impl_vdist (A, sketch_of_A, dimension);
00295             } catch (std::logic_error e) {
00296                 SKYLARK_THROW_EXCEPTION (
00297                     base::elemental_exception()
00298                         << base::error_msg(e.what()) );
00299             } catch(boost::mpi::exception e) {
00300                 SKYLARK_THROW_EXCEPTION (
00301                     base::mpi_exception()
00302                         << base::error_msg(e.what()) );
00303             }
00304 
00305             break;
00306 
00307         default:
00308             SKYLARK_THROW_EXCEPTION (
00309                 base::unsupported_matrix_distribution() );
00310 
00311         }
00312     }
00313 
00314     int get_N() const { return this->_N; } 
00315     int get_S() const { return this->_S; } 
00317     const sketch_transform_data_t* get_data() const { return this; }
00318 
00319 private:
00324     void apply_impl_vdist(const matrix_type& A,
00325                     output_matrix_type& sketch_A,
00326                     skylark::sketch::columnwise_tag) const {
00327 
00328         // Rearrange the matrix to fit the underlying transform
00329         intermediate_type inter_A(A.Grid());
00330         inter_A = A;
00331 
00332         // Apply the underlying transform
00333         underlying_type underlying(*data_type::underlying_data);
00334         underlying.apply(inter_A, inter_A,
00335             skylark::sketch::columnwise_tag());
00336 
00337         // Create the sampled and scaled matrix -- still in distributed mode
00338         intermediate_type dist_sketch_A(data_type::_S,
00339             inter_A.Width(), inter_A.Grid());
00340         double scale = sqrt((double)data_type::_N / (double)data_type::_S);
00341         for (int j = 0; j < inter_A.LocalWidth(); j++)
00342             for (int i = 0; i < data_type::_S; i++) {
00343                 int row = data_type::samples[i];
00344                 dist_sketch_A.Matrix().Set(i, j,
00345                     scale * inter_A.Matrix().Get(row, j));
00346             }
00347 
00348         sketch_A = dist_sketch_A;
00349     }
00350 
00355     void apply_impl_vdist(const matrix_type& A,
00356                     output_matrix_type& sketch_of_A,
00357                     skylark::sketch::rowwise_tag) const {
00358 
00359         // TODO This is a quick&dirty hack - uses the columnwise implementation.
00360         matrix_type A_t(A.Grid());
00361         elem::Transpose(A, A_t);
00362         output_matrix_type sketch_of_A_t(sketch_of_A.Width(),
00363             sketch_of_A.Height());
00364         apply_impl_vdist(A_t, sketch_of_A_t,
00365             skylark::sketch::columnwise_tag());
00366         elem::Transpose(sketch_of_A_t, sketch_of_A);
00367      }
00368 };
00369 
00370 } } 
00372 #endif // SKYLARK_FJLT_ELEMENTAL_HPP