[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_online_prediction_set.hxx
00001 #include "../multi_array.hxx" 00002 #include <set> 00003 #include <vector> 00004 00005 namespace vigra 00006 { 00007 00008 template<class T> 00009 struct SampleRange 00010 { 00011 SampleRange(int start,int end,int num_features) 00012 { 00013 this->start=start; 00014 this->end=end; 00015 this->min_boundaries.resize(num_features,-FLT_MAX); 00016 this->max_boundaries.resize(num_features,FLT_MAX); 00017 } 00018 00019 int start; 00020 mutable int end; 00021 mutable std::vector<T> max_boundaries; 00022 mutable std::vector<T> min_boundaries; 00023 00024 bool operator<(const SampleRange& o) const 00025 { 00026 return o.start<start; 00027 } 00028 }; 00029 00030 template<class T> 00031 class OnlinePredictionSet 00032 { 00033 public: 00034 template<class U> 00035 OnlinePredictionSet(MultiArrayView<2,T,U>& features,int num_sets) 00036 { 00037 this->features=features; 00038 std::vector<int> init(features.shape(0)); 00039 for(unsigned int i=0;i<init.size();++i) 00040 init[i]=i; 00041 indices.resize(num_sets,init); 00042 std::set<SampleRange<T> > set_init; 00043 set_init.insert(SampleRange<T>(0,init.size(),features.shape(1))); 00044 ranges.resize(num_sets,set_init); 00045 cumulativePredTime.resize(num_sets,0); 00046 } 00047 00048 int get_worsed_tree() 00049 { 00050 int result=0; 00051 for(unsigned int i=0;i<cumulativePredTime.size();++i) 00052 { 00053 if(cumulativePredTime[i]>cumulativePredTime[result]) 00054 { 00055 result=i; 00056 } 00057 } 00058 return result; 00059 } 00060 00061 void reset_tree(int index) 00062 { 00063 index=index % ranges.size(); 00064 std::set<SampleRange<T> > set_init; 00065 set_init.insert(SampleRange<T>(0,features.shape(0),features.shape(1))); 00066 ranges[index]=set_init; 00067 cumulativePredTime[index]=0; 00068 } 00069 00070 std::vector<std::set<SampleRange<T> > > ranges; 00071 std::vector<std::vector<int> > indices; 00072 std::vector<int> cumulativePredTime; 00073 MultiArray<2,T> features; 00074 }; 00075 00076 } 00077
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
HTML generated using Doxygen and Python
|