ATLAS Offline Software
Reconstruction
MVAUtils
MVAUtils
ForestLGBM.h
Go to the documentation of this file.
1
/*
2
Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3
*/
4
5
#ifndef MVAUtils_ForestLGBM_H
6
#define MVAUtils_ForestLGBM_H
7
8
#include "
MVAUtils/Forest.h
"
9
#include <cmath>
10
#include <algorithm>
11
#include <numeric>
12
#include <vector>
13
namespace
MVAUtils
14
{
15
/*
16
* Support LGBM processing of the forest response.
17
*
18
* User should use ForestLGBM (for nan input support) or ForestSimple
19
*
20
* Implement only the classification as: sigmoid(raw-reponse)
21
* Other methods are from Forest:
22
* Multiclassification as softmax(raw-response)
23
* Regression (GetResponse) as raw-response
24
* */
25
template
<
typename
Node_t>
26
class
ForestLGBMBase
:
public
Forest
<Node_t>
27
{
28
public
:
29
using
Forest<Node_t>::GetResponse
;
30
31
virtual
float
GetClassification
(
const
std::vector<float>&
values
)
const
final
32
{
33
return
detail::sigmoid
(
GetResponse
(
values
));
34
}
35
virtual
float
GetClassification
(
const
std::vector<float*>&
pointers
)
const
final
36
{
37
return
detail::sigmoid
(
GetResponse
(
pointers
));
38
}
39
};
40
42
/**
 * LGBM forest without NaN input support, built on NodeLGBMSimple nodes.
 * Classification is inherited from ForestLGBMBase.
 */
class ForestLGBMSimple final : public ForestLGBMBase<NodeLGBMSimple>
{
public:
  /// Build the forest from a TTree (defined out of line).
  explicit ForestLGBMSimple(TTree* tree);

  // Value semantics: every special member is the compiler-generated one.
  ForestLGBMSimple() = default;
  ~ForestLGBMSimple() = default;
  ForestLGBMSimple(const ForestLGBMSimple&) = default;
  ForestLGBMSimple(ForestLGBMSimple&&) = default;
  ForestLGBMSimple& operator=(const ForestLGBMSimple&) = default;
  ForestLGBMSimple& operator=(ForestLGBMSimple&&) = default;

  /// Return a TTree representing the BDT; `name` is presumably used as the
  /// tree name (see the out-of-line definition).
  TTree* WriteTree(TString name) const override;

  /// Print the forest (defined out of line).
  void PrintForest() const override;

  /// Number of input variables expected by the Get* methods: m_max_var + 1.
  int GetNVars() const override { return m_max_var + 1; }

private:
  // Presumably the largest input-variable index used by the forest, set by
  // the TTree constructor; 0 for a default-constructed forest.
  int m_max_var{0};
};
59
61
/**
 * LGBM forest with NaN input support, built on NodeLGBM nodes.
 * Classification is inherited from ForestLGBMBase.
 */
class ForestLGBM final : public ForestLGBMBase<NodeLGBM>
{
public:
  /// Build the forest from a TTree (defined out of line).
  explicit ForestLGBM(TTree* tree);

  // Value semantics: every special member is the compiler-generated one.
  ForestLGBM() = default;
  ~ForestLGBM() = default;
  ForestLGBM(const ForestLGBM&) = default;
  ForestLGBM(ForestLGBM&&) = default;
  ForestLGBM& operator=(const ForestLGBM&) = default;
  ForestLGBM& operator=(ForestLGBM&&) = default;

  /// Return a TTree representing the BDT; `name` is presumably used as the
  /// tree name (see the out-of-line definition).
  TTree* WriteTree(TString name) const override;

  /// Print the forest (defined out of line).
  void PrintForest() const override;

  /// Number of input variables expected by the Get* methods: m_max_var + 1.
  int GetNVars() const override { return m_max_var + 1; }

private:
  // Presumably the largest input-variable index used by the forest, set by
  // the TTree constructor; 0 for a default-constructed forest.
  int m_max_var{0};
};
77
}
78
79
#endif
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM()=default
MVAUtils::ForestLGBMSimple::m_max_var
int m_max_var
Definition:
ForestLGBM.h:56
MVAUtils::ForestLGBM::operator=
ForestLGBM & operator=(ForestLGBM &&)=default
MVAUtils::ForestLGBMSimple::operator=
ForestLGBMSimple & operator=(ForestLGBMSimple &&)=default
MVAUtils::detail::sigmoid
T sigmoid(T x)
Definition:
Forest.h:21
MVAUtils
Definition:
InDetTrkInJetType.h:48
MVAUtils::Forest::GetResponse
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
MVAUtils::ForestLGBMBase
Definition:
ForestLGBM.h:27
tree
TChain * tree
Definition:
tile_monitor.h:30
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple()=default
MVAUtils::ForestLGBMBase::GetClassification
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
Definition:
ForestLGBM.h:31
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM(const ForestLGBM &)=default
python.Bindings.values
values
Definition:
Control/AthenaPython/python/Bindings.py:805
MVAUtils::ForestLGBMBase::GetClassification
virtual float GetClassification(const std::vector< float * > &pointers) const final
Definition:
ForestLGBM.h:35
MVAUtils::ForestLGBMSimple::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition:
ForestLGBM.h:54
Forest.h
MVAUtils::ForestLGBM::ForestLGBM
ForestLGBM(ForestLGBM &&)=default
MVAUtils::ForestLGBMSimple::operator=
ForestLGBMSimple & operator=(const ForestLGBMSimple &)=default
MVAUtils::ForestLGBM::GetNVars
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition:
ForestLGBM.h:73
MVAUtils::ForestLGBMSimple::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition:
ForestLGBM.cxx:60
MVAUtils::ForestLGBM::operator=
ForestLGBM & operator=(const ForestLGBM &)=default
MVAUtils::Forest
Generic Forest base class.
Definition:
Forest.h:54
MVAUtils::ForestLGBM::WriteTree
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Definition:
ForestLGBM.cxx:148
MVAUtils::ForestLGBM
Implement LGBM with nan support.
Definition:
ForestLGBM.h:62
name
std::string name
Definition:
Control/AthContainers/Root/debug.cxx:228
MVAUtils::ForestLGBMSimple::PrintForest
virtual void PrintForest() const override
Definition:
ForestLGBM.cxx:82
MVAUtils::ForestLGBM::m_max_var
int m_max_var
Definition:
ForestLGBM.h:75
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple(const ForestLGBMSimple &)=default
MVAUtils::ForestLGBMSimple::~ForestLGBMSimple
~ForestLGBMSimple()=default
MVAUtils::ForestLGBM::PrintForest
virtual void PrintForest() const override
Definition:
ForestLGBM.cxx:174
pointers
std::vector< T * > pointers(std::vector< T > &v)
Definition:
rmain.cxx:366
MVAUtils::ForestLGBMSimple
Implement LGBM Forest without nan support.
Definition:
ForestLGBM.h:43
MVAUtils::ForestLGBMSimple::ForestLGBMSimple
ForestLGBMSimple(ForestLGBMSimple &&)=default
MVAUtils::ForestLGBM::~ForestLGBM
~ForestLGBM()=default
Generated on Sun Dec 22 2024 21:10:33 for ATLAS Offline Software by
1.8.18