ATLAS Offline Software
BDT.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_BDT_H
6 #define MVAUtils_BDT_H
7 
9 #include "TString.h"
10 #include <vector>
11 #include <map>
12 #include <cassert>
13 #include <memory>
14 #include "ForestBase.h"
15 
16 class TTree;
17 namespace MVAUtils
18 {
19 
33  class BDT
34  {
35  public:
41  explicit BDT(TTree *tree);//ctor TTree
42 
43  explicit BDT(std::unique_ptr<IForest> forest):
44  m_forest(std::move(forest)){
45  }
46 
47  /* delete default ctor
48  * and default copy / assignment*/
49  BDT() = delete;
50  BDT (const BDT&) = delete;
51  BDT& operator=(const BDT&) = delete;
54  BDT (BDT&&) = default;
55  BDT& operator=(BDT&&) = default;
56  ~BDT()=default;
57 
58 
59 
61  unsigned int GetNTrees() const ;
63  int GetNVars() const ;
65  float GetOffset() const ;
67  float GetResponse(const std::vector<float>& values) const;
68  float GetResponse(const std::vector<float*>& pointers) const;
69  float GetResponse() const;
70 
72  float GetClassification(const std::vector<float>& values) const;
73  float GetClassification(const std::vector<float*>& pointers) const;
74  float GetClassification() const;
75 
76  // TMVA specific: return 2.0/(1.0+exp(-2.0*sum))-1, with no offset.
77  float GetGradBoostMVA(const std::vector<float>& values) const;
78  float GetGradBoostMVA(const std::vector<float*>& pointers) const;
79 
81  std::vector<float> GetMultiResponse(const std::vector<float>& values, unsigned int numClasses) const;
82  std::vector<float> GetMultiResponse(const std::vector<float*>& pointers, unsigned int numClasses) const;
83  std::vector<float> GetMultiResponse(unsigned int numClasses) const;
84 
86  std::vector<float> GetValues() const;
88  const std::vector<float*>& GetPointers() const;
90  void SetPointers(const std::vector<float*>& pointers);
91 
95  TTree* WriteTree(TString name = "BDT") const;
96 
98  void PrintForest() const;
99  void PrintTree(unsigned int itree) const;
100 
102  float GetTreeResponse(const std::vector<float>& values, MVAUtils::index_t index) const;
103  float GetTreeResponse(const std::vector<float*>& pointers, MVAUtils::index_t index) const;
104 
105  private:
106  std::unique_ptr<IForest> m_forest;
107  std::vector<float*> m_pointers;
108  };
109 }
110 
111 #include "MVAUtils/BDT.icc"
112 #endif
MVAUtils::BDT::GetOffset
float GetOffset() const
Get the offset to the whole forest.
MVAUtils::BDT::GetResponse
float GetResponse() const
MVAUtils::BDT::GetNTrees
unsigned int GetNTrees() const
Number of trees in the whole forest.
MVAUtils
Definition: InDetTrkInJetType.h:47
MVAUtils::BDT::GetValues
std::vector< float > GetValues() const
Return the values corresponding to m_pointers (or an empty vector)
MVAUtils::BDT::PrintTree
void PrintTree(unsigned int itree) const
Definition: BDT.cxx:93
MVAUtils::BDT::BDT
BDT(BDT &&)=default
default move ctor, move assignment and dtor
index
Definition: index.py:1
MVAUtils::BDT::BDT
BDT()=delete
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::BDT::SetPointers
void SetPointers(const std::vector< float * > &pointers)
Set the stored pointers so that one can use methods with no args.
MVAUtils::BDT::BDT
BDT(TTree *tree)
Constructor.
MVAUtils::BDT
Simplified Boosted Regression Tree; supports TMVA, lgbm and xgboost.
Definition: BDT.h:34
MVAUtils::BDT::GetPointers
const std::vector< float * > & GetPointers() const
Return stored pointers (which are used by methods with no args)
MVAUtils::BDT::GetTreeResponse
float GetTreeResponse(const std::vector< float * > &pointers, MVAUtils::index_t index) const
MVAUtilsDefs.h
MVAUtils::BDT::GetTreeResponse
float GetTreeResponse(const std::vector< float > &values, MVAUtils::index_t index) const
for debugging, return the response of a single tree given the index of its top node
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:797
MVAUtils::BDT::operator=
BDT & operator=(const BDT &)=delete
MVAUtils::BDT::m_forest
std::unique_ptr< IForest > m_forest
the implementation of the forest, doing the hard work
Definition: BDT.h:106
MVAUtils::BDT::GetResponse
float GetResponse(const std::vector< float * > &pointers) const
BDT.icc
MVAUtils::BDT::operator=
BDT & operator=(BDT &&)=default
MVAUtils::BDT::GetGradBoostMVA
float GetGradBoostMVA(const std::vector< float * > &pointers) const
MVAUtils::BDT::GetClassification
float GetClassification(const std::vector< float > &values) const
Get response of the forest, for classification.
MVAUtils::BDT::PrintForest
void PrintForest() const
for debugging, print out tree or forest to stdout
Definition: BDT.cxx:92
MVAUtils::BDT::BDT
BDT(const BDT &)=delete
MVAUtils::index_t
int32_t index_t
The index type of the node in the vector.
Definition: MVAUtilsDefs.h:12
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(const std::vector< float * > &pointers, unsigned int numClasses) const
MVAUtils::BDT::m_pointers
std::vector< float * > m_pointers
stored pointers to the input variables, used by the methods with no args (inputs can also be passed explicitly)
Definition: BDT.h:107
MVAUtils::BDT::WriteTree
TTree * WriteTree(TString name="BDT") const
Return a TTree representing the BDT: each entry is a binary tree, each element of the vectors is a no...
Definition: BDT.cxx:91
MVAUtils::BDT::GetClassification
float GetClassification(const std::vector< float * > &pointers) const
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:192
MVAUtils::BDT::GetGradBoostMVA
float GetGradBoostMVA(const std::vector< float > &values) const
MVAUtils::BDT::~BDT
~BDT()=default
MVAUtils::BDT::BDT
BDT(std::unique_ptr< IForest > forest)
Definition: BDT.h:43
MVAUtils::BDT::GetNVars
int GetNVars() const
Number of variables expected in the inputs.
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(unsigned int numClasses) const
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition: rmain.cxx:366
MVAUtils::BDT::GetResponse
float GetResponse(const std::vector< float > &values) const
Get response of the forest, for regression.
ForestBase.h
MVAUtils::BDT::GetClassification
float GetClassification() const
MVAUtils::BDT::GetMultiResponse
std::vector< float > GetMultiResponse(const std::vector< float > &values, unsigned int numClasses) const
Get response of the forest, for multiclassification (e.g.