[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest_deprec.hxx
/************************************************************************/
/*                                                                      */
/*               Copyright 2008 by Ullrich Koethe                       */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/

#ifndef VIGRA_RANDOM_FOREST_DEPREC_HXX
#define VIGRA_RANDOM_FOREST_DEPREC_HXX

#include <algorithm>
#include <map>
#include <numeric>
#include <iostream>
#include <ctime>
#include <cstdlib>
#include "vigra/mathutil.hxx"
#include "vigra/array_vector.hxx"
#include "vigra/sized_int.hxx"
#include "vigra/matrix.hxx"
#include "vigra/random.hxx"
#include "vigra/functorexpression.hxx"


namespace vigra
{

/** \addtogroup MachineLearning
**/
//@{

namespace detail
{

// Comparison functor for std::sort: orders example indices by the value
// of one selected feature column of the data matrix.
template<class DataMatrix>
class RandomForestDeprecFeatureSorter
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;

  public:

    RandomForestDeprecFeatureSorter(DataMatrix const & data, MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    // Select the feature column used by subsequent comparisons.
    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }
};

// Comparison functor for std::sort: orders example indices by class label.
template<class LabelArray>
class RandomForestDeprecLabelSorter
{
    LabelArray const & labels_;

  public:

    RandomForestDeprecLabelSorter(LabelArray const & labels)
    : labels_(labels)
    {}

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return labels_[l] < labels_[r];
    }
};

// Histogram accumulator for std::for_each: applying the functor to an
// example index increments the count of that example's class in counts_.
template <class CountArray>
class RandomForestDeprecClassCounter
{
    ArrayVector<int> const & labels_;
    CountArray & counts_;

  public:

    RandomForestDeprecClassCounter(ArrayVector<int> const & labels, CountArray & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    // Zero all class counts (counts_ is held by reference and shared
    // with the caller).
    void reset()
    {
        counts_.init(0);
    }

    void operator()(MultiArrayIndex l) const
    {
        ++counts_[labels_[l]];
    }
};

// Binary functor for std::accumulate: counts how many entries of a
// sequence are non-zero (used below to detect pure nodes, i.e. nodes
// whose class histogram has exactly one non-zero bin).
struct DecisionTreeDeprecCountNonzeroFunctor
{
    double operator()(double old, double other) const
    {
        if(other != 0.0)
            ++old;
        return old;
    }
};

// Plain-old-data view of one interior tree node.  Note that the packed
// on-disk/in-array layout is actually produced by
// DecisionTreeDeprecAxisSplitFunctor::writeSplitParameters() below.
struct DecisionTreeDeprecNode
{
    DecisionTreeDeprecNode(int t, MultiArrayIndex bestColumn)
    : thresholdIndex(t), splitColumn(bestColumn)
    {}

    int children[2];     // indices of the left/right child nodes
    int thresholdIndex;  // index of the split threshold in the weight array
    Int32 splitColumn;   // feature column tested at this node
};

// Lightweight accessor into the packed tree array.  A node occupies
// four consecutive ints starting at offset n:
//   [0] left child, [1] right child, [2] weight index, [3] split column.
// Child values <= 0 encode leaves (negated index into the weight array).
template <class INT>
struct DecisionTreeDeprecNodeProxy
{
    DecisionTreeDeprecNodeProxy(ArrayVector<INT> const & tree, INT n)
    : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
    {}

    // l == 0: left child, l == 1: right child.
    INT & child(INT l) const
    {
        return node[l];
    }

    INT & decisionWeightsIndex() const
    {
        return node[2];
    }

    typename ArrayVector<INT>::iterator decisionColumns() const
    {
        return node+3;
    }

    mutable typename ArrayVector<INT>::iterator node;
};

// Split functor implementing axis-parallel splits chosen by the Gini
// criterion among 'mtry' randomly selected feature columns.
struct DecisionTreeDeprecAxisSplitFunctor
{
    ArrayVector<Int32> splitColumns;
    ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
    double threshold;
    double totalCounts[2], bestTotalCounts[2];
    int mtry, classCount, bestSplitColumn;
    bool pure[2], isWeighted;

    // Prepare the internal buffers for a problem with 'cols' features and
    // 'classCount' classes; 'weights' may be empty (unweighted forest).
    void init(int mtry, int cols, int classCount, ArrayVector<double> const & weights)
    {
        this->mtry = mtry;
        splitColumns.resize(cols);
        for(int k=0; k<cols; ++k)
            splitColumns[k] = k;

        this->classCount = classCount;
        classCounts.resize(classCount);
        currentCounts[0].resize(classCount);
        currentCounts[1].resize(classCount);
        bestCounts[0].resize(classCount);
        bestCounts[1].resize(classCount);

        isWeighted = weights.size() > 0;
        if(isWeighted)
            classWeights = weights;
        else
            classWeights.resize(classCount, 1.0);  // unweighted: all classes count equally
    }

    // True if side k (0: left, 1: right) of the best split contains a
    // single class only (set by findBestSplit()).
    bool isPure(int k) const
    {
        return pure[k];
    }

    // (Weighted) number of examples on side k of the best split,
    // truncated to an integer.
    unsigned int totalCount(int k) const
    {
        return (unsigned int)bestTotalCounts[k];
    }

    // Number of ints one node occupies in the packed tree array.
    int sizeofNode() const { return 4; }

    // Append the current best split as a packed node to 'tree' and its
    // threshold to 'terminalWeights'; returns the new node's index.
    // Child entries start as -1 and are patched later by the caller.
    int writeSplitParameters(ArrayVector<Int32> & tree,
                             ArrayVector<double> &terminalWeights)
    {
        int currentWeightIndex = terminalWeights.size();
        terminalWeights.push_back(threshold);

        int currentNodeIndex = tree.size();
        tree.push_back(-1);  // left child
        tree.push_back(-1);  // right child
        tree.push_back(currentWeightIndex);
        tree.push_back(bestSplitColumn);

        return currentNodeIndex;
    }

    // Append the class weights of terminal node l: raw (weighted) counts
    // for a weighted forest, class frequencies otherwise.
    void writeWeights(int l, ArrayVector<double> &terminalWeights)
    {
        for(int k=0; k<classCount; ++k)
            terminalWeights.push_back(isWeighted
                                       ? bestCounts[l][k]
                                       : bestCounts[l][k] / totalCount(l));
    }

    // Prediction-time decision: descend left iff the feature value at
    // column *a is below the threshold *w.  'features' is a single row.
    template <class U, class C, class AxesIterator, class WeightIterator>
    bool decideAtNode(MultiArrayView<2, U, C> const & features,
                      AxesIterator a, WeightIterator w) const
    {
        return (features(0, *a) < *w);
    }

    template <class U, class C, class IndexIterator, class Random>
    IndexIterator findBestSplit(MultiArrayView<2, U, C> const & features,
                                ArrayVector<int> const & labels,
                                IndexIterator indices, int exampleCount,
                                Random & randint);

};


// Find the axis-parallel split of the examples [indices, indices+exampleCount)
// minimizing the Gini impurity over 'mtry' randomly chosen columns.
// On return, the examples are sorted along the winning column and the
// returned iterator marks the first example of the right branch; the
// functor's threshold/bestCounts/pure state describes the split.
template <class U, class C, class IndexIterator, class Random>
IndexIterator
DecisionTreeDeprecAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> const & features,
                                                  ArrayVector<int> const & labels,
                                                  IndexIterator indices, int exampleCount,
                                                  Random & randint)
{
    // select columns to be tried for split (partial Fisher-Yates shuffle:
    // the first mtry entries of splitColumns become a random subset)
    for(int k=0; k<mtry; ++k)
        std::swap(splitColumns[k], splitColumns[k+randint(columnCount(features)-k)]);

    RandomForestDeprecFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
    RandomForestDeprecClassCounter<ArrayVector<double> > counter(labels, classCounts);
    std::for_each(indices, indices+exampleCount, counter);

    // find the best gini index
    double minGini = NumericTraits<double>::max();
    IndexIterator bestSplit = indices;
    for(int k=0; k<mtry; ++k)
    {
        sorter.setColumn(splitColumns[k]);
        std::sort(indices, indices+exampleCount, sorter);

        // start with everything on the right side; currentCounts[1] holds
        // the weighted class histogram, then examples migrate to the left
        currentCounts[0].init(0);
        std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
                       currentCounts[1].begin(), std::multiplies<double>());
        totalCounts[0] = 0;
        totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
        for(int m = 0; m < exampleCount-1; ++m)
        {
            // move example m from the right to the left side
            int label = labels[indices[m]];
            double w = classWeights[label];
            currentCounts[0][label] += w;
            totalCounts[0] += w;
            currentCounts[1][label] -= w;
            totalCounts[1] -= w;

            // only split between distinct feature values
            if (m < exampleCount-2 &&
                features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
                continue ;

            double gini = 0.0;
            if(classCount == 2)
            {
                // two-class shortcut of the Gini impurity
                gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
                       currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
            }
            else
            {
                for(int l=0; l<classCount; ++l)
                    gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
                            currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
            }
            if(gini < minGini)
            {
                minGini = gini;
                bestSplit = indices+m;
                bestSplitColumn = splitColumns[k];
                bestCounts[0] = currentCounts[0];
                bestCounts[1] = currentCounts[1];
            }
        }
    }
    //std::cerr << minGini << " " << bestSplitColumn << std::endl;
    // split using the best feature
    sorter.setColumn(bestSplitColumn);
    std::sort(indices, indices+exampleCount, sorter);

    for(int k=0; k<2; ++k)
    {
        bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
    }

    // threshold is the midpoint between the last left and first right value
    threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
    ++bestSplit;

    // a side is pure when its class histogram has exactly one non-zero bin
    counter.reset();
    std::for_each(indices, bestSplit, counter);
    pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
    counter.reset();
    std::for_each(bestSplit, indices+exampleCount, counter);
    pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());

    return bestSplit;
}

