ATLAS Offline Software
Reconstruction
MVAUtils
MVAUtils
ForestXGBoost.h
Go to the documentation of this file.
1
/*
2
Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3
*/
4
5
#ifndef MVAUtils_ForestXGBOOST_H
6
#define MVAUtils_ForestXGBOOST_H
7
8
#include "
MVAUtils/Forest.h
"
9
#include <cmath>
10
#include <algorithm>
11
#include <numeric>
12
#include <vector>
13
14
15
namespace
MVAUtils
16
{
17
/*
18
* Support XGBoost processing of the forest response.
19
*
20
* User should use ForestXGBoost (for nan input support)
21
*
22
* Implement only the classification as: sigmoid(raw-reponse)
23
* Other methods are from Forest:
24
* Regression (GetResponse) as raw-response.
25
* Global bias 'base_score' is not included [default=0.5]
26
* */
27
template
<
typename
Node_t>
28
class
ForestXGBoostBase
:
public
Forest
<Node_t>
29
{
30
public
:
31
using
Forest<Node_t>::GetResponse
;
32
33
virtual
float
GetClassification
(
const
std::vector<float>&
values
)
const
final
34
{
35
return
detail::sigmoid
(
GetResponse
(
values
));
36
}
37
virtual
float
GetClassification
(
const
std::vector<float*>&
pointers
)
const
final
38
{
39
return
detail::sigmoid
(
GetResponse
(
pointers
));
40
}
41
};
42
44
/// Implement XGBoost with nan support.
class ForestXGBoost final : public ForestXGBoostBase<NodeXGBoost>
{
public:
  /// Construct from a TTree; presumably reads the forest description from
  /// `tree` (NOTE(review): definition not visible here — confirm in the .cxx).
  explicit ForestXGBoost(TTree* tree);

  // All remaining special members are the compiler-generated ones.
  ForestXGBoost() = default;
  ~ForestXGBoost() = default;
  ForestXGBoost(const ForestXGBoost&) = default;
  ForestXGBoost(ForestXGBoost&&) = default;
  ForestXGBoost& operator=(const ForestXGBoost&) = default;
  ForestXGBoost& operator=(ForestXGBoost&&) = default;

  /// Return a TTree representing the BDT, using `name` for it.
  TTree* WriteTree(TString name) const override;

  /// Print the forest.
  void PrintForest() const override;

  /// Number of input variables expected by the Get* methods (m_max_var + 1).
  int GetNVars() const override
  {
    return m_max_var + 1;
  }

private:
  // Presumably the largest input-variable index referenced by the forest
  // (GetNVars() returns this + 1). Defaults to 0.
  int m_max_var{0};
};
61
}
62
63
#endif
MVAUtils::ForestXGBoost
Implement XGBoost with nan support.
Definition:
ForestXGBoost.h:45
MVAUtils::detail::sigmoid
T sigmoid(T x)
Definition:
Forest.h:21
MVAUtils
Definition:
InDetTrkInJetType.h:48
MVAUtils::ForestXGBoost::operator=
ForestXGBoost & operator=(ForestXGBoost &&)=default
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::ForestXGBoost::operator=
ForestXGBoost & operator=(const ForestXGBoost &)=default
tree
TChain * tree
Definition:
tile_monitor.h:30
MVAUtils::ForestXGBoost::~ForestXGBoost
~ForestXGBoost()=default
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost()=default
python.Bindings.values
values
Definition:
Control/AthenaPython/python/Bindings.py:805
MVAUtils::ForestXGBoostBase
Definition:
ForestXGBoost.h:29
MVAUtils::ForestXGBoostBase::GetClassification
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
Definition:
ForestXGBoost.h:33
Forest.h
MVAUtils::ForestXGBoost::PrintForest
virtual void PrintForest() const override
Definition:
ForestXGBoost.cxx:81
MVAUtils::Forest
Generic Forest base class.
Definition:
Forest.h:54
name
std::string name
Definition:
Control/AthContainers/Root/debug.cxx:228
MVAUtils::ForestXGBoost::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition:
ForestXGBoost.cxx:55
MVAUtils::ForestXGBoostBase::GetClassification
virtual float GetClassification(const std::vector< float * > &pointers) const final
Definition:
ForestXGBoost.h:37
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(ForestXGBoost &&)=default
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition:
rmain.cxx:366
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(const ForestXGBoost &)=default
MVAUtils::ForestXGBoost::m_max_var
int m_max_var
Definition:
ForestXGBoost.h:59
MVAUtils::ForestXGBoost::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition:
ForestXGBoost.h:57
Generated on Mon Dec 23 2024 21:10:35 for ATLAS Offline Software by
1.8.18