[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_ridge_split.hxx
00001 // 00002 // C++ Interface: rf_ridge_split 00003 // 00004 // Description: 00005 // 00006 // 00007 // Author: Nico Splitthoff <splitthoff@zg00103>, (C) 2009 00008 // 00009 // Copyright: See COPYING file that comes with this distribution 00010 // 00011 // 00012 #ifndef VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H 00013 #define VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H 00014 //#include "rf_sampling.hxx" 00015 #include "../sampling.hxx" 00016 #include "rf_split.hxx" 00017 #include "rf_nodeproxy.hxx" 00018 #include "../regression.hxx" 00019 00020 #define outm(v) std::cout << (#v) << ": " << (v) << std::endl; 00021 #define outm2(v) std::cout << (#v) << ": " << (v) << ", "; 00022 00023 namespace vigra 00024 { 00025 00026 /*template<> 00027 class Node<i_RegrNode> 00028 : public NodeBase 00029 { 00030 public: 00031 typedef NodeBase BT; 00032 00033 00034 Node( BT::T_Container_type & topology, 00035 BT::P_Container_type & param, 00036 int nNumCols) 00037 : BT(5+nNumCols,2+nNumCols,topology, param) 00038 { 00039 BT::typeID() = i_RegrNode; 00040 } 00041 00042 Node( BT::T_Container_type & topology, 00043 BT::P_Container_type & param, 00044 INT n ) 00045 : BT(5,2,topology, param, n) 00046 {} 00047 00048 Node( BT & node_) 00049 : BT(5, 2, node_) 00050 {} 00051 00052 double& threshold() 00053 { 00054 return BT::parameters_begin()[1]; 00055 } 00056 00057 BT::INT& column() 00058 { 00059 return BT::column_data()[0]; 00060 } 00061 00062 template<class U, class C> 00063 BT::INT& next(MultiArrayView<2,U,C> const & feature) 00064 { 00065 return (feature(0, column()) < threshold())? 
child(0):child(1); 00066 } 00067 };*/ 00068 00069 00070 template<class ColumnDecisionFunctor, class Tag = ClassificationTag> 00071 class RidgeSplit: public SplitBase<Tag> 00072 { 00073 public: 00074 00075 00076 typedef SplitBase<Tag> SB; 00077 00078 ArrayVector<Int32> splitColumns; 00079 ColumnDecisionFunctor bgfunc; 00080 00081 double region_gini_; 00082 ArrayVector<double> min_gini_; 00083 ArrayVector<std::ptrdiff_t> min_indices_; 00084 ArrayVector<double> min_thresholds_; 00085 00086 int bestSplitIndex; 00087 00088 //dns 00089 bool m_bDoScalingInTraining; 00090 bool m_bDoBestLambdaBasedOnGini; 00091 00092 RidgeSplit() 00093 :m_bDoScalingInTraining(true), 00094 m_bDoBestLambdaBasedOnGini(true) 00095 { 00096 } 00097 00098 double minGini() const 00099 { 00100 return min_gini_[bestSplitIndex]; 00101 } 00102 00103 int bestSplitColumn() const 00104 { 00105 return splitColumns[bestSplitIndex]; 00106 } 00107 00108 bool& doScalingInTraining() 00109 { return m_bDoScalingInTraining; } 00110 00111 bool& doBestLambdaBasedOnGini() 00112 { return m_bDoBestLambdaBasedOnGini; } 00113 00114 template<class T> 00115 void set_external_parameters(ProblemSpec<T> const & in) 00116 { 00117 SB::set_external_parameters(in); 00118 bgfunc.set_external_parameters(in); 00119 int featureCount_ = in.column_count_; 00120 splitColumns.resize(featureCount_); 00121 for(int k=0; k<featureCount_; ++k) 00122 splitColumns[k] = k; 00123 min_gini_.resize(featureCount_); 00124 min_indices_.resize(featureCount_); 00125 min_thresholds_.resize(featureCount_); 00126 } 00127 00128 00129 template<class T, class C, class T2, class C2, class Region, class Random> 00130 int findBestSplit(MultiArrayView<2, T, C> features, 00131 MultiArrayView<2, T2, C2> multiClassLabels, 00132 Region & region, 00133 ArrayVector<Region>& childRegions, 00134 Random & randint) 00135 { 00136 00137 //std::cerr << "Split called" << std::endl; 00138 typedef typename Region::IndexIterator IndexIterator; 00139 typedef typename 
MultiArrayView <2, T, C>::difference_type fShape; 00140 typedef typename MultiArrayView <2, T2, C2>::difference_type lShape; 00141 typedef typename MultiArrayView <2, double>::difference_type dShape; 00142 00143 // calculate things that haven't been calculated yet. 00144 // std::cout << "start" << std::endl; 00145 if(std::accumulate(region.classCounts().begin(), 00146 region.classCounts().end(), 0) != region.size()) 00147 { 00148 RandomForestClassCounter< MultiArrayView<2,T2, C2>, 00149 ArrayVector<double> > 00150 counter(multiClassLabels, region.classCounts()); 00151 std::for_each( region.begin(), region.end(), counter); 00152 region.classCountsIsValid = true; 00153 } 00154 00155 00156 // Is the region pure already? 00157 region_gini_ = GiniCriterion::impurity(region.classCounts(), 00158 region.size()); 00159 if(region_gini_ == 0 || region.size() < SB::ext_param_.actual_mtry_ || region.oob_size() < 2) 00160 return SB::makeTerminalNode(features, multiClassLabels, region, randint); 00161 00162 // select columns to be tried. 00163 for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii) 00164 std::swap(splitColumns[ii], 00165 splitColumns[ii+ randint(features.shape(1) - ii)]); 00166 00167 //do implicit binary case 00168 MultiArray<2, T2> labels(lShape(multiClassLabels.shape(0),1)); 00169 //number of classes should be >1, otherwise makeTerminalNode would have been called 00170 int nNumClasses=0; 00171 for(int n=0; n<(int)region.classCounts().size(); n++) 00172 nNumClasses+=((region.classCounts()[n]>0) ? 
1:0); 00173 00174 //convert to binary case 00175 if(nNumClasses>2) 00176 { 00177 int nMaxClass=0; 00178 int nMaxClassCounts=0; 00179 for(int n=0; n<(int)region.classCounts().size(); n++) 00180 { 00181 //this should occur in any case: 00182 //we had more than two non-zero classes in order to get here 00183 if(region.classCounts()[n]>nMaxClassCounts) 00184 { 00185 nMaxClassCounts=region.classCounts()[n]; 00186 nMaxClass=n; 00187 } 00188 } 00189 00190 //create binary labels 00191 for(int n=0; n<multiClassLabels.shape(0); n++) 00192 labels(n,0)=((multiClassLabels(n,0)==nMaxClass) ? 1:0); 00193 } 00194 else 00195 labels=multiClassLabels; 00196 00197 //_do implicit binary case 00198 00199 //uncomment this for some debugging 00200 /* int nNumCases=features.shape(0); 00201 00202 typedef typename MultiArrayView <2, int>::difference_type nShape; 00203 MultiArray<2, int> elementCounterArray(nShape(nNumCases,1),(int)0); 00204 int nUniqueElements=0; 00205 for(int n=0; n<region.size(); n++) 00206 elementCounterArray[region[n]]++; 00207 00208 for(int n=0; n<nNumCases; n++) 00209 nUniqueElements+=((elementCounterArray[n]>0) ? 1:0); 00210 00211 outm(nUniqueElements); 00212 nUniqueElements=0; 00213 MultiArray<2, int> elementCounterArray_oob(nShape(nNumCases,1),(int)0); 00214 for(int n=0; n<region.oob_size(); n++) 00215 elementCounterArray_oob[region.oob_begin()[n]]++; 00216 for(int n=0; n<nNumCases; n++) 00217 nUniqueElements+=((elementCounterArray_oob[n]>0) ? 1:0); 00218 outm(nUniqueElements); 00219 00220 int notUniqueElements=0; 00221 for(int n=0; n<nNumCases; n++) 00222 notUniqueElements+=(((elementCounterArray_oob[n]>0) && (elementCounterArray[n]>0)) ? 
1:0); 00223 outm(notUniqueElements);*/ 00224 00225 //outm(SB::ext_param_.actual_mtry_); 00226 00227 00228 //select submatrix of features for regression calculation 00229 MultiArrayView<2, T, C> cVector; 00230 MultiArray<2, T> xtrain(fShape(region.size(),SB::ext_param_.actual_mtry_)); 00231 //we only want -1 and 1 for this 00232 MultiArray<2, double> regrLabels(dShape(region.size(),1)); 00233 00234 //copy data into a vigra data structure and centre and scale while doing so 00235 MultiArray<2, double> meanMatrix(dShape(SB::ext_param_.actual_mtry_,1)); 00236 MultiArray<2, double> stdMatrix(dShape(SB::ext_param_.actual_mtry_,1)); 00237 for(int m=0; m<SB::ext_param_.actual_mtry_; m++) 00238 { 00239 cVector=columnVector(features, splitColumns[m]); 00240 00241 //centre and scale the data 00242 double dCurrFeatureColumnMean=0.0; 00243 double dCurrFeatureColumnStd=1.0; //default value 00244 00245 //calc mean on bootstrap data 00246 for(int n=0; n<region.size(); n++) 00247 dCurrFeatureColumnMean+=cVector[region[n]]; 00248 dCurrFeatureColumnMean/=region.size(); 00249 //calc scaling 00250 if(m_bDoScalingInTraining) 00251 { 00252 for(int n=0; n<region.size(); n++) 00253 { 00254 dCurrFeatureColumnStd+= 00255 (cVector[region[n]]-dCurrFeatureColumnMean)*(cVector[region[n]]-dCurrFeatureColumnMean); 00256 } 00257 //unbiased std estimator: 00258 dCurrFeatureColumnStd=sqrt(dCurrFeatureColumnStd/(region.size()-1)); 00259 } 00260 //dCurrFeatureColumnStd is still 1.0 if we didn't want scaling 00261 stdMatrix(m,0)=dCurrFeatureColumnStd; 00262 00263 meanMatrix(m,0)=dCurrFeatureColumnMean; 00264 00265 //get feature matrix, i.e. A (note that weighting is done automatically 00266 //since rows can occur multiple times -> bagging) 00267 for(int n=0; n<region.size(); n++) 00268 xtrain(n,m)=(cVector[region[n]]-dCurrFeatureColumnMean)/dCurrFeatureColumnStd; 00269 } 00270 00271 // std::cout << "middle" << std::endl; 00272 //get label vector (i.e. 
b) 00273 for(int n=0; n<region.size(); n++) 00274 { 00275 //we checked for/built binary case further up. 00276 //class labels should thus be either 0 or 1 00277 //-> convert to -1 and 1 for regression 00278 regrLabels(n,0)=((labels[region[n]]==0) ? -1:1); 00279 } 00280 00281 MultiArray<2, double> dLambdas(dShape(11,1)); 00282 int nCounter=0; 00283 for(int nLambda=-5; nLambda<=5; nLambda++) 00284 dLambdas[nCounter++]=pow(10.0,nLambda); 00285 //destination vector for regression coefficients; use same type as for xtrain 00286 MultiArray<2, double> regrCoef(dShape(SB::ext_param_.actual_mtry_,11)); 00287 ridgeRegressionSeries(xtrain,regrLabels,regrCoef,dLambdas); 00288 00289 double dMaxRidgeSum=NumericTraits<double>::min(); 00290 double dCurrRidgeSum; 00291 int nMaxRidgeSumAtLambdaInd=0; 00292 00293 for(int nLambdaInd=0; nLambdaInd<11; nLambdaInd++) 00294 { 00295 //just sum up the correct answers 00296 //(correct means >=intercept for class 1, <intercept for class 0) 00297 //(intercept=0 or intercept=threshold based on gini) 00298 dCurrRidgeSum=0.0; 00299 00300 //assemble projection vector 00301 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1)); 00302 00303 for(int n=0; n<region.oob_size(); n++) 00304 { 00305 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0; 00306 for (int m=0; m<SB::ext_param_.actual_mtry_; m++) 00307 { 00308 dDistanceFromHyperplane(region.oob_begin()[n],0)+= 00309 features(region.oob_begin()[n],splitColumns[m])*regrCoef(m,nLambdaInd); 00310 } 00311 } 00312 00313 double dCurrIntercept=0.0; 00314 if(m_bDoBestLambdaBasedOnGini) 00315 { 00316 //calculate gini index 00317 bgfunc(dDistanceFromHyperplane, 00318 labels, 00319 region.oob_begin(), region.oob_end(), 00320 region.classCounts()); 00321 dCurrIntercept=bgfunc.min_threshold_; 00322 } 00323 else 00324 { 00325 for (int m=0; m<SB::ext_param_.actual_mtry_; m++) 00326 dCurrIntercept+=meanMatrix(m,0)*regrCoef(m,nLambdaInd); 00327 } 00328 00329 for(int n=0; n<region.oob_size(); 
n++) 00330 { 00331 //check what lambda performs best on oob data 00332 int nClassPrediction=((dDistanceFromHyperplane(region.oob_begin()[n],0) >=dCurrIntercept) ? 1:0); 00333 dCurrRidgeSum+=((nClassPrediction == labels(region.oob_begin()[n],0)) ? 1:0); 00334 } 00335 if(dCurrRidgeSum>dMaxRidgeSum) 00336 { 00337 dMaxRidgeSum=dCurrRidgeSum; 00338 nMaxRidgeSumAtLambdaInd=nLambdaInd; 00339 } 00340 } 00341 00342 // std::cout << "middle2" << std::endl; 00343 //create a Node for output 00344 Node<i_HyperplaneNode> node(SB::ext_param_.actual_mtry_, SB::t_data, SB::p_data); 00345 00346 //normalise coeffs 00347 //data was scaled (by 1.0 or by std) -> take into account 00348 MultiArray<2, double> dCoeffVector(dShape(SB::ext_param_.actual_mtry_,1)); 00349 for(int n=0; n<SB::ext_param_.actual_mtry_; n++) 00350 dCoeffVector(n,0)=regrCoef(n,nMaxRidgeSumAtLambdaInd)*stdMatrix(n,0); 00351 00352 //calc norm 00353 double dVnorm=columnVector(regrCoef,nMaxRidgeSumAtLambdaInd).norm(); 00354 00355 for(int n=0; n<SB::ext_param_.actual_mtry_; n++) 00356 node.weights()[n]=dCoeffVector(n,0)/dVnorm; 00357 //_normalise coeffs 00358 00359 //save the columns 00360 node.column_data()[0]=SB::ext_param_.actual_mtry_; 00361 for(int n=0; n<SB::ext_param_.actual_mtry_; n++) 00362 node.column_data()[n+1]=splitColumns[n]; 00363 00364 //assemble projection vector 00365 //careful here: "region" is a pointer to indices... 
00366 //all the indices in "region" need to have valid data 00367 //convert from "region" space to original "feature" space 00368 MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1)); 00369 00370 for(int n=0; n<region.size(); n++) 00371 { 00372 dDistanceFromHyperplane(region[n],0)=0.0; 00373 for (int m=0; m<SB::ext_param_.actual_mtry_; m++) 00374 { 00375 dDistanceFromHyperplane(region[n],0)+= 00376 features(region[n],m)*node.weights()[m]; 00377 } 00378 } 00379 for(int n=0; n<region.oob_size(); n++) 00380 { 00381 dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0; 00382 for (int m=0; m<SB::ext_param_.actual_mtry_; m++) 00383 { 00384 dDistanceFromHyperplane(region.oob_begin()[n],0)+= 00385 features(region.oob_begin()[n],m)*node.weights()[m]; 00386 } 00387 } 00388 00389 //calculate gini index 00390 bgfunc(dDistanceFromHyperplane, 00391 labels, 00392 region.begin(), region.end(), 00393 region.classCounts()); 00394 00395 // did not find any suitable split 00396 if(closeAtTolerance(bgfunc.min_gini_, NumericTraits<double>::max())) 00397 return SB::makeTerminalNode(features, multiClassLabels, region, randint); 00398 00399 //take gini threshold here due to scaling, normalisation, etc. of the coefficients 00400 node.intercept() = bgfunc.min_threshold_; 00401 SB::node_ = node; 00402 00403 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0]; 00404 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1]; 00405 childRegions[0].classCountsIsValid = true; 00406 childRegions[1].classCountsIsValid = true; 00407 00408 // Save the ranges of the child stack entries. 
00409 childRegions[0].setRange( region.begin() , region.begin() + bgfunc.min_index_ ); 00410 childRegions[0].rule = region.rule; 00411 childRegions[0].rule.push_back(std::make_pair(1, 1.0)); 00412 childRegions[1].setRange( region.begin() + bgfunc.min_index_ , region.end() ); 00413 childRegions[1].rule = region.rule; 00414 childRegions[1].rule.push_back(std::make_pair(1, 1.0)); 00415 00416 //adjust oob ranges 00417 // std::cout << "adjust oob" << std::endl; 00418 //sort the oobs 00419 std::sort(region.oob_begin(), region.oob_end(), 00420 SortSamplesByDimensions< MultiArray<2, double> > (dDistanceFromHyperplane, 0)); 00421 00422 //find split index 00423 int nOOBindx; 00424 for(nOOBindx=0; nOOBindx<region.oob_size(); nOOBindx++) 00425 { 00426 if(dDistanceFromHyperplane(region.oob_begin()[nOOBindx],0)>=node.intercept()) 00427 break; 00428 } 00429 00430 childRegions[0].set_oob_range( region.oob_begin() , region.oob_begin() + nOOBindx ); 00431 childRegions[1].set_oob_range( region.oob_begin() + nOOBindx , region.oob_end() ); 00432 00433 // std::cout << "end" << std::endl; 00434 // outm2(region.oob_begin());outm2(nOOBindx);outm(region.oob_begin() + nOOBindx); 00435 //_adjust oob ranges 00436 00437 return i_HyperplaneNode; 00438 } 00439 }; 00440 00441 /** Standard ridge regression split 00442 */ 00443 typedef RidgeSplit<BestGiniOfColumn<GiniCriterion> > GiniRidgeSplit; 00444 00445 00446 } //namespace vigra 00447 #endif // VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|