// Sentinel: the node has no parent (pointer) on that side.
enum  { DecisionTreeDeprecNoParent = -1 };

template <class Iterator>
// Work item for the iterative (stack-based) tree construction: a range
// of example indices plus the parent slots to be patched once this
// node's index in the packed tree array is known.
struct DecisionTreeDeprecStackEntry
{
    DecisionTreeDeprecStackEntry(Iterator i, int c,
                                 int lp = DecisionTreeDeprecNoParent, int rp = DecisionTreeDeprecNoParent)
    : indices(i), exampleCount(c),
      leftParent(lp), rightParent(rp)
    {}

    Iterator indices;
    int exampleCount, leftParent, rightParent;
};

// One decision tree of the (deprecated) random forest.  The tree is
// stored as a flat int array (tree_, see DecisionTreeDeprecNodeProxy
// for the node layout) plus a flat double array (terminalWeights_)
// holding the split thresholds and the per-class leaf weights.
class DecisionTreeDeprec
{
  public:
    typedef Int32 TreeInt;
    ArrayVector<TreeInt> tree_;
    ArrayVector<double> terminalWeights_;
    unsigned int classCount_;
    DecisionTreeDeprecAxisSplitFunctor split;

  public:


    DecisionTreeDeprec(unsigned int classCount)
    : classCount_(classCount)
    {}

