ATLAS Offline Software
ForestXGBoost.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #ifndef MVAUtils_ForestXGBOOST_H
6 #define MVAUtils_ForestXGBOOST_H
7 
8 #include "MVAUtils/Forest.h"
9 #include <cmath>
10 #include <algorithm>
11 #include <numeric>
12 #include <vector>
13 
14 
15 namespace MVAUtils
16 {
17  /*
18  * Support XGBoost processing of the forest response.
19  *
20  * User should use ForestXGBoost (for nan input support)
21  *
22  * Implement only the classification as: sigmoid(raw-reponse)
23  * Other methods are from Forest:
24  * Regression (GetResponse) as raw-response.
25  * Global bias 'base_score' is not included [default=0.5]
26  * */
27  template<typename Node_t>
28  class ForestXGBoostBase : public Forest<Node_t>
29  {
30  public:
32 
33  virtual float GetClassification(const std::vector<float>& values) const final
34  {
36  }
37  virtual float GetClassification(const std::vector<float*>& pointers) const final
38  {
40  }
41  };
42 
44  class ForestXGBoost final : public ForestXGBoostBase<NodeXGBoost>
45  {
46  public:
47  explicit ForestXGBoost(TTree* tree);
48  ForestXGBoost() = default;
49  ForestXGBoost (const ForestXGBoost&) = default;
53  ~ForestXGBoost()=default;
54 
55  virtual TTree* WriteTree(TString name) const override;
56  virtual void PrintForest() const override;
57  virtual int GetNVars() const override { return m_max_var + 1; }
58  private:
59  int m_max_var=0;
60  };
61 }
62 
63 #endif
MVAUtils::ForestXGBoost
Implement XGBoost with nan support.
Definition: ForestXGBoost.h:45
MVAUtils::detail::sigmoid
T sigmoid(T x)
Definition: Forest.h:21
MVAUtils
Definition: InDetTrkInJetType.h:48
MVAUtils::ForestXGBoost::operator=
ForestXGBoost & operator=(ForestXGBoost &&)=default
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::ForestXGBoost::operator=
ForestXGBoost & operator=(const ForestXGBoost &)=default
tree
TChain * tree
Definition: tile_monitor.h:30
MVAUtils::ForestXGBoost::~ForestXGBoost
~ForestXGBoost()=default
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost()=default
python.Bindings.values
values
Definition: Control/AthenaPython/python/Bindings.py:805
MVAUtils::ForestXGBoostBase
Definition: ForestXGBoost.h:29
MVAUtils::ForestXGBoostBase::GetClassification
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
Definition: ForestXGBoost.h:33
Forest.h
MVAUtils::ForestXGBoost::PrintForest
virtual void PrintForest() const override
Definition: ForestXGBoost.cxx:81
MVAUtils::Forest
Generic Forest base class.
Definition: Forest.h:54
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
MVAUtils::ForestXGBoost::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition: ForestXGBoost.cxx:55
MVAUtils::ForestXGBoostBase::GetClassification
virtual float GetClassification(const std::vector< float * > &pointers) const final
Definition: ForestXGBoost.h:37
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(ForestXGBoost &&)=default
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition: rmain.cxx:366
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(const ForestXGBoost &)=default
MVAUtils::ForestXGBoost::m_max_var
int m_max_var
Definition: ForestXGBoost.h:59
MVAUtils::ForestXGBoost::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to Get* methods.
Definition: ForestXGBoost.h:57