35#ifndef VIGRA_RF3_VISITORS_HXX
36#define VIGRA_RF3_VISITORS_HXX
40#include "../multi_array.hxx"
41#include "../multi_shape.hxx"
89 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
98 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
105 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
115 template <
typename TREE,
179 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
186 double const EPS = 1e-20;
190 is_in_bag_.resize(weights.size(),
true);
191 for (
size_t i = 0; i < weights.size(); ++i)
193 if (std::abs(weights[i]) < EPS)
195 is_in_bag_[i] =
false;
201 throw std::runtime_error(
"OOBError::visit_before_tree(): The tree has no out-of-bags.");
207 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
211 const FEATURES & features,
212 const LABELS & labels
215 vigra_precondition(rf.num_trees() > 0,
"OOBError::visit_after_training(): Number of trees must be greater than zero after training.");
216 vigra_precondition(visitors.size() == rf.num_trees(),
"OOBError::visit_after_training(): Number of visitors must be equal to number of trees.");
217 size_t const num_instances = features.shape()[0];
218 auto const num_features = features.shape()[1];
219 for (
auto vptr : visitors)
220 vigra_precondition(vptr->is_in_bag_.size() == num_instances,
"OOBError::visit_after_training(): Some visitors have the wrong number of data points.");
223 typedef typename std::remove_const<LABELS>::type Labels;
224 Labels pred(Shape1(1));
226 for (
size_t i = 0; i < (size_t)num_instances; ++i)
229 std::vector<size_t> tree_indices;
230 for (
size_t k = 0; k < visitors.size(); ++k)
231 if (!visitors[k]->is_in_bag_[i])
232 tree_indices.push_back(k);
235 auto const sub_features = features.subarray(Shape2(i, 0), Shape2(i+1, num_features));
236 rf.predict(sub_features, pred, 1, tree_indices);
237 if (pred(0) != labels(i))
249 std::vector<bool> is_in_bag_;
257class VariableImportance :
public RFVisitorBase
261 VariableImportance(
size_t repetition_count = 10)
269 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
279 auto const num_features = features.shape()[1];
283 double const EPS = 1e-20;
285 is_in_bag_.resize(weights.size(),
true);
286 for (
size_t i = 0; i < weights.size(); ++i)
288 if (std::abs(weights[i]) < EPS)
290 is_in_bag_[i] =
false;
295 throw std::runtime_error(
"VariableImportance::visit_before_tree(): The tree has no out-of-bags.");
301 template <
typename TREE,
317 typename SCORER::Functor functor;
318 auto const region_impurity = functor.region_score(labels, weights, begin, end);
319 auto const split_impurity = scorer.best_score_;
320 variable_importance_(scorer.best_dim_, tree.num_classes()+1) += region_impurity - split_impurity;
326 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
328 const FEATURES & features,
329 const LABELS & labels,
333 typedef typename std::remove_const<FEATURES>::type Features;
334 typedef typename std::remove_const<LABELS>::type Labels;
336 typedef typename Features::value_type FeatureType;
338 auto const num_features = features.shape()[1];
345 copy_out_of_bags(features, labels, feats, labs);
346 auto const num_oobs = feats.shape()[0];
351 rf.predict(feats, pred, 1);
352 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
354 if (labs(i) == pred(i))
356 oob_right(labs(i)) += 1.0;
357 oob_right(rf.num_classes()) += 1.0;
363 for (
size_t j = 0; j < (size_t)num_features; ++j)
366 backup = feats.template bind<1>(j);
372 for (
int ii = num_oobs-1; ii >= 1; --ii)
373 std::swap(feats(ii, j), feats(randint(ii+1), j));
376 rf.predict(feats, pred, 1);
377 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
379 if (labs(i) == pred(i))
381 perm_oob_right(0, labs(i)) += 1.0;
382 perm_oob_right(0, rf.num_classes()) += 1.0;
389 perm_oob_right.bind<0>(0) -= oob_right;
390 perm_oob_right *= -1;
391 perm_oob_right /= num_oobs;
395 feats.template bind<1>(j) = backup;
402 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
406 const FEATURES & features,
409 vigra_precondition(rf.num_trees() > 0,
"VariableImportance::visit_after_training(): Number of trees must be greater than zero after training.");
410 vigra_precondition(visitors.size() == rf.num_trees(),
"VariableImportance::visit_after_training(): Number of visitors must be equal to number of trees.");
413 auto const num_features = features.shape()[1];
415 for (
auto vptr : visitors)
418 "VariableImportance::visit_after_training(): Shape mismatch.");
464 template <
typename F0,
typename L0,
typename F1,
typename L1>
465 void copy_out_of_bags(
466 F0
const & features_in,
467 L0
const & labels_in,
471 auto const num_instances = features_in.shape()[0];
472 auto const num_features = features_in.shape()[1];
476 for (
auto x : is_in_bag_)
481 features_out.reshape(Shape2(num_oobs, num_features));
482 labels_out.reshape(Shape1(num_oobs));
484 for (
size_t i = 0; i < (size_t)num_instances; ++i)
488 auto const src = features_in.template bind<0>(i);
489 auto out = features_out.template bind<0>(current);
491 labels_out(current) = labels_in(i);
497 std::vector<bool> is_in_bag_;
518template <
typename VISITOR,
typename NEXT = RFStopVisiting,
bool CPY = false>
523 typedef VISITOR Visitor;
526 typename std::conditional<CPY, Visitor, Visitor &>::type visitor_;
529 RFVisitorNode(Visitor & visitor, Next next)
535 explicit RFVisitorNode(Visitor & visitor)
541 explicit RFVisitorNode(RFVisitorNode<Visitor, Next, !CPY> & other)
543 visitor_(other.visitor_),
547 explicit RFVisitorNode(RFVisitorNode<Visitor, Next, !CPY>
const & other)
549 visitor_(other.visitor_),
553 void visit_before_training()
555 if (visitor_.is_active())
556 visitor_.visit_before_training();
557 next_.visit_before_training();
560 template <
typename VISITORS,
typename RF,
typename FEATURES,
typename LABELS>
561 void visit_after_training(VISITORS & v, RF & rf,
const FEATURES & features,
const LABELS & labels)
563 typedef typename VISITORS::value_type VisitorNodeType;
564 typedef typename VisitorNodeType::Visitor VisitorType;
565 typedef typename VisitorNodeType::Next NextType;
571 if (visitor_.is_active())
573 std::vector<VisitorType*> visitors;
575 visitors.push_back(&x.visitor_);
576 visitor_.visit_after_training(visitors, rf, features, labels);
580 std::vector<NextType> nexts;
582 nexts.push_back(x.next_);
585 next_.visit_after_training(nexts, rf, features, labels);
588 template <
typename TREE,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
589 void visit_before_tree(TREE & tree, FEATURES & features, LABELS & labels, WEIGHTS & weights)
591 if (visitor_.is_active())
592 visitor_.visit_before_tree(tree, features, labels, weights);
593 next_.visit_before_tree(tree, features, labels, weights);
596 template <
typename RF,
typename FEATURES,
typename LABELS,
typename WEIGHTS>
597 void visit_after_tree(RF & rf,
602 if (visitor_.is_active())
603 visitor_.visit_after_tree(rf, features, labels, weights);
604 next_.visit_after_tree(rf, features, labels, weights);
607 template <
typename TREE,
613 void visit_after_split(TREE & tree,
622 if (visitor_.is_active())
623 visitor_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
624 next_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
634template <
typename VISITOR>
654detail::RFVisitorNode<A>
657 typedef detail::RFVisitorNode<A> _0_t;
662template<
typename A,
typename B>
663detail::RFVisitorNode<A, detail::RFVisitorNode<B> >
664create_visitor(A & a, B & b)
666 typedef detail::RFVisitorNode<B> _1_t;
668 typedef detail::RFVisitorNode<A, _1_t> _0_t;
673template<
typename A,
typename B,
typename C>
675create_visitor(A & a, B & b, C & c)
686template<
typename A,
typename B,
typename C,
typename D>
689create_visitor(A & a, B & b, C & c, D & d)
702template<
typename A,
typename B,
typename C,
typename D,
typename E>
705create_visitor(A & a, B & b, C & c, D & d, E & e)
720template<
typename A,
typename B,
typename C,
typename D,
typename E,
724create_visitor(A & a, B & b, C & c, D & d, E & e, F & f)
741template<
typename A,
typename B,
typename C,
typename D,
typename E,
742 typename F,
typename G>
746create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g)
765template<
typename A,
typename B,
typename C,
typename D,
typename E,
766 typename F,
typename G,
typename H>
770create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h)
791template<
typename A,
typename B,
typename C,
typename D,
typename E,
792 typename F,
typename G,
typename H,
typename I>
796create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i)
819template<
typename A,
typename B,
typename C,
typename D,
typename E,
820 typename F,
typename G,
typename H,
typename I,
typename J>
825create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i,
Main MultiArray class containing the memory management.
Definition multi_array.hxx:2479
Compute the out-of-bag error.
Definition random_forest_visitors.hxx:173
double oob_err_
Definition random_forest_visitors.hxx:246
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &weights)
Definition random_forest_visitors.hxx:180
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &labels)
Definition random_forest_visitors.hxx:208
The default visitor node (= "do nothing").
Definition random_forest_visitors.hxx:510
void visit_before_training()
Do something before training starts.
Definition random_forest_visitors.hxx:80
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
Do something before a tree has been learned.
Definition random_forest_visitors.hxx:99
void activate()
Activate the visitor.
Definition random_forest_visitors.hxx:142
void deactivate()
Deactivate the visitor.
Definition random_forest_visitors.hxx:150
void visit_after_tree(RF &, FEATURES &, LABELS &, WEIGHTS &)
Do something after a tree has been learned.
Definition random_forest_visitors.hxx:106
void visit_after_split(TREE &, FEATURES &, LABELS &, WEIGHTS &, SCORER &, ITER, ITER, ITER)
Do something after the split was made.
Definition random_forest_visitors.hxx:121
bool is_active() const
Return whether the visitor is active or not.
Definition random_forest_visitors.hxx:134
void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
Do something after all trees have been learned.
Definition random_forest_visitors.hxx:90
void visit_after_split(TREE &tree, FEATURES &, LABELS &labels, WEIGHTS &weights, SCORER &scorer, ITER begin, ITER, ITER end)
Definition random_forest_visitors.hxx:307
void visit_after_tree(RF &rf, const FEATURES &features, const LABELS &labels, WEIGHTS &)
Definition random_forest_visitors.hxx:327
void visit_before_tree(TREE &tree, FEATURES &features, LABELS &, WEIGHTS &weights)
Definition random_forest_visitors.hxx:270
size_t repetition_count_
Definition random_forest_visitors.hxx:457
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &)
Definition random_forest_visitors.hxx:403
MultiArray< 2, double > variable_importance_
Definition random_forest_visitors.hxx:452
Container elements of the statically linked visitor list. Use the create_visitor() functions to create such a list.
Definition random_forest_visitors.hxx:520
Random forest version 3.
Definition random_forest_3.hxx:66
Definition random_forest_visitors.hxx:636