ATLAS Offline Software
Loading...
Searching...
No Matches
ForestXGBoost.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef MVAUtils_ForestXGBOOST_H
6#define MVAUtils_ForestXGBOOST_H
7
8#include "MVAUtils/Forest.h"
9#include <cmath>
10#include <algorithm>
11#include <numeric>
12#include <vector>
13
14
15namespace MVAUtils
16{
17 /*
18 * Support XGBoost processing of the forest response.
19 *
20 * User should use ForestXGBoost (for nan input support)
21 *
22 * Implement only the classification as: sigmoid(raw-reponse)
23 * Other methods are from Forest:
24 * Regression (GetResponse) as raw-response.
25 * Global bias 'base_score' is not included [default=0.5]
26 * */
27 template<typename Node_t>
28 class ForestXGBoostBase : public Forest<Node_t>
29 {
30 public:
31 using Forest<Node_t>::GetResponse;
32
33 virtual float GetClassification(const std::vector<float>& values) const final
34 {
35 return detail::sigmoid(GetResponse(values));
36 }
37 virtual float GetClassification(const std::vector<float*>& pointers) const final
38 {
40 }
41 };
42
44 class ForestXGBoost final : public ForestXGBoostBase<NodeXGBoost>
45 {
46 public:
47 explicit ForestXGBoost(TTree* tree);
48 ForestXGBoost() = default;
49 ForestXGBoost (const ForestXGBoost&) = default;
53 ~ForestXGBoost()=default;
54
55 virtual TTree* WriteTree(TString name) const override;
56 virtual void PrintForest() const override;
57 virtual int GetNVars() const override { return m_max_var + 1; }
58 private:
59 int m_max_var=0;
60 };
61}
62
63#endif
virtual float GetClassification(const std::vector< float * > &pointers) const final
virtual float GetClassification(const std::vector< float > &values) const final
Compute the prediction of a classification.
ForestXGBoost(ForestXGBoost &&)=default
virtual int GetNVars() const override
Get the number of input variable to be passed with std::vector to Get* methods.
ForestXGBoost & operator=(const ForestXGBoost &)=default
virtual void PrintForest() const override
ForestXGBoost(const ForestXGBoost &)=default
ForestXGBoost & operator=(ForestXGBoost &&)=default
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
Generic Forest base class.
Definition Forest.h:54
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
T sigmoid(T x)
Definition Forest.h:21
std::vector< T * > pointers(std::vector< T > &v)
Definition rmain.cxx:367
TChain * tree