ATLAS Offline Software
ForestLGBM.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_ForestLGBM_H
6 #define MVAUtils_ForestLGBM_H
7 
8 #include "MVAUtils/Forest.h"
9 #include <cmath>
10 #include <algorithm>
11 #include <numeric>
12 #include <vector>
13 namespace MVAUtils
14 {
15  /*
16  * Support LGBM processing of the forest response.
17  *
18  * User should use ForestLGBM (for nan input support) or ForestSimple
19  *
20  * Implement only the classification as: sigmoid(raw-reponse)
21  * Other methods are from Forest:
22  * Multiclassification as softmax(raw-response)
23  * Regression (GetResponse) as raw-response
24  * */
25  template<typename Node_t>
26  class ForestLGBMBase : public Forest<Node_t>
27  {
28  public:
30 
31  virtual float GetClassification(const std::vector<float>& values) const final
32  {
34  }
35  virtual float GetClassification(const std::vector<float*>& pointers) const final
36  {
38  }
39  };
40 
42  class ForestLGBMSimple final : public ForestLGBMBase<NodeLGBMSimple>
43  {
44  public:
45  explicit ForestLGBMSimple(TTree* tree);
46  ForestLGBMSimple() = default;
47  ForestLGBMSimple (const ForestLGBMSimple&) = default;
51  ~ForestLGBMSimple()=default;
52  virtual TTree* WriteTree(TString name) const override;
53  virtual void PrintForest() const override;
54  virtual int GetNVars() const override { return m_max_var + 1; }
55  private:
56  int m_max_var=0;
57 
58  };
59 
61  class ForestLGBM final : public ForestLGBMBase<NodeLGBM>
62  {
63  public:
64  explicit ForestLGBM(TTree* tree);
65  ForestLGBM() = default;
66  ForestLGBM (const ForestLGBM&) = default;
67  ForestLGBM& operator=(const ForestLGBM&) = default;
68  ForestLGBM (ForestLGBM&&) = default;
70  ~ForestLGBM()=default;
71  virtual TTree* WriteTree(TString name) const override;
72  virtual void PrintForest() const override;
73  virtual int GetNVars() const override { return m_max_var + 1; }
74  private:
75  int m_max_var=0;
76  };
77 }
78 
79 #endif
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM()=default
MVAUtils::ForestLGBMSimple::m_max_var
int m_max_var
Definition: ForestLGBM.h:56
MVAUtils::ForestLGBM::operator=
ForestLGBM & operator=(ForestLGBM &&)=default
MVAUtils::ForestLGBMSimple::operator=
ForestLGBMSimple & operator=(ForestLGBMSimple &&)=default
MVAUtils::detail::sigmoid
T sigmoid(T x)
Definition: Forest.h:21
MVAUtils
Definition: InDetTrkInJetType.h:48
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::ForestLGBMBase
Definition: ForestLGBM.h:27
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple()=default
MVAUtils::ForestLGBMBase::GetClassification
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
Definition: ForestLGBM.h:31
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM(const ForestLGBM &)=default
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
MVAUtils::ForestLGBMBase::GetClassification
virtual float GetClassification(const std::vector< float * > &pointers) const final
Definition: ForestLGBM.h:35
MVAUtils::ForestLGBMSimple::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition: ForestLGBM.h:54
Forest.h
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM(ForestLGBM &&)=default
MVAUtils::ForestLGBMSimple::operator=
ForestLGBMSimple & operator=(const ForestLGBMSimple &)=default
MVAUtils::ForestLGBM::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition: ForestLGBM.h:73
MVAUtils::ForestLGBMSimple::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestLGBM.cxx:60
MVAUtils::ForestLGBM::operator=
ForestLGBM & operator=(const ForestLGBM &)=default
MVAUtils::Forest
Generic Forest base class.
Definition: Forest.h:54
MVAUtils::ForestLGBM::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestLGBM.cxx:148
MVAUtils::ForestLGBM
Implement LGBM with nan support.
Definition: ForestLGBM.h:62
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
MVAUtils::ForestLGBMSimple::PrintForest
virtual void PrintForest() const override
Definition: ForestLGBM.cxx:82
MVAUtils::ForestLGBM::m_max_var
int m_max_var
Definition: ForestLGBM.h:75
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple(const ForestLGBMSimple &)=default
MVAUtils::ForestLGBMSimple::~ForestLGBMSimple
~ForestLGBMSimple()=default
MVAUtils::ForestLGBM::PrintForest
virtual void PrintForest() const override
Definition: ForestLGBM.cxx:174
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition: rmain.cxx:366
MVAUtils::ForestLGBMSimple
Implement LGBM Forest without nan support.
Definition: ForestLGBM.h:43
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple(ForestLGBMSimple &&)=default
MVAUtils::ForestLGBM::~ForestLGBM
~ForestLGBM()=default