ATLAS Offline Software
Loading...
Searching...
No Matches
MVAUtils::ForestLGBM Class Reference final abstract

Implement LGBM with nan support. More...

#include <ForestLGBM.h>

Inheritance diagram for MVAUtils::ForestLGBM:
Collaboration diagram for MVAUtils::ForestLGBM:

Public Member Functions

 ForestLGBM (TTree *tree)
 ForestLGBM ()=default
 ForestLGBM (const ForestLGBM &)=default
ForestLGBM & operator= (const ForestLGBM &)=default
 ForestLGBM (ForestLGBM &&)=default
ForestLGBM & operator= (ForestLGBM &&)=default
 ~ForestLGBM ()=default
virtual TTree * WriteTree (TString name) const override
 Return a TTree representing the BDT.
virtual void PrintForest () const override
virtual int GetNVars () const override
 Get the number of input variables to be passed with std::vector to Get* methods.
virtual float GetClassification (const std::vector< float > &values) const final
 Compute the prediction of a classification.
virtual float GetClassification (const std::vector< float * > &pointers) const =0
virtual float GetTreeResponse (const std::vector< float > &values, unsigned int itree) const override final
 Return the response of one tree. Must pass the features in a std::vector<float> values and the index of the tree.
virtual float GetTreeResponse (const std::vector< float * > &pointers, unsigned int itree) const =0
virtual float GetOffset () const override
 Return the offset of the forest.
virtual float GetRawResponse (const std::vector< float > &values) const override final
 Return the response of the whole Forest.
virtual float GetRawResponse (const std::vector< float * > &pointers) const =0
virtual float GetResponse (const std::vector< float > &values) const override
 Compute the prediction for regression.
virtual float GetResponse (const std::vector< float * > &pointers) const =0
virtual std::vector< float > GetMultiResponse (const std::vector< float > &values, unsigned int numClasses) const override
 Compute the prediction for multiclassification (a score for each class).
virtual std::vector< float > GetMultiResponse (const std::vector< float * > &pointers, unsigned int numClasses) const =0
virtual unsigned int GetNTrees () const override final
virtual void PrintTree (unsigned int itree) const override
std::vector< NodeLGBM > GetTree (unsigned int itree) const
 Return the vector of nodes for the tree itree.

Protected Member Functions

float GetTreeResponseFromNode (const std::vector< float > &values, index_t index) const
 Get the response of a tree.
void newTree (const std::vector< NodeLGBM > &nodes)
 append a new tree (defined by a vector of nodes serialized in preorder) to the forest

Private Attributes

int m_max_var =0
std::vector< index_t > m_forest
 indices of the top-level nodes of each tree
std::vector< NodeLGBM > m_nodes
 where the nodes of the forest are stored

Detailed Description

Implement LGBM with nan support.

Definition at line 61 of file ForestLGBM.h.

Constructor & Destructor Documentation

◆ ForestLGBM() [1/4]

ForestLGBM::ForestLGBM ( TTree * tree)
explicit

Definition at line 88 of file ForestLGBM.cxx.

89 : ForestLGBMBase<NodeLGBM>()
90 , m_max_var(0)
91{
92
93
94 // variables read from the TTree
95 std::vector<int> *vars = nullptr;
96 std::vector<float> *values = nullptr;
97 std::vector<bool> *default_left = nullptr;
98
99 std::vector<NodeLGBM> nodes;
100
101 tree->SetBranchAddress("vars", &vars);
102 tree->SetBranchAddress("values", &values);
103 tree->SetBranchAddress("default_left", &default_left);
104 int numEntries = tree->GetEntries();
105 for (int entry = 0; entry < numEntries; ++entry) {
106 // each entry in the TTree is a decision tree
107 tree->GetEntry(entry);
108 if (!vars) {
109 throw std::runtime_error(
110 "vars pointer is null in ForestLGBM constructor");
111 }
112 if (!values) {
113 throw std::runtime_error(
114 "values pointers is null in ForestLGBM constructor");
115 }
116 if (!default_left) {
117 throw std::runtime_error(
118 "default_left pointers is null in ForestLGBM constructor");
119 }
120 if (vars->size() != values->size()) {
121 throw std::runtime_error(
122 "inconsistent size for vars and values in ForestLGBM constructor");
123 }
124 if (default_left->size() != values->size()) {
125 throw std::runtime_error("inconsistent size for default_left and "
126 "values in ForestLGBM constructor");
127 }
128
129 nodes.clear();
130
131 std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
132
133 for (size_t i = 0; i < vars->size(); ++i) {
134 nodes.emplace_back(
135 vars->at(i), values->at(i), right[i], default_left->at(i));
136 if (vars->at(i) > m_max_var) {
137 m_max_var = vars->at(i);
138 }
139 }
140 newTree(nodes);
141 } // end loop on TTree, all decision tree loaded
142 delete vars;
143 delete values;
144 delete default_left;
145}
void newTree(const std::vector< NodeLGBM > &nodes)
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
TChain * tree

