ATLAS Offline Software
TMVAToMVAUtils.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_TMVATOMVAUtils_H
6 #define MVAUtils_TMVATOMVAUtils_H
7 
#include "MVAUtils/BDT.h"
#include "MVAUtils/ForestTMVA.h"
#include "TMVA/Reader.h"
#include "TMVA/MethodBDT.h"
#include <cstddef>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
13 
14 namespace TMVAToMVAUtils{
16  const TMVA::DecisionTreeNode *node,
17  float weight,
18  bool isRegression,
19  bool useYesNoLeaf)
20 {
21  int max_var=0;
22  // index is relative to the current node
23  std::vector<MVAUtils::index_t> right;
24  {
25  // not strictly parent if doing a right node
26  std::stack<const TMVA::DecisionTreeNode *> parent;
27  std::stack<MVAUtils::index_t> parentIndex;
28  parentIndex.push(-1);
29  parent.push(nullptr);
30  auto currNode = node;
31  int i = -1;
32  while (currNode) {
33  ++i;
34  right.push_back(-1);
35  if (!currNode->GetLeft()) {
36  // a leaf
37  auto currParent = parent.top();
38  auto currParentIndex = parentIndex.top();
39  // if right has not been visited, next will be right
40  if (currParentIndex >= 0) {
41  right[currParentIndex] = i + 1 - currParentIndex;
42  currNode = currParent->GetCutType() ? currParent->GetLeft() : currParent->GetRight();
43  } else {
44  currNode = nullptr;
45  }
46  parent.pop();
47  parentIndex.pop();
48  } else {
49  // not a leaf
50  parent.push(currNode);
51  parentIndex.push(i);
52  currNode = currNode->GetCutType() ? currNode->GetRight() : currNode->GetLeft();
53  }
54  }
55  }
56 
57  {
58  std::stack<const TMVA::DecisionTreeNode *> parent; // not strictly parent if doing a right node
59 
60  parent.push(nullptr);
61 
62  auto currNode = node;
63  int i = -1;
64  std::vector<MVAUtils::NodeTMVA> nodes;
65  while (currNode) {
66  ++i;
67  if (!currNode->GetLeft()){
68  // a leaf
69  nodes.emplace_back(-1,
70  isRegression ?
71  currNode->GetResponse() : useYesNoLeaf ? currNode->GetNodeType() : currNode->GetPurity(),
72  right[i]);
73  auto currParent = parent.top();
74  // if right has not been visited, next will be right
75  if (currParent) {
76  currNode = currParent->GetCutType() ? currParent->GetLeft() : currParent->GetRight();
77  } else {
78  currNode = nullptr;
79  }
80  parent.pop();
81  } else {
82  // not a leaf
83  parent.push(currNode);
84 
85  if (currNode->GetSelector() >max_var) { max_var = currNode->GetSelector(); }
86 
87  nodes.emplace_back(currNode->GetSelector(), currNode->GetCutValue(), right[i]);
88 
89  currNode = currNode->GetCutType() ? currNode->GetRight() : currNode->GetLeft();
90  }
91  }
92  forest->newTree(nodes, weight);
93  }
94  forest->setNVars(max_var);
95 }
96 
97 std::unique_ptr<MVAUtils::ForestTMVA>createForestTMVA ( TMVA::MethodBDT* bdt,
98  bool isRegression ,
99  bool useYesNoLeaf){
100  auto forest=std::make_unique<MVAUtils::ForestTMVA>();
101  std::vector<TMVA::DecisionTree*>::const_iterator it;
102  for(it = bdt->GetForest().begin(); it != bdt->GetForest().end(); ++it) {
103  uint index = it - bdt->GetForest().begin();
104  float weight = 0.;
105  if(bdt->GetBoostWeights().size() > index) {
106  weight = bdt->GetBoostWeights()[index];
107  }
108 
109  newTree(forest.get(),(*it)->GetRoot(), weight, isRegression, useYesNoLeaf);
110  }
111  return forest;
112 }
113 
114 std::unique_ptr<MVAUtils::BDT> convert(TMVA::MethodBDT* bdt,
115  bool isRegression = true,
116  bool useYesNoLeaf = false){
117 
118  std::unique_ptr<MVAUtils::IForest> forest= createForestTMVA(bdt,isRegression,useYesNoLeaf);
119  return std::make_unique<MVAUtils::BDT> (std::move(forest));
120 }
121 }
122 #endif
TMVAToMVAUtils
Definition: TMVAToMVAUtils.h:14
MVAUtils::ForestWeighted::newTree
void newTree(const std::vector< Node_t > &nodes, float weight)
index
Definition: index.py:1
TMVAToMVAUtils::createForestTMVA
std::unique_ptr< MVAUtils::ForestTMVA > createForestTMVA(TMVA::MethodBDT *bdt, bool isRegression, bool useYesNoLeaf)
Definition: TMVAToMVAUtils.h:97
skel.it
it
Definition: skel.GENtoEVGEN.py:423
dqt_zlumi_pandas.weight
int weight
Definition: dqt_zlumi_pandas.py:200
uint
unsigned int uint
Definition: LArOFPhaseFill.cxx:20
ForestTMVA.h
lumiFormat.i
int i
Definition: lumiFormat.py:92
test_pyathena.parent
parent
Definition: test_pyathena.py:15
PyPoolBrowser.node
node
Definition: PyPoolBrowser.py:131
BDT.h
TMVAToMVAUtils::convert
std::unique_ptr< MVAUtils::BDT > convert(TMVA::MethodBDT *bdt, bool isRegression=true, bool useYesNoLeaf=false)
Definition: TMVAToMVAUtils.h:114
MVAUtils::ForestTMVA::setNVars
void setNVars(const int max_var)
Definition: ForestTMVA.h:82
DeMoScan.index
string index
Definition: DeMoScan.py:362
TMVAToMVAUtils::newTree
void newTree(MVAUtils::ForestTMVA *forest, const TMVA::DecisionTreeNode *node, float weight, bool isRegression, bool useYesNoLeaf)
Definition: TMVAToMVAUtils.h:15
MVAUtils::ForestTMVA
Definition: ForestTMVA.h:64
node
Definition: memory_hooks-stdcmalloc.h:74