ATLAS Offline Software
Reconstruction/MVAUtils/MVAUtils/BDT.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_BDT_H
6 #define MVAUtils_BDT_H
7 
9 #include "TString.h"
10 #include <vector>
11 #include <map>
12 #include <cassert>
13 #include <memory>
14 #include <cmath>
15 #include "ForestBase.h"
16 
17 class TTree;
18 namespace MVAUtils
19 {
20 
34  class BDT
35  {
36  public:
42  explicit BDT(TTree *tree);//ctor TTree
43 
44  explicit BDT(std::unique_ptr<IForest> forest):
45  m_forest(std::move(forest)){
46  }
47 
48  /* delete default ctor
49  * and default copy / assignment*/
50  BDT() = delete;
51  BDT (const BDT&) = delete;
52  BDT& operator=(const BDT&) = delete;
55  BDT (BDT&&) = default;
56  BDT& operator=(BDT&&) = default;
57  ~BDT()=default;
58 
59 
60 
62  unsigned int GetNTrees() const ;
64  int GetNVars() const ;
66  float GetOffset() const ;
68  float GetResponse(const std::vector<float>& values) const;
69  float GetResponse(const std::vector<float*>& pointers) const;
70  float GetResponse() const;
71 
73  float GetClassification(const std::vector<float>& values) const;
74  float GetClassification(const std::vector<float*>& pointers) const;
75  float GetClassification() const;
76 
77  // TMVA specific: return 2.0/(1.0+exp(-2.0*sum))-1, with no offset.
78  float GetGradBoostMVA(const std::vector<float>& values) const;
79  float GetGradBoostMVA(const std::vector<float*>& pointers) const;
80 
82  std::vector<float> GetMultiResponse(const std::vector<float>& values, unsigned int numClasses) const;
83  std::vector<float> GetMultiResponse(const std::vector<float*>& pointers, unsigned int numClasses) const;
84  std::vector<float> GetMultiResponse(unsigned int numClasses) const;
85 
87  std::vector<float> GetValues() const;
89  const std::vector<float*>& GetPointers() const;
91  void SetPointers(const std::vector<float*>& pointers);
92 
96  TTree* WriteTree(TString name = "BDT") const;
97 
99  void PrintForest() const;
100  void PrintTree(unsigned int itree) const;
101 
103  float GetTreeResponse(const std::vector<float>& values, MVAUtils::index_t index) const;
104  float GetTreeResponse(const std::vector<float*>& pointers, MVAUtils::index_t index) const;
105 
106  private:
107  std::unique_ptr<IForest> m_forest;
108  std::vector<float*> m_pointers;
109  };
110 }
111 
112 #include "MVAUtils/BDT.icc"
113 #endif
MVAUtils::BDT::GetOffset
float GetOffset() const
Get the offset to the whole forest.
MVAUtils::BDT::GetResponse
float GetResponse() const
MVAUtils::BDT::GetNTrees
unsigned int GetNTrees() const
Number of trees in the whole forest.
MVAUtils
Definition: InDetTrkInJetType.h:48
MVAUtils::BDT::GetValues
std::vector< float > GetValues() const
Return the values corresponding to m_pointers (or an empty vector)
MVAUtils::BDT::PrintTree
void PrintTree(unsigned int itree) const
Definition: BDT.cxx:93
MVAUtils::BDT::BDT
BDT(BDT &&)=default
default move ctor, move assignment and dtor
index
Definition: index.py:1
MVAUtils::BDT::BDT
BDT()=delete
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::BDT::SetPointers
void SetPointers(const std::vector< float * > &pointers)
Set the stored pointers so that one can use methods with no args.
MVAUtils::BDT::BDT
BDT(TTree *tree)
Constructor.
MVAUtils::BDT
Simplified Boosted Regression Tree; supports TMVA, LightGBM (lgbm), and XGBoost models.
Definition: Reconstruction/MVAUtils/MVAUtils/BDT.h:35
MVAUtils::BDT::GetPointers
const std::vector< float * > & GetPointers() const
Return stored pointers (which are used by methods with no args)
MVAUtils::BDT::GetTreeResponse
float GetTreeResponse(const std::vector< float * > &pointers, MVAUtils::index_t index) const
MVAUtilsDefs.h
MVAUtils::BDT::GetTreeResponse
float GetTreeResponse(const std::vector< float > &values, MVAUtils::index_t index) const
for debugging, return the response of a single tree given the index of its top node
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:808
MVAUtils::BDT::operator=
BDT & operator=(const BDT &)=delete
MVAUtils::BDT::m_forest
std::unique_ptr< IForest > m_forest
the implementation of the forest, doing the hard work
Definition: Reconstruction/MVAUtils/MVAUtils/BDT.h:107
MVAUtils::BDT::GetResponse
float GetResponse(const std::vector< float * > &pointers) const
BDT.icc
MVAUtils::BDT::operator=
BDT & operator=(BDT &&)=default
MVAUtils::BDT::GetGradBoostMVA
float GetGradBoostMVA(const std::vector< float * > &pointers) const
MVAUtils::BDT::GetClassification
float GetClassification(const std::vector< float > &values) const
Get response of the forest, for classification.
MVAUtils::BDT::PrintForest
void PrintForest() const
for debugging, print out tree or forest to stdout
Definition: BDT.cxx:92
MVAUtils::BDT::BDT
BDT(const BDT &)=delete
MVAUtils::index_t
int32_t index_t
The index type of the node in the vector.
Definition: MVAUtilsDefs.h:12
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(const std::vector< float * > &pointers, unsigned int numClasses) const
MVAUtils::BDT::m_pointers
std::vector< float * > m_pointers
pointers to the input variables, stored for use by the no-argument methods (the inputs can also be passed explicitly)
Definition: Reconstruction/MVAUtils/MVAUtils/BDT.h:108
MVAUtils::BDT::WriteTree
TTree * WriteTree(TString name="BDT") const
Return a TTree representing the BDT: each entry is a binary tree, each element of the vectors is a no...
Definition: BDT.cxx:91
MVAUtils::BDT::GetClassification
float GetClassification(const std::vector< float * > &pointers) const
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
MVAUtils::BDT::GetGradBoostMVA
float GetGradBoostMVA(const std::vector< float > &values) const
MVAUtils::BDT::~BDT
~BDT()=default
MVAUtils::BDT::BDT
BDT(std::unique_ptr< IForest > forest)
Definition: Reconstruction/MVAUtils/MVAUtils/BDT.h:44
MVAUtils::BDT::GetNVars
int GetNVars() const
Number of variables expected in the inputs.
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(unsigned int numClasses) const
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition: rmain.cxx:367
MVAUtils::BDT::GetResponse
float GetResponse(const std::vector< float > &values) const
Get response of the forest, for regression.
ForestBase.h
MVAUtils::BDT::GetClassification
float GetClassification() const
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(const std::vector< float > &values, unsigned int numClasses) const
Get response of the forest, for multiclassification (e.g.