ATLAS Offline Software
ForestLGBM.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "MVAUtils/ForestLGBM.h"
6 #include "TTree.h"
7 #include <iostream>
8 #include <stdexcept>
9 
10 using namespace MVAUtils;
11 
14  , m_max_var(0)
15 {
16 
17 
18  // variables read from the TTree
19  std::vector<int> *vars = nullptr;
20  std::vector<float> *values = nullptr;
21 
22  std::vector<NodeLGBMSimple> nodes;
23 
24  tree->SetBranchAddress("vars", &vars);
25  tree->SetBranchAddress("values", &values);
26 
27  int numEntries = tree->GetEntries();
28  for (int entry = 0; entry < numEntries; ++entry)
29  {
30  // each entry in the TTree is a decision tree
31  tree->GetEntry(entry);
32  if (!vars) {
33  throw std::runtime_error(
34  "vars pointer is null in ForestLGBMSimple constructor");
35  }
36  if (!values) {
37  throw std::runtime_error(
38  "values pointers is null in ForestLGBMSimple constructor");
39  }
40  if (vars->size() != values->size()) {
41  throw std::runtime_error("inconsistent size for vars and values in "
42  "ForestLGBMSimple constructor");
43  }
44 
45  nodes.clear();
46 
47  std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
48 
49  for (size_t i = 0; i < vars->size(); ++i) {
50  nodes.emplace_back(vars->at(i), values->at(i), right[i]);
51  if (vars->at(i) > m_max_var) { m_max_var = vars->at(i); }
52  }
53  newTree(nodes);
54  } // end loop on TTree, all decision tree loaded
55  delete vars;
56  delete values;
57 }
58 
59 
60 TTree* ForestLGBMSimple::WriteTree(TString name) const
61 {
62  TTree *tree = new TTree(name.Data(), "creator=lgbm;node_type=lgbm_simple");
63 
64  std::vector<int> vars;
65  std::vector<float> values;
66 
67  tree->Branch("vars", &vars);
68  tree->Branch("values", &values);
69 
70  for (size_t itree = 0; itree < GetNTrees(); ++itree) {
71  vars.clear();
72  values.clear();
73  for(const auto& node : GetTree(itree)) {
74  vars.push_back(node.GetVar());
75  values.push_back(node.GetVal());
76  }
77  tree->Fill();
78  }
79  return tree;
80 }
81 
83 {
84  std::cout << "***BDT LGBMSimple: Printing entire forest***" << std::endl;
86 }
87 
90  , m_max_var(0)
91 {
92 
93 
94  // variables read from the TTree
95  std::vector<int> *vars = nullptr;
96  std::vector<float> *values = nullptr;
97  std::vector<bool> *default_left = nullptr;
98 
99  std::vector<NodeLGBM> nodes;
100 
101  tree->SetBranchAddress("vars", &vars);
102  tree->SetBranchAddress("values", &values);
103  tree->SetBranchAddress("default_left", &default_left);
104  int numEntries = tree->GetEntries();
105  for (int entry = 0; entry < numEntries; ++entry) {
106  // each entry in the TTree is a decision tree
107  tree->GetEntry(entry);
108  if (!vars) {
109  throw std::runtime_error(
110  "vars pointer is null in ForestLGBM constructor");
111  }
112  if (!values) {
113  throw std::runtime_error(
114  "values pointers is null in ForestLGBM constructor");
115  }
116  if (!default_left) {
117  throw std::runtime_error(
118  "default_left pointers is null in ForestLGBM constructor");
119  }
120  if (vars->size() != values->size()) {
121  throw std::runtime_error(
122  "inconsistent size for vars and values in ForestLGBM constructor");
123  }
124  if (default_left->size() != values->size()) {
125  throw std::runtime_error("inconsistent size for default_left and "
126  "values in ForestLGBM constructor");
127  }
128 
129  nodes.clear();
130 
131  std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
132 
133  for (size_t i = 0; i < vars->size(); ++i) {
134  nodes.emplace_back(
135  vars->at(i), values->at(i), right[i], default_left->at(i));
136  if (vars->at(i) > m_max_var) {
137  m_max_var = vars->at(i);
138  }
139  }
140  newTree(nodes);
141  } // end loop on TTree, all decision tree loaded
142  delete vars;
143  delete values;
144  delete default_left;
145 }
146 
147 
148 TTree* ForestLGBM::WriteTree(TString name) const
149 {
150  TTree *tree = new TTree(name.Data(), "creator=lgbm;node_type=lgbm");
151 
152  std::vector<int> vars;
153  std::vector<float> values;
154  std::vector<bool> default_left;
155 
156  tree->Branch("vars", &vars);
157  tree->Branch("values", &values);
158  tree->Branch("default_left", &default_left);
159 
160  for (size_t itree = 0; itree < GetNTrees(); ++itree) {
161  vars.clear();
162  values.clear();
163  default_left.clear();
164  for(const auto& node : GetTree(itree)) {
165  vars.push_back(node.GetVar());
166  values.push_back(node.GetVal());
167  default_left.push_back(node.GetDefaultLeft());
168  }
169  tree->Fill();
170  }
171  return tree;
172 }
173 
175 {
176  std::cout << "***BDT LGBM: Printing entire forest***" << std::endl;
178 }
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM()=default
MVAUtils::ForestLGBMSimple::m_max_var
int m_max_var
Definition: ForestLGBM.h:56
MVAUtils
Definition: InDetTrkInJetType.h:48
MVAUtils::ForestLGBMBase
Definition: ForestLGBM.h:27
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple()=default
MVAUtils::Forest< NodeLGBMSimple >::GetTree
std::vector< NodeLGBMSimple > GetTree(unsigned int itree) const
Return the vector of nodes for the tree itree.
MVAUtils::detail::computeRight
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
ForestLGBM.h
MVAUtils::Forest< NodeLGBMSimple >::GetNTrees
virtual unsigned int GetNTrees() const override final
Definition: Forest.h:94
lumiFormat.i
int i
Definition: lumiFormat.py:85
MVAUtils::Forest< NodeLGBMSimple >::newTree
void newTree(const std::vector< NodeLGBMSimple > &nodes)
append a new tree (defined by a vector of nodes serialized in preorder) to the forest
MVAUtils::NodeLGBMSimple
Node for LGBM without nan implementation.
Definition: NodeImpl.h:92
MVAUtils::ForestLGBMSimple::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestLGBM.cxx:60
MVAUtils::ForestLGBM::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestLGBM.cxx:148
GetAllXsec.entry
list entry
Definition: GetAllXsec.py:132
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
MVAUtils::ForestLGBMSimple::PrintForest
virtual void PrintForest() const override
Definition: ForestLGBM.cxx:82
MVAUtils::Forest::PrintForest
virtual void PrintForest() const override
MVAUtils::ForestLGBM::m_max_var
int m_max_var
Definition: ForestLGBM.h:75
MVAUtils::ForestLGBM::PrintForest
virtual void PrintForest() const override
Definition: ForestLGBM.cxx:174
MVAUtils::NodeLGBM
Node for LGBM with nan implementation.
Definition: NodeImpl.h:135
node
Definition: memory_hooks-stdcmalloc.h:74