
vigra/random_forest_deprec.hxx

/************************************************************************/
/*                                                                      */
/*                  Copyright 2008 by Ullrich Koethe                    */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/

#ifndef VIGRA_RANDOM_FOREST_DEPREC_HXX
#define VIGRA_RANDOM_FOREST_DEPREC_HXX

#include <algorithm>
#include <functional>   // for std::multiplies, used in findBestSplit()
#include <map>
#include <numeric>
#include <iostream>
#include <ctime>
#include <cstdlib>
#include "vigra/mathutil.hxx"
#include "vigra/array_vector.hxx"
#include "vigra/sized_int.hxx"
#include "vigra/matrix.hxx"
#include "vigra/random.hxx"
#include "vigra/functorexpression.hxx"


namespace vigra
{

/** \addtogroup MachineLearning
**/
//@{

namespace detail
{

template<class DataMatrix>
class RandomForestDeprecFeatureSorter
{
    DataMatrix const & data_;
    MultiArrayIndex sortColumn_;

  public:

    RandomForestDeprecFeatureSorter(DataMatrix const & data, MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }
};

template<class LabelArray>
class RandomForestDeprecLabelSorter
{
    LabelArray const & labels_;

  public:

    RandomForestDeprecLabelSorter(LabelArray const & labels)
    : labels_(labels)
    {}

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return labels_[l] < labels_[r];
    }
};

template <class CountArray>
class RandomForestDeprecClassCounter
{
    ArrayVector<int> const & labels_;
    CountArray & counts_;

  public:

    RandomForestDeprecClassCounter(ArrayVector<int> const & labels, CountArray & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    void reset()
    {
        counts_.init(0);
    }

    void operator()(MultiArrayIndex l) const
    {
        ++counts_[labels_[l]];
    }
};

struct DecisionTreeDeprecCountNonzeroFunctor
{
    double operator()(double old, double other) const
    {
        if(other != 0.0)
            ++old;
        return old;
    }
};

struct DecisionTreeDeprecNode
{
    DecisionTreeDeprecNode(int t, MultiArrayIndex bestColumn)
    : thresholdIndex(t), splitColumn(bestColumn)
    {}

    int children[2];
    int thresholdIndex;
    Int32 splitColumn;
};

template <class INT>
struct DecisionTreeDeprecNodeProxy
{
    DecisionTreeDeprecNodeProxy(ArrayVector<INT> const & tree, INT n)
    : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
    {}

    INT & child(INT l) const
    {
        return node[l];
    }

    INT & decisionWeightsIndex() const
    {
        return node[2];
    }

    typename ArrayVector<INT>::iterator decisionColumns() const
    {
        return node+3;
    }

    mutable typename ArrayVector<INT>::iterator node;
};
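
// Editor's note on the flat tree layout (derived from writeSplitParameters()
// and predict() below; not part of the original file): each interior node
// occupies sizeofNode() == 4 consecutive entries of the tree array,
//
//     node[0]   index of the left child node
//     node[1]   index of the right child node
//     node[2]   index of the split threshold within the terminal weights array
//     node[3]   column (feature) tested by the split
//
// and a child entry <= 0 encodes a leaf: its negation is the offset into the
// terminal weights array where that leaf's per-class weights begin.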

struct DecisionTreeDeprecAxisSplitFunctor
{
    ArrayVector<Int32> splitColumns;
    ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
    double threshold;
    double totalCounts[2], bestTotalCounts[2];
    int mtry, classCount, bestSplitColumn;
    bool pure[2], isWeighted;

    void init(int mtry, int cols, int classCount, ArrayVector<double> const & weights)
    {
        this->mtry = mtry;
        splitColumns.resize(cols);
        for(int k=0; k<cols; ++k)
            splitColumns[k] = k;

        this->classCount = classCount;
        classCounts.resize(classCount);
        currentCounts[0].resize(classCount);
        currentCounts[1].resize(classCount);
        bestCounts[0].resize(classCount);
        bestCounts[1].resize(classCount);

        isWeighted = weights.size() > 0;
        if(isWeighted)
            classWeights = weights;
        else
            classWeights.resize(classCount, 1.0);
    }

    bool isPure(int k) const
    {
        return pure[k];
    }

    unsigned int totalCount(int k) const
    {
        return (unsigned int)bestTotalCounts[k];
    }

    int sizeofNode() const { return 4; }

    int writeSplitParameters(ArrayVector<Int32> & tree,
                             ArrayVector<double> & terminalWeights)
    {
        int currentWeightIndex = terminalWeights.size();
        terminalWeights.push_back(threshold);

        int currentNodeIndex = tree.size();
        tree.push_back(-1);  // left child
        tree.push_back(-1);  // right child
        tree.push_back(currentWeightIndex);
        tree.push_back(bestSplitColumn);

        return currentNodeIndex;
    }

    void writeWeights(int l, ArrayVector<double> & terminalWeights)
    {
        for(int k=0; k<classCount; ++k)
            terminalWeights.push_back(isWeighted
                                           ? bestCounts[l][k]
                                           : bestCounts[l][k] / totalCount(l));
    }

    template <class U, class C, class AxesIterator, class WeightIterator>
    bool decideAtNode(MultiArrayView<2, U, C> const & features,
                      AxesIterator a, WeightIterator w) const
    {
        return (features(0, *a) < *w);
    }

    template <class U, class C, class IndexIterator, class Random>
    IndexIterator findBestSplit(MultiArrayView<2, U, C> const & features,
                                ArrayVector<int> const & labels,
                                IndexIterator indices, int exampleCount,
                                Random & randint);
};


template <class U, class C, class IndexIterator, class Random>
IndexIterator
DecisionTreeDeprecAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> const & features,
                                            ArrayVector<int> const & labels,
                                            IndexIterator indices, int exampleCount,
                                            Random & randint)
{
    // select columns to be tried for split
    for(int k=0; k<mtry; ++k)
        std::swap(splitColumns[k], splitColumns[k+randint(columnCount(features)-k)]);

    RandomForestDeprecFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
    RandomForestDeprecClassCounter<ArrayVector<double> > counter(labels, classCounts);
    std::for_each(indices, indices+exampleCount, counter);

    // find the best gini index
    double minGini = NumericTraits<double>::max();
    IndexIterator bestSplit = indices;
    for(int k=0; k<mtry; ++k)
    {
        sorter.setColumn(splitColumns[k]);
        std::sort(indices, indices+exampleCount, sorter);

        currentCounts[0].init(0);
        std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
                       currentCounts[1].begin(), std::multiplies<double>());
        totalCounts[0] = 0;
        totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
        for(int m = 0; m < exampleCount-1; ++m)
        {
            int label = labels[indices[m]];
            double w = classWeights[label];
            currentCounts[0][label] += w;
            totalCounts[0] += w;
            currentCounts[1][label] -= w;
            totalCounts[1] -= w;

            if (m < exampleCount-2 &&
                features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
                continue;

            double gini = 0.0;
            if(classCount == 2)
            {
                gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
                       currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
            }
            else
            {
                for(int l=0; l<classCount; ++l)
                    gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
                            currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
            }
            if(gini < minGini)
            {
                minGini = gini;
                bestSplit = indices+m;
                bestSplitColumn = splitColumns[k];
                bestCounts[0] = currentCounts[0];
                bestCounts[1] = currentCounts[1];
            }
        }
    }

    // split using the best feature
    sorter.setColumn(bestSplitColumn);
    std::sort(indices, indices+exampleCount, sorter);

    for(int k=0; k<2; ++k)
    {
        bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
    }

    threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
    ++bestSplit;

    counter.reset();
    std::for_each(indices, bestSplit, counter);
    pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
    counter.reset();
    std::for_each(bestSplit, indices+exampleCount, counter);
    pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());

    return bestSplit;
}
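
// Editor's note (not part of the original file): for per-class counts n_l on
// one side of a split with total N, the loop above accumulates the impurity
// sum_l n_l * (1 - n_l / N). In the two-class case this equals
// 2 * n_0 * n_1 / N, so the two-class branch drops the constant factor 2 and
// computes n_0 * n_1 / N directly; the minimizing split is unchanged. For
// example, left counts (3, 1) and right counts (0, 4) score
// 3*1/4 + 0*4/4 = 0.75, while left (2, 2) and right (1, 3) score
// 2*2/4 + 1*3/4 = 1.75, so the first split is preferred.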

enum { DecisionTreeDeprecNoParent = -1 };

template <class Iterator>
struct DecisionTreeDeprecStackEntry
{
    DecisionTreeDeprecStackEntry(Iterator i, int c,
                           int lp = DecisionTreeDeprecNoParent, int rp = DecisionTreeDeprecNoParent)
    : indices(i), exampleCount(c),
      leftParent(lp), rightParent(rp)
    {}

    Iterator indices;
    int exampleCount, leftParent, rightParent;
};

class DecisionTreeDeprec
{
  public:
    typedef Int32 TreeInt;
    ArrayVector<TreeInt> tree_;
    ArrayVector<double> terminalWeights_;
    unsigned int classCount_;
    DecisionTreeDeprecAxisSplitFunctor split;

    DecisionTreeDeprec(unsigned int classCount)
    : classCount_(classCount)
    {}

    void reset(unsigned int classCount = 0)
    {
        if(classCount)
            classCount_ = classCount;
        tree_.clear();
        terminalWeights_.clear();
    }

    template <class U, class C, class Iterator, class Options, class Random>
    void learn(MultiArrayView<2, U, C> const & features,
               ArrayVector<int> const & labels,
               Iterator indices, int exampleCount,
               Options const & options,
               Random & randint);

    template <class U, class C>
    ArrayVector<double>::const_iterator
    predict(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                       terminalWeights_.begin() + node.decisionWeightsIndex())
                                ? node.child(0)
                                : node.child(1);
            if(nodeindex <= 0)
                return terminalWeights_.begin() + (-nodeindex);
        }
    }

    template <class U, class C>
    int
    predictLabel(MultiArrayView<2, U, C> const & features) const
    {
        ArrayVector<double>::const_iterator weights = predict(features);
        return argMax(weights, weights+classCount_) - weights;
    }

    template <class U, class C>
    int
    leafID(MultiArrayView<2, U, C> const & features) const
    {
        int nodeindex = 0;
        for(;;)
        {
            DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
            nodeindex = split.decideAtNode(features, node.decisionColumns(),
                                       terminalWeights_.begin() + node.decisionWeightsIndex())
                                ? node.child(0)
                                : node.child(1);
            if(nodeindex <= 0)
                return -nodeindex;
        }
    }

    void depth(int & maxDep, int & interiorCount, int & leafCount, int k = 0, int d = 1) const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        ++interiorCount;
        ++d;
        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child > 0)
                depth(maxDep, interiorCount, leafCount, child, d);
            else
            {
                ++leafCount;
                if(maxDep < d)
                    maxDep = d;
            }
        }
    }

    void printStatistics(std::ostream & o) const
    {
        int maxDep = 0, interiorCount = 0, leafCount = 0;
        depth(maxDep, interiorCount, leafCount);

        o << "interior nodes: " << interiorCount <<
             ", terminal nodes: " << leafCount <<
             ", depth: " << maxDep << "\n";
    }

    void print(std::ostream & o, int k = 0, std::string s = "") const
    {
        DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
        o << s << (*node.decisionColumns()) << " " << terminalWeights_[node.decisionWeightsIndex()] << "\n";

        for(int l=0; l<2; ++l)
        {
            int child = node.child(l);
            if(child <= 0)
                o << s << " weights " << terminalWeights_[-child] << " "
                                      << terminalWeights_[-child+1] << "\n";
            else
                print(o, child, s+" ");
        }
    }
};

template <class U, class C, class Iterator, class Options, class Random>
void DecisionTreeDeprec::learn(MultiArrayView<2, U, C> const & features,
                          ArrayVector<int> const & labels,
                          Iterator indices, int exampleCount,
                          Options const & options,
                          Random & randint)
{
    ArrayVector<double> const & classLoss = options.class_weights;

    vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
        "DecisionTreeDeprec::learn(): class weights array has wrong size.");

    reset();

    unsigned int mtry = options.mtry;
    MultiArrayIndex cols = columnCount(features);

    split.init(mtry, cols, classCount_, classLoss);

    typedef DecisionTreeDeprecStackEntry<Iterator> Entry;
    ArrayVector<Entry> stack;
    stack.push_back(Entry(indices, exampleCount));

    while(!stack.empty())
    {
        indices = stack.back().indices;
        exampleCount = stack.back().exampleCount;
        int leftParent  = stack.back().leftParent,
            rightParent = stack.back().rightParent;

        stack.pop_back();

        Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);

        int currentNode = split.writeSplitParameters(tree_, terminalWeights_);

        if(leftParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
        if(rightParent != DecisionTreeDeprecNoParent)
            DecisionTreeDeprecNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
        leftParent = currentNode;
        rightParent = DecisionTreeDeprecNoParent;

        for(int l=0; l<2; ++l)
        {
            if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
            {
                // sample is still large enough and not yet perfectly separated => split
                stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
            }
            else
            {
                DecisionTreeDeprecNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();

                split.writeWeights(l, terminalWeights_);
            }
            std::swap(leftParent, rightParent);
            indices = bestSplit;
        }
    }
}

} // namespace detail

class RandomForestOptionsDeprec
{
  public:
        /** Initialize all options with default values.
        */
    RandomForestOptionsDeprec()
    : training_set_proportion(1.0),
      mtry(0),
      min_split_node_size(1),
      training_set_size(0),
      sample_with_replacement(true),
      sample_classes_individually(false),
      treeCount(255)
    {}

        /** Number of features considered in each node.

            If \a n is 0 (the default), the number of features tried in every node
            is determined by the square root of the total number of features.
            According to Breiman, this quantity should always be optimized by means
            of the out-of-bag error.<br>
            Default: 0 (use <tt>sqrt(columnCount(featureMatrix))</tt>)
        */
    RandomForestOptionsDeprec & featuresPerNode(unsigned int n)
    {
        mtry = n;
        return *this;
    }

        /** How to sample the subset of the training data for each tree.

            Each tree is only trained with a subset of the entire training data.
            If \a r is <tt>true</tt>, this subset is sampled from the entire training set with
            replacement.<br>
            Default: <tt>true</tt> (use sampling with replacement)
        */
    RandomForestOptionsDeprec & sampleWithReplacement(bool r)
    {
        sample_with_replacement = r;
        return *this;
    }

        /** Number of trees in the forest.<br>
            Default: 255
        */
    RandomForestOptionsDeprec & setTreeCount(unsigned int cnt)
    {
        treeCount = cnt;
        return *this;
    }

        /** Proportion of training examples used for each tree.

            If \a p is 1.0 (the default), and samples are drawn with replacement,
            the training set of each tree will contain as many examples as the entire
            training set, but some are drawn multiple times and others not at all. On
            average, each tree is actually trained on about 65% of the examples in the
            full training set. Changing the proportion mainly makes sense when
            sampleWithReplacement() is set to <tt>false</tt>. trainingSetSizeProportional() gets
            overridden by trainingSetSizeAbsolute().<br>
            Default: 1.0
        */
    RandomForestOptionsDeprec & trainingSetSizeProportional(double p)
    {
        vigra_precondition(p >= 0.0 && p <= 1.0,
            "RandomForestOptionsDeprec::trainingSetSizeProportional(): proportion must be in [0, 1].");
        if(training_set_size == 0) // otherwise, absolute size gets priority
            training_set_proportion = p;
        return *this;
    }

        /** Size of the training set for each tree.

            If this option is set, it overrides the proportion set by
            trainingSetSizeProportional(). When classes are sampled individually,
            the number of examples is divided by the number of classes (rounded upwards)
            to determine the number of examples drawn from every class.<br>
            Default: <tt>0</tt> (determine size by proportion)
        */
    RandomForestOptionsDeprec & trainingSetSizeAbsolute(unsigned int s)
    {
        training_set_size = s;
        if(s > 0)
            training_set_proportion = 0.0;
        return *this;
    }

        /** Are the classes sampled individually?

            If \a s is <tt>false</tt> (the default), the training set for each tree is sampled
            without considering class labels. Otherwise, samples are drawn from each
            class independently. The latter is especially useful in connection
            with the specification of an absolute training set size: then, the same number of
            examples is drawn from every class. This can be used as a counter-measure when the
            classes are very unbalanced in size.<br>
            Default: <tt>false</tt>
        */
    RandomForestOptionsDeprec & sampleClassesIndividually(bool s)
    {
        sample_classes_individually = s;
        return *this;
    }

        /** Number of examples required for a node to be split.

            When the number of examples in a node is below this number, the node is not
            split even if class separation is not yet perfect. Instead, the node returns
            the proportion of each class (among the remaining examples) during the
            prediction phase.<br>
            Default: 1 (complete growing)
        */
    RandomForestOptionsDeprec & minSplitNodeSize(unsigned int n)
    {
        if(n == 0)
            n = 1;
        min_split_node_size = n;
        return *this;
    }

        /** Use a weighted random forest.

            This is usually used to penalize the errors for the minority class.
            Weights must be convertible to <tt>double</tt>, and the array of weights
            must contain as many entries as there are classes.<br>
            Default: do not use weights
        */
    template <class WeightIterator>
    RandomForestOptionsDeprec & weights(WeightIterator weights, unsigned int classCount)
    {
        class_weights.clear();
        if(weights != 0)
            class_weights.insert(weights, classCount);
        return *this;
    }

    RandomForestOptionsDeprec & oobData(MultiArrayView<2, UInt8>& data)
    {
        oob_data = data;
        return *this;
    }

    MultiArrayView<2, UInt8> oob_data;
    ArrayVector<double> class_weights;
    double training_set_proportion;
    unsigned int mtry, min_split_node_size, training_set_size;
    bool sample_with_replacement, sample_classes_individually;
    unsigned int treeCount;
};
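
/*  Usage sketch (editor's addition, not part of the original file): every
    setter above returns *this, so a complete configuration can be written
    as a single chain:

    \code
    RandomForestOptionsDeprec options = RandomForestOptionsDeprec()
                                            .featuresPerNode(10)
                                            .sampleWithReplacement(false)
                                            .trainingSetSizeProportional(0.8)
                                            .minSplitNodeSize(5);
    \endcode
*/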

/*****************************************************************/
/*                                                               */
/*                          RandomForestDeprec                   */
/*                                                               */
/*****************************************************************/

template <class ClassLabelType>
class RandomForestDeprec
{
  public:
    ArrayVector<ClassLabelType> classes_;
    ArrayVector<detail::DecisionTreeDeprec> trees_;
    MultiArrayIndex columnCount_;
    RandomForestOptionsDeprec options_;

  public:

    // The first two constructors are straightforward: they take either
    // iterators to an array of class labels, or the two label values themselves.
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                  unsigned int treeCount = 255,
                  RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    RandomForestDeprec(ClassLabelType const & c1, ClassLabelType const & c2,
                  unsigned int treeCount = 255,
                  RandomForestOptionsDeprec const & options = RandomForestOptionsDeprec())
    : classes_(2),
      trees_(treeCount, detail::DecisionTreeDeprec(2)),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
        classes_[0] = c1;
        classes_[1] = c2;
    }

    // This constructor is intended especially for the CrossValidator class:
    // the tree count is taken from the options object.
    template<class ClassLabelIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                  RandomForestOptionsDeprec const & options)
    : classes_(cl, cend),
      trees_(options.treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(0),
      options_(options)
    {
        vigra_precondition(options.training_set_proportion == 0.0 ||
                           options.training_set_size == 0,
            "RandomForestOptionsDeprec: absolute and proportional training set sizes "
            "cannot be specified at the same time.");
        vigra_precondition(classes_.size() > 1,
            "RandomForestOptionsDeprec::weights(): need at least two classes.");
        vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
            "RandomForestOptionsDeprec::weights(): wrong number of classes.");
    }

    // Reconstructs a forest from pre-computed trees and terminal weights;
    // it uses the given columnCount instead of an options object.
    template<class ClassLabelIterator, class TreeIterator, class WeightIterator>
    RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
                  unsigned int treeCount, unsigned int columnCount,
                  TreeIterator trees, WeightIterator weights)
    : classes_(cl, cend),
      trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
      columnCount_(columnCount)
    {
        for(unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
        {
            trees_[k].tree_ = *trees;
            trees_[k].terminalWeights_ = *weights;
        }
    }

    int featureCount() const
    {
        vigra_precondition(columnCount_ > 0,
           "RandomForestDeprec::featureCount(): Random forest has not been trained yet.");
        return columnCount_;
    }

    int labelCount() const
    {
        return classes_.size();
    }

    int treeCount() const
    {
        return trees_.size();
    }

    // an empty class_weights array means an unweighted random forest
    template <class U, class C, class Array, class Random>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels,
               Random const& random);

    template <class U, class C, class Array>
    double learn(MultiArrayView<2, U, C> const & features, Array const & labels)
    {
        RandomNumberGenerator<> generator(RandomSeed);
        return learn(features, labels, generator);
    }

    template <class U, class C>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features) const;

    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1> const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForestDeprec::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = predictLabel(rowVector(features, k));
    }

    template <class U, class C, class Iterator>
    ClassLabelType predictLabel(MultiArrayView<2, U, C> const & features,
                                Iterator priors) const;

    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1> const & features,
                              MultiArrayView<2, T, C2> & prob) const;

    template <class U, class C1, class T, class C2>
    void predictNodes(MultiArrayView<2, U, C1> const & features,
                      MultiArrayView<2, T, C2> & NodeIDs) const;
};
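
/*  Usage sketch (editor's addition, not part of the original file; the
    variable names and data are hypothetical, assuming <tt>features</tt> is an
    n x p <tt>Matrix<double></tt> and <tt>labels</tt> an <tt>ArrayVector<int></tt>
    with n entries):

    \code
    int labelValues[] = { 0, 1 };                      // the distinct class labels
    vigra::RandomForestDeprec<int> forest(labelValues, labelValues + 2);

    double oobError = forest.learn(features, labels);  // returns out-of-bag error

    int prediction = forest.predictLabel(rowVector(features, 0));
    \endcode
*/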

template <class ClassLabelType>
template <class U, class C1, class Array, class Random>
double
RandomForestDeprec<ClassLabelType>::learn(MultiArrayView<2, U, C1> const & features,
                                             Array const & labels,
                                             Random const& random)
{
    unsigned int classCount = classes_.size();
    unsigned int m = rowCount(features);
    unsigned int n = columnCount(features);
    vigra_precondition((unsigned int)(m) == (unsigned int)labels.size(),
      "RandomForestDeprec::learn(): Label array has wrong size.");

    vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
       "RandomForestDeprec::learn(): Requested training set size exceeds total number of examples.");

    MultiArrayIndex mtry = (options_.mtry == 0)
                                ? int(std::floor(std::sqrt(double(n)) + 0.5))
                                : options_.mtry;

    vigra_precondition(mtry <= (MultiArrayIndex)n,
       "RandomForestDeprec::learn(): mtry must not exceed the number of features.");

    MultiArrayIndex msamples = options_.training_set_size;
    if(options_.sample_classes_individually)
        msamples = int(std::ceil(double(msamples) / classCount));

    ArrayVector<int> intLabels(m), classExampleCounts(classCount);

    // verify the input labels
    int minClassCount;
    {
        typedef std::map<ClassLabelType, int > LabelChecker;
        typedef typename LabelChecker::iterator LabelCheckerIterator;
        LabelChecker labelChecker;
        for(unsigned int k=0; k<classCount; ++k)
            labelChecker[classes_[k]] = k;

        for(unsigned int k=0; k<m; ++k)
        {
            LabelCheckerIterator found = labelChecker.find(labels[k]);
            vigra_precondition(found != labelChecker.end(),
                "RandomForestDeprec::learn(): Unknown class label encountered.");
            intLabels[k] = found->second;
            ++classExampleCounts[intLabels[k]];
        }
        minClassCount = *argMin(classExampleCounts.begin(), classExampleCounts.end());
        vigra_precondition(minClassCount > 0,
             "RandomForestDeprec::learn(): At least one class is missing in the training set.");
        if(msamples > 0 && options_.sample_classes_individually &&
                          !options_.sample_with_replacement)
        {
            vigra_precondition(msamples <= minClassCount,
                "RandomForestDeprec::learn(): Too few examples in smallest class to reach "
                "requested training set size.");
        }
    }
    columnCount_ = n;
    ArrayVector<int> indices(m);
    for(unsigned int k=0; k<m; ++k)
        indices[k] = k;

    if(options_.sample_classes_individually)
    {
        detail::RandomForestDeprecLabelSorter<ArrayVector<int> > sorter(intLabels);
        std::sort(indices.begin(), indices.end(), sorter);
    }

    ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);

    UniformIntRandomFunctor<Random> randint(0, m-1, random);

    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        ArrayVector<int> trainingSet;
        usedIndices.init(0);

        if(options_.sample_classes_individually)
        {
            int first = 0;
            for(unsigned int l=0; l<classCount; ++l)
            {
                int lc = classExampleCounts[l];
                int lsamples = (msamples == 0)
                                   ? int(std::ceil(options_.training_set_proportion*lc))
                                   : msamples;

                if(options_.sample_with_replacement)
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        trainingSet.push_back(indices[first+randint(lc)]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                else
                {
                    for(int ll=0; ll<lsamples; ++ll)
                    {
                        std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
                        trainingSet.push_back(indices[first+ll]);
                        ++usedIndices[trainingSet.back()];
                    }
                }
                first += lc;
            }
        }
        else
        {
            if(msamples == 0)
                msamples = int(std::ceil(options_.training_set_proportion*m));

            if(options_.sample_with_replacement)
            {
                for(int l=0; l<msamples; ++l)
                {
                    trainingSet.push_back(indices[randint(m)]);
                    ++usedIndices[trainingSet.back()];
                }
            }
            else
            {
                for(int l=0; l<msamples; ++l)
                {
                    std::swap(indices[l], indices[l+randint(m-l)]);
                    trainingSet.push_back(indices[l]);
                    ++usedIndices[trainingSet.back()];
                }
            }
        }
        trees_[k].learn(features, intLabels,
                        trainingSet.begin(), trainingSet.size(),
                        options_.featuresPerNode(mtry), randint);

        for(unsigned int l=0; l<m; ++l)
        {
            if(!usedIndices[l])
            {
                ++oobCount[l];
                if(trees_[k].predictLabel(rowVector(features, l)) != intLabels[l])
                {
                    ++oobErrorCount[l];
                    if(options_.oob_data.data() != 0)
                        options_.oob_data(l, k) = 2;
                }
                else if(options_.oob_data.data() != 0)
                {
                    options_.oob_data(l, k) = 1;
                }
            }
        }
        // TODO: default value for oob_data
        // TODO: implement variable importance
        #ifdef VIGRA_RF_VERBOSE
        trees_[k].printStatistics(std::cerr);
        #endif
    }
    double oobError = 0.0;
    int totalOobCount = 0;
    for(unsigned int l=0; l<m; ++l)
        if(oobCount[l])
        {
            oobError += double(oobErrorCount[l]) / oobCount[l];
            ++totalOobCount;
        }
    return oobError / totalOobCount;
}
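
// Editor's note on the return value (not part of the original file): each
// example that was out-of-bag at least once contributes the fraction of its
// out-of-bag trees that misclassified it, and the function returns the mean
// of these fractions. E.g., an example left out of the bootstrap sample by
// 100 trees and misclassified by 25 of them contributes 0.25 to the average.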

template <class ClassLabelType>
template <class U, class C>
ClassLabelType
RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    return classes_[argMax(prob)];
}


// Same as above, but with a prior (weight) for each label.
template <class ClassLabelType>
template <class U, class C, class Iterator>
ClassLabelType
RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> const & features,
                                           Iterator priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= featureCount(),
        "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
    Matrix<double> prob(1, classes_.size());
    predictProbabilities(features, prob);
    std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
    return classes_[argMax(prob)];
}

template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForestDeprec<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> const & features,
                                                   MultiArrayView<2, T, C2> & prob) const
{
    // features is n x p; prob is n x labelCount and receives, for each row,
    // the predicted probability of each class

    vigra_precondition(rowCount(features) == rowCount(prob),
      "RandomForestDeprec::predictProbabilities(): Feature matrix and probability matrix size mismatch.");

    // the feature matrix must contain at least as many columns as were used
    // during training (additional columns are ignored)
    vigra_precondition(columnCount(features) >= featureCount(),
      "RandomForestDeprec::predictProbabilities(): Too few columns in feature matrix.");
    vigra_precondition(columnCount(prob) == (MultiArrayIndex)labelCount(),
      "RandomForestDeprec::predictProbabilities(): Probability matrix must have as many columns as there are classes.");

    // classify each row
    for(int row=0; row < rowCount(features); ++row)
    {
        // per-class weights returned by a single tree: each tree votes with the
        // weight vector stored at its leaf, not with a single label
        ArrayVector<double>::const_iterator weights;

        // total weight accumulated over all trees and classes, used for normalization
        double totalWeight = 0.0;

        // initialize the vote counts; prob(row, l) accumulates votes until
        // the final normalization
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) = 0.0;

        // let each tree classify...
        for(unsigned int k=0; k<trees_.size(); ++k)
        {
            // get the weights predicted by a single tree
            weights = trees_[k].predict(rowVector(features, row));

            // update the vote counts
            for(unsigned int l=0; l<classes_.size(); ++l)
            {
                prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
                // every weight enters totalWeight
                totalWeight += weights[l];
            }
        }

        // normalize the votes in each row by the total weight
        for(unsigned int l=0; l<classes_.size(); ++l)
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
    }
}
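
/*  Usage sketch (editor's addition; the names <tt>features</tt> and
    <tt>forest</tt> are hypothetical): predict class probabilities for all
    rows of a feature matrix at once:

    \code
    vigra::Matrix<double> prob(rowCount(features), forest.labelCount());
    forest.predictProbabilities(features, prob);
    // prob(i, l) now estimates P(class l | row i); each row sums to 1
    \endcode
*/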

template <class ClassLabelType>
template <class U, class C1, class T, class C2>
void
RandomForestDeprec<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> const & features,
                                                   MultiArrayView<2, T, C2> & NodeIDs) const
{
    vigra_precondition(columnCount(features) >= featureCount(),
      "RandomForestDeprec::predictNodes(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) <= rowCount(NodeIDs),
      "RandomForestDeprec::predictNodes(): Too few rows in NodeIDs matrix.");
    vigra_precondition(columnCount(NodeIDs) >= treeCount(),
      "RandomForestDeprec::predictNodes(): Too few columns in NodeIDs matrix.");
    NodeIDs.init(0);
    for(unsigned int k=0; k<trees_.size(); ++k)
    {
        for(int row=0; row < rowCount(features); ++row)
        {
            NodeIDs(row,k) = trees_[k].leafID(rowVector(features, row));
        }
    }
}

//@}

} // namespace vigra


#endif // VIGRA_RANDOM_FOREST_DEPREC_HXX