    // Discard the learned tree; optionally change the class count.
    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        tree_.clear();
        terminalWeights_.clear();
    }

    template <class U, class C, class Iterator, class Options, class Random>
    void learn(MultiArrayView<2, U, C> const & features,
               ArrayVector<int> const & labels,
               Iterator indices, int exampleCount,
               Options const & options,
               Random & randint);

    // Walk the tree for one example (a 1 x n feature row) and return an
    // iterator to the classCount_ weights of the leaf that is reached.
    template <class U, class C>
    ArrayVector<double>::const_iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            // child indices <= 0 encode leaves (negated weight index)
            if(nodeindex <= 0)
                return terminalWeights_.begin() + (-nodeindex);
        }
    }

    // Predict the (internal, 0-based) label index with maximal leaf weight.
    template <class U, class C>
    int
    predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights+classCount_) - weights;
    }

    // Like predict(), but returns the (positive) index of the reached leaf
    // in terminalWeights_ instead of the weights themselves.
    template <class U, class C>
    int
    leafID(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                           terminalWeights_.begin() + node.decisionWeightsIndex())
                            ? node.child(0)
                            : node.child(1);
            if(nodeindex <= 0)
                return -nodeindex;
        }
    }

    // Recursively accumulate maximum depth and interior/leaf counts of
    // the subtree rooted at node k (depth d).
    void depth(int & maxDep, int & interiorCount, int & leafCount, int k = 0, int d = 1) const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        ++interiorCount;
        ++d;
        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child > 0)
                depth(maxDep, interiorCount, leafCount, child, d);
            else
            {
                ++leafCount;
                if(maxDep < d)
                    maxDep = d;
            }
        }
    }

    // Print node counts and depth of this tree to the given stream.
    void printStatistics(std::ostream & o) const
    {
        int maxDep = 0, interiorCount = 0, leafCount = 0;
        depth(maxDep, interiorCount, leafCount);

        o << "interior nodes: " << interiorCount <<
             ", terminal nodes: " << leafCount <<
             ", depth: " << maxDep << "\n";
    }

    // Debug dump of the subtree rooted at node k, indented by prefix s.
    // NOTE(review): only the first two leaf weights are printed, so the
    // output is complete only for two-class trees.
    void print(std::ostream & o, int k = 0, std::string s = "") const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        o << s << (*node.decisionColumns()) << " " << terminalWeights_[node.decisionWeightsIndex()] << "\n";

        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child <= 0)
                o << s << " weights " << terminalWeights_[-child] << " "
                  << terminalWeights_[-child+1] << "\n";
            else
                print(o, child, s+" ");
        }
    }
};


