ATLAS Offline Software
NodeImpl.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_NodeImpl_H
6 #define MVAUtils_NodeImpl_H
7 
9 #include <vector>
10 #include <cstdint>
11 #include <cmath>
12 
// Forward declarations of the TMVA types referenced by MVAUtils, so that this
// header does not need to include the TMVA headers themselves.
namespace TMVA {
class MethodBDT;
class DecisionTreeNode;
}  // namespace TMVA
17 
18 namespace MVAUtils
19 {
20 
21  // The classes do not inherit from a common node base. Different Node classes
22  // are used as template arguments for the Forest classes. Forest classes
23  // hold vectors of Nodes, so it is important to avoid overhead when calling
24  // Node methods -> no polymorphism.
25 
26  // The main difference between the Node classes is the logic used to get the
27  // next node, depending on the input values.
28 
35  class NodeTMVA
36  {
37  public:
42  NodeTMVA(const int ivar, const float val, const index_t right):
43  m_cut(val), m_right(right), m_var(ivar) { }
44 
46  void Print(index_t index) const;
47 
52  index_t GetNext(const float value, index_t index) const;
53 
54  bool IsLeaf() const { return m_var < 0; }
55 
57  var_t GetVar() const { return m_var; }
58 
60  float GetVal() const { return m_cut; }
61 
64  index_t GetLeft(index_t index) const { return index + 1; }
65 
68  index_t GetRight(index_t index) const { return index + m_right; }
69 
70  private:
71  // the order is important to have the optimal memory usage
72  float m_cut;
75  };
76 
77  // inline speedup ~15%
78  inline index_t NodeTMVA::GetNext(const float value, index_t index) const {
79  return (value >= m_cut) ? GetLeft(index) : GetRight(index);
80  }
81 
82 
92  {
93  public:
94  NodeLGBMSimple(const int ivar, const float val, const index_t right)
95  : m_cut(val), m_right(right), m_var(ivar) { }
96  void Print(index_t index) const;
97  index_t GetNext(const float value, index_t index) const;
98  bool IsLeaf() const { return m_var < 0; }
99 
101  var_t GetVar() const { return m_var; }
102 
104  float GetVal() const { return m_cut; }
105 
108  index_t GetLeft(index_t index) const { return index + 1; }
109 
112  index_t GetRight(index_t index) const { return index + m_right; }
113 
114  private:
115  float m_cut;
118  };
119 
120  inline index_t NodeLGBMSimple::GetNext(const float value, index_t index) const {
121  // note that this is different from TMVA (and not the opposite, e.g. cannot simply invert left/right)
122  return (value <= m_cut) ? GetLeft(index) : GetRight(index);
123  }
124 
134  class NodeLGBM
135  {
136  public:
137  NodeLGBM(const int ivar, const float val, const index_t right, const int8_t default_left)
138  : m_cut(val), m_right(right), m_var(ivar), m_default_left(default_left) { }
139  void Print(index_t index) const;
140  index_t GetNext(const float value, index_t index) const;
141  bool GetDefaultLeft() const { return m_default_left; }
142  bool IsLeaf() const { return m_var < 0; }
143 
144  var_t GetVar() const { return m_var; }
145  float GetVal() const { return m_cut; }
146  index_t GetLeft(index_t index) const { return index + 1; }
147  index_t GetRight(index_t index) const { return index + m_right; }
148 
149  private:
150  float m_cut;
153  int8_t m_default_left; //< if to go left in case of nan input
154  };
155 
156 
157  inline index_t NodeLGBM::GetNext(const float value, index_t index) const {
158  if (not std::isnan(value)) {
159  return (value <= m_cut) ? GetLeft(index) : GetRight(index);
160  }
161  else {
162  return (m_default_left) ? GetLeft(index) : GetRight(index);
163  }
164  }
165 
177  {
178  public:
179  NodeXGBoost(const int ivar, const float val, const index_t right, const int8_t default_left)
180  : m_cut(val), m_right(right), m_var(ivar), m_default_left(default_left) { }
181  void Print(index_t index) const;
182  index_t GetNext(const float value, index_t index) const;
183  bool GetDefaultLeft() const { return m_default_left; }
184  bool IsLeaf() const { return m_var < 0; }
185 
186  var_t GetVar() const { return m_var; }
187  float GetVal() const { return m_cut; }
188  index_t GetLeft(index_t index) const { return index + 1; }
189  index_t GetRight(index_t index) const { return index + m_right; }
190 
191  private:
192  float m_cut; // cut value for internal nodes or response for leaf nodes
193  int16_t m_right; // right relative index (to be added to current) (left is always current + 1)
194  var_t m_var; // index of the variable to cut for internal nodes, -1 for leaf nodes
195  int8_t m_default_left; //default if to go left in case of nan input
196  };
197 
198 
199  inline index_t NodeXGBoost::GetNext(const float value, index_t index) const {
200  if (not std::isnan(value)) {
201  return (value < m_cut) ? GetLeft(index) : GetRight(index);
202  }
203  else {
204  return (m_default_left) ? GetLeft(index) : GetRight(index);
205  }
206  }
207 
208 }
209 #endif
MVAUtils::NodeTMVA
Node for TMVA implementation.
Definition: NodeImpl.h:36
MVAUtils::NodeTMVA::GetVal
float GetVal() const
The value to cut on (if not leaf), or the response (if leaf).
Definition: NodeImpl.h:60
MVAUtils
Definition: InDetTrkInJetType.h:47
MVAUtils::NodeXGBoost::m_var
var_t m_var
Definition: NodeImpl.h:194
MVAUtils::NodeLGBMSimple::NodeLGBMSimple
NodeLGBMSimple(const int ivar, const float val, const index_t right)
Definition: NodeImpl.h:94
MVAUtils::NodeXGBoost::GetLeft
index_t GetLeft(index_t index) const
Definition: NodeImpl.h:188
MVAUtils::NodeLGBMSimple::IsLeaf
bool IsLeaf() const
is the current node a leaf node
Definition: NodeImpl.h:98
MVAUtils::NodeTMVA::Print
void Print(index_t index) const
For debugging only: print the node values.
Definition: NodeImpl.cxx:11
MVAUtils::NodeTMVA::GetVar
var_t GetVar() const
The variable index to cut on (or -1 if leaf, but use IsLeaf instead if checking for leaf)
Definition: NodeImpl.h:57
MVAUtils::NodeXGBoost::GetDefaultLeft
bool GetDefaultLeft() const
Definition: NodeImpl.h:183
MVAUtils::NodeLGBM::GetVar
var_t GetVar() const
Definition: NodeImpl.h:144
index
Definition: index.py:1
MVAUtils::NodeLGBMSimple::GetVal
float GetVal() const
The value to cut on (if not leaf), or the response (if leaf).
Definition: NodeImpl.h:104
MVAUtils::NodeLGBM::GetLeft
index_t GetLeft(index_t index) const
Definition: NodeImpl.h:146
athena.value
value
Definition: athena.py:122
MVAUtils::NodeLGBM::GetRight
index_t GetRight(index_t index) const
Definition: NodeImpl.h:147
MVAUtils::NodeLGBMSimple::m_cut
float m_cut
cut value for internal nodes or response for leaf nodes
Definition: NodeImpl.h:115
MVAUtils::var_t
int8_t var_t
The variable type (i.e., the index of the variable to cut)
Definition: MVAUtilsDefs.h:13
MVAUtils::NodeXGBoost
Node for XGBoost with NaN handling.
Definition: NodeImpl.h:177
MVAUtilsDefs.h
xAOD::int16_t
setScaleOne setStatusOne setSaturated int16_t
Definition: gFexGlobalRoI_v1.cxx:55
MVAUtils::NodeXGBoost::GetRight
index_t GetRight(index_t index) const
Definition: NodeImpl.h:189
MVAUtils::NodeTMVA::GetLeft
index_t GetLeft(index_t index) const
For debugging: returns the index of the left node; is passed the current node index.
Definition: NodeImpl.h:64
MVAUtils::NodeLGBMSimple::GetLeft
index_t GetLeft(index_t index) const
For debugging: returns the index of the left node; is passed the current node index.
Definition: NodeImpl.h:108
MVAUtils::NodeTMVA::m_cut
float m_cut
cut value for internal nodes or response for leaf nodes
Definition: NodeImpl.h:72
MVAUtils::NodeLGBM::m_default_left
int8_t m_default_left
Definition: NodeImpl.h:153
MVAUtils::NodeTMVA::IsLeaf
bool IsLeaf() const
is the current node a leaf node
Definition: NodeImpl.h:54
MVAUtils::NodeXGBoost::NodeXGBoost
NodeXGBoost(const int ivar, const float val, const index_t right, const int8_t default_left)
Definition: NodeImpl.h:179
MVAUtils::NodeLGBM::IsLeaf
bool IsLeaf() const
is the current node a leaf node
Definition: NodeImpl.h:142
MVAUtils::NodeTMVA::GetRight
index_t GetRight(index_t index) const
For debugging: returns the index of the right node; is passed the current node index.
Definition: NodeImpl.h:68
MVAUtils::NodeLGBMSimple::GetNext
index_t GetNext(const float value, index_t index) const
Definition: NodeImpl.h:120
plotBeamSpotCompare.ivar
int ivar
Definition: plotBeamSpotCompare.py:383
MVAUtils::NodeLGBM::NodeLGBM
NodeLGBM(const int ivar, const float val, const index_t right, const int8_t default_left)
Definition: NodeImpl.h:137
MVAUtils::NodeXGBoost::Print
void Print(index_t index) const
Definition: NodeImpl.cxx:31
MVAUtils::NodeLGBMSimple
Node for LightGBM without NaN handling.
Definition: NodeImpl.h:92
MVAUtils::NodeLGBMSimple::Print
void Print(index_t index) const
Definition: NodeImpl.cxx:18
MVAUtils::NodeLGBM::m_cut
float m_cut
cut value for internal nodes or response for leaf nodes
Definition: NodeImpl.h:150
MVAUtils::NodeXGBoost::IsLeaf
bool IsLeaf() const
is the current node a leaf node
Definition: NodeImpl.h:184
MVAUtils::NodeXGBoost::m_cut
float m_cut
Definition: NodeImpl.h:192
MVAUtils::index_t
int32_t index_t
The index type of the node in the vector.
Definition: MVAUtilsDefs.h:12
MVAUtils::NodeXGBoost::m_right
int16_t m_right
Definition: NodeImpl.h:193
MVAUtils::NodeTMVA::m_right
int16_t m_right
right relative index (to be added to current) (left is always current + 1)
Definition: NodeImpl.h:73
MVAUtils::NodeXGBoost::GetVar
var_t GetVar() const
Definition: NodeImpl.h:186
MVAUtils::NodeXGBoost::GetVal
float GetVal() const
Definition: NodeImpl.h:187
MVAUtils::NodeLGBM::GetVal
float GetVal() const
Definition: NodeImpl.h:145
MVAUtils::NodeTMVA::m_var
var_t m_var
index of the variable to cut for internal nodes, -1 for leaf nodes
Definition: NodeImpl.h:74
MVAUtils::NodeLGBM::m_var
var_t m_var
index of the variable to cut for internal nodes, -1 for leaf nodes
Definition: NodeImpl.h:152
Pythia8_RapidityOrderMPI.val
val
Definition: Pythia8_RapidityOrderMPI.py:14
MVAUtils::NodeLGBM::m_right
int16_t m_right
right relative index (to be added to current) (left is always current + 1)
Definition: NodeImpl.h:151
MVAUtils::NodeLGBMSimple::GetVar
var_t GetVar() const
The variable index to cut on (or -1 if leaf, but use IsLeaf instead if checking for leaf)
Definition: NodeImpl.h:101
MVAUtils::NodeXGBoost::m_default_left
int8_t m_default_left
Definition: NodeImpl.h:195
MVAUtils::NodeLGBM::GetNext
index_t GetNext(const float value, index_t index) const
Definition: NodeImpl.h:157
MVAUtils::NodeLGBMSimple::m_right
int16_t m_right
right relative index (to be added to current) (left is always current + 1)
Definition: NodeImpl.h:116
MVAUtils::NodeLGBM
Node for LightGBM with NaN handling.
Definition: NodeImpl.h:135
MVAUtils::NodeTMVA::NodeTMVA
NodeTMVA(const int ivar, const float val, const index_t right)
The constructor gets the index of the variable to cut on (-1 if leaf), the index of the right child (...
Definition: NodeImpl.h:42
MVAUtils::NodeLGBM::GetDefaultLeft
bool GetDefaultLeft() const
Definition: NodeImpl.h:141
MVAUtils::NodeLGBMSimple::m_var
var_t m_var
index of the variable to cut for internal nodes, -1 for leaf nodes
Definition: NodeImpl.h:117
MVAUtils::NodeXGBoost::GetNext
index_t GetNext(const float value, index_t index) const
Definition: NodeImpl.h:199
MVAUtils::NodeLGBM::Print
void Print(index_t index) const
Definition: NodeImpl.cxx:25
MVAUtils::NodeLGBMSimple::GetRight
index_t GetRight(index_t index) const
For debugging: returns the index of the right node; is passed the current node index.
Definition: NodeImpl.h:112
MVAUtils::NodeTMVA::GetNext
index_t GetNext(const float value, index_t index) const
Based on the value of the variable that's passed in, return the index of the appropriate child.
Definition: NodeImpl.h:78
TMVA
Definition: PhotonVertexSelectionTool.h:22