ATLAS Offline Software
Forest.icc
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 template<typename Container_t>
6 inline void MVAUtils::detail::applySoftmax(Container_t& x)
7 {
8  using T = typename Container_t::value_type;
9  // subtract max to avoid overflow (softmax is invariant to shifts)
10  const T max_x = *std::max_element(x.begin(), x.end());
11  std::transform(x.begin(), x.end(), x.begin(), [max_x](T v){ return exp(v - max_x); });
12  const T sum = std::accumulate(x.begin(), x.end(), T{});
13  std::transform(x.begin(), x.end(), x.begin(), [sum](T v){ return v / sum; });
14 }
15 
16 
17 inline std::vector<MVAUtils::index_t> MVAUtils::detail::computeRight(const std::vector<int>& vars)
18 {
19  // parent index is relative to root of tree (and only used inside this function)
20  // right index is relative to the then processed node
21  // example: vars = 2 0 -1 -1 -1 returns: 4 2 0 0 0
22  std::vector<MVAUtils::index_t> right(vars.size());
23  std::stack<MVAUtils::index_t> parent; // not strictly parent if doing a right node
24 
25  parent.push(-1);
26  for (std::vector<int>::size_type i = 0; i < vars.size(); ++i)
27  {
28  if (vars.at(i) >= 0) { // not a leaf
29  parent.push(i);
30  } else {
31  // a leaf
32  const MVAUtils::index_t currParent = parent.top();
33  // if right has not been visited, next will be right
34  if (currParent >= 0) {
35  right[currParent] = i + 1 - currParent;
36  }
37  parent.pop();
38  }
39  }
40  return right;
41 }
42 
43 
44 template<typename Node_t> void MVAUtils::Forest<Node_t>::PrintTree(unsigned int itree) const
45 {
46  index_t top_node_index = m_forest[itree];
47  std::stack<index_t> s;
48  s.push(top_node_index);
49  while (!s.empty()) {
50  auto node = s.top();
51  s.pop();
52  m_nodes.at(node).Print(node);
53  if (!m_nodes[node].IsLeaf()) {
54  s.push(m_nodes[node].GetRight(node));
55  s.push(m_nodes[node].GetLeft(node));
56  }
57  }
58 }
59 
60 template<typename Node_t> void MVAUtils::Forest<Node_t>::PrintForest() const {
61  for (unsigned int itree = 0; itree != GetNTrees(); ++itree) {
62  std::cout << "Tree number: " << itree << std::endl;
63  PrintTree(itree);
64  }
65 }
66 
67 template<typename Node_t>
68 std::vector<Node_t> MVAUtils::Forest<Node_t>::GetTree(unsigned int itree) const
69 {
70  index_t top_node_index = m_forest[itree];
71  index_t last_node_index = (itree < GetNTrees() - 1) ? m_forest[itree + 1] : m_nodes.size();
72  return std::vector<Node_t>(m_nodes.begin() + top_node_index, m_nodes.begin() + last_node_index);
73 }
74 
75 template<typename Node_t>
76 float MVAUtils::Forest<Node_t>::GetTreeResponseFromNode(const std::vector<float>& values, index_t index) const
77 {
78  while (!m_nodes[index].IsLeaf())
79  {
80  index = m_nodes[index].GetNext(values[m_nodes[index].GetVar()], index);
81  }
82  return m_nodes[index].GetVal();
83 }
84 
85 template<typename Node_t>
86 float MVAUtils::Forest<Node_t>::GetTreeResponseFromNode(const std::vector<float*>& pointers, index_t index) const
87 {
88  while (!m_nodes[index].IsLeaf())
89  {
90  index = m_nodes[index].GetNext(*(pointers[m_nodes[index].GetVar()]), index);
91  }
92  return m_nodes[index].GetVal();
93 }
94 
95 template<typename Node_t>
96 float MVAUtils::Forest<Node_t>::GetTreeResponse(const std::vector<float>& values, unsigned int itree) const
97 {
98  index_t top_node_index = m_forest[itree];
99  return GetTreeResponseFromNode(values, top_node_index);
100 }
101 
102 template<typename Node_t>
103 float MVAUtils::Forest<Node_t>::GetTreeResponse(const std::vector<float*>& pointers, unsigned int itree) const
104 {
105  index_t top_node_index = m_forest[itree];
106  return GetTreeResponseFromNode(pointers, top_node_index);
107 }
108 
109 template<typename Node_t>
110 float MVAUtils::Forest<Node_t>::GetRawResponse(const std::vector<float>& values) const
111 {
112  float result = 0.;
113  // looping in the reverse order since usually the response of the trees
114  // is decreasing (first more important). So it better to start the sum
115  // from the smaller to avoid numerical precision issues.
116  for (int itree = GetNTrees() - 1; itree >= 0; --itree)
117  {
118  result += GetTreeResponse(values, itree);
119  }
120  return result;
121 }
122 
123 template<typename Node_t>
124 float MVAUtils::Forest<Node_t>::GetRawResponse(const std::vector<float*>& pointers) const
125 {
126  float result = 0.;
127  for (int itree = GetNTrees() - 1; itree >= 0; --itree)
128  {
129  result += GetTreeResponse(pointers, itree);
130  }
131  return result;
132 }
133 
134 template<typename Node_t>
135 float MVAUtils::Forest<Node_t>::GetResponse(const std::vector<float>& values) const
136 {
137  return GetRawResponse(values);
138 }
139 
140 template<typename Node_t>
141 float MVAUtils::Forest<Node_t>::GetResponse(const std::vector<float*>& pointers) const
142 {
143  return GetRawResponse(pointers);
144 }
145 
146 template<typename Node_t>
147 std::vector<float> MVAUtils::Forest<Node_t>::GetMultiResponse(const std::vector<float>& values,
148  unsigned int numClasses) const
149 {
150  // this implementation is common TMVA / LGBM
151  // In multiclass each class has a separate forest. Each forest is made by the same number
152  // of trees. Here all the nodes of all the trees of all the forests are stored interleaved,
153  // e.g. assume each forest is made by 10 trees
154  // class0: tree-0, tree-10, ...
155  // class1: tree-1, tree-11, ...
156  // this very same scheme is used internally by lgbm. The representation of the nodes will be:
157  // tree0(class0)-node0, tree0(class0)-node1, ...
158 
159  std::vector<float> result;
160  if (numClasses > 0) {
161  result.resize(numClasses); // ignores the offset
162  // note that the loop is not the trees, not on the classes
163  // that would be equivalent, but better to read the vector in order
164  for (unsigned int itree = 0; itree < GetNTrees(); ++itree) {
165  result[itree % numClasses] += GetTreeResponse(values, itree);
166  }
167 
168  detail::applySoftmax(result);
169  }
170  return result;
171 }
172 
173 template<typename Node_t>
174 std::vector<float> MVAUtils::Forest<Node_t>::GetMultiResponse(const std::vector<float*>& pointers,
175  unsigned int numClasses) const
176 {
177  std::vector<float> result;
178 
179  if (numClasses > 0) {
180  result.resize(numClasses); // ignores the offset
181  for (unsigned int itree = 0; itree < GetNTrees(); ++itree) {
182  result[itree % numClasses] += GetTreeResponse(pointers, itree);
183  }
184  detail::applySoftmax(result);
185  }
186  return result;
187 }
188 
189 template<typename Node_t>
190 void MVAUtils::Forest<Node_t>::newTree(const std::vector<Node_t>& nodes)
191 {
192  m_forest.push_back(m_nodes.size());
193  m_nodes.insert(m_nodes.end(), nodes.begin(), nodes.end());
194 }