ATLAS Offline Software
Forest.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_Forest_H
6 #define MVAUtils_Forest_H
7 
8 #include "MVAUtils/ForestBase.h"
9 #include "NodeImpl.h"
10 #include <stack>
11 #include <cmath>
12 #include <algorithm>
13 #include <numeric>
14 #include <iostream>
15 #include <vector>
16 
17 
18 namespace MVAUtils { namespace detail { // helpers
19 
20  template<typename T>
21  inline T sigmoid(T x) { return 1. / (1. + exp(-x)); }
22 
26  template<typename Container_t> void applySoftmax(Container_t& x);
27 
34  inline std::vector<index_t> computeRight(const std::vector<int>& vars);
35  } }
36 
37 namespace MVAUtils
38 {
52  template<typename Node_t>
53  class Forest : public IForest
54  {
55  public:
56  virtual float GetTreeResponse(const std::vector<float>& values,
57  unsigned int itree) const override final;
58  virtual float GetTreeResponse(const std::vector<float*>& pointers,
59  unsigned int itree) const override final;
60 
63  virtual float GetOffset() const override { return 0.; }
64 
67  virtual float GetRawResponse(
68  const std::vector<float>& values) const override final;
69  virtual float GetRawResponse(
70  const std::vector<float*>& pointers) const override final;
71 
73  // In this class it is equal to the raw-reponse. Derived class should
74  // override this.
75  virtual float GetResponse(
76  const std::vector<float>& values) const override;
77  virtual float GetResponse(
78  const std::vector<float*>& pointers) const override;
79 
84  // Since TMVA and lgbm are identical the common implementation is here:
85  // Return the softmax of the sub-forest raw-response
86  virtual std::vector<float> GetMultiResponse(
87  const std::vector<float>& values,
88  unsigned int numClasses) const override;
89 
90  virtual std::vector<float> GetMultiResponse(
91  const std::vector<float*>& pointers,
92  unsigned int numClasses) const override;
93 
94  virtual unsigned int GetNTrees() const override final
95  {
96  return m_forest.size();
97  }
98 
99  virtual void PrintForest() const override;
100 
101  virtual void PrintTree(unsigned int itree) const override;
102 
104  std::vector<Node_t> GetTree(unsigned int itree) const;
105 
106  protected:
110  float GetTreeResponseFromNode(const std::vector<float>& values, index_t index) const;
111  float GetTreeResponseFromNode(const std::vector<float*>& pointers, index_t index) const;
112 
114  void newTree(const std::vector<Node_t>& nodes);
115 
116  private:
117  std::vector<index_t> m_forest;
118  std::vector<Node_t> m_nodes;
119  };
120 
121 }
122 
123 #include "Forest.icc"
124 
125 #endif
MVAUtils::Forest::GetTreeResponseFromNode
float GetTreeResponseFromNode(const std::vector< float > &values, index_t index) const
Get the response of a tree.
MVAUtils::Forest::GetRawResponse
virtual float GetRawResponse(const std::vector< float > &values) const override final
Return the response of the whole Forest.
MVAUtils::detail::sigmoid
T sigmoid(T x)
Definition: Forest.h:21
MVAUtils
Definition: InDetTrkInJetType.h:48
Forest.icc
index
Definition: index.py:1
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::Forest::GetOffset
virtual float GetOffset() const override
Return the offset of the forest.
Definition: Forest.h:63
detail
Definition: extract_histogram_tag.cxx:14
MVAUtils::Forest::GetTree
std::vector< Node_t > GetTree(unsigned int itree) const
Return the vector of nodes for the tree itree.
const
bool const RAWDATA *ch2 const
Definition: LArRodBlockPhysicsV0.cxx:560
drawFromPickle.exp
exp
Definition: drawFromPickle.py:36
x
#define x
MVAUtils::detail::computeRight
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes and their right children from a serialized representation of the...
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
MVAUtils::Forest::GetTreeResponseFromNode
float GetTreeResponseFromNode(const std::vector< float * > &pointers, index_t index) const
MVAUtils::Forest::GetMultiResponse
virtual std::vector< float > GetMultiResponse(const std::vector< float * > &pointers, unsigned int numClasses) const override
MVAUtils::Forest::GetNTrees
virtual unsigned int GetNTrees() const override final
Definition: Forest.h:94
MVAUtils::Forest::m_forest
std::vector< index_t > m_forest
indices of the top-level nodes of each tree
Definition: Forest.h:117
MVAUtils::Forest::newTree
void newTree(const std::vector< Node_t > &nodes)
Append a new tree (defined by a vector of nodes serialized in preorder) to the forest.
MVAUtils::Forest::GetTreeResponse
virtual float GetTreeResponse(const std::vector< float * > &pointers, unsigned int itree) const override final
MVAUtils::Forest
Generic Forest base class.
Definition: Forest.h:54
MVAUtils::index_t
int32_t index_t
The index type of the node in the vector.
Definition: MVAUtilsDefs.h:12
MVAUtils::Forest::PrintForest
virtual void PrintForest() const override
MVAUtils::detail::applySoftmax
void applySoftmax(Container_t &x)
Apply softmax to the input: $x_i \mapsto \frac{e^{x_i}}{\sum_j e^{x_j}}$ for each $x_i$ in $x$.
MVAUtils::Forest::GetMultiResponse
virtual std::vector< float > GetMultiResponse(const std::vector< float > &values, unsigned int numClasses) const override
Compute the prediction for multiclassification (a score for each class).
MVAUtils::IForest
Compute the response from the binary trees in the forest.
Definition: ForestBase.h:23
MVAUtils::Forest::GetRawResponse
virtual float GetRawResponse(const std::vector< float * > &pointers) const override final
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition: rmain.cxx:366
ForestBase.h
NodeImpl.h
MVAUtils::Forest::m_nodes
std::vector< Node_t > m_nodes
where the nodes of the forest are stored
Definition: Forest.h:118
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float * > &pointers) const override
MVAUtils::Forest::GetTreeResponse
virtual float GetTreeResponse(const std::vector< float > &values, unsigned int itree) const override final
Return the response of one tree. Must pass the features in a std::vector<float> values and the index o...
MVAUtils::Forest::PrintTree
virtual void PrintTree(unsigned int itree) const override