/*
  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
*/
5 template<typename Container_t>
6 inline void MVAUtils::detail::applySoftmax(Container_t& x)
8 using T = typename Container_t::value_type;
9 // subtract max to avoid overflow (softmax is invariant to shifts)
10 const T max_x = *std::max_element(x.begin(), x.end());
11 std::transform(x.begin(), x.end(), x.begin(), [max_x](T v){ return exp(v - max_x); });
12 const T sum = std::accumulate(x.begin(), x.end(), T{});
13 std::transform(x.begin(), x.end(), x.begin(), [sum](T v){ return v / sum; });
17 inline std::vector<MVAUtils::index_t> MVAUtils::detail::computeRight(const std::vector<int>& vars)
19 // parent index is relative to root of tree (and only used inside this function)
20 // right index is relative to the then processed node
21 // example: vars = 2 0 -1 -1 -1 returns: 4 2 0 0 0
22 std::vector<MVAUtils::index_t> right(vars.size());
23 std::stack<MVAUtils::index_t> parent; // not strictly parent if doing a right node
26 for (std::vector<int>::size_type i = 0; i < vars.size(); ++i)
28 if (vars.at(i) >= 0) { // not a leaf
32 const MVAUtils::index_t currParent = parent.top();
33 // if right has not been visited, next will be right
34 if (currParent >= 0) {
35 right[currParent] = i + 1 - currParent;
44 template<typename Node_t> void MVAUtils::Forest<Node_t>::PrintTree(unsigned int itree) const
46 index_t top_node_index = m_forest[itree];
47 std::stack<index_t> s;
48 s.push(top_node_index);
52 m_nodes.at(node).Print(node);
53 if (!m_nodes[node].IsLeaf()) {
54 s.push(m_nodes[node].GetRight(node));
55 s.push(m_nodes[node].GetLeft(node));
60 template<typename Node_t> void MVAUtils::Forest<Node_t>::PrintForest() const {
61 for (unsigned int itree = 0; itree != GetNTrees(); ++itree) {
62 std::cout << "Tree number: " << itree << std::endl;
67 template<typename Node_t>
68 std::vector<Node_t> MVAUtils::Forest<Node_t>::GetTree(unsigned int itree) const
70 index_t top_node_index = m_forest[itree];
71 index_t last_node_index = (itree < GetNTrees() - 1) ? m_forest[itree + 1] : m_nodes.size();
72 return std::vector<Node_t>(m_nodes.begin() + top_node_index, m_nodes.begin() + last_node_index);
75 template<typename Node_t>
76 float MVAUtils::Forest<Node_t>::GetTreeResponseFromNode(const std::vector<float>& values, index_t index) const
78 while (!m_nodes[index].IsLeaf())
80 index = m_nodes[index].GetNext(values[m_nodes[index].GetVar()], index);
82 return m_nodes[index].GetVal();
85 template<typename Node_t>
86 float MVAUtils::Forest<Node_t>::GetTreeResponseFromNode(const std::vector<float*>& pointers, index_t index) const
88 while (!m_nodes[index].IsLeaf())
90 index = m_nodes[index].GetNext(*(pointers[m_nodes[index].GetVar()]), index);
92 return m_nodes[index].GetVal();
95 template<typename Node_t>
96 float MVAUtils::Forest<Node_t>::GetTreeResponse(const std::vector<float>& values, unsigned int itree) const
98 index_t top_node_index = m_forest[itree];
99 return GetTreeResponseFromNode(values, top_node_index);
102 template<typename Node_t>
103 float MVAUtils::Forest<Node_t>::GetTreeResponse(const std::vector<float*>& pointers, unsigned int itree) const
105 index_t top_node_index = m_forest[itree];
106 return GetTreeResponseFromNode(pointers, top_node_index);
109 template<typename Node_t>
110 float MVAUtils::Forest<Node_t>::GetRawResponse(const std::vector<float>& values) const
113 // looping in the reverse order since usually the response of the trees
114 // is decreasing (first more important). So it better to start the sum
115 // from the smaller to avoid numerical precision issues.
116 for (int itree = GetNTrees() - 1; itree >= 0; --itree)
118 result += GetTreeResponse(values, itree);
123 template<typename Node_t>
124 float MVAUtils::Forest<Node_t>::GetRawResponse(const std::vector<float*>& pointers) const
127 for (int itree = GetNTrees() - 1; itree >= 0; --itree)
129 result += GetTreeResponse(pointers, itree);
134 template<typename Node_t>
135 float MVAUtils::Forest<Node_t>::GetResponse(const std::vector<float>& values) const
137 return GetRawResponse(values);
140 template<typename Node_t>
141 float MVAUtils::Forest<Node_t>::GetResponse(const std::vector<float*>& pointers) const
143 return GetRawResponse(pointers);
146 template<typename Node_t>
147 std::vector<float> MVAUtils::Forest<Node_t>::GetMultiResponse(const std::vector<float>& values,
148 unsigned int numClasses) const
150 // this implementation is common TMVA / LGBM
151 // In multiclass each class has a separate forest. Each forest is made by the same number
152 // of trees. Here all the nodes of all the trees of all the forests are stored interleaved,
153 // e.g. assume each forest is made by 10 trees
154 // class0: tree-0, tree-10, ...
155 // class1: tree-1, tree-11, ...
156 // this very same scheme is used internally by lgbm. The representation of the nodes will be:
157 // tree0(class0)-node0, tree0(class0)-node1, ...
159 std::vector<float> result;
160 if (numClasses > 0) {
161 result.resize(numClasses); // ignores the offset
162 // note that the loop is not the trees, not on the classes
163 // that would be equivalent, but better to read the vector in order
164 for (unsigned int itree = 0; itree < GetNTrees(); ++itree) {
165 result[itree % numClasses] += GetTreeResponse(values, itree);
168 detail::applySoftmax(result);
173 template<typename Node_t>
174 std::vector<float> MVAUtils::Forest<Node_t>::GetMultiResponse(const std::vector<float*>& pointers,
175 unsigned int numClasses) const
177 std::vector<float> result;
179 if (numClasses > 0) {
180 result.resize(numClasses); // ignores the offset
181 for (unsigned int itree = 0; itree < GetNTrees(); ++itree) {
182 result[itree % numClasses] += GetTreeResponse(pointers, itree);
184 detail::applySoftmax(result);
189 template<typename Node_t>
190 void MVAUtils::Forest<Node_t>::newTree(const std::vector<Node_t>& nodes)
192 m_forest.push_back(m_nodes.size());
193 m_nodes.insert(m_nodes.end(), nodes.begin(), nodes.end());