ATLAS Offline Software
ForestXGBoost.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 #include "TTree.h"
7 #include <iostream>
8 #include <stdexcept>
9 
10 using namespace MVAUtils;
11 
14  , m_max_var(0)
15 {
16 
17 
18  // variables read from the TTree
19  std::vector<int> *vars = nullptr;
20  std::vector<float> *values = nullptr;
21  std::vector<bool> *default_left = nullptr;
22 
23  std::vector<NodeXGBoost> nodes;
24 
25  tree->SetBranchAddress("vars", &vars);
26  tree->SetBranchAddress("values", &values);
27  tree->SetBranchAddress("default_left", &default_left);
28 
29  for (int i = 0; i < tree->GetEntries(); ++i)
30  {
31  // each entry in the TTree is a decision tree
32  tree->GetEntry(i);
33  if (!vars) { throw std::runtime_error("vars pointer is null in ForestXGBoost constructor"); }
34  if (!values) { throw std::runtime_error("values pointers is null in ForestXGBoost constructor"); }
35  if (!default_left) { throw std::runtime_error("default_left pointers is null in ForestXGBoost constructor"); }
36  if (vars->size() != values->size()) { throw std::runtime_error("inconsistent size for vars and values in ForestXGBoost constructor"); }
37  if (default_left->size() != values->size()) { throw std::runtime_error("inconsistent size for default_left and values in ForestXGBoost constructor"); }
38 
39  nodes.clear();
40 
41  std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
42 
43  for (size_t i = 0; i < vars->size(); ++i) {
44  nodes.emplace_back(vars->at(i), values->at(i), right[i], default_left->at(i));
45  if (vars->at(i) > m_max_var) { m_max_var = vars->at(i); }
46  }
47  newTree(nodes);
48  } // end loop on TTree, all decision tree loaded
49  delete vars;
50  delete values;
51  delete default_left;
52 }
53 
54 
55 TTree* ForestXGBoost::WriteTree(TString name) const
56 {
57  TTree *tree = new TTree(name.Data(), "creator=xgboost");
58 
59  std::vector<int> vars;
60  std::vector<float> values;
61  std::vector<bool> default_left;
62 
63  tree->Branch("vars", &vars);
64  tree->Branch("values", &values);
65  tree->Branch("default_left", &default_left);
66 
67  for (size_t itree = 0; itree < GetNTrees(); ++itree) {
68  vars.clear();
69  values.clear();
70  default_left.clear();
71  for(const auto& node : GetTree(itree)) {
72  vars.push_back(node.GetVar());
73  values.push_back(node.GetVal());
74  default_left.push_back(node.GetDefaultLeft());
75  }
76  tree->Fill();
77  }
78  return tree;
79 }
80 
82 {
83  std::cout << "***BDT XGBoost: Printing entire forest***" << std::endl;
85 }
MVAUtils
Definition: InDetTrkInJetType.h:48
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::Forest< NodeXGBoost >::GetTree
std::vector< NodeXGBoost > GetTree(unsigned int itree) const
Return the vector of nodes for the tree itree.
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost()=default
MVAUtils::NodeXGBoost
Node for XGBoost with nan implementation.
Definition: NodeImpl.h:177
MVAUtils::detail::computeRight
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
MVAUtils::ForestXGBoostBase
Definition: ForestXGBoost.h:29
MVAUtils::Forest< NodeXGBoost >::GetNTrees
virtual unsigned int GetNTrees() const override final
Definition: Forest.h:94
lumiFormat.i
int i
Definition: lumiFormat.py:85
MVAUtils::Forest< NodeXGBoost >::newTree
void newTree(const std::vector< NodeXGBoost > &nodes)
append a new tree (defined by a vector of nodes serialized in preorder) to the forest
MVAUtils::ForestXGBoost::PrintForest
virtual void PrintForest() const override
Definition: ForestXGBoost.cxx:81
ForestXGBoost.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
MVAUtils::Forest::PrintForest
virtual void PrintForest() const override
MVAUtils::ForestXGBoost::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestXGBoost.cxx:55
MVAUtils::ForestXGBoost::m_max_var
int m_max_var
Definition: ForestXGBoost.h:59
node
Definition: memory_hooks-stdcmalloc.h:74