// Grow the tree on the examples [indices, indices+exampleCount) using
// an explicit stack instead of recursion.  Nodes are split until they
// are pure or fall below options.min_split_node_size.
template <class U, class C, class Iterator, class Options, class Random>
void DecisionTreeDeprec::learn(MultiArrayView<2, U, C> const & features,
                               ArrayVector<int> const & labels,
                               Iterator indices, int exampleCount,
                               Options const & options,
                               Random & randint)
{
    ArrayVector<double> const & classLoss = options.class_weights;

    vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
        "DecisionTreeDeprec2::learn(): class weights array has wrong size.");

    reset();

    unsigned int mtry = options.mtry;
    MultiArrayIndex cols = columnCount(features);

    split.init(mtry, cols, classCount_, classLoss);

    typedef DecisionTreeDeprecStackEntry<Iterator> Entry;
    ArrayVector<Entry> stack;
    stack.push_back(Entry(indices, exampleCount));

    while(!stack.empty())
    {
//        std::cerr << "*";
        indices = stack.back().indices;
        exampleCount = stack.back().exampleCount;
        int leftParent = stack.back().leftParent,
            rightParent = stack.back().rightParent;

        stack.pop_back();

        Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);


        int currentNode = split.writeSplitParameters(tree_, terminalWeights_);

        // patch the parent's child pointer now that this node's index is known
        if(leftParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
        if(rightParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
        leftParent = currentNode;
        rightParent = DecisionTreeDeprecNoParent;

        for(int l=0; l<2; ++l)
        {

            if(!split.isPure(l) && split.totalCount(l) >=
                                                           options.min_split_node_size)
            {
                // sample is still large enough and not yet perfectly separated => split
                stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
            }
            else
            {
                // create a leaf: the child slot stores the negated index of
                // the leaf's weights in terminalWeights_
                DecisionTreeDeprecNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();

                split.writeWeights(l, terminalWeights_);
            }
            // second iteration handles the right branch of the same node
            std::swap(leftParent, rightParent);
            indices = bestSplit;
        }
    }
//    std::cerr << "\n";
}

} // namespace detail

// Option object collecting all tunable parameters of RandomForestDeprec.
// All setters return *this so that calls can be chained.
class RandomForestOptionsDeprec
{
  public:
        /** Initialize all options with default values.
        */
    RandomForestOptionsDeprec()
    : training_set_proportion(1.0),
      mtry(0),
      min_split_node_size(1),
      training_set_size(0),
      sample_with_replacement(true),
      sample_classes_individually(false),
      treeCount(255)
    {}

        /** Number of features considered in each node.

            If \a n is 0 (the default), the number of features tried in every node
            is determined by the square root of the total number of features.
            According to Breiman, this quantity should always be optimized by means
            of the out-of-bag error.<br>
            Default: 0 (use <tt>sqrt(columnCount(featureMatrix))</tt>)
        */
    RandomForestOptionsDeprec & featuresPerNode(unsigned int n)
    {
        mtry = n;
        return *this;
    }

        /** How to sample the subset of the training data for each tree.

            Each tree is only trained with a subset of the entire training data.
            If \a r is <tt>true</tt>, this subset is sampled from the entire training set with
            replacement.<br>
            Default: <tt>true</tt> (use sampling with replacement))
        */
    RandomForestOptionsDeprec & sampleWithReplacement(bool r)
    {
        sample_with_replacement = r;
        return *this;
    }

