ATLAS Offline Software
MVAUtils::ForestWeighted< Node_t > Class Template Reference (abstract)

Implements a Forest with weighted nodes. This is a general Forest class which implements the strategy used by TMVA in some cases.

#include <ForestTMVA.h>


Public Member Functions

 ForestWeighted ()
 
float GetTreeResponseWeighted (const std::vector< float > &values, unsigned int itree) const
 
float GetTreeResponseWeighted (const std::vector< float * > &pointers, unsigned int itree) const
 
float GetWeightedResponse (const std::vector< float > &values) const
 
float GetWeightedResponse (const std::vector< float * > &pointers) const
 
void newTree (const std::vector< Node_t > &nodes, float weight)
 
float GetTreeWeight (unsigned int itree) const
 
float GetSumWeights () const
 
virtual float GetOffset () const override
 Return the offset of the forest.
 
virtual void PrintTree (unsigned int itree) const override
 
virtual float GetTreeResponse (const std::vector< float > &values, unsigned int itree) const override final
 Return the response of one tree. The features must be passed in a std::vector<float> values, together with the index of the tree.
 
virtual float GetTreeResponse (const std::vector< float * > &pointers, unsigned int itree) const override final
 
virtual float GetRawResponse (const std::vector< float > &values) const override final
 Return the response of the whole Forest.
 
virtual float GetRawResponse (const std::vector< float * > &pointers) const override final
 
virtual float GetResponse (const std::vector< float > &values) const override
 Compute the prediction for regression.
 
virtual float GetResponse (const std::vector< float * > &pointers) const override
 
virtual std::vector< float > GetMultiResponse (const std::vector< float > &values, unsigned int numClasses) const override
 Compute the prediction for multiclassification (a score for each class).
 
virtual std::vector< float > GetMultiResponse (const std::vector< float * > &pointers, unsigned int numClasses) const override
 
virtual unsigned int GetNTrees () const override final
 
virtual void PrintForest () const override
 
std::vector< Node_t > GetTree (unsigned int itree) const
 Return the vector of nodes for the tree itree.
 
virtual float GetClassification (const std::vector< float > &values) const =0
 Compute the prediction of a classification.
 
virtual float GetClassification (const std::vector< float * > &pointers) const =0
 
virtual TTree * WriteTree (TString) const =0
 Return a TTree representing the BDT.
 
virtual int GetNVars () const =0
 Get the number of input variables to be passed with std::vector to the Get* methods.
 

Protected Member Functions

float GetTreeResponseFromNode (const std::vector< float > &values, index_t index) const
 Get the response of a tree.
 
float GetTreeResponseFromNode (const std::vector< float * > &pointers, index_t index) const
 
void newTree (const std::vector< Node_t > &nodes)
 Append a new tree (defined by a vector of nodes serialized in preorder) to the forest.
 

Private Attributes

std::vector< float > m_weights
 boost weights
 
float m_sumWeights
 the sum of the boost weights, cached so that it need not be recomputed on each call
 
std::vector< index_t > m_forest
 indices of the top-level nodes of each tree
 
std::vector< Node_t > m_nodes
 where the nodes of the forest are stored
 

Detailed Description

template<typename Node_t>
class MVAUtils::ForestWeighted< Node_t >

Implements a Forest with weighted nodes. This is a general Forest class which implements the strategy used by TMVA in some cases.

Each tree is stored with a weight (see newTree and GetTreeWeight) that can be used to compute GetTreeResponseWeighted. In some cases an offset is needed, which is just the first weight (in TMVA all the weights are identical when the offset is used).

Definition at line 23 of file ForestTMVA.h.
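
As a rough illustration of the weighting strategy described above, the following minimal sketch scales each tree's response by its weight and sums the results. The type ToyWeightedForest and its members are hypothetical stand-ins, and the per-tree responses are supplied precomputed; it is a sketch under those assumptions, not the MVAUtils implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a weighted forest: each tree is reduced to a
// precomputed response so that only the weighting logic is shown.
struct ToyWeightedForest {
  std::vector<float> responses;  // assumed per-tree responses for one input
  std::vector<float> weights;    // one weight per tree

  // Store a tree (here: just its response) together with its weight.
  void addTree(float response, float weight) {
    responses.push_back(response);
    weights.push_back(weight);
  }

  // Response of one tree scaled by the weight of that tree.
  float treeResponseWeighted(std::size_t itree) const {
    assert(itree < responses.size());
    return responses[itree] * weights[itree];
  }

  // Sum of the weighted responses of all trees.
  float weightedResponse() const {
    float result = 0.f;
    for (std::size_t i = 0; i < responses.size(); ++i) {
      result += treeResponseWeighted(i);
    }
    return result;
  }
};
```

Dividing the weighted sum by the sum of the weights would give a weighted average; whether a given derived class normalises in this way is not stated on this page.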

Constructor & Destructor Documentation

◆ ForestWeighted()

template<typename Node_t >
MVAUtils::ForestWeighted< Node_t >::ForestWeighted ( )
inline

Definition at line 26 of file ForestTMVA.h.

26 : m_sumWeights(0.) { }

Member Function Documentation

◆ GetClassification() [1/2]

virtual float MVAUtils::IForest::GetClassification ( const std::vector< float * > &  pointers) const
pure virtual, inherited

◆ GetClassification() [2/2]

virtual float MVAUtils::IForest::GetClassification ( const std::vector< float > &  values) const
pure virtual, inherited

◆ GetMultiResponse() [1/2]

template<typename Node_t >
virtual std::vector<float> MVAUtils::Forest< Node_t >::GetMultiResponse ( const std::vector< float * > &  pointers,
unsigned int  numClasses 
) const
override, virtual, inherited

Implements MVAUtils::IForest.

◆ GetMultiResponse() [2/2]

template<typename Node_t >
virtual std::vector<float> MVAUtils::Forest< Node_t >::GetMultiResponse ( const std::vector< float > &  values,
unsigned int  numClasses 
) const
override, virtual, inherited

Compute the prediction for multiclassification (a score for each class).

In addition to the input values, the number of classes must be passed.

Implements MVAUtils::IForest.

◆ GetNTrees()

template<typename Node_t >
virtual unsigned int MVAUtils::Forest< Node_t >::GetNTrees ( ) const
inline, final, override, virtual, inherited

Implements MVAUtils::IForest.

Definition at line 94 of file Forest.h.

95  {
96  return m_forest.size();
97  }

◆ GetNVars()

virtual int MVAUtils::IForest::GetNVars ( ) const
pure virtual, inherited

Get the number of input variables to be passed with std::vector to the Get* methods.

Implemented in MVAUtils::ForestTMVA, MVAUtils::ForestLGBM, MVAUtils::ForestXGBoost, and MVAUtils::ForestLGBMSimple.

◆ GetOffset()

template<typename Node_t >
virtual float MVAUtils::ForestWeighted< Node_t >::GetOffset ( ) const
inline, override, virtual

Return the offset of the forest.

Returns the first weight, m_weights[0], which serves as the offset when one is needed (see the detailed description).

Reimplemented from MVAUtils::Forest< Node_t >.

Definition at line 41 of file ForestTMVA.h.

41 { return m_weights[0]; }
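
The inline definition above reads the offset from the first stored weight, so calling it on a forest that holds no trees would index an empty vector. A defensive variant could fall back to a default value instead. The free function safeOffset below is a hypothetical helper written for illustration and is not part of MVAUtils.

```cpp
#include <vector>

// Hypothetical helper: return the first weight as the offset, or a fallback
// value when no tree (and hence no weight) has been stored yet.
inline float safeOffset(const std::vector<float>& weights, float fallback = 0.f) {
  return weights.empty() ? fallback : weights.front();
}
```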

◆ GetRawResponse() [1/2]

template<typename Node_t >
virtual float MVAUtils::Forest< Node_t >::GetRawResponse ( const std::vector< float * > &  pointers) const
final, override, virtual, inherited

Implements MVAUtils::IForest.

◆ GetRawResponse() [2/2]

template<typename Node_t >
virtual float MVAUtils::Forest< Node_t >::GetRawResponse ( const std::vector< float > &  values) const
final, override, virtual, inherited

Return the response of the whole Forest.

The raw response is just the sum of the responses of all the trees.

Implements MVAUtils::IForest.
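
The idea that the raw response is a plain sum over trees can be sketched as follows. The function name rawResponse and the choice to pass already-evaluated per-tree responses are assumptions made for illustration only; the real method evaluates each tree on the input features.

```cpp
#include <numeric>
#include <vector>

// Hypothetical illustration: the raw forest response as the plain,
// unweighted sum of the individual tree responses.
inline float rawResponse(const std::vector<float>& treeResponses) {
  return std::accumulate(treeResponses.begin(), treeResponses.end(), 0.0f);
}
```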

◆ GetResponse() [1/2]

template<typename Node_t >
virtual float MVAUtils::Forest< Node_t >::GetResponse ( const std::vector< float * > &  pointers) const
override, virtual, inherited

Implements MVAUtils::IForest.

Reimplemented in MVAUtils::ForestTMVA.

◆ GetResponse() [2/2]

template<typename Node_t >
virtual float MVAUtils::Forest< Node_t >::GetResponse ( const std::vector< float > &  values) const
override, virtual, inherited

Compute the prediction for regression.

Implements MVAUtils::IForest.

Reimplemented in MVAUtils::ForestTMVA.

◆ GetSumWeights()

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::GetSumWeights ( ) const
inline

Definition at line 39 of file ForestTMVA.h.

39 { return m_sumWeights; }

◆ GetTree()

template<typename Node_t >
std::vector<Node_t> MVAUtils::Forest< Node_t >::GetTree ( unsigned int  itree) const
inherited

Return the vector of nodes for the tree itree.

◆ GetTreeResponse() [1/2]

template<typename Node_t >
virtual float MVAUtils::Forest< Node_t >::GetTreeResponse ( const std::vector< float * > &  pointers,
unsigned int  itree 
) const
final, override, virtual, inherited

Implements MVAUtils::IForest.

◆ GetTreeResponse() [2/2]

template<typename Node_t >
virtual float MVAUtils::Forest< Node_t >::GetTreeResponse ( const std::vector< float > &  values,
unsigned int  itree 
) const
final, override, virtual, inherited

Return the response of one tree. The features must be passed in a std::vector<float> values, together with the index of the tree.

Implements MVAUtils::IForest.

◆ GetTreeResponseFromNode() [1/2]

template<typename Node_t >
float MVAUtils::Forest< Node_t >::GetTreeResponseFromNode ( const std::vector< float * > &  pointers,
index_t  index 
) const
protected, inherited

◆ GetTreeResponseFromNode() [2/2]

template<typename Node_t >
float MVAUtils::Forest< Node_t >::GetTreeResponseFromNode ( const std::vector< float > &  values,
index_t  index 
) const
protected, inherited

Get the response of a tree.

Instead of specifying the index of the tree (as in GetTreeResponse), the index of the top node of the tree should be specified.
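
The difference between a tree index and a top-node index can be sketched with a flat node vector. Everything below is hypothetical: the ToyNode layout, the rule that the left child directly follows its parent, and the function names are assumptions chosen for illustration, not the MVAUtils node interface.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical node layout, not an MVAUtils node class.
struct ToyNode {
  int var;            // feature index to test, or -1 for a leaf
  float value;        // cut value (internal node) or response (leaf)
  std::size_t right;  // offset from this node to its right child

  bool isLeaf() const { return var < 0; }
};

// Walk one tree stored in a flat node vector, starting at the index of its
// top node. Nodes are assumed to be in preorder: left child right after parent.
inline float responseFromNode(const std::vector<ToyNode>& nodes,
                              const std::vector<float>& values,
                              std::size_t index) {
  while (true) {
    assert(index < nodes.size());
    const ToyNode& n = nodes[index];
    if (n.isLeaf()) return n.value;
    const float x = values[static_cast<std::size_t>(n.var)];
    // Go left when the feature is below the cut, otherwise jump to the right child.
    index += (x < n.value) ? 1 : n.right;
  }
}

// Map a tree index to the index of its top node, then reuse the walk above.
inline float treeResponse(const std::vector<ToyNode>& nodes,
                          const std::vector<std::size_t>& topNodes,
                          const std::vector<float>& values,
                          std::size_t itree) {
  return responseFromNode(nodes, values, topNodes.at(itree));
}
```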

◆ GetTreeResponseWeighted() [1/2]

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::GetTreeResponseWeighted ( const std::vector< float * > &  pointers,
unsigned int  itree 
) const

◆ GetTreeResponseWeighted() [2/2]

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::GetTreeResponseWeighted ( const std::vector< float > &  values,
unsigned int  itree 
) const

◆ GetTreeWeight()

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::GetTreeWeight ( unsigned int  itree) const
inline

Definition at line 38 of file ForestTMVA.h.

38 { return m_weights[itree]; }

◆ GetWeightedResponse() [1/2]

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::GetWeightedResponse ( const std::vector< float * > &  pointers) const

◆ GetWeightedResponse() [2/2]

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::GetWeightedResponse ( const std::vector< float > &  values) const

◆ newTree() [1/2]

template<typename Node_t >
void MVAUtils::Forest< Node_t >::newTree ( const std::vector< Node_t > &  nodes)
protected, inherited

Append a new tree (defined by a vector of nodes serialized in preorder) to the forest.
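
One way to picture this is a flat storage in which every tree is appended to a shared node vector and the position of its top node is recorded separately, which matches the roles documented for m_nodes and m_forest. The ToyForestStorage type below is a hypothetical sketch of that layout and not the MVAUtils code.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical flat forest storage, templated on an arbitrary node type.
template <typename Node>
struct ToyForestStorage {
  std::vector<std::size_t> topNodes;  // index of the top node of each tree
  std::vector<Node> nodes;            // nodes of all trees, back to back

  // Append one tree given as a preorder-serialized vector of nodes.
  void newTree(const std::vector<Node>& tree) {
    topNodes.push_back(nodes.size());  // the new tree starts at the current end
    nodes.insert(nodes.end(), tree.begin(), tree.end());
  }

  // The number of trees equals the number of recorded top nodes.
  std::size_t nTrees() const { return topNodes.size(); }
};
```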

◆ newTree() [2/2]

template<typename Node_t >
void MVAUtils::ForestWeighted< Node_t >::newTree ( const std::vector< Node_t > &  nodes,
float  weight 
)

◆ PrintForest()

template<typename Node_t >
virtual void MVAUtils::Forest< Node_t >::PrintForest ( ) const
override, virtual, inherited

Implements MVAUtils::IForest.

◆ PrintTree()

template<typename Node_t >
virtual void MVAUtils::ForestWeighted< Node_t >::PrintTree ( unsigned int  itree) const
inline, override, virtual

Reimplemented from MVAUtils::Forest< Node_t >.

Definition at line 43 of file ForestTMVA.h.

43  {
44  std::cout << "weight: " << m_weights[itree] << std::endl;
45  Forest<Node_t>::PrintTree(itree);
46  }

◆ WriteTree()

virtual TTree* MVAUtils::IForest::WriteTree ( TString  ) const
pure virtual, inherited

Return a TTree representing the BDT.

The caller is the owner of the returned TTree.

Implemented in MVAUtils::ForestTMVA, MVAUtils::ForestLGBM, MVAUtils::ForestXGBoost, and MVAUtils::ForestLGBMSimple.
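
Since the returned pointer is owned by the caller, one common pattern is to wrap it in a std::unique_ptr straight away. The sketch below uses hypothetical stand-in types (ToyTree, ToyWriter) so that it compiles without ROOT; in real code these roles would be played by TTree and a concrete forest class, and the helper writeOwned is likewise an assumption.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-ins so the sketch compiles without ROOT.
struct ToyTree {
  std::string name;
};

struct ToyWriter {
  // Mirrors the documented contract: returns a raw pointer the caller owns.
  ToyTree* WriteTree(const std::string& name) const { return new ToyTree{name}; }
};

// Take ownership immediately so the tree is deleted on every exit path.
inline std::unique_ptr<ToyTree> writeOwned(const ToyWriter& w, const std::string& name) {
  return std::unique_ptr<ToyTree>(w.WriteTree(name));
}
```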

Member Data Documentation

◆ m_forest

template<typename Node_t >
std::vector<index_t> MVAUtils::Forest< Node_t >::m_forest
private, inherited

indices of the top-level nodes of each tree

Definition at line 117 of file Forest.h.

◆ m_nodes

template<typename Node_t >
std::vector<Node_t> MVAUtils::Forest< Node_t >::m_nodes
private, inherited

where the nodes of the forest are stored

Definition at line 118 of file Forest.h.

◆ m_sumWeights

template<typename Node_t >
float MVAUtils::ForestWeighted< Node_t >::m_sumWeights
private

the sum of the boost weights, cached so that it need not be recomputed on each call

Definition at line 50 of file ForestTMVA.h.
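
The caching described here can be sketched as a running sum that is updated whenever a weight is added, so that reading it back does not require a pass over all weights. The class ToyWeights and its method names are hypothetical and only illustrate the idea.

```cpp
#include <numeric>
#include <vector>

// Hypothetical weight bookkeeping: the running sum is updated on insertion,
// so reading it back costs O(1) instead of a pass over all weights.
class ToyWeights {
public:
  void add(float weight) {
    m_weights.push_back(weight);
    m_sum += weight;
  }

  // Cached value, no loop needed.
  float sum() const { return m_sum; }

  // O(n) reference computation, shown for comparison only.
  float recomputedSum() const {
    return std::accumulate(m_weights.begin(), m_weights.end(), 0.0f);
  }

private:
  std::vector<float> m_weights;
  float m_sum = 0.f;
};
```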

◆ m_weights

template<typename Node_t >
std::vector<float> MVAUtils::ForestWeighted< Node_t >::m_weights
private

boost weights

Definition at line 49 of file ForestTMVA.h.


The documentation for this class was generated from the following file:
ForestTMVA.h