ATLAS Offline Software
MVAUtils::NodeXGBoost Class Reference

Node for XGBoost with NaN handling.

#include <NodeImpl.h>


Public Member Functions

 NodeXGBoost (const int ivar, const float val, const index_t right, const int8_t default_left)
 
void Print (index_t index) const
 
index_t GetNext (const float value, index_t index) const
 
bool GetDefaultLeft () const
 
bool IsLeaf () const
 Is the current node a leaf node?
 
var_t GetVar () const
 
float GetVal () const
 
index_t GetLeft (index_t index) const
 
index_t GetRight (index_t index) const
 

Private Attributes

float m_cut
 
int16_t m_right
 
var_t m_var
 
int8_t m_default_left
 

Detailed Description

Node for XGBoost with NaN handling.

This follows the implementation in XGBoost:

    next = !isnan(value) ? (value < cut ? left : right) : (default_left ? left : right)

The left child corresponds to the XGBoost "YES" branch and the right child to the "NO" branch. default_left is stored for each node, so the direction taken by a missing (NaN) value can differ from node to node. Categorical inputs are not supported.
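The decision rule above can be sketched as a free function. This is a minimal illustration, not the MVAUtils code. The name goes_left is hypothetical, and the function returns true when traversal should move to the left ("YES") child.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical free-function sketch of the per-node decision rule.
// Returns true if traversal should continue to the left ("YES") child.
bool goes_left(float value, float cut, bool default_left) {
  if (std::isnan(value)) {
    // Missing value: follow the direction stored with this node.
    return default_left;
  }
  // Ordinary value: strict less-than sends the input left.
  return value < cut;
}
```

Every comparison with NaN is false, so `value < cut` on its own would always send a NaN to the right. The explicit `std::isnan` check is what allows default_left to take effect. A value exactly equal to the cut goes right.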

Definition at line 176 of file NodeImpl.h.

Constructor & Destructor Documentation

◆ NodeXGBoost()

MVAUtils::NodeXGBoost::NodeXGBoost (const int ivar, const float val, const index_t right, const int8_t default_left)
inline

Definition at line 179 of file NodeImpl.h.

: m_cut(val), m_right(right), m_var(ivar), m_default_left(default_left) { }
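The constructor copies right, an index_t, into the int16_t member m_right. Right-child offsets outside the int16_t range are therefore not preserved by that conversion. Below is a minimal sketch of a range check. It assumes index_t is a signed 32-bit integer, because the actual typedef is not shown on this page. right_offset_fits is a hypothetical helper and is not part of MVAUtils.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Assumption: index_t is a signed 32-bit integer. The real typedef lives in
// NodeImpl.h and is not shown on this page.
using index_t_assumed = std::int32_t;

// Hypothetical check: m_right is an int16_t, so the right-child offset passed
// to the constructor is only preserved if it fits in that range.
bool right_offset_fits(index_t_assumed right) {
  return right >= std::numeric_limits<std::int16_t>::min() &&
         right <= std::numeric_limits<std::int16_t>::max();
}
```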

Member Function Documentation

◆ GetDefaultLeft()

bool MVAUtils::NodeXGBoost::GetDefaultLeft ( ) const
inline

Definition at line 183 of file NodeImpl.h.

{ return m_default_left; }

◆ GetLeft()

index_t MVAUtils::NodeXGBoost::GetLeft ( index_t  index) const
inline

Definition at line 188 of file NodeImpl.h.

{ return index + 1; }

◆ GetNext()

index_t MVAUtils::NodeXGBoost::GetNext (const float value, index_t index) const
inline

Definition at line 199 of file NodeImpl.h.

{
  if (not std::isnan(value)) {
    return (value < m_cut) ? GetLeft(index) : GetRight(index);
  }
  else {
    return (m_default_left) ? GetLeft(index) : GetRight(index);
  }
}
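Applying GetNext() repeatedly from the root until IsLeaf() is true evaluates one tree. The sketch below mirrors that loop using a hypothetical SketchNode type. This page does not document how leaves are marked or where the leaf response is stored. The sketch therefore assumes that a negative var marks a leaf and that the leaf value is kept in cut.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for NodeXGBoost. Assumptions for this sketch only:
// a negative var marks a leaf, and a leaf stores its response in cut.
struct SketchNode {
  int var;
  float cut;
  std::int16_t right;
  bool default_left;

  bool is_leaf() const { return var < 0; }

  // Same rule as GetNext(): NaN follows default_left, otherwise value < cut.
  // The left child is at index + 1, the right child at index + right.
  std::size_t next(float value, std::size_t index) const {
    if (!std::isnan(value)) {
      return value < cut ? index + 1 : index + right;
    }
    return default_left ? index + 1 : index + right;
  }
};

// Walk one tree from the root (index 0) until a leaf is reached.
float evaluate(const std::vector<SketchNode>& nodes,
               const std::vector<float>& features) {
  std::size_t index = 0;
  while (!nodes[index].is_leaf()) {
    const SketchNode& node = nodes[index];
    index = node.next(features[node.var], index);
  }
  return nodes[index].cut;
}
```

Consider a root that cuts feature 0 at 0.5, with right offset 2 and default_left true, and two leaves holding 10 and 20. The inputs 0.2, 0.9 and NaN reach the leaves 10, 20 and 10 respectively.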

◆ GetRight()

index_t MVAUtils::NodeXGBoost::GetRight ( index_t  index) const
inline

Definition at line 189 of file NodeImpl.h.

{ return index + m_right; }
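GetLeft() and GetRight() imply a flat array layout. A node's left child is stored immediately after it, and its right child is stored m_right entries further on. A pre-order flattening (parent, then left subtree, then right subtree) is one way to produce such a layout. The sketch below assumes that scheme and uses hypothetical TreeNode and FlatNode types. It illustrates the indexing only and is not the actual MVAUtils tree builder.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical pointer-based binary tree used only to illustrate the layout.
struct TreeNode {
  float cut = 0.f;
  std::unique_ptr<TreeNode> left, right;  // both null for a leaf
};

// Flat entry mirroring the indexing used by NodeXGBoost: the left child sits
// at index + 1 and the right child at index + right_offset.
struct FlatNode {
  float cut;
  std::int16_t right_offset;  // left as 0 for leaves in this sketch
};

// Pre-order flattening: write the parent, then the whole left subtree, then
// the right subtree, recording how far away the right child ends up.
void flatten(const TreeNode& node, std::vector<FlatNode>& out) {
  const std::size_t index = out.size();
  out.push_back({node.cut, 0});
  if (node.left && node.right) {
    flatten(*node.left, out);  // lands at index + 1
    const std::size_t right_index = out.size();
    out[index].right_offset = static_cast<std::int16_t>(right_index - index);
    flatten(*node.right, out);
  }
}
```

Consider a root whose left child has two leaves and whose right child is a leaf. The flattened order is root, left, left-left, left-right, right. The root therefore stores an offset of 4 and its left child an offset of 2.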

◆ GetVal()

float MVAUtils::NodeXGBoost::GetVal ( ) const
inline

Definition at line 187 of file NodeImpl.h.

{ return m_cut; }

◆ GetVar()

var_t MVAUtils::NodeXGBoost::GetVar ( ) const
inline

Definition at line 186 of file NodeImpl.h.

{ return m_var; }

◆ IsLeaf()

bool MVAUtils::NodeXGBoost::IsLeaf ( ) const
inline

Is the current node a leaf node?

Definition at line 184 of file NodeImpl.h.

◆ Print()

void NodeXGBoost::Print ( index_t  index) const

Definition at line 31 of file NodeImpl.cxx.

{
  std::cout << " Variable: " << int(m_var) << ", Cut: " << m_cut << ", DefaultLeft: " << (int)m_default_left
            << " (index = " << index << ")" << std::endl;
}

Member Data Documentation

◆ m_cut

float MVAUtils::NodeXGBoost::m_cut
private

Definition at line 192 of file NodeImpl.h.

◆ m_default_left

int8_t MVAUtils::NodeXGBoost::m_default_left
private

Definition at line 195 of file NodeImpl.h.

◆ m_right

int16_t MVAUtils::NodeXGBoost::m_right
private

Definition at line 193 of file NodeImpl.h.

◆ m_var

var_t MVAUtils::NodeXGBoost::m_var
private

Definition at line 194 of file NodeImpl.h.


The documentation for this class was generated from the following files:
NodeImpl.h
NodeImpl.cxx