ATLAS Offline Software
Loading...
Searching...
No Matches
ForestLGBM.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef MVAUtils_ForestLGBM_H
6#define MVAUtils_ForestLGBM_H
7
8#include "MVAUtils/Forest.h"
9#include <cmath>
10#include <algorithm>
11#include <numeric>
12#include <vector>
13namespace MVAUtils
14{
15 /*
16 * Support LGBM processing of the forest response.
17 *
18 * User should use ForestLGBM (for nan input support) or ForestSimple
19 *
20 * Implement only the classification as: sigmoid(raw-reponse)
21 * Other methods are from Forest:
22 * Multiclassification as softmax(raw-response)
23 * Regression (GetResponse) as raw-response
24 * */
25 template<typename Node_t>
26 class ForestLGBMBase : public Forest<Node_t>
27 {
28 public:
29 using Forest<Node_t>::GetResponse;
30
31 virtual float GetClassification(const std::vector<float>& values) const final
32 {
33 return detail::sigmoid(GetResponse(values));
34 }
35 virtual float GetClassification(const std::vector<float*>& pointers) const final
36 {
38 }
39 };
40
42 class ForestLGBMSimple final : public ForestLGBMBase<NodeLGBMSimple>
43 {
44 public:
45 explicit ForestLGBMSimple(TTree* tree);
46 ForestLGBMSimple() = default;
52 virtual TTree* WriteTree(TString name) const override;
53 virtual void PrintForest() const override;
54 virtual int GetNVars() const override { return m_max_var + 1; }
55 private:
56 int m_max_var=0;
57
58 };
59
61 class ForestLGBM final : public ForestLGBMBase<NodeLGBM>
62 {
63 public:
64 explicit ForestLGBM(TTree* tree);
65 ForestLGBM() = default;
66 ForestLGBM (const ForestLGBM&) = default;
67 ForestLGBM& operator=(const ForestLGBM&) = default;
68 ForestLGBM (ForestLGBM&&) = default;
70 ~ForestLGBM()=default;
71 virtual TTree* WriteTree(TString name) const override;
72 virtual void PrintForest() const override;
73 virtual int GetNVars() const override { return m_max_var + 1; }
74 private:
75 int m_max_var=0;
76 };
77}
78
79#endif
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
Definition ForestLGBM.h:31
virtual float GetClassification(const std::vector< float * > &pointers) const final
Definition ForestLGBM.h:35
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
ForestLGBMSimple & operator=(ForestLGBMSimple &&)=default
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition ForestLGBM.h:54
ForestLGBMSimple & operator=(const ForestLGBMSimple &)=default
virtual void PrintForest() const override
ForestLGBMSimple(ForestLGBMSimple &&)=default
ForestLGBMSimple(const ForestLGBMSimple &)=default
ForestLGBM(TTree *tree)
ForestLGBM(ForestLGBM &&)=default
ForestLGBM(const ForestLGBM &)=default
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
ForestLGBM & operator=(ForestLGBM &&)=default
virtual void PrintForest() const override
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition ForestLGBM.h:73
ForestLGBM & operator=(const ForestLGBM &)=default
Generic Forest base class.
Definition Forest.h:54
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
T sigmoid(T x)
Definition Forest.h:21
std::vector< T * > pointers(std::vector< T > &v)
Definition rmain.cxx:367
TChain * tree