        /** Number of trees in the forest (overrides the constructor argument
            of RandomForestDeprec in the options-only constructor).
        */
    RandomForestOptionsDeprec & setTreeCount(unsigned int cnt)
    {
        treeCount = cnt;
        return *this;
    }
        /** Proportion of training examples used for each tree.

            If \a p is 1.0 (the default), and samples are drawn with replacement,
            the training set of each tree will contain as many examples as the entire
            training set, but some are drawn multiply and others not at all. On average,
            each tree is actually trained on about 65% of the examples in the full
            training set. Changing the proportion makes mainly sense when
            sampleWithReplacement() is set to <tt>false</tt>. trainingSetSizeProportional() gets
            overridden by trainingSetSizeAbsolute().<br>
            Default: 1.0
        */
    RandomForestOptionsDeprec & trainingSetSizeProportional(double p)
    {
        vigra_precondition(p >= 0.0 && p <= 1.0,
            "RandomForestOptionsDeprec::trainingSetSizeProportional(): proportion must be in [0, 1].");
        if(training_set_size == 0) // otherwise, absolute size gets priority
            training_set_proportion = p;
        return *this;
    }

        /** Size of the training set for each tree.

            If this option is set, it overrides the proportion set by
            trainingSetSizeProportional(). When classes are sampled individually,
            the number of examples is divided by the number of classes (rounded upwards)
            to determine the number of examples drawn from every class.<br>
            Default: <tt>0</tt> (determine size by proportion)
        */
    RandomForestOptionsDeprec & trainingSetSizeAbsolute(unsigned int s)
    {
        training_set_size = s;
        if(s > 0)
            training_set_proportion = 0.0;
        return *this;
    }

        /** Are the classes sampled individually?

            If \a s is <tt>false</tt> (the default), the training set for each tree is sampled
            without considering class labels. Otherwise, samples are drawn from each
            class independently. The latter is especially useful in connection
            with the specification of an absolute training set size: then, the same number of
            examples is drawn from every class. This can be used as a counter-measure when the
            classes are very unbalanced in size.<br>
            Default: <tt>false</tt>
        */
    RandomForestOptionsDeprec & sampleClassesIndividually(bool s)
    {
        sample_classes_individually = s;
        return *this;
    }

        /** Number of examples required for a node to be split.

            When the number of examples in a node is below this number, the node is not
            split even if class separation is not yet perfect. Instead, the node returns
            the proportion of each class (among the remaining examples) during the
            prediction phase.<br>
            Default: 1 (complete growing)
        */
    RandomForestOptionsDeprec & minSplitNodeSize(unsigned int n)
    {
        if(n == 0)
            n = 1;
        min_split_node_size = n;
        return *this;
    }

        /** Use a weighted random forest.

            This is usually used to penalize the errors for the minority class.
            Weights must be convertible to <tt>double</tt>, and the array of weights
            must contain as many entries as there are classes.<br>
            Default: do not use weights
        */
    template <class WeightIterator>
    RandomForestOptionsDeprec & weights(WeightIterator weights, unsigned int classCount)
    {
        class_weights.clear();
        // NOTE(review): comparing a generic iterator against 0 only works
        // for pointer-like iterators -- confirm callers pass pointers.
        if(weights != 0)
            class_weights.insert(weights, classCount);
        return *this;
    }

        /** Attach a matrix that records out-of-bag results per example and
            tree during learning (see RandomForestDeprec::learn()).
        */
    RandomForestOptionsDeprec & oobData(MultiArrayView<2, UInt8>& data)
    {
        oob_data =data;
        return *this;
    }

    MultiArrayView<2, UInt8> oob_data;
    ArrayVector<double> class_weights;
    double training_set_proportion;
    unsigned int mtry, min_split_node_size, training_set_size;
    bool sample_with_replacement, sample_classes_individually;
    unsigned int treeCount;
};

/*****************************************************************/
/*                                                               */
/*                      RandomForestDeprec                       */
/*                                                               */
/*****************************************************************/

// Deprecated random forest classifier: an ensemble of DecisionTreeDeprec
// trees, each trained on a bootstrap sample of the training data.
template <class ClassLabelType>
class RandomForestDeprec
{
  public:
    ArrayVector<ClassLabelType> classes_;                 // user-visible class labels
    ArrayVector<detail::DecisionTreeDeprec> trees_;       // the ensemble
    MultiArrayIndex columnCount_;                         // feature count (0 until trained)
    RandomForestOptionsDeprec options_;

  public:

    //First two constructors are straight forward.
    //they take either the iterators to an Array of Classlabels or the values
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       unsigned int treeCount = 255,
                       RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    // Convenience constructor for the common two-class case.
    RandomForestDeprec(ClassLabelType const & c1, ClassLabelType const & c2,
                       unsigned int treeCount = 255,
                       RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(2),
      trees_(treeCount, detail::DecisionTreeDeprec(2)),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
        classes_[0] = c1;
        classes_[1] = c2;
    }
    //This is esp. For the CrosValidator Class
    // (tree count is taken from options.treeCount instead of a parameter)
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       RandomForestOptionsDeprec const & options )
    : classes_(cl, cend),
      trees_(options.treeCount , detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {

        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    //Not understood yet
    //Does not use the options object but the columnCount object.
    // Deserializing constructor: rebuilds a trained forest from
    // previously stored tree arrays and terminal weights.
    template<class ClassLabelIterator, class TreeIterator, class WeightIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                       unsigned int treeCount, unsigned int columnCount,
                       TreeIterator trees, WeightIterator weights)
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(columnCount)
    {
        for(unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].tree_ = *trees;
            trees_[k].terminalWeights_ = *weights;
        }
    }

    // Number of features the forest was trained on; requires training.
    int featureCount() const
    {
        vigra_precondition(columnCount_ > 0,
           "RandomForestDeprec::featureCount(): Random forest has not been trained yet.");
        return columnCount_;
    }

    int labelCount() const
    {
        return classes_.size();
    }

    int treeCount() const
    {
        return trees_.size();
    }

    // loss == 0.0 means unweighted random forest
    // Returns the out-of-bag error estimate.
    template <class U, class C, class Array, class Random>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels,
                 Random const& random);

    // Overload using a freshly seeded default random generator.
    template <class U, class C, class Array>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels)
    {
        RandomNumberGenerator<> generator(RandomSeed);
        return learn(features, labels, generator);
    }

    template <class U, class C>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features) const;

    // Predict one label per row of 'features' into column 0 of 'labels'.
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForestDeprec::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = predictLabel(rowVector(features, k));
    }

    template <class U, class C, class Iterator>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                Iterator priors) const;

    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const;

    template <class U, class C1, class T, class C2>
    void predictNodes(MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, T, C2> & NodeIDs) const;
};

