ATLAS Offline Software
Loading...
Searching...
No Matches
MVAUtils::ForestTMVA Class Reference [final], [abstract]

#include <ForestTMVA.h>

Inheritance diagram for MVAUtils::ForestTMVA:
Collaboration diagram for MVAUtils::ForestTMVA:

Public Member Functions

 ForestTMVA (TTree *tree)
 ForestTMVA ()=default
 ForestTMVA (const ForestTMVA &)=default
ForestTMVA & operator= (const ForestTMVA &)=default
 ForestTMVA (ForestTMVA &&)=default
ForestTMVA & operator= (ForestTMVA &&)=default
 ~ForestTMVA ()=default
virtual TTree * WriteTree (TString name) const override
 Return a TTree representing the BDT.
virtual float GetResponse (const std::vector< float > &values) const override
 Compute the prediction for regression.
virtual float GetResponse (const std::vector< float * > &pointers) const override
virtual float GetClassification (const std::vector< float > &values) const override
 Compute the prediction of a classification.
virtual float GetClassification (const std::vector< float * > &pointers) const override
virtual void PrintForest () const override
virtual int GetNVars () const override
 Get the number of input variables to be passed with std::vector to the Get* methods.
void setNVars (const int max_var)
float GetTreeResponseWeighted (const std::vector< float > &values, unsigned int itree) const
float GetWeightedResponse (const std::vector< float > &values) const
void newTree (const std::vector< NodeTMVA > &nodes, float weight)
float GetTreeWeight (unsigned int itree) const
float GetSumWeights () const
virtual float GetOffset () const override
 Return the offset of the forest.
virtual void PrintTree (unsigned int itree) const override
virtual float GetTreeResponse (const std::vector< float > &values, unsigned int itree) const override final
 Return the response of one tree. The features must be passed in a std::vector<float> values, together with the index of the tree.
virtual float GetTreeResponse (const std::vector< float * > &pointers, unsigned int itree) const =0
virtual float GetRawResponse (const std::vector< float > &values) const override final
 Return the response of the whole Forest.
virtual float GetRawResponse (const std::vector< float * > &pointers) const =0
virtual std::vector< float > GetMultiResponse (const std::vector< float > &values, unsigned int numClasses) const override
 Compute the prediction for multiclassification (a score for each class).
virtual std::vector< float > GetMultiResponse (const std::vector< float * > &pointers, unsigned int numClasses) const =0
virtual unsigned int GetNTrees () const override final
std::vector< NodeTMVA > GetTree (unsigned int itree) const
 Return the vector of nodes for the tree itree.

Protected Member Functions

void newTree (const std::vector< NodeTMVA > &nodes)
 Append a new tree (defined by a vector of nodes serialized in preorder) to the forest.
float GetTreeResponseFromNode (const std::vector< float > &values, index_t index) const
 Get the response of a tree.

Private Attributes

int m_max_var =0
std::vector< float > m_weights
 boost weights
float m_sumWeights
 The sum of the boost weights (sumOfBoostWeights), cached so that it does not need to be recomputed on each call.
std::vector< index_t > m_forest
 indices of the top-level nodes of each tree
std::vector< NodeTMVA > m_nodes
 where the nodes of the forest are stored

Detailed Description

Definition at line 63 of file ForestTMVA.h.

Constructor & Destructor Documentation

◆ ForestTMVA() [1/4]

ForestTMVA::ForestTMVA ( TTree * tree)
explicit

Definition at line 12 of file ForestTMVA.cxx.

14 , m_max_var(0)
15{
16 // variables read from the TTree
17 std::vector<int> *vars = nullptr;
18 std::vector<float> *values = nullptr;
19 float offset; // the offset is the weight
20
21 std::vector<NodeTMVA> nodes;
22
23 tree->SetBranchAddress("vars", &vars);
24 tree->SetBranchAddress("values", &values);
25 tree->SetBranchAddress("offset", &offset);
26
27 int numEntries = tree->GetEntries();
28 for (int entry = 0; entry < numEntries; ++entry) {
29 // each entry in the TTree is a decision tree
30 tree->GetEntry(entry);
31 if (!vars) {
32 throw std::runtime_error(
33 "vars pointer is null in ForestTMVA constructor");
34 }
35 if (!values) {
36 throw std::runtime_error(
37 "values pointers is null in ForestTMVA constructor");
38 }
39 if (vars->size() != values->size()) {
40 throw std::runtime_error(
41 "inconsistent size for vars and values in ForestTMVA constructor");
42 }
43
44 nodes.clear();
45
46 std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
47
48 for (size_t i = 0; i < vars->size(); ++i) {
49 nodes.emplace_back(vars->at(i), values->at(i), right[i]);
50 if (vars->at(i) > m_max_var) {
51 m_max_var = vars->at(i);
52 }
53 }
54 newTree(nodes, offset);
55 } // end loop on TTree, all decision tree loaded
56 delete vars;
57 delete values;
58}
void newTree(const std::vector< NodeTMVA > &nodes, float weight)
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
TChain * tree