◆ ForestLGBM() [2/4]

MVAUtils::ForestLGBM::ForestLGBM ( )
default

◆ ForestLGBM() [3/4]

MVAUtils::ForestLGBM::ForestLGBM ( const ForestLGBM & )
default

◆ ForestLGBM() [4/4]

MVAUtils::ForestLGBM::ForestLGBM ( ForestLGBM && )
default

◆ ~ForestLGBM()

MVAUtils::ForestLGBM::~ForestLGBM ( )
default

Member Function Documentation

◆ GetClassification() [1/2]

virtual float MVAUtils::ForestLGBMBase< NodeLGBM >::GetClassification ( const std::vector< float > & values) const
inline final virtual inherited

Compute the prediction of a classification.

Implements MVAUtils::IForest.

Definition at line 31 of file ForestLGBM.h.

32 {
34 }
virtual float GetResponse(const std::vector< float > &values) const override
T sigmoid(T x)
Definition Forest.h:21

◆ GetClassification() [2/2]

virtual float MVAUtils::IForest::GetClassification ( const std::vector< float * > & pointers) const
pure virtual inherited

◆ GetMultiResponse() [1/2]

virtual std::vector< float > MVAUtils::Forest< NodeLGBM >::GetMultiResponse ( const std::vector< float > & values,
unsigned int numClasses ) const
override virtual inherited

Compute the prediction for multiclassification (a score for each class).

In addition to the input values, the number of classes must be passed.

Implements MVAUtils::IForest.

◆ GetMultiResponse() [2/2]

virtual std::vector< float > MVAUtils::IForest::GetMultiResponse ( const std::vector< float * > & pointers,
unsigned int numClasses ) const
pure virtual inherited

Implemented in MVAUtils::Forest< Node_t >.

◆ GetNTrees()

virtual unsigned int MVAUtils::Forest< NodeLGBM >::GetNTrees ( ) const
inline final override virtual inherited

Implements MVAUtils::IForest.

Definition at line 94 of file Forest.h.

95 {
96 return m_forest.size();
97 }
Generic Forest base class.
Definition Forest.h:54

◆ GetNVars()

virtual int MVAUtils::ForestLGBM::GetNVars ( ) const
inline override virtual

Get the number of input variables to be passed with std::vector to Get* methods.

Implements MVAUtils::IForest.

Definition at line 73 of file ForestLGBM.h.

73{ return m_max_var + 1; }

◆ GetOffset()

virtual float MVAUtils::Forest< NodeLGBM >::GetOffset ( ) const
inline override virtual inherited

Return the offset of the forest.

Since by default there is no offset, return 0

Implements MVAUtils::IForest.

Definition at line 63 of file Forest.h.

63{ return 0.; }

◆ GetRawResponse() [1/2]

virtual float MVAUtils::Forest< NodeLGBM >::GetRawResponse ( const std::vector< float > & values) const
final override virtual inherited

Return the response of the whole Forest.

Raw is just the sum of all the trees

Implements MVAUtils::IForest.

◆ GetRawResponse() [2/2]

virtual float MVAUtils::IForest::GetRawResponse ( const std::vector< float * > & pointers) const
pure virtual inherited

Implemented in MVAUtils::Forest< Node_t >.

◆ GetResponse() [1/2]