// Train all trees on bootstrap samples of (features, labels) and return
// the out-of-bag error estimate.
template <class ClassLabelType>
template <class U, class C1, class Array, class Random>
double
RandomForestDeprec<ClassLabelType>::learn(MultiArrayView<2, U, C1> const & features,
                                          Array const & labels,
                                          Random const& random)
{
    unsigned int classCount = classes_.size();
    unsigned int m = rowCount(features);      // number of training examples
    unsigned int n = columnCount(features);   // number of features
    vigra_precondition((unsigned int)(m) == (unsigned int)labels.size(),
      "RandomForestDeprec::learn(): Label array has wrong size.");

    vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
       "RandomForestDeprec::learn(): Requested training set size exceeds total number of examples.");

    // mtry == 0 means: use round(sqrt(n)) features per node
    MultiArrayIndex mtry = (options_.mtry == 0)
                                ?
                                  int(std::floor(std::sqrt(double(n)) + 0.5))
                                : options_.mtry;

    vigra_precondition(mtry <= (MultiArrayIndex)n,
      "RandomForestDeprec::learn(): mtry must be less than number of features.");

    MultiArrayIndex msamples = options_.training_set_size;
    if(options_.sample_classes_individually)
        // per-class quota: total size divided by class count, rounded up
        msamples = int(std::ceil(double(msamples) / classCount));

    ArrayVector<int> intLabels(m), classExampleCounts(classCount);

    // verify the input labels
    int minClassCount;
    {
        // map user labels to internal 0-based indices and count per class
        typedef std::map<ClassLabelType, int > LabelChecker;
        typedef typename LabelChecker::iterator LabelCheckerIterator;
        LabelChecker labelChecker;
        for(unsigned int k=0; k<classCount; ++k)
            labelChecker[classes_[k]] = k;

        for(unsigned int k=0; k<m; ++k)
        {
            LabelCheckerIterator found = labelChecker.find(labels[k]);
            vigra_precondition(found != labelChecker.end(),
              "RandomForestDeprec::learn(): Unknown class label encountered.");
            intLabels[k] = found->second;
            ++classExampleCounts[intLabels[k]];
        }
        minClassCount = *argMin(classExampleCounts.begin(), classExampleCounts.end());
        vigra_precondition(minClassCount > 0,
             "RandomForestDeprec::learn(): At least one class is missing in the training set.");
        if(msamples > 0 && options_.sample_classes_individually &&
                           !options_.sample_with_replacement)
        {
            vigra_precondition(msamples <= minClassCount,
                "RandomForestDeprec::learn(): Too few examples in smallest class to reach "
                "requested training set size.");
        }
    }
    columnCount_ = n;
    ArrayVector<int> indices(m);
    for(unsigned int k=0; k<m; ++k)
        indices[k] = k;

    if(options_.sample_classes_individually)
    {
        // group indices by class so per-class ranges are contiguous
        detail::RandomForestDeprecLabelSorter<ArrayVector<int> > sorter(intLabels);
        std::sort(indices.begin(), indices.end(), sorter);
    }

    ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);

    UniformIntRandomFunctor<Random> randint(0, m-1, random);
    //std::cerr << "Learning a RF \n";
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        //std::cerr << "Learning tree " << k << " ...\n";

        // draw the bootstrap sample for this tree; usedIndices marks
        // in-bag examples so the rest can be used for the OOB estimate
        ArrayVector<int> trainingSet;
        usedIndices.init(0);

        if(options_.sample_classes_individually)
        {
            int first = 0;
            for(unsigned int l=0; l<classCount; ++l)
            {
                int lc = classExampleCounts[l];
                int lsamples = (msamples == 0)
                                   ? int(std::ceil(options_.training_set_proportion*lc))
                                   : msamples;

                if(options_.sample_with_replacement)
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        trainingSet.push_back(indices[first+randint(lc)]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                else
                {
                    // partial Fisher-Yates shuffle within this class range
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
                        trainingSet.push_back(indices[first+ll]);
                        ++usedIndices[trainingSet.back()];
                    }
                    //std::sort(indices.begin(), indices.begin()+lsamples);
                }
                first += lc;
            }
        }
        else
        {
            if(msamples == 0)
                msamples = int(std::ceil(options_.training_set_proportion*m));

            if(options_.sample_with_replacement)
            {
                for(int l=0; l<msamples; ++l)
                {
                    trainingSet.push_back(indices[randint(m)]);
                    ++usedIndices[trainingSet.back()];
                }
            }
            else
            {
                // partial Fisher-Yates shuffle over all examples
                for(int l=0; l<msamples; ++l)
                {
                    std::swap(indices[l], indices[l+randint(m-l)/*oikas*/]);
                    trainingSet.push_back(indices[l]);
                    ++usedIndices[trainingSet.back()];
                }


            }

        }
        trees_[k].learn(features, intLabels,
                        trainingSet.begin(), trainingSet.size(),
                        options_.featuresPerNode(mtry), randint);
//        for(unsigned int l=0; l<m; ++l)
//        {
//            if(!usedIndices[l])
//            {
//                ++oobCount[l];
//                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
//                    ++oobErrorCount[l];
//            }
//        }

        // out-of-bag bookkeeping: evaluate this tree on the examples it
        // was NOT trained on; optionally record per-example results
        // (1 = correct, 2 = wrong) in options_.oob_data
        for(unsigned int l=0; l<m; ++l)
        {
            if(!usedIndices[l])
            {
                ++oobCount[l];
                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
                {
                    ++oobErrorCount[l];
                    if(options_.oob_data.data() != 0)
                        options_.oob_data(l, k) = 2;
                }
                else if(options_.oob_data.data() != 0)
                {
                    options_.oob_data(l, k) = 1;
                }
            }
        }
        // TODO: default value for oob_data
        // TODO: implement variable importance
        //if(!options_.sample_with_replacement){
        //std::cerr << "done\n";
        //trees_[k].print(std::cerr);
#ifdef VIGRA_RF_VERBOSE
        trees_[k].printStatistics(std::cerr);
#endif
    }
    // average per-example OOB error rates over all examples that were
    // out-of-bag at least once
    // NOTE(review): totalOobCount can be 0 (e.g. proportion 1.0 without
    // replacement), which makes this 0/0 -- confirm intended behavior.
    double oobError = 0.0;
    int totalOobCount = 0;
    for(unsigned int l=0; l<m; ++l)
        if(oobCount[l])
        {
            oobError += double(oobErrorCount[l]) / oobCount[l];
            ++totalOobCount;
        }
    return oobError / totalOobCount;
}

