ATLAS Offline Software
Loading...
Searching...
No Matches
Reconstruction
MVAUtils
MVAUtils
ForestXGBoost.h
Go to the documentation of this file.
1
/*
2
Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3
*/
4
5
#ifndef MVAUtils_ForestXGBOOST_H
6
#define MVAUtils_ForestXGBOOST_H
7
8
#include "
MVAUtils/Forest.h
"
9
#include <cmath>
10
#include <algorithm>
11
#include <numeric>
12
#include <vector>
13
14
15
namespace
MVAUtils
16
{
17
/*
18
* Support XGBoost processing of the forest response.
19
*
20
* User should use ForestXGBoost (for nan input support)
21
*
22
* Implement only the classification as: sigmoid(raw-response)
23
* Other methods are from Forest:
24
* Regression (GetResponse) as raw-response.
25
* Global bias 'base_score' is not included [default=0.5]
26
* */
27
template
<
typename
Node_t>
28
class
ForestXGBoostBase
:
public
Forest
<Node_t>
29
{
30
public
:
31
using
Forest
<Node_t>
::GetResponse
;
32
33
virtual
float
GetClassification
(
const
std::vector<float>& values)
const
final
34
{
35
return
detail::sigmoid
(
GetResponse
(values));
36
}
37
virtual
float
GetClassification
(
const
std::vector<float*>&
pointers
)
const
final
38
{
39
return
detail::sigmoid
(
GetResponse
(
pointers
));
40
}
41
};
42
44
// Concrete XGBoost forest built on NodeXGBoost nodes. The class is final and
// relies on compiler-generated copy/move/destruction; the TTree constructor,
// WriteTree and PrintForest are defined out of line.
class ForestXGBoost final : public ForestXGBoostBase<NodeXGBoost>
{
public:
  // --- construction -------------------------------------------------------
  ForestXGBoost() = default;
  // Build the forest from a TTree (definition lives in the implementation file).
  explicit ForestXGBoost(TTree* tree);

  // --- copy / move / destroy: all compiler-generated -----------------------
  ForestXGBoost(const ForestXGBoost&) = default;
  ForestXGBoost(ForestXGBoost&&) = default;
  ForestXGBoost& operator=(const ForestXGBoost&) = default;
  ForestXGBoost& operator=(ForestXGBoost&&) = default;
  ~ForestXGBoost() = default;

  // --- overridden interface -------------------------------------------------
  // Return a TTree representing the BDT, created with the given name.
  TTree* WriteTree(TString name) const override;
  // Print the forest (definition lives in the implementation file).
  void PrintForest() const override;
  // Number of input variables expected by the Get* methods: m_max_var + 1.
  int GetNVars() const override { return 1 + m_max_var; }

private:
  // NOTE(review): presumably the largest variable index referenced by any
  // node -- confirm against the constructor in the implementation file.
  int m_max_var{0};
};
61
}
62
63
#endif
Forest.h
MVAUtils::ForestXGBoostBase
Definition
ForestXGBoost.h:29
MVAUtils::ForestXGBoostBase::GetClassification
virtual float GetClassification(const std::vector< float * > &pointers) const final
Definition
ForestXGBoost.h:37
MVAUtils::ForestXGBoostBase::GetClassification
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
Definition
ForestXGBoost.h:33
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(ForestXGBoost &&)=default
MVAUtils::ForestXGBoost::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to Get* methods.
Definition
ForestXGBoost.h:57
MVAUtils::ForestXGBoost::~ForestXGBoost
~ForestXGBoost()=default
MVAUtils::ForestXGBoost::operator=
ForestXGBoost & operator=(const ForestXGBoost &)=default
MVAUtils::ForestXGBoost::m_max_var
int m_max_var
Definition
ForestXGBoost.h:59
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(TTree *tree)
Definition
ForestXGBoost.cxx:12
MVAUtils::ForestXGBoost::PrintForest
virtual void PrintForest() const override
Definition
ForestXGBoost.cxx:81
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost(const ForestXGBoost &)=default
MVAUtils::ForestXGBoost::ForestXGBoost
ForestXGBoost()=default
MVAUtils::ForestXGBoost::operator=
ForestXGBoost & operator=(ForestXGBoost &&)=default
MVAUtils::ForestXGBoost::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition
ForestXGBoost.cxx:55
MVAUtils::Forest
Generic Forest base class.
Definition
Forest.h:54
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::detail::sigmoid
T sigmoid(T x)
Definition
Forest.h:21
MVAUtils
Definition
InDetTrkInJetType.h:48
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition
rmain.cxx:367
tree
TChain * tree
Definition
tile_monitor.h:30
Generated on
for ATLAS Offline Software by
1.14.0