// Skylark (Sketching Library) 0.1
00001 #ifndef SKYLARK_HILBERT_RUN_HPP 00002 #define SKYLARK_HILBERT_RUN_HPP 00003 00004 //#include "hilbert.hpp" 00005 #include "BlockADMM.hpp" 00006 #include "options.hpp" 00007 #include "io.hpp" 00008 #include "../base/context.hpp" 00009 #include "model.hpp" 00010 00011 00012 template <class InputType> 00013 BlockADMMSolver<InputType>* GetSolver(skylark::base::context_t& context, 00014 const hilbert_options_t& options, int dimensions) { 00015 00016 lossfunction *loss = NULL; 00017 switch(options.lossfunction) { 00018 case SQUARED: 00019 loss = new squaredloss(); 00020 break; 00021 case HINGE: 00022 loss = new hingeloss(); 00023 break; 00024 case LOGISTIC: 00025 loss = new logisticloss(); 00026 break; 00027 case LAD: 00028 loss = new ladloss(); 00029 break; 00030 default: 00031 // TODO 00032 break; 00033 } 00034 00035 regularization *regularizer = NULL; 00036 switch(options.regularizer) { 00037 case L2: 00038 regularizer = new l2(); 00039 break; 00040 case L1: 00041 regularizer = new l1(); 00042 break; 00043 default: 00044 // TODO 00045 break; 00046 } 00047 00048 BlockADMMSolver<InputType> *Solver = NULL; 00049 int features = 0; 00050 switch(options.kernel) { 00051 case LINEAR: 00052 features = dimensions; 00053 Solver = 00054 new BlockADMMSolver<InputType>(loss, 00055 regularizer, 00056 options.lambda, 00057 dimensions, 00058 options.numfeaturepartitions); 00059 break; 00060 00061 case GAUSSIAN: 00062 features = options.randomfeatures; 00063 if (options.regularmap) 00064 Solver = 00065 new BlockADMMSolver<InputType>(context, 00066 loss, 00067 regularizer, 00068 options.lambda, 00069 features, 00070 skylark::ml::kernels::gaussian_t(dimensions, 00071 options.kernelparam), 00072 skylark::ml::regular_feature_transform_tag(), 00073 options.numfeaturepartitions); 00074 else 00075 Solver = 00076 new BlockADMMSolver<InputType>(context, 00077 loss, 00078 regularizer, 00079 options.lambda, 00080 features, 00081 skylark::ml::kernels::gaussian_t(dimensions, 00082 
options.kernelparam), 00083 skylark::ml::fast_feature_transform_tag(), 00084 options.numfeaturepartitions); 00085 break; 00086 00087 case POLYNOMIAL: 00088 features = options.randomfeatures; 00089 Solver = 00090 new BlockADMMSolver<InputType>(context, 00091 loss, 00092 regularizer, 00093 options.lambda, 00094 features, 00095 skylark::ml::kernels::polynomial_t(dimensions, 00096 options.kernelparam, options.kernelparam2, options.kernelparam3), 00097 skylark::ml::regular_feature_transform_tag(), 00098 options.numfeaturepartitions); 00099 break; 00100 00101 case LAPLACIAN: 00102 features = options.randomfeatures; 00103 Solver = 00104 new BlockADMMSolver<InputType>(context, 00105 loss, 00106 regularizer, 00107 options.lambda, 00108 features, 00109 skylark::ml::kernels::laplacian_t(dimensions, options.kernelparam), 00110 skylark::ml::regular_feature_transform_tag(), 00111 options.numfeaturepartitions); 00112 break; 00113 00114 case EXPSEMIGROUP: 00115 features = options.randomfeatures; 00116 Solver = 00117 new BlockADMMSolver<InputType>(context, 00118 loss, 00119 regularizer, 00120 options.lambda, 00121 features, 00122 skylark::ml::kernels::expsemigroup_t(dimensions, options.kernelparam), 00123 skylark::ml::regular_feature_transform_tag(), 00124 options.numfeaturepartitions); 00125 break; 00126 00127 default: 00128 // TODO! 
00129 break; 00130 00131 } 00132 00133 // Set parameters 00134 Solver->set_rho(options.rho); 00135 Solver->set_maxiter(options.MAXITER); 00136 Solver->set_tol(options.tolerance); 00137 Solver->set_nthreads(options.numthreads); 00138 Solver->set_cache_transform(options.cachetransforms); 00139 00140 return Solver; 00141 } 00142 00143 00144 void ShiftForLogistic(LocalMatrixType& Y) { 00145 double y; 00146 for(int i=0;i<Y.Height(); i++) { 00147 y = Y.Get(i, 0); 00148 Y.Set(i, 0, 0.5*(y+1.0)); 00149 } 00150 } 00151 00152 template <class InputType, class LabelType> 00153 int run(const boost::mpi::communicator& comm, skylark::base::context_t& context, 00154 hilbert_options_t& options) { 00155 00156 int rank = comm.rank(); 00157 00158 InputType X, Xv, Xt; 00159 LabelType Y, Yv, Yt; 00160 00161 if(!options.trainfile.empty()) { //training mode 00162 00163 read(comm, options.fileformat, options.trainfile, X, Y); 00164 int dimensions = skylark::base::Height(X); 00165 int targets = GetNumTargets<LabelType>(comm, Y); 00166 bool shift = false; 00167 00168 if ((options.lossfunction == LOGISTIC) && (targets == 1)) { 00169 ShiftForLogistic(Y); 00170 targets = 2; 00171 shift = true; 00172 } 00173 00174 BlockADMMSolver<InputType>* Solver = 00175 GetSolver<InputType>(context, options, dimensions); 00176 00177 if(!options.valfile.empty()) { 00178 comm.barrier(); 00179 if(rank == 0) std::cout << "Loading validation data." 
<< std::endl; 00180 00181 read(comm, options.fileformat, options.valfile, Xv, Yv, 00182 skylark::base::Height(X)); 00183 00184 if ((options.lossfunction == LOGISTIC) && shift) { 00185 ShiftForLogistic(Yv); 00186 } 00187 } 00188 00189 skylark::ml::model_t<InputType, LabelType>* model = 00190 Solver->train(X, Y, Xv, Yv, comm); 00191 00192 if (comm.rank() == 0) 00193 model->save(options.modelfile, options.print()); 00194 } 00195 00196 else { 00197 00198 std::cout << "Testing Mode" << std::endl; 00199 skylark::ml::model_t<InputType, LabelType> model(options.modelfile); 00200 read(comm, options.fileformat, options.testfile, Xt, Yt, 00201 model.get_input_size()); 00202 LabelType DecisionValues(Yt.Height(), model.get_num_outputs()); 00203 LabelType PredictedLabels(Yt.Height(), 1); 00204 elem::MakeZeros(DecisionValues); 00205 elem::MakeZeros(PredictedLabels); 00206 00207 std::cout << "Starting predictions" << std::endl; 00208 model.predict(Xt, PredictedLabels, DecisionValues, options.numthreads); 00209 double accuracy = model.evaluate(Yt, DecisionValues, comm); 00210 if(rank == 0) 00211 std::cout << "Test Accuracy = " << accuracy << " %" << std::endl; 00212 00213 // fix logistic case -- provide mechanism to dump predictions -- clean up evaluate 00214 00215 } 00216 00217 return 0; 00218 } 00219 00220 00221 00222 #endif /* SKYLARK_HILBERT_RUN_HPP */