ATLAS Offline Software
ForestTMVA.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_ForestTMVA_H
6 #define MVAUtils_ForestTMVA_H
7 
#include "MVAUtils/Forest.h"
#include "TTree.h"
#include <iostream>
#include <vector>
11 
12 namespace MVAUtils
13 {
14 
22  template<typename Node_t>
23  class ForestWeighted : public Forest<Node_t>
24  {
25  public:
27 
30 
31  float GetTreeResponseWeighted(const std::vector<float>& values, unsigned int itree) const;
32  float GetTreeResponseWeighted(const std::vector<float*>& pointers, unsigned int itree) const;
33 
34  float GetWeightedResponse(const std::vector<float>& values) const;
35  float GetWeightedResponse(const std::vector<float*>& pointers) const;
36 
37  void newTree(const std::vector<Node_t>& nodes, float weight);
38  float GetTreeWeight(unsigned int itree) const { return m_weights[itree]; }
39  float GetSumWeights() const { return m_sumWeights; }
40 
41  virtual float GetOffset() const override { return m_weights[0]; }
42 
43  virtual void PrintTree(unsigned int itree) const override {
44  std::cout << "weight: " << m_weights[itree] << std::endl;
46  }
47 
48  private:
49  std::vector<float> m_weights;
50  float m_sumWeights;
51  };
52 
53 
54  /*
55  * Support TMVA processing
56  *
57  * Forest implementing the TMVA forest.
58  *
59  * Regression (GetResponse): offset + raw-response
60  * Classification (GetClassification): weighted average of the nodes
61  * MultiClassification: softmax of the raw-response
62  */
63  class ForestTMVA final : public ForestWeighted<NodeTMVA>
64  {
65  public:
66 
68  explicit ForestTMVA(TTree* tree);
69  ForestTMVA() = default;
70  ForestTMVA (const ForestTMVA&) = default;
71  ForestTMVA& operator=(const ForestTMVA&)=default;
72  ForestTMVA (ForestTMVA&&) = default;
74  ~ForestTMVA()=default;
75  virtual TTree* WriteTree(TString name) const override;
76  virtual float GetResponse(const std::vector<float>& values) const override ;
77  virtual float GetResponse(const std::vector<float*>& pointers) const override;
78  virtual float GetClassification(const std::vector<float>& values) const override;
79  virtual float GetClassification(const std::vector<float*>& pointers) const override ;
80  virtual void PrintForest() const override;
81  virtual int GetNVars() const override { return m_max_var + 1; }
82  void setNVars(const int max_var) {m_max_var=max_var;}
83  private:
84  int m_max_var=0;
85  };
86 
87 }
88 #include "MVAUtils/ForestTMVA.icc"
89 #endif
MVAUtils::ForestTMVA::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestTMVA.cxx:61
MVAUtils::ForestWeighted::GetWeightedResponse
float GetWeightedResponse(const std::vector< float * > &pointers) const
MVAUtils::ForestTMVA::m_max_var
int m_max_var
Definition: ForestTMVA.h:84
MVAUtils
Definition: InDetTrkInJetType.h:48
MVAUtils::ForestWeighted::newTree
void newTree(const std::vector< Node_t > &nodes, float weight)
ForestTMVA.icc
MVAUtils::ForestWeighted::GetTreeResponseWeighted
float GetTreeResponseWeighted(const std::vector< float > &values, unsigned int itree) const
MVAUtils::ForestTMVA::GetClassification
virtual float GetClassification(const std::vector< float > &values) const override
Compute the prediction of a classification.
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::ForestWeighted::m_weights
std::vector< float > m_weights
boost weights
Definition: ForestTMVA.h:49
MVAUtils::ForestWeighted::ForestWeighted
ForestWeighted()
Definition: ForestTMVA.h:26
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
MVAUtils::ForestTMVA::GetResponse
virtual float GetResponse(const std::vector< float * > &pointers) const override
dqt_zlumi_pandas.weight
int weight
Definition: dqt_zlumi_pandas.py:189
MVAUtils::ForestTMVA::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition: ForestTMVA.h:81
MVAUtils::ForestWeighted::GetOffset
virtual float GetOffset() const override
Return the offset of the forest.
Definition: ForestTMVA.h:41
Forest.h
MVAUtils::ForestTMVA::operator=
ForestTMVA & operator=(ForestTMVA &&)=default
MVAUtils::Forest
Generic Forest base class.
Definition: Forest.h:54
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
MVAUtils::ForestWeighted::GetTreeResponseWeighted
float GetTreeResponseWeighted(const std::vector< float * > &pointers, unsigned int itree) const
MVAUtils::ForestTMVA::ForestTMVA
ForestTMVA()=default
MVAUtils::ForestTMVA::setNVars
void setNVars(const int max_var)
Definition: ForestTMVA.h:82
MVAUtils::ForestWeighted::m_sumWeights
float m_sumWeights
the sum of the boost weights (sumOfBoostWeights), cached so it does not need to be recomputed on each call
Definition: ForestTMVA.h:50
MVAUtils::ForestWeighted::GetWeightedResponse
float GetWeightedResponse(const std::vector< float > &values) const
MVAUtils::ForestTMVA::ForestTMVA
ForestTMVA(ForestTMVA &&)=default
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition: rmain.cxx:366
MVAUtils::ForestTMVA::~ForestTMVA
~ForestTMVA()=default
MVAUtils::ForestWeighted::GetTreeWeight
float GetTreeWeight(unsigned int itree) const
Definition: ForestTMVA.h:38
MVAUtils::ForestTMVA::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::ForestTMVA
Definition: ForestTMVA.h:64
MVAUtils::ForestWeighted::GetSumWeights
float GetSumWeights() const
Definition: ForestTMVA.h:39
MVAUtils::ForestTMVA::PrintForest
virtual void PrintForest() const override
Definition: ForestTMVA.cxx:86
MVAUtils::ForestTMVA::GetClassification
virtual float GetClassification(const std::vector< float * > &pointers) const override
MVAUtils::ForestWeighted::PrintTree
virtual void PrintTree(unsigned int itree) const override
Definition: ForestTMVA.h:43
MVAUtils::ForestTMVA::ForestTMVA
ForestTMVA(const ForestTMVA &)=default
MVAUtils::ForestTMVA::operator=
ForestTMVA & operator=(const ForestTMVA &)=default
MVAUtils::Forest::PrintTree
virtual void PrintTree(unsigned int itree) const override
MVAUtils::ForestWeighted
Implements a Forest with weighted nodes. This is a general Forest class which implements the strategy used ...
Definition: ForestTMVA.h:24