◆ ForestTMVA() [2/4]

MVAUtils::ForestTMVA::ForestTMVA ( )
default

◆ ForestTMVA() [3/4]

MVAUtils::ForestTMVA::ForestTMVA ( const ForestTMVA & )
default

◆ ForestTMVA() [4/4]

MVAUtils::ForestTMVA::ForestTMVA ( ForestTMVA && )
default

◆ ~ForestTMVA()

MVAUtils::ForestTMVA::~ForestTMVA ( )
default

Member Function Documentation

◆ GetClassification() [1/2]

virtual float MVAUtils::ForestTMVA::GetClassification ( const std::vector< float * > & pointers) const
[override], [virtual]

Implements MVAUtils::IForest.

◆ GetClassification() [2/2]

virtual float MVAUtils::ForestTMVA::GetClassification ( const std::vector< float > & values) const
[override], [virtual]

Compute the prediction of a classification.

Implements MVAUtils::IForest.

◆ GetMultiResponse() [1/2]

virtual std::vector< float > MVAUtils::Forest< NodeTMVA >::GetMultiResponse ( const std::vector< float > & values,
unsigned int numClasses ) const
[override], [virtual], [inherited]

Compute the prediction for multiclassification (a score for each class).

In addition to the input values, the number of classes must be passed.

Implements MVAUtils::IForest.

◆ GetMultiResponse() [2/2]

virtual std::vector< float > MVAUtils::IForest::GetMultiResponse ( const std::vector< float * > & pointers,
unsigned int numClasses ) const
[pure virtual], [inherited]

Implemented in MVAUtils::Forest< Node_t >.

◆ GetNTrees()

virtual unsigned int MVAUtils::Forest< NodeTMVA >::GetNTrees ( ) const
[inline], [final], [override], [virtual], [inherited]

Implements MVAUtils::IForest.

Definition at line 94 of file Forest.h.

95 {
96 return m_forest.size();
97 }
Generic Forest base class.
Definition Forest.h:54

◆ GetNVars()

virtual int MVAUtils::ForestTMVA::GetNVars ( ) const
[inline], [override], [virtual]

Get the number of input variables to be passed with std::vector to the Get* methods.

Implements MVAUtils::IForest.

Definition at line 81 of file ForestTMVA.h.

81{ return m_max_var + 1; }

◆ GetOffset()

virtual float MVAUtils::ForestWeighted< NodeTMVA >::GetOffset ( ) const
[inline], [override], [virtual], [inherited]

Return the offset of the forest.

The base Forest has no offset and returns 0; this ForestWeighted reimplementation instead returns the weight of the first tree (m_weights[0]).

Reimplemented from MVAUtils::Forest< NodeTMVA >.

Definition at line 41 of file ForestTMVA.h.

41{ return m_weights[0]; }
Implements a Forest with weighted nodes. This is a general Forest class which implements the strategy used ...
Definition ForestTMVA.h:24

◆ GetRawResponse() [1/2]

virtual float MVAUtils::Forest< NodeTMVA >::GetRawResponse ( const std::vector< float > & values) const
[final], [override], [virtual], [inherited]

Return the response of the whole Forest.

The raw response is simply the sum of the responses of all the trees.

Implements MVAUtils::IForest.

◆ GetRawResponse() [2/2]

virtual float MVAUtils::IForest::GetRawResponse ( const std::vector< float * > & pointers) const
[pure virtual], [inherited]

Implemented in MVAUtils::Forest< Node_t >.

◆ GetResponse() [1/2]

virtual float MVAUtils::ForestTMVA::GetResponse ( const std::vector< float * > & pointers) const
[override], [virtual]

Implements MVAUtils::IForest.

◆ GetResponse() [2/2]

virtual float MVAUtils::ForestTMVA::GetResponse ( const std::vector< float > & values) const
[override], [virtual]

Compute the prediction for regression.

Reimplemented from MVAUtils::Forest< NodeTMVA >.

◆ GetSumWeights()

float MVAUtils::ForestWeighted< NodeTMVA >::GetSumWeights ( ) const
[inline], [inherited]

Definition at line 39 of file ForestTMVA.h.

39{ return m_sumWeights; }

◆ GetTree()

std::vector< NodeTMVA > MVAUtils::Forest< NodeTMVA >::GetTree ( unsigned int itree) const
inherited

Return the vector of nodes for the tree itree.

◆ GetTreeResponse() [1/2]

virtual float MVAUtils::Forest< NodeTMVA >::GetTreeResponse ( const std::vector< float > & values,
unsigned int itree ) const
[final], [override], [virtual], [inherited]

Return the response of one tree. The features must be passed in a std::vector<float> values, together with the index of the tree.

Implements MVAUtils::IForest.

◆ GetTreeResponse() [2/2]

virtual float MVAUtils::IForest::GetTreeResponse ( const std::vector< float * > & pointers,
unsigned int itree ) const
[pure virtual], [inherited]

