ATLAS Offline Software
Functions
TMVAToMVAUtils Namespace Reference

Functions

void newTree (MVAUtils::ForestTMVA *forest, const TMVA::DecisionTreeNode *node, float weight, bool isRegression, bool useYesNoLeaf)
 
std::unique_ptr< MVAUtils::ForestTMVA > createForestTMVA (TMVA::MethodBDT *bdt, bool isRegression, bool useYesNoLeaf)
 
std::unique_ptr< MVAUtils::BDT > convert (TMVA::MethodBDT *bdt, bool isRegression=true, bool useYesNoLeaf=false)
 

Function Documentation

◆ convert()

std::unique_ptr<MVAUtils::BDT> TMVAToMVAUtils::convert ( TMVA::MethodBDT *  bdt,
bool  isRegression = true,
bool  useYesNoLeaf = false 
)

Definition at line 114 of file TMVAToMVAUtils.h.

116  {
117 
118  std::unique_ptr<MVAUtils::IForest> forest= createForestTMVA(bdt,isRegression,useYesNoLeaf);
119  return std::make_unique<MVAUtils::BDT> (std::move(forest));
120 }

◆ createForestTMVA()

std::unique_ptr<MVAUtils::ForestTMVA> TMVAToMVAUtils::createForestTMVA ( TMVA::MethodBDT *  bdt,
bool  isRegression,
bool  useYesNoLeaf 
)

Definition at line 97 of file TMVAToMVAUtils.h.

99  {
100  auto forest=std::make_unique<MVAUtils::ForestTMVA>();
101  std::vector<TMVA::DecisionTree*>::const_iterator it;
102  for(it = bdt->GetForest().begin(); it != bdt->GetForest().end(); ++it) {
103  uint index = it - bdt->GetForest().begin();
104  float weight = 0.;
105  if(bdt->GetBoostWeights().size() > index) {
106  weight = bdt->GetBoostWeights()[index];
107  }
108 
109  newTree(forest.get(),(*it)->GetRoot(), weight, isRegression, useYesNoLeaf);
110  }
111  return forest;
112 }

◆ newTree()

void TMVAToMVAUtils::newTree ( MVAUtils::ForestTMVA forest,
const TMVA::DecisionTreeNode *  node,
float  weight,
bool  isRegression,
bool  useYesNoLeaf 
)

Definition at line 15 of file TMVAToMVAUtils.h.

20 {
21  int max_var=0;
22  // index is relative to the current node
23  std::vector<MVAUtils::index_t> right;
24  {
25  // not strictly parent if doing a right node
26  std::stack<const TMVA::DecisionTreeNode *> parent;
27  std::stack<MVAUtils::index_t> parentIndex;
28  parentIndex.push(-1);
29  parent.push(nullptr);
30  auto currNode = node;
31  int i = -1;
32  while (currNode) {
33  ++i;
34  right.push_back(-1);
35  if (!currNode->GetLeft()) {
36  // a leaf
37  auto currParent = parent.top();
38  auto currParentIndex = parentIndex.top();
39  // if right has not been visited, next will be right
40  if (currParentIndex >= 0) {
41  right[currParentIndex] = i + 1 - currParentIndex;
42  currNode = currParent->GetCutType() ? currParent->GetLeft() : currParent->GetRight();
43  } else {
44  currNode = nullptr;
45  }
46  parent.pop();
47  parentIndex.pop();
48  } else {
49  // not a leaf
50  parent.push(currNode);
51  parentIndex.push(i);
52  currNode = currNode->GetCutType() ? currNode->GetRight() : currNode->GetLeft();
53  }
54  }
55  }
56 
57  {
58  std::stack<const TMVA::DecisionTreeNode *> parent; // not strictly parent if doing a right node
59 
60  parent.push(nullptr);
61 
62  auto currNode = node;
63  int i = -1;
64  std::vector<MVAUtils::NodeTMVA> nodes;
65  while (currNode) {
66  ++i;
67  if (!currNode->GetLeft()){
68  // a leaf
69  nodes.emplace_back(-1,
70  isRegression ?
71  currNode->GetResponse() : useYesNoLeaf ? currNode->GetNodeType() : currNode->GetPurity(),
72  right[i]);
73  auto currParent = parent.top();
74  // if right has not been visited, next will be right
75  if (currParent) {
76  currNode = currParent->GetCutType() ? currParent->GetLeft() : currParent->GetRight();
77  } else {
78  currNode = nullptr;
79  }
80  parent.pop();
81  } else {
82  // not a leaf
83  parent.push(currNode);
84 
85  if (currNode->GetSelector() >max_var) { max_var = currNode->GetSelector(); }
86 
87  nodes.emplace_back(currNode->GetSelector(), currNode->GetCutValue(), right[i]);
88 
89  currNode = currNode->GetCutType() ? currNode->GetRight() : currNode->GetLeft();
90  }
91  }
92  forest->newTree(nodes, weight);
93  }
94  forest->setNVars(max_var);
95 }
MVAUtils::ForestWeighted::newTree
void newTree(const std::vector< Node_t > &nodes, float weight)
index
Definition: index.py:1
TMVAToMVAUtils::createForestTMVA
std::unique_ptr< MVAUtils::ForestTMVA > createForestTMVA(TMVA::MethodBDT *bdt, bool isRegression, bool useYesNoLeaf)
Definition: TMVAToMVAUtils.h:97
skel.it
it
Definition: skel.GENtoEVGEN.py:423
dqt_zlumi_pandas.weight
int weight
Definition: dqt_zlumi_pandas.py:200
uint
unsigned int uint
Definition: LArOFPhaseFill.cxx:20
lumiFormat.i
int i
Definition: lumiFormat.py:92
test_pyathena.parent
parent
Definition: test_pyathena.py:15
PyPoolBrowser.node
node
Definition: PyPoolBrowser.py:131
MVAUtils::ForestTMVA::setNVars
void setNVars(const int max_var)
Definition: ForestTMVA.h:82
DeMoScan.index
string index
Definition: DeMoScan.py:362
TMVAToMVAUtils::newTree
void newTree(MVAUtils::ForestTMVA *forest, const TMVA::DecisionTreeNode *node, float weight, bool isRegression, bool useYesNoLeaf)
Definition: TMVAToMVAUtils.h:15