ATLAS Offline Software
Loading...
Searching...
No Matches
TMVAToMVAUtils Namespace Reference

Functions

void newTree (MVAUtils::ForestTMVA *forest, const TMVA::DecisionTreeNode *node, float weight, bool isRegression, bool useYesNoLeaf)
std::unique_ptr< MVAUtils::ForestTMVA > createForestTMVA (TMVA::MethodBDT *bdt, bool isRegression, bool useYesNoLeaf)
std::unique_ptr< MVAUtils::BDT > convert (TMVA::MethodBDT *bdt, bool isRegression=true, bool useYesNoLeaf=false)

Function Documentation

◆ convert()

std::unique_ptr< MVAUtils::BDT > TMVAToMVAUtils::convert ( TMVA::MethodBDT * bdt,
bool isRegression = true,
bool useYesNoLeaf = false )

Definition at line 114 of file TMVAToMVAUtils.h.

116 {
117
118 std::unique_ptr<MVAUtils::IForest> forest= createForestTMVA(bdt,isRegression,useYesNoLeaf);
119 return std::make_unique<MVAUtils::BDT> (std::move(forest));
120}
std::unique_ptr< MVAUtils::ForestTMVA > createForestTMVA(TMVA::MethodBDT *bdt, bool isRegression, bool useYesNoLeaf)

◆ createForestTMVA()

std::unique_ptr< MVAUtils::ForestTMVA > TMVAToMVAUtils::createForestTMVA ( TMVA::MethodBDT * bdt,
bool isRegression,
bool useYesNoLeaf )

Definition at line 97 of file TMVAToMVAUtils.h.

99 {
100 auto forest=std::make_unique<MVAUtils::ForestTMVA>();
101 std::vector<TMVA::DecisionTree*>::const_iterator it;
102 for(it = bdt->GetForest().begin(); it != bdt->GetForest().end(); ++it) {
103 uint index = it - bdt->GetForest().begin();
104 float weight = 0.;
105 if(bdt->GetBoostWeights().size() > index) {
106 weight = bdt->GetBoostWeights()[index];
107 }
108
109 newTree(forest.get(),(*it)->GetRoot(), weight, isRegression, useYesNoLeaf);
110 }
111 return forest;
112}
unsigned int uint
void newTree(MVAUtils::ForestTMVA *forest, const TMVA::DecisionTreeNode *node, float weight, bool isRegression, bool useYesNoLeaf)
Definition index.py:1

◆ newTree()

void TMVAToMVAUtils::newTree ( MVAUtils::ForestTMVA * forest,
const TMVA::DecisionTreeNode * node,
float weight,
bool isRegression,
bool useYesNoLeaf )

Definition at line 15 of file TMVAToMVAUtils.h.

20{
21 int max_var=0;
22 // index is relative to the current node
23 std::vector<MVAUtils::index_t> right;
24 {
25 // not strictly parent if doing a right node
26 std::stack<const TMVA::DecisionTreeNode *> parent;
27 std::stack<MVAUtils::index_t> parentIndex;
28 parentIndex.push(-1);
29 parent.push(nullptr);
30 auto currNode = node;
31 int i = -1;
32 while (currNode) {
33 ++i;
34 right.push_back(-1);
35 if (!currNode->GetLeft()) {
36 // a leaf
37 auto currParent = parent.top();
38 auto currParentIndex = parentIndex.top();
39 // if right has not been visited, next will be right
40 if (currParentIndex >= 0) {
41 right[currParentIndex] = i + 1 - currParentIndex;
42 currNode = currParent->GetCutType() ? currParent->GetLeft() : currParent->GetRight();
43 } else {
44 currNode = nullptr;
45 }
46 parent.pop();
47 parentIndex.pop();
48 } else {
49 // not a leaf
50 parent.push(currNode);
51 parentIndex.push(i);
52 currNode = currNode->GetCutType() ? currNode->GetRight() : currNode->GetLeft();
53 }
54 }
55 }
56
57 {
58 std::stack<const TMVA::DecisionTreeNode *> parent; // not strictly parent if doing a right node
59
60 parent.push(nullptr);
61
62 auto currNode = node;
63 int i = -1;
64 std::vector<MVAUtils::NodeTMVA> nodes;
65 while (currNode) {
66 ++i;
67 if (!currNode->GetLeft()){
68 // a leaf
69 nodes.emplace_back(-1,
70 isRegression ?
71 currNode->GetResponse() : useYesNoLeaf ? currNode->GetNodeType() : currNode->GetPurity(),
72 right[i]);
73 auto currParent = parent.top();
74 // if right has not been visited, next will be right
75 if (currParent) {
76 currNode = currParent->GetCutType() ? currParent->GetLeft() : currParent->GetRight();
77 } else {
78 currNode = nullptr;
79 }
80 parent.pop();
81 } else {
82 // not a leaf
83 parent.push(currNode);
84
85 if (currNode->GetSelector() >max_var) { max_var = currNode->GetSelector(); }
86
87 nodes.emplace_back(currNode->GetSelector(), currNode->GetCutValue(), right[i]);
88
89 currNode = currNode->GetCutType() ? currNode->GetRight() : currNode->GetLeft();
90 }
91 }
92 forest->newTree(nodes, weight);
93 }
94 forest->setNVars(max_var);
95}
void setNVars(const int max_var)
Definition ForestTMVA.h:82
void newTree(const std::vector< Node_t > &nodes, float weight)
Definition node.h:24