vigra/random_forest/rf_preprocessing.hxx
/************************************************************************/
/*                                                                      */
/*  Copyright 2008-2009 by Ullrich Koethe and Rahul Nair                */
/*                                                                      */
/*  This file is part of the VIGRA computer vision library.             */
/*  The VIGRA Website is                                                */
/*      http://hci.iwr.uni-heidelberg.de/vigra/                         */
/*  Please direct questions, bug reports, and contributions to          */
/*      ullrich.koethe@iwr.uni-heidelberg.de or                         */
/*      vigra@informatik.uni-hamburg.de                                 */
/*                                                                      */
/*  Permission is hereby granted, free of charge, to any person         */
/*  obtaining a copy of this software and associated documentation      */
/*  files (the "Software"), to deal in the Software without             */
/*  restriction, including without limitation the rights to use,        */
/*  copy, modify, merge, publish, distribute, sublicense, and/or        */
/*  sell copies of the Software, and to permit persons to whom the      */
/*  Software is furnished to do so, subject to the following            */
/*  conditions:                                                         */
/*                                                                      */
/*  The above copyright notice and this permission notice shall be      */
/*  included in all copies or substantial portions of the               */
/*  Software.                                                           */
/*                                                                      */
/*  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,     */
/*  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES     */
/*  OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND            */
/*  NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT         */
/*  HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,        */
/*  WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING        */
/*  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR       */
/*  OTHER DEALINGS IN THE SOFTWARE.                                     */
/*                                                                      */
/************************************************************************/

#ifndef VIGRA_RF_PREPROCESSING_HXX
#define VIGRA_RF_PREPROCESSING_HXX

#include <limits>
#include "rf_common.hxx"

namespace vigra
{

/** Class used during preprocessing (currently only used during learning).
 *
 * This class is used internally by the Random Forest learn function.
 * Different split functors may need to preprocess the data in different
 * ways (e.g., regression labels are left untouched, whereas classification
 * labels must be converted into an integral format).
 *
 * This class only exists in specialised versions, in which the Tag class
 * is fixed.
 *
 * The Tag class is determined by SplitFunctor::Preprocessor_t. Currently
 * it can be either ClassificationTag or RegressionTag. Look at the
 * RegressionTag specialisation for the basic interface in case you need
 * to write a new kind of preprocessor (e.g., for soft labels).
 */
template<class Tag, class LabelType, class T1, class C1, class T2, class C2>
class Processor;
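/* Illustration only (not part of this header): a rough sketch of how the
 * learn() code is expected to pick the right specialisation through the
 * split functor's Preprocessor_t tag, as described in the comment above.
 * The names SplitFunctor, features, response, options and ext_param are
 * placeholders assumed for this sketch.
 *
 *     typedef typename SplitFunctor::Preprocessor_t Preprocessor_t;
 *     Processor<Preprocessor_t, LabelType, T1, C1, T2, C2>
 *         preprocessor(features, response, options, ext_param);
 *     // preprocessor.features() and preprocessor.response() are then
 *     // handed on to the actual tree learning code.
 */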

namespace detail
{

/* Common helper function used by all Processors.
 * It analyses the options struct and computes the actual values needed
 * for the current problem (data).
 */
template<class T>
void fill_external_parameters(RandomForestOptions const & options,
                              ProblemSpec<T> & ext_param)
{
    // set the correct value for mtry
    switch(options.mtry_switch_)
    {
        case RF_SQRT:
            ext_param.actual_mtry_ =
                int(std::floor(
                        std::sqrt(double(ext_param.column_count_))
                        + 0.5));
            break;
        case RF_LOG:
            // this is in Breiman's original paper
            ext_param.actual_mtry_ =
                int(1+(std::log(double(ext_param.column_count_))
                       /std::log(2.0)));
            break;
        case RF_FUNCTION:
            ext_param.actual_mtry_ =
                options.mtry_func_(ext_param.column_count_);
            break;
        case RF_ALL:
            ext_param.actual_mtry_ = ext_param.column_count_;
            break;
        default:
            ext_param.actual_mtry_ =
                options.mtry_;
    }
    // set the correct value for msample
    switch(options.training_set_calc_switch_)
    {
        case RF_CONST:
            ext_param.actual_msample_ =
                options.training_set_size_;
            break;
        case RF_PROPORTIONAL:
            ext_param.actual_msample_ =
                (int)std::ceil(options.training_set_proportion_ *
                               ext_param.row_count_);
            break;
        case RF_FUNCTION:
            ext_param.actual_msample_ =
                options.training_set_func_(ext_param.row_count_);
            break;
        default:
            vigra_precondition(false, "unexpected error");
    }
}

/* Returns true if the MultiArray contains NaNs.
 */
template<unsigned int N, class T, class C>
bool contains_nan(MultiArrayView<N, T, C> const & in)
{
    for(int ii = 0; ii < in.size(); ++ii)
        if(in[ii] != in[ii])
            return true;
    return false;
}

/* Returns true if the MultiArray contains Infs.
 */
template<unsigned int N, class T, class C>
bool contains_inf(MultiArrayView<N, T, C> const & in)
{
    if(!std::numeric_limits<T>::has_infinity)
        return false;
    for(int ii = 0; ii < in.size(); ++ii)
        if(in[ii] == std::numeric_limits<T>::infinity())
            return true;
    return false;
}

} // namespace detail
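/* Worked illustration (not part of this header): for a data set with
 * column_count_ == 100 features, the mtry switch above yields
 *
 *     RF_SQRT: actual_mtry_ = int(std::floor(std::sqrt(100.0) + 0.5)) == 10
 *     RF_LOG : actual_mtry_ = int(1 + std::log(100.0)/std::log(2.0))  == 7
 *     RF_ALL : actual_mtry_ = 100
 *
 * and with row_count_ == 1000 and training_set_proportion_ == 0.66,
 * RF_PROPORTIONAL gives actual_msample_ = (int)std::ceil(0.66 * 1000) == 660.
 */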


/** Preprocessor used during classification.
 *
 * This class converts the labels into integral labels, which are used by
 * the standard split functor to address memory in the node objects.
 */
template<class LabelType, class T1, class C1, class T2, class C2>
class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
{
  public:
    typedef Int32 LabelInt;
    typedef MultiArrayView<2, T1, C1> Feature_t;
    typedef MultiArray<2, T1> FeatureWithMemory_t;
    typedef MultiArrayView<2, LabelInt> Label_t;
    MultiArrayView<2, T1, C1> const & features_;
    MultiArray<2, LabelInt> intLabels_;
    MultiArrayView<2, LabelInt> strata_;

    template<class T>
    Processor(MultiArrayView<2, T1, C1> const & features,
              MultiArrayView<2, T2, C2> const & response,
              RandomForestOptions & options,
              ProblemSpec<T> & ext_param)
    :
        features_(features) // do not touch the features
    {
        vigra_precondition(!detail::contains_nan(features), "Processor(): Feature matrix "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_inf(features), "Processor(): Feature matrix "
                                                            "contains inf");
        vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
                                                            "contains inf");
        // set some of the problem-specific parameters
        ext_param.column_count_ = features.shape(1);
        ext_param.row_count_    = features.shape(0);
        ext_param.problem_type_ = CLASSIFICATION;
        ext_param.used_         = true;
        intLabels_.reshape(response.shape());

        // get the class labels
        if(ext_param.class_count_ == 0)
        {
            // collect the distinct labels and then create the
            // integral labels.
            std::set<T2> labelToInt;
            for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
                labelToInt.insert(response(k,0));
            std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
            ext_param.classes_(tmp_.begin(), tmp_.end());
        }
        for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
        {
            if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
            {
                throw std::runtime_error("unknown label type");
            }
            else
                intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
                                   - ext_param.classes.begin();
        }
        // set class weights
        if(ext_param.class_weights_.size() == 0)
        {
            ArrayVector<T2>
                tmp((std::size_t)ext_param.class_count_,
                    NumericTraits<T2>::one());
            ext_param.class_weights(tmp.begin(), tmp.end());
        }

        // set mtry and msample
        detail::fill_external_parameters(options, ext_param);

        // set strata
        strata_ = intLabels_;
    }

    /** Access the processed features.
     */
    MultiArrayView<2, T1, C1> const & features()
    {
        return features_;
    }

    /** Access the processed labels.
     */
    MultiArrayView<2, LabelInt> & response()
    {
        return intLabels_;
    }

    /** Access the processed strata.
     */
    ArrayVectorView<LabelInt> strata()
    {
        return ArrayVectorView<LabelInt>(intLabels_.size(), intLabels_.data());
    }

    /** Access the strata fractions - not used currently.
     */
    ArrayVectorView<double> strata_prob()
    {
        return ArrayVectorView<double>();
    }
};
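/* Usage sketch (illustration only, not part of the library; the Random
 * Forest learn() code normally instantiates this class internally).  The
 * concrete element types and the UnstridedArrayTag stride tags below are
 * assumptions made for the sake of the example.
 *
 *     MultiArray<2, double> features(MultiArrayShape<2>::type(4, 2));
 *     MultiArray<2, double> labels(MultiArrayShape<2>::type(4, 1));
 *     labels(0,0) = 10.0; labels(1,0) = 20.0;
 *     labels(2,0) = 10.0; labels(3,0) = 30.0;
 *
 *     RandomForestOptions options;
 *     ProblemSpec<double> ext_param;
 *     Processor<ClassificationTag, double,
 *               double, UnstridedArrayTag,
 *               double, UnstridedArrayTag>
 *         proc(features, labels, options, ext_param);
 *
 *     // proc.response() should now hold the integral labels 0, 1, 0, 2
 *     // (indices of the distinct label values in sorted order), and
 *     // ext_param records three classes.
 */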


/** Regression preprocessor - this class essentially leaves the data
 * untouched; only views are stored.
 */
template<class LabelType, class T1, class C1, class T2, class C2>
class Processor<RegressionTag, LabelType, T1, C1, T2, C2>
{
  public:
    // only views are created - no data is copied.
    MultiArrayView<2, T1, C1> features_;
    MultiArrayView<2, T2, C2> response_;
    RandomForestOptions const & options_;
    ProblemSpec<LabelType> const & ext_param_;
    // will only be filled if needed
    MultiArray<2, int> strata_;
    bool strata_filled;

    // copy the views.
    template<class T>
    Processor(MultiArrayView<2, T1, C1> features,
              MultiArrayView<2, T2, C2> response,
              RandomForestOptions const & options,
              ProblemSpec<T> & ext_param)
    :
        features_(features),
        response_(response),
        options_(options),
        ext_param_(ext_param)
    {
        // set some of the problem-specific parameters
        ext_param.column_count_ = features.shape(1);
        ext_param.row_count_    = features.shape(0);
        ext_param.problem_type_ = REGRESSION;
        ext_param.used_         = true;
        detail::fill_external_parameters(options, ext_param);
        vigra_precondition(!detail::contains_nan(features), "Processor(): Feature matrix "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
                                                            "contains NaNs");
        vigra_precondition(!detail::contains_inf(features), "Processor(): Feature matrix "
                                                            "contains inf");
        vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
                                                            "contains inf");
        strata_ = MultiArray<2, int>(MultiArrayShape<2>::type(response_.shape(0), 1));
        ext_param.response_size_ = response.shape(1);
        ext_param.class_count_   = response_.shape(1);
        std::vector<T2> tmp_(ext_param.class_count_, 0);
        ext_param.classes_(tmp_.begin(), tmp_.end());
    }

    /** Access the preprocessed features.
     */
    MultiArrayView<2, T1, C1> & features()
    {
        return features_;
    }

    /** Access the preprocessed response.
     */
    MultiArrayView<2, T2, C2> & response()
    {
        return response_;
    }

    /** Access the strata - not used currently.
     */
    MultiArray<2, int> & strata()
    {
        return strata_;
    }
};

} // namespace vigra

#endif // VIGRA_RF_PREPROCESSING_HXX