ATLAS Offline Software
Loading...
Searching...
No Matches
Forest.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef MVAUtils_Forest_H
6#define MVAUtils_Forest_H
7
9#include "NodeImpl.h"
10#include <stack>
11#include <cmath>
12#include <algorithm>
13#include <numeric>
14#include <iostream>
15#include <vector>
16
17
18namespace MVAUtils { namespace detail { // helpers
19
20 template<typename T>
21 inline T sigmoid(T x) { return 1. / (1. + exp(-x)); }
22
26 template<typename Container_t> void applySoftmax(Container_t& x);
27
34 inline std::vector<index_t> computeRight(const std::vector<int>& vars);
35 } }
36
37namespace MVAUtils
38{
52 template<typename Node_t>
53 class Forest : public IForest
54 {
55 public:
56 virtual float GetTreeResponse(const std::vector<float>& values,
57 unsigned int itree) const override final;
58 virtual float GetTreeResponse(const std::vector<float*>& pointers,
59 unsigned int itree) const override final;
60
63 virtual float GetOffset() const override { return 0.; }
64
67 virtual float GetRawResponse(
68 const std::vector<float>& values) const override final;
69 virtual float GetRawResponse(
70 const std::vector<float*>& pointers) const override final;
71
73 // In this class it is equal to the raw-reponse. Derived class should
74 // override this.
75 virtual float GetResponse(
76 const std::vector<float>& values) const override;
77 virtual float GetResponse(
78 const std::vector<float*>& pointers) const override;
79
84 // Since TMVA and lgbm are identical the common implementation is here:
85 // Return the softmax of the sub-forest raw-response
86 virtual std::vector<float> GetMultiResponse(
87 const std::vector<float>& values,
88 unsigned int numClasses) const override;
89
90 virtual std::vector<float> GetMultiResponse(
91 const std::vector<float*>& pointers,
92 unsigned int numClasses) const override;
93
94 virtual unsigned int GetNTrees() const override final
95 {
96 return m_forest.size();
97 }
98
99 virtual void PrintForest() const override;
100
101 virtual void PrintTree(unsigned int itree) const override;
102
104 std::vector<Node_t> GetTree(unsigned int itree) const;
105
106 protected:
110 float GetTreeResponseFromNode(const std::vector<float>& values, index_t index) const;
111 float GetTreeResponseFromNode(const std::vector<float*>& pointers, index_t index) const;
112
114 void newTree(const std::vector<Node_t>& nodes);
115
116 private:
117 std::vector<index_t> m_forest;
118 std::vector<Node_t> m_nodes;
119 };
120
121}
122
123#include "Forest.icc"
124
125#endif
#define x
Generic Forest base class.
Definition Forest.h:54
virtual float GetRawResponse(const std::vector< float * > &pointers) const override final
void newTree(const std::vector< Node_t > &nodes)
Append a new tree (defined by a vector of nodes serialized in preorder) to the forest.
std::vector< Node_t > GetTree(unsigned int itree) const
Return the vector of nodes for the tree itree.
virtual std::vector< float > GetMultiResponse(const std::vector< float > &values, unsigned int numClasses) const override
Compute the prediction for multiclassification (a score for each class).
virtual float GetOffset() const override
Return the offset of the forest.
Definition Forest.h:63
std::vector< index_t > m_forest
indices of the top-level nodes of each tree
Definition Forest.h:117
float GetTreeResponseFromNode(const std::vector< float * > &pointers, index_t index) const
virtual std::vector< float > GetMultiResponse(const std::vector< float * > &pointers, unsigned int numClasses) const override
virtual float GetTreeResponse(const std::vector< float * > &pointers, unsigned int itree) const override final
virtual float GetRawResponse(const std::vector< float > &values) const override final
Return the response of the whole Forest.
std::vector< Node_t > m_nodes
where the nodes of the forest are stored
Definition Forest.h:118
virtual void PrintForest() const override
virtual float GetResponse(const std::vector< float * > &pointers) const override
virtual void PrintTree(unsigned int itree) const override
virtual float GetTreeResponse(const std::vector< float > &values, unsigned int itree) const override final
Return the response of one tree. The features must be passed in the std::vector<float> values, together with the index itree of the tree.
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
float GetTreeResponseFromNode(const std::vector< float > &values, index_t index) const
Get the response of a tree.
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94
Compute the response from the binary trees in the forest.
Definition ForestBase.h:23
void applySoftmax(Container_t &x)
Apply softmax to the input, in place: $x_i \leftarrow \exp(x_i) / \sum_j \exp(x_j)$ for every $x_i$ in $x$.
T sigmoid(T x)
Definition Forest.h:21
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes and their right children from a serialized representation of the...
int32_t index_t
The index type of the node in the vector.
Definition index.py:1
std::vector< T * > pointers(std::vector< T > &v)
Definition rmain.cxx:367