virtual float MVAUtils::Forest< NodeLGBM >::GetResponse ( const std::vector< float > & values) const
override virtual inherited

Compute the prediction for regression.

Implements MVAUtils::IForest.

◆ GetResponse() [2/2]

virtual float MVAUtils::IForest::GetResponse ( const std::vector< float * > & pointers) const
pure virtual inherited

◆ GetTree()

std::vector< NodeLGBM > MVAUtils::Forest< NodeLGBM >::GetTree ( unsigned int itree) const
inherited

Return the vector of nodes for the tree itree.

◆ GetTreeResponse() [1/2]

virtual float MVAUtils::Forest< NodeLGBM >::GetTreeResponse ( const std::vector< float > & values,
unsigned int itree ) const
final override virtual inherited

Return the response of one tree. Must pass the features in a std::vector<float> values and the index of the tree.

Implements MVAUtils::IForest.

◆ GetTreeResponse() [2/2]

virtual float MVAUtils::IForest::GetTreeResponse ( const std::vector< float * > & pointers,
unsigned int itree ) const
pure virtual inherited

Implemented in MVAUtils::Forest< Node_t >.

◆ GetTreeResponseFromNode()

float MVAUtils::Forest< NodeLGBM >::GetTreeResponseFromNode ( const std::vector< float > & values,
index_t index ) const
protected inherited

Get the response of a tree.

Instead of specifying the index of the tree (as in GetTreeResponse), the index of the top node of the tree should be specified.

◆ newTree()

void MVAUtils::Forest< NodeLGBM >::newTree ( const std::vector< NodeLGBM > & nodes)
protected inherited

append a new tree (defined by a vector of nodes serialized in preorder) to the forest

◆ operator=() [1/2]

ForestLGBM & MVAUtils::ForestLGBM::operator= ( const ForestLGBM & )
default

◆ operator=() [2/2]

ForestLGBM & MVAUtils::ForestLGBM::operator= ( ForestLGBM && )
default

◆ PrintForest()

void ForestLGBM::PrintForest ( ) const
override virtual

Reimplemented from MVAUtils::Forest< NodeLGBM >.

Definition at line 174 of file ForestLGBM.cxx.

175{
176 std::cout << "***BDT LGBM: Printing entire forest***" << std::endl;
178}
virtual void PrintForest() const override

◆ PrintTree()

virtual void MVAUtils::Forest< NodeLGBM >::PrintTree ( unsigned int itree) const
override virtual inherited

Implements MVAUtils::IForest.

◆ WriteTree()

TTree * ForestLGBM::WriteTree ( TString ) const
override virtual

Return a TTree representing the BDT.

The caller is the owner of the returned TTree.

Implements MVAUtils::IForest.

Definition at line 148 of file ForestLGBM.cxx.

149{
150 TTree *tree = new TTree(name.Data(), "creator=lgbm;node_type=lgbm");
151
152 std::vector<int> vars;
153 std::vector<float> values;
154 std::vector<bool> default_left;
155
156 tree->Branch("vars", &vars);
157 tree->Branch("values", &values);
158 tree->Branch("default_left", &default_left);
159
160 for (size_t itree = 0; itree < GetNTrees(); ++itree) {
161 vars.clear();
162 values.clear();
163 default_left.clear();
164 for(const auto& node : GetTree(itree)) {
165 vars.push_back(node.GetVar());
166 values.push_back(node.GetVal());
167 default_left.push_back(node.GetDefaultLeft());
168 }
169 tree->Fill();
170 }
171 return tree;
172}
std::vector< NodeLGBM > GetTree(unsigned int itree) const
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94

Member Data Documentation

◆ m_forest

std::vector<index_t> MVAUtils::Forest< NodeLGBM >::m_forest
private inherited

indices of the top-level nodes of each tree

Definition at line 117 of file Forest.h.

◆ m_max_var

int MVAUtils::ForestLGBM::m_max_var =0
private

Definition at line 75 of file ForestLGBM.h.

◆ m_nodes

std::vector<NodeLGBM> MVAUtils::Forest< NodeLGBM >::m_nodes
private inherited

where the nodes of the forest are stored

Definition at line 118 of file Forest.h.


The documentation for this class was generated from the following files: