ATLAS Offline Software
Loading...
Searching...
No Matches
Forest.icc
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3*/
4
5template<typename Container_t>
6inline void MVAUtils::detail::applySoftmax(Container_t& x)
7{
8 using T = typename Container_t::value_type;
9 // subtract max to avoid overflow (softmax is invariant to shifts)
10 const T max_x = *std::max_element(x.begin(), x.end());
11 std::transform(x.begin(), x.end(), x.begin(), [max_x](T v){ return exp(v - max_x); });
12 const T sum = std::accumulate(x.begin(), x.end(), T{});
13 std::transform(x.begin(), x.end(), x.begin(), [sum](T v){ return v / sum; });
14}
15
16
17inline std::vector<MVAUtils::index_t> MVAUtils::detail::computeRight(const std::vector<int>& vars)
18{
19 // parent index is relative to root of tree (and only used inside this function)
20 // right index is relative to the then processed node
21 // example: vars = 2 0 -1 -1 -1 returns: 4 2 0 0 0
22 std::vector<MVAUtils::index_t> right(vars.size());
23 std::stack<MVAUtils::index_t> parent; // not strictly parent if doing a right node
24
25 parent.push(-1);
26 for (std::vector<int>::size_type i = 0; i < vars.size(); ++i)
27 {
28 if (vars.at(i) >= 0) { // not a leaf
29 parent.push(i);
30 } else {
31 // a leaf
32 const MVAUtils::index_t currParent = parent.top();
33 // if right has not been visited, next will be right
34 if (currParent >= 0) {
35 right[currParent] = i + 1 - currParent;
36 }
37 parent.pop();
38 }
39 }
40 return right;
41}
42
43
44template<typename Node_t> void MVAUtils::Forest<Node_t>::PrintTree(unsigned int itree) const
45{
46 index_t top_node_index = m_forest[itree];
47 std::stack<index_t> s;
48 s.push(top_node_index);
49 while (!s.empty()) {
50 auto node = s.top();
51 s.pop();
52 m_nodes.at(node).Print(node);
53 if (!m_nodes[node].IsLeaf()) {
54 s.push(m_nodes[node].GetRight(node));
55 s.push(m_nodes[node].GetLeft(node));
56 }
57 }
58}
59
60template<typename Node_t> void MVAUtils::Forest<Node_t>::PrintForest() const {
61 for (unsigned int itree = 0; itree != GetNTrees(); ++itree) {
62 std::cout << "Tree number: " << itree << std::endl;
63 PrintTree(itree);
64 }
65}
66
67template<typename Node_t>
68std::vector<Node_t> MVAUtils::Forest<Node_t>::GetTree(unsigned int itree) const
69{
70 index_t top_node_index = m_forest[itree];
71 index_t last_node_index = (itree < GetNTrees() - 1) ? m_forest[itree + 1] : m_nodes.size();
72 return std::vector<Node_t>(m_nodes.begin() + top_node_index, m_nodes.begin() + last_node_index);
73}
74
75template<typename Node_t>
76float MVAUtils::Forest<Node_t>::GetTreeResponseFromNode(const std::vector<float>& values, index_t index) const
77{
78 while (!m_nodes[index].IsLeaf())
79 {
80 index = m_nodes[index].GetNext(values[m_nodes[index].GetVar()], index);
81 }
82 return m_nodes[index].GetVal();
83}
84
85template<typename Node_t>
86float MVAUtils::Forest<Node_t>::GetTreeResponseFromNode(const std::vector<float*>& pointers, index_t index) const
87{
88 while (!m_nodes[index].IsLeaf())
89 {
90 index = m_nodes[index].GetNext(*(pointers[m_nodes[index].GetVar()]), index);
91 }
92 return m_nodes[index].GetVal();
93}
94
95template<typename Node_t>
96float MVAUtils::Forest<Node_t>::GetTreeResponse(const std::vector<float>& values, unsigned int itree) const
97{
98 index_t top_node_index = m_forest[itree];
99 return GetTreeResponseFromNode(values, top_node_index);
100}
101
102template<typename Node_t>
103float MVAUtils::Forest<Node_t>::GetTreeResponse(const std::vector<float*>& pointers, unsigned int itree) const
104{
105 index_t top_node_index = m_forest[itree];
106 return GetTreeResponseFromNode(pointers, top_node_index);
107}
108
109template<typename Node_t>
110float MVAUtils::Forest<Node_t>::GetRawResponse(const std::vector<float>& values) const
111{
112 float result = 0.;
113 // looping in the reverse order since usually the response of the trees
114 // is decreasing (first more important). So it better to start the sum
115 // from the smaller to avoid numerical precision issues.
116 for (int itree = GetNTrees() - 1; itree >= 0; --itree)
117 {
118 result += GetTreeResponse(values, itree);
119 }
120 return result;
121}
122
123template<typename Node_t>
124float MVAUtils::Forest<Node_t>::GetRawResponse(const std::vector<float*>& pointers) const
125{
126 float result = 0.;
127 for (int itree = GetNTrees() - 1; itree >= 0; --itree)
128 {
129 result += GetTreeResponse(pointers, itree);
130 }
131 return result;
132}
133
134template<typename Node_t>
135float MVAUtils::Forest<Node_t>::GetResponse(const std::vector<float>& values) const
136{
137 return GetRawResponse(values);
138}
139
140template<typename Node_t>
141float MVAUtils::Forest<Node_t>::GetResponse(const std::vector<float*>& pointers) const
142{
143 return GetRawResponse(pointers);
144}
145
146template<typename Node_t>
147std::vector<float> MVAUtils::Forest<Node_t>::GetMultiResponse(const std::vector<float>& values,
148 unsigned int numClasses) const
149{
150 // this implementation is common TMVA / LGBM
151 // In multiclass each class has a separate forest. Each forest is made by the same number
152 // of trees. Here all the nodes of all the trees of all the forests are stored interleaved,
153 // e.g. assume each forest is made by 10 trees
154 // class0: tree-0, tree-10, ...
155 // class1: tree-1, tree-11, ...
156 // this very same scheme is used internally by lgbm. The representation of the nodes will be:
157 // tree0(class0)-node0, tree0(class0)-node1, ...
158
159 std::vector<float> result;
160 if (numClasses > 0) {
161 result.resize(numClasses); // ignores the offset
162 // note that the loop is not the trees, not on the classes
163 // that would be equivalent, but better to read the vector in order
164 for (unsigned int itree = 0; itree < GetNTrees(); ++itree) {
165 result[itree % numClasses] += GetTreeResponse(values, itree);
166 }
167
168 detail::applySoftmax(result);
169 }
170 return result;
171}
172
173template<typename Node_t>
174std::vector<float> MVAUtils::Forest<Node_t>::GetMultiResponse(const std::vector<float*>& pointers,
175 unsigned int numClasses) const
176{
177 std::vector<float> result;
178
179 if (numClasses > 0) {
180 result.resize(numClasses); // ignores the offset
181 for (unsigned int itree = 0; itree < GetNTrees(); ++itree) {
182 result[itree % numClasses] += GetTreeResponse(pointers, itree);
183 }
184 detail::applySoftmax(result);
185 }
186 return result;
187}
188
189template<typename Node_t>
190void MVAUtils::Forest<Node_t>::newTree(const std::vector<Node_t>& nodes)
191{
192 m_forest.push_back(m_nodes.size());
193 m_nodes.insert(m_nodes.end(), nodes.begin(), nodes.end());
194}