ATLAS Offline Software
Loading...
Searching...
No Matches
NodeImpl.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef MVAUtils_NodeImpl_H
6#define MVAUtils_NodeImpl_H
7
9#include <vector>
10#include <cstdint>
11#include <cmath>
12
13namespace TMVA{
14 class MethodBDT;
15 class DecisionTreeNode;
16}
17
18namespace MVAUtils
19{
20
21 // The Node classes do not inherit from a common base class. Different Node classes
22 // are used as template arguments for the Forest classes. Forest classes
23 // hold vectors of Nodes, so it is important to avoid overhead when calling
24 // Node methods -> no runtime polymorphism (no virtual calls)
25
26 // The main difference between the Node classes is the logic used to select
27 // the next node, depending on the input value
28
36 {
37 public:
42 NodeTMVA(const int ivar, const float val, const index_t right):
43 m_cut(val), m_right(right), m_var(ivar) { }
44
46 void Print(index_t index) const;
47
52 index_t GetNext(const float value, index_t index) const;
53
54 bool IsLeaf() const { return m_var < 0; }
55
57 var_t GetVar() const { return m_var; }
58
60 float GetVal() const { return m_cut; }
61
64 index_t GetLeft(index_t index) const { return index + 1; }
65
69
70 private:
71 // the order is important to have the optimal memory usage
72 float m_cut;
73 int16_t m_right;
75 };
76
77 // inline speedup ~15%
78 inline index_t NodeTMVA::GetNext(const float value, index_t index) const {
79 return (value >= m_cut) ? GetLeft(index) : GetRight(index);
80 }
81
82
92 {
93 public:
94 NodeLGBMSimple(const int ivar, const float val, const index_t right)
95 : m_cut(val), m_right(right), m_var(ivar) { }
96 void Print(index_t index) const;
97 index_t GetNext(const float value, index_t index) const;
98 bool IsLeaf() const { return m_var < 0; }
99
101 var_t GetVar() const { return m_var; }
102
104 float GetVal() const { return m_cut; }
105
108 index_t GetLeft(index_t index) const { return index + 1; }
109
113
114 private:
115 float m_cut;
116 int16_t m_right;
118 };
119
120 inline index_t NodeLGBMSimple::GetNext(const float value, index_t index) const {
121 // note that this is different from TMVA (and not the opposite, e.g. cannot simply invert left/right)
122 return (value <= m_cut) ? GetLeft(index) : GetRight(index);
123 }
124
135 {
136 public:
137 NodeLGBM(const int ivar, const float val, const index_t right, const int8_t default_left)
138 : m_cut(val), m_right(right), m_var(ivar), m_default_left(default_left) { }
139 void Print(index_t index) const;
140 index_t GetNext(const float value, index_t index) const;
141 bool GetDefaultLeft() const { return m_default_left; }
142 bool IsLeaf() const { return m_var < 0; }
143
144 var_t GetVar() const { return m_var; }
145 float GetVal() const { return m_cut; }
146 index_t GetLeft(index_t index) const { return index + 1; }
148
149 private:
150 float m_cut;
151 int16_t m_right;
153 int8_t m_default_left; //< if to go left in case of nan input
154 };
155
156
157 inline index_t NodeLGBM::GetNext(const float value, index_t index) const {
158 if (not std::isnan(value)) {
159 return (value <= m_cut) ? GetLeft(index) : GetRight(index);
160 }
161 else {
163 }
164 }
165
177 {
178 public:
179 NodeXGBoost(const int ivar, const float val, const index_t right, const int8_t default_left)
180 : m_cut(val), m_right(right), m_var(ivar), m_default_left(default_left) { }
181 void Print(index_t index) const;
182 index_t GetNext(const float value, index_t index) const;
183 bool GetDefaultLeft() const { return m_default_left; }
184 bool IsLeaf() const { return m_var < 0; }
185
186 var_t GetVar() const { return m_var; }
187 float GetVal() const { return m_cut; }
188 index_t GetLeft(index_t index) const { return index + 1; }
190
191 private:
192 float m_cut; // cut value for internal nodes or response for leaf nodes
193 int16_t m_right; // right relative index (to be added to current) (left is always current + 1)
194 var_t m_var; // index of the variable to cut for internal nodes, -1 for leaf nodes
195 int8_t m_default_left; //default if to go left in case of nan input
196 };
197
198
199 inline index_t NodeXGBoost::GetNext(const float value, index_t index) const {
200 if (not std::isnan(value)) {
201 return (value < m_cut) ? GetLeft(index) : GetRight(index);
202 }
203 else {
205 }
206 }
207
208}
209#endif
void Print(index_t index) const
Definition NodeImpl.cxx:18
float GetVal() const
The value to cut on (if not leaf), or the response (if leaf).
Definition NodeImpl.h:104
int16_t m_right
right relative index (to be added to current) (left is always current + 1)
Definition NodeImpl.h:116
index_t GetNext(const float value, index_t index) const
Definition NodeImpl.h:120
index_t GetRight(index_t index) const
For debugging: returns the index of the right node; is passed the current node index.
Definition NodeImpl.h:112
NodeLGBMSimple(const int ivar, const float val, const index_t right)
Definition NodeImpl.h:94
float m_cut
cut value for internal nodes or response for leaf nodes
Definition NodeImpl.h:115
var_t GetVar() const
The variable index to cut on (or -1 if leaf, but use IsLeaf instead if checking for leaf)
Definition NodeImpl.h:101
var_t m_var
index of the variable to cut for internal nodes, -1 for leaf nodes
Definition NodeImpl.h:117
index_t GetLeft(index_t index) const
For debugging: returns the index of the left node; is passed the current node index.
Definition NodeImpl.h:108
bool IsLeaf() const
is the current node a leaf node
Definition NodeImpl.h:98
int8_t m_default_left
Definition NodeImpl.h:153
var_t m_var
index of the variable to cut for internal nodes, -1 for leaf nodes
Definition NodeImpl.h:152
index_t GetRight(index_t index) const
Definition NodeImpl.h:147
float GetVal() const
Definition NodeImpl.h:145
int16_t m_right
right relative index (to be added to current) (left is always current + 1)
Definition NodeImpl.h:151
index_t GetNext(const float value, index_t index) const
Definition NodeImpl.h:157
void Print(index_t index) const
Definition NodeImpl.cxx:25
NodeLGBM(const int ivar, const float val, const index_t right, const int8_t default_left)
Definition NodeImpl.h:137
bool GetDefaultLeft() const
Definition NodeImpl.h:141
bool IsLeaf() const
is the current node a leaf node
Definition NodeImpl.h:142
index_t GetLeft(index_t index) const
Definition NodeImpl.h:146
float m_cut
cut value for internal nodes or response for leaf nodes
Definition NodeImpl.h:150
var_t GetVar() const
Definition NodeImpl.h:144
var_t m_var
index of the variable to cut for internal nodes, -1 for leaf nodes
Definition NodeImpl.h:74
float m_cut
cut value for internal nodes or response for leaf nodes
Definition NodeImpl.h:72
index_t GetRight(index_t index) const
For debugging: returns the index of the right node; is passed the current node index.
Definition NodeImpl.h:68
NodeTMVA(const int ivar, const float val, const index_t right)
The constructor gets the index of the variable to cut on (-1 if leaf), the index of the right child (...
Definition NodeImpl.h:42
void Print(index_t index) const
For debugging only: print the node values.
Definition NodeImpl.cxx:11
int16_t m_right
right relative index (to be added to current) (left is always current + 1)
Definition NodeImpl.h:73
var_t GetVar() const
The variable index to cut on (or -1 if leaf, but use IsLeaf instead if checking for leaf)
Definition NodeImpl.h:57
bool IsLeaf() const
is the current node a leaf node
Definition NodeImpl.h:54
index_t GetLeft(index_t index) const
For debugging: returns the index of the left node; is passed the current node index.
Definition NodeImpl.h:64
index_t GetNext(const float value, index_t index) const
Based on the value of the variable that's passed in, return the index of the appropriate child.
Definition NodeImpl.h:78
float GetVal() const
The value to cut on (if not leaf), or the response (if leaf).
Definition NodeImpl.h:60
index_t GetLeft(index_t index) const
Definition NodeImpl.h:188
index_t GetRight(index_t index) const
Definition NodeImpl.h:189
float GetVal() const
Definition NodeImpl.h:187
bool GetDefaultLeft() const
Definition NodeImpl.h:183
NodeXGBoost(const int ivar, const float val, const index_t right, const int8_t default_left)
Definition NodeImpl.h:179
var_t GetVar() const
Definition NodeImpl.h:186
void Print(index_t index) const
Definition NodeImpl.cxx:31
bool IsLeaf() const
is the current node a leaf node
Definition NodeImpl.h:184
index_t GetNext(const float value, index_t index) const
Definition NodeImpl.h:199
int8_t var_t
The variable type (i.e., the index of the variable to cut)
int32_t index_t
The index type of the node in the vector.
Definition index.py:1