// Predict the most probable class label for a single example
// (a 1 x n feature row) by majority vote of all trees.
template <class ClassLabelType>
template <class U, class C>
ClassLabelType
RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    return classes_[argMax(prob)];
}


//Same thing as above with priors for each label !!!
01039 template <class ClassLabelType> 01040 template <class U, class C, class Iterator> 01041 ClassLabelType 01042 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features, 01043 Iterator priors) const 01044 { 01045 using namespace functor; 01046 vigra_precondition(columnCount(features) >= featureCount(), 01047 "RandomForestDeprec::predictLabel(): Too few columns in feature matrix."); 01048 vigra_precondition(rowCount(features) == 1, 01049 "RandomForestDeprec::predictLabel(): Feature matrix must have a single row."); 01050 Matrix<double> prob(1,classes_.size()); 01051 predictProbabilities(features, prob); 01052 std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2()); 01053 return classes_[argMax(prob)]; 01054 } 01055 01056 template <class ClassLabelType> 01057 template <class U, class C1, class T, class C2> 01058 void 01059 RandomForestDeprec<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> const & features, 01060 MultiArrayView<2, T, C2> & prob) const 01061 { 01062 01063 //Features are n xp 01064 //prob is n x NumOfLabel probability for each feature in each class 01065 01066 vigra_precondition(rowCount(features) == rowCount(prob), 01067 "RandomForestDeprec::predictProbabilities(): Feature matrix and probability matrix size mismatch."); 01068 01069 // num of features must be bigger than num of features in Random forest training 01070 // but why bigger? 01071 vigra_precondition(columnCount(features) >= featureCount(), 01072 "RandomForestDeprec::predictProbabilities(): Too few columns in feature matrix."); 01073 vigra_precondition(columnCount(prob) == (MultiArrayIndex)labelCount(), 01074 "RandomForestDeprec::predictProbabilities(): Probability matrix must have as many columns as there are classes."); 01075 01076 //Classify for each row. 01077 for(int row=0; row < rowCount(features); ++row) 01078 { 01079 //contains the weights returned by a single tree??? 
01080 //thought that one tree has only one vote??? 01081 //Pruning??? 01082 ArrayVector<double>::const_iterator weights; 01083 01084 //totalWeight == totalVoteCount! 01085 double totalWeight = 0.0; 01086 01087 //Set each VoteCount = 0 - prob(row,l) contains vote counts until 01088 //further normalisation 01089 for(unsigned int l=0; l<classes_.size(); ++l) 01090 prob(row, l) = 0.0; 01091 01092 //Let each tree classify... 01093 for(unsigned int k=0; k<trees_.size(); ++k) 01094 { 01095 //get weights predicted by single tree 01096 weights = trees_[k].predict(rowVector(features, row)); 01097 01098 //update votecount. 01099 for(unsigned int l=0; l<classes_.size(); ++l) 01100 { 01101 prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]); 01102 //every weight in totalWeight. 01103 totalWeight += weights[l]; 01104 } 01105 } 01106 01107 //Normalise votes in each row by total VoteCount (totalWeight 01108 for(unsigned int l=0; l<classes_.size(); ++l) 01109 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight); 01110 } 01111 } 01112 01113 01114 template <class ClassLabelType> 01115 template <class U, class C1, class T, class C2> 01116 void 01117 RandomForestDeprec<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> const & features, 01118 MultiArrayView<2, T, C2> & NodeIDs) const 01119 { 01120 vigra_precondition(columnCount(features) >= featureCount(), 01121 "RandomForestDeprec::getNodesRF(): Too few columns in feature matrix."); 01122 vigra_precondition(rowCount(features) <= rowCount(NodeIDs), 01123 "RandomForestDeprec::getNodesRF(): Too few rows in NodeIds matrix"); 01124 vigra_precondition(columnCount(NodeIDs) >= treeCount(), 01125 "RandomForestDeprec::getNodesRF(): Too few columns in NodeIds matrix."); 01126 NodeIDs.init(0); 01127 for(unsigned int k=0; k<trees_.size(); ++k) 01128 { 01129 for(int row=0; row < rowCount(features); ++row) 01130 { 01131 NodeIDs(row,k) = trees_[k].leafID(rowVector(features, row)); 01132 } 01133 } 01134 } 01135 01136 
//@}

} // namespace vigra


#endif // VIGRA_RANDOM_FOREST_DEPREC_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|