Implemented in MVAUtils::Forest< Node_t >.

◆ GetTreeResponseFromNode()

float MVAUtils::Forest< NodeTMVA >::GetTreeResponseFromNode ( const std::vector< float > & values,
index_t index ) const
[protected], [inherited]

Get the response of a tree.

Instead of the index of the tree (as in GetTreeResponse), the index of the top node of the tree must be specified.

◆ GetTreeResponseWeighted()

float MVAUtils::ForestWeighted< NodeTMVA >::GetTreeResponseWeighted ( const std::vector< float > & values,
unsigned int itree ) const
inherited

◆ GetTreeWeight()

float MVAUtils::ForestWeighted< NodeTMVA >::GetTreeWeight ( unsigned int itree) const
[inline], [inherited]

Definition at line 38 of file ForestTMVA.h.

38{ return m_weights[itree]; }

◆ GetWeightedResponse()

float MVAUtils::ForestWeighted< NodeTMVA >::GetWeightedResponse ( const std::vector< float > & values) const
inherited

◆ newTree() [1/2]

void MVAUtils::Forest< NodeTMVA >::newTree ( const std::vector< NodeTMVA > & nodes)
[protected], [inherited]

Append a new tree (defined by a vector of nodes serialized in preorder) to the forest.

◆ newTree() [2/2]

void MVAUtils::ForestWeighted< NodeTMVA >::newTree ( const std::vector< NodeTMVA > & nodes,
float weight )
inherited

◆ operator=() [1/2]

ForestTMVA & MVAUtils::ForestTMVA::operator= ( const ForestTMVA & )
default

◆ operator=() [2/2]

ForestTMVA & MVAUtils::ForestTMVA::operator= ( ForestTMVA && )
default

◆ PrintForest()

void ForestTMVA::PrintForest ( ) const
[override], [virtual]

Reimplemented from MVAUtils::Forest< NodeTMVA >.

Definition at line 86 of file ForestTMVA.cxx.

87{
88 std::cout << "***BDT TMVA: Printing entire forest***" << std::endl;
90}
virtual void PrintForest() const override

◆ PrintTree()

virtual void MVAUtils::ForestWeighted< NodeTMVA >::PrintTree ( unsigned int itree) const
[inline], [override], [virtual], [inherited]

Reimplemented from MVAUtils::Forest< NodeTMVA >.

Definition at line 43 of file ForestTMVA.h.

43 {
44 std::cout << "weight: " << m_weights[itree] << std::endl;
46 }
virtual void PrintTree(unsigned int itree) const override

◆ setNVars()

void MVAUtils::ForestTMVA::setNVars ( const int max_var)
inline

Definition at line 82 of file ForestTMVA.h.

82{m_max_var=max_var;}

◆ WriteTree()

TTree * ForestTMVA::WriteTree ( TString name) const
[override], [virtual]

Return a TTree representing the BDT.

The caller is the owner of the returned TTree.

Implements MVAUtils::IForest.

Definition at line 61 of file ForestTMVA.cxx.

62{
63 TTree *tree = new TTree(name.Data(), "creator=TMVA");
64
65 std::vector<int> vars;
66 std::vector<float> values;
67 float offset;
68
69 tree->Branch("offset", &offset);
70 tree->Branch("vars", &vars);
71 tree->Branch("values", &values);
72
73 for (size_t itree = 0; itree < GetNTrees(); ++itree) {
74 vars.clear();
75 values.clear();
76 for(const auto& node : GetTree(itree)) {
77 vars.push_back(node.GetVar());
78 values.push_back(node.GetVal());
79 }
80 offset = GetTreeWeight(itree);
81 tree->Fill();
82 }
83 return tree;
84}
float GetTreeWeight(unsigned int itree) const
Definition ForestTMVA.h:38
std::vector< NodeTMVA > GetTree(unsigned int itree) const
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94

Member Data Documentation

◆ m_forest

std::vector<index_t> MVAUtils::Forest< NodeTMVA >::m_forest
[private], [inherited]

indices of the top-level nodes of each tree

Definition at line 117 of file Forest.h.

◆ m_max_var

int MVAUtils::ForestTMVA::m_max_var =0
private

Definition at line 84 of file ForestTMVA.h.

◆ m_nodes

std::vector<NodeTMVA> MVAUtils::Forest< NodeTMVA >::m_nodes
[private], [inherited]

where the nodes of the forest are stored

Definition at line 118 of file Forest.h.

◆ m_sumWeights

float MVAUtils::ForestWeighted< NodeTMVA >::m_sumWeights
[private], [inherited]

The sum of the boost weights (sumOfBoostWeights), cached so that it does not need to be recomputed on each call.

Definition at line 50 of file ForestTMVA.h.

◆ m_weights

std::vector<float> MVAUtils::ForestWeighted< NodeTMVA >::m_weights
[private], [inherited]

boost weights

Definition at line 49 of file ForestTMVA.h.


The documentation for this class was generated from the following files: