ATLAS Offline Software
ForestTMVA.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "MVAUtils/ForestTMVA.h"
6 #include "TTree.h"
7 #include <iostream>
8 #include <stdexcept>
9 
10 using namespace MVAUtils;
11 
14  , m_max_var(0)
15 {
16  // variables read from the TTree
17  std::vector<int> *vars = nullptr;
18  std::vector<float> *values = nullptr;
19  float offset; // the offset is the weight
20 
21  std::vector<NodeTMVA> nodes;
22 
23  tree->SetBranchAddress("vars", &vars);
24  tree->SetBranchAddress("values", &values);
25  tree->SetBranchAddress("offset", &offset);
26 
27  int numEntries = tree->GetEntries();
28  for (int entry = 0; entry < numEntries; ++entry) {
29  // each entry in the TTree is a decision tree
30  tree->GetEntry(entry);
31  if (!vars) {
32  throw std::runtime_error(
33  "vars pointer is null in ForestTMVA constructor");
34  }
35  if (!values) {
36  throw std::runtime_error(
37  "values pointers is null in ForestTMVA constructor");
38  }
39  if (vars->size() != values->size()) {
40  throw std::runtime_error(
41  "inconsistent size for vars and values in ForestTMVA constructor");
42  }
43 
44  nodes.clear();
45 
46  std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
47 
48  for (size_t i = 0; i < vars->size(); ++i) {
49  nodes.emplace_back(vars->at(i), values->at(i), right[i]);
50  if (vars->at(i) > m_max_var) {
51  m_max_var = vars->at(i);
52  }
53  }
54  newTree(nodes, offset);
55  } // end loop on TTree, all decision tree loaded
56  delete vars;
57  delete values;
58 }
59 
60 
61 TTree* ForestTMVA::WriteTree(TString name) const
62 {
63  TTree *tree = new TTree(name.Data(), "creator=TMVA");
64 
65  std::vector<int> vars;
66  std::vector<float> values;
67  float offset;
68 
69  tree->Branch("offset", &offset);
70  tree->Branch("vars", &vars);
71  tree->Branch("values", &values);
72 
73  for (size_t itree = 0; itree < GetNTrees(); ++itree) {
74  vars.clear();
75  values.clear();
76  for(const auto& node : GetTree(itree)) {
77  vars.push_back(node.GetVar());
78  values.push_back(node.GetVal());
79  }
80  offset = GetTreeWeight(itree);
81  tree->Fill();
82  }
83  return tree;
84 }
85 
// ForestTMVA::PrintForest: writes the banner
// "***BDT TMVA: Printing entire forest***" to std::cout.
// NOTE(review): this listing omits source line 86 (the signature) and
// line 89. The page's cross-references mention
// MVAUtils::Forest::PrintForest, so line 89 presumably delegates to the
// base-class printer. Confirm against the original file.
87 {
88  std::cout << "***BDT TMVA: Printing entire forest***" << std::endl;
90 }
MVAUtils::ForestTMVA::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestTMVA.cxx:61
MVAUtils::NodeTMVA
Node for TMVA implementation.
Definition: NodeImpl.h:36
MVAUtils::ForestTMVA::m_max_var
int m_max_var
Definition: ForestTMVA.h:84
MVAUtils
Definition: InDetTrkInJetType.h:47
MVAUtils::ForestWeighted< NodeTMVA >::newTree
void newTree(const std::vector< NodeTMVA > &nodes, float weight)
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::Forest< NodeTMVA >::GetTree
std::vector< NodeTMVA > GetTree(unsigned int itree) const
Return the vector of nodes for the tree itree.
MVAUtils::detail::computeRight
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:797
MVAUtils::Forest< NodeTMVA >::GetNTrees
virtual unsigned int GetNTrees() const override final
Definition: Forest.h:94
ForestTMVA.h
lumiFormat.i
int i
Definition: lumiFormat.py:92
GetAllXsec.entry
list entry
Definition: GetAllXsec.py:132
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
MVAUtils::Forest::PrintForest
virtual void PrintForest() const override
MVAUtils::ForestTMVA::ForestTMVA
ForestTMVA()=default
convertTimingResiduals.offset
offset
Definition: convertTimingResiduals.py:71
MVAUtils::ForestWeighted< NodeTMVA >::GetTreeWeight
float GetTreeWeight(unsigned int itree) const
Definition: ForestTMVA.h:38
MVAUtils::ForestTMVA::PrintForest
virtual void PrintForest() const override
Definition: ForestTMVA.cxx:86
node
Definition: memory_hooks-stdcmalloc.h:74
MVAUtils::ForestWeighted
Implements a Forest with weighted nodes. This is a general Forest class which implements the strategy used ...
Definition: ForestTMVA.h:24