Skylark (Sketching Library)  0.1
/var/lib/jenkins/jobs/Skylark/workspace/sketch/FRFT_Elemental.hpp
Go to the documentation of this file.
00001 #ifndef SKYLARK_FRFT_ELEMENTAL_HPP
00002 #define SKYLARK_FRFT_ELEMENTAL_HPP
00003 
00004 namespace skylark {
00005 namespace sketch {
00006 
00007 #if SKYLARK_HAVE_ELEMENTAL && (SKYLARK_HAVE_FFTW || SKYLARK_HAVE_SPIRALWHT)
00008 
00013 template <typename ValueType,
00014           template <typename> class InputType>
00015 struct FastRFT_t <
00016     InputType<ValueType>,
00017     elem::Matrix<ValueType> > :
00018         public FastRFT_data_t {
00019     // Typedef value, matrix, transform, distribution and transform data types
00020     // so that we can use them regularly and consistently.
00021     typedef ValueType value_type;
00022     typedef InputType<value_type> matrix_type;
00023     typedef elem::Matrix<value_type> output_matrix_type;
00024     typedef FastRFT_data_t data_type;
00025 
00026 public:
00027 
00028     // No regular contructor, since need to be subclassed.
00029 
00033     FastRFT_t(const FastRFT_t<matrix_type,
00034                       output_matrix_type>& other)
00035         : data_type(other), _fut(data_type::_N) {
00036 
00037     }
00038 
00042     FastRFT_t(const data_type& other_data)
00043         : data_type(other_data), _fut(data_type::_N) {
00044 
00045     }
00046 
00050     template <typename Dimension>
00051     void apply (const matrix_type& A,
00052                 output_matrix_type& sketch_of_A,
00053                 Dimension dimension) const {
00054         try {
00055             apply_impl(A, sketch_of_A, dimension);
00056         } catch (std::logic_error e) {
00057             SKYLARK_THROW_EXCEPTION (
00058                 base::elemental_exception()
00059                     << base::error_msg(e.what()) );
00060         } catch(boost::mpi::exception e) {
00061             SKYLARK_THROW_EXCEPTION (
00062                 base::mpi_exception()
00063                     << base::error_msg(e.what()) );
00064         }
00065     }
00066 
00067 private:
00072     void apply_impl(const matrix_type& A,
00073         output_matrix_type& sketch_of_A,
00074         skylark::sketch::columnwise_tag tag) const {
00075 
00076 #       ifdef SKYLARK_HAVE_OPENMP
00077 #       pragma omp parallel
00078 #       endif
00079         {
00080         output_matrix_type W(data_type::_NB, 1);
00081         double *w = W.Buffer();
00082 
00083         output_matrix_type Ac(data_type::_NB, 1);
00084         double *ac = Ac.Buffer();
00085 
00086         output_matrix_type Acv;
00087         elem::View(Acv, Ac, 0, 0, data_type::_N, 1);
00088 
00089         double *sa = sketch_of_A.Buffer();
00090         int ldsa = sketch_of_A.LDim();
00091 
00092         value_type scal =
00093             std::sqrt(data_type::_NB) * _fut.scale();
00094 
00095         output_matrix_type B(data_type::_NB, 1), G(data_type::_NB, 1);
00096         output_matrix_type Sm(data_type::_NB, 1);
00097 
00098 #       ifdef SKYLARK_HAVE_OPENMP
00099 #       pragma omp for
00100 #       endif
00101         for(int c = 0; c < base::Width(A); c++) {
00102             const matrix_type Acs = base::ColumnView(A, c, 1);
00103             base::DenseCopy(Acs, Acv);
00104             std::fill(ac + data_type::_N, ac + data_type::_NB, 0);
00105 
00106             for(int i = 0; i < data_type::numblks; i++) {
00107 
00108                 int s = i * data_type::_NB;
00109                 int e = std::min(s + data_type::_NB,  data_type::_S);
00110 
00111                 // Set the local values of B, G and S
00112                 for(int j = 0; j < data_type::_NB; j++) {
00113                     B.Set(j, 0, data_type::B[i * data_type::_NB + j]);
00114                     G.Set(j, 0, scal * data_type::G[i * data_type::_NB + j]);
00115                     Sm.Set(j, 0, scal * data_type::Sm[i * data_type::_NB + j]);
00116                 }
00117 
00118                 W = Ac;
00119 
00120                 elem::DiagonalScale(elem::LEFT, elem::NORMAL, B, W);
00121                 _fut.apply(W, tag);
00122                 for(int l = 0; l < data_type::_NB - 1; l++) {
00123                     int idx1 = data_type::_NB - 1 - l;
00124                     int idx2 = data_type::P[i * (data_type::_NB - 1) + l];
00125                     std::swap(w[idx1], w[idx2]);
00126                 }
00127                 elem::DiagonalScale(elem::LEFT, elem::NORMAL, G, W);
00128                 _fut.apply(W, tag);
00129                 elem::DiagonalScale(elem::LEFT, elem::NORMAL, Sm, W);
00130 
00131                 double *sac = sa + ldsa * c;
00132                 for(int l = s; l < e; l++) {
00133                     value_type x = w[l - s];
00134                     x += data_type::shifts[l];
00135 
00136 #                   ifdef SKYLARK_EXACT_COSINE
00137                     x = std::cos(x);
00138 #                   else
00139                     // x = std::cos(x) is slow
00140                     // Instead use low-accuracy approximation
00141                     if (x < -3.14159265) x += 6.28318531;
00142                     else if (x >  3.14159265) x -= 6.28318531;
00143                     x += 1.57079632;
00144                     if (x >  3.14159265)
00145                         x -= 6.28318531;
00146                     x = (x < 0) ?
00147                         1.27323954 * x + 0.405284735 * x * x :
00148                         1.27323954 * x - 0.405284735 * x * x;
00149 #                   endif
00150 
00151                     x = data_type::scale * x;
00152                     sac[l - s] = x;
00153                 }
00154             }
00155         }
00156 
00157         }
00158     }
00159 
00164     void apply_impl(const matrix_type& A,
00165         output_matrix_type& sketch_of_A,
00166         skylark::sketch::rowwise_tag tag) const {
00167 
00168         // TODO this version is really bad: it completely densifies the matrix
00169         //      on the begining.
00170         // TODO this version does not work with _NB and N
00171         // TODO this version is not as optimized as the columnwise version.
00172 
00173         // Create a work array W
00174         output_matrix_type W(base::Height(A), base::Width(A));
00175 
00176         output_matrix_type B(data_type::_N, 1), G(data_type::_N, 1);
00177         output_matrix_type Sm(data_type::_N, 1);
00178         for(int i = 0; i < data_type::numblks; i++) {
00179             int s = i * data_type::_N;
00180             int e = std::min(s + data_type::_N, data_type::_S);
00181 
00182             base::DenseCopy(A, W);
00183 
00184             // Set the local values of B, G and S
00185             value_type scal =
00186                 std::sqrt(data_type::_N) * _fut.scale();
00187             for(int j = 0; j < data_type::_N; j++) {
00188                 B.Set(j, 0, data_type::B[i * data_type::_N + j]);
00189                 G.Set(j, 0, scal * data_type::G[i * data_type::_N + j]);
00190                 Sm.Set(j, 0, scal * data_type::Sm[i * data_type::_N + j]);
00191             }
00192 
00193             elem::DiagonalScale(elem::RIGHT, elem::NORMAL, B, W);
00194 
00195             _fut.apply(W, tag);
00196 
00197             double *w = W.Buffer();
00198             for(int c = 0; c < base::Height(W); c++)
00199                 for(int l = 0; l < data_type::_N - 1; l++) {
00200                     int idx1 = c + (data_type::_N - 1 - l) * W.LDim();
00201                     int idx2 = c  +
00202                         (data_type::P[i * (data_type::_N - 1) + l]) * W.LDim();
00203                     std::swap(w[idx1], w[idx2]);
00204                 }
00205 
00206             elem::DiagonalScale(elem::RIGHT, elem::NORMAL, G, W);
00207 
00208             _fut.apply(W, tag);
00209 
00210             elem::DiagonalScale(elem::RIGHT, elem::NORMAL, Sm, W);
00211 
00212             // Copy that part to the output
00213             output_matrix_type view_sketch_of_A;
00214             elem::View(view_sketch_of_A, sketch_of_A, 0, s,
00215                 base::Height(A), e - s);
00216             output_matrix_type view_W;
00217             elem::View(view_W, W, 0, 0, base::Height(A), e - s);
00218             view_sketch_of_A = view_W;
00219         }
00220 
00221         for(int j = 0; j < data_type::_S; j++)
00222             for(int i = 0; i < base::Height(A); i++) {
00223                 value_type x = sketch_of_A.Get(i, j);
00224                 x += data_type::shifts[j];
00225 
00226 #               ifdef SKYLARK_EXACT_COSINE
00227                 x = std::cos(x);
00228 #               else
00229                 // x = std::cos(x) is slow
00230                 // Instead use low-accuracy approximation
00231                 if (x < -3.14159265) x += 6.28318531;
00232                 else if (x >  3.14159265) x -= 6.28318531;
00233                 x += 1.57079632;
00234                 if (x >  3.14159265)
00235                     x -= 6.28318531;
00236                 x = (x < 0) ?
00237                     1.27323954 * x + 0.405284735 * x * x :
00238                     1.27323954 * x - 0.405284735 * x * x;
00239 #               endif
00240 
00241                 x = data_type::scale * x;
00242                 sketch_of_A.Set(i, j, x);
00243             }
00244     }
00245 
00246 private:
00247 
00248 #ifdef SKYLARK_HAVE_FFTW
00249     typename fft_futs<ValueType>::DCT_t _fut;
00250 #elif SKYLARK_HAVE_SPIRALWHT
00251     WHT_t<double> _fut;
00252 #endif
00253 
00254 };
00255 
00256 #endif // SKYLARK_HAVE_ELEMENTAL && (SKYLARK_HAVE_FFTW || SKYLARK_HAVE_SPIRALWHT)
00257 
00258 } } 
00260 #endif // SKYLARK_FRFT_ELEMENTAL_HPP