ATLAS Offline Software
Loading...
Searching...
No Matches
ForestTMVA.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef MVAUtils_ForestTMVA_H
6#define MVAUtils_ForestTMVA_H
7
#include "MVAUtils/Forest.h"
#include "TTree.h"
#include <iostream>
#include <vector>
11
12namespace MVAUtils
13{
14
22 template<typename Node_t>
23 class ForestWeighted : public Forest<Node_t>
24 {
25 public:
27
28 using Forest<Node_t>::GetNTrees;
29 using Forest<Node_t>::newTree;
30
31 float GetTreeResponseWeighted(const std::vector<float>& values, unsigned int itree) const;
32 float GetTreeResponseWeighted(const std::vector<float*>& pointers, unsigned int itree) const;
33
34 float GetWeightedResponse(const std::vector<float>& values) const;
35 float GetWeightedResponse(const std::vector<float*>& pointers) const;
36
37 void newTree(const std::vector<Node_t>& nodes, float weight);
38 float GetTreeWeight(unsigned int itree) const { return m_weights[itree]; }
39 float GetSumWeights() const { return m_sumWeights; }
40
41 virtual float GetOffset() const override { return m_weights[0]; }
42
43 virtual void PrintTree(unsigned int itree) const override {
44 std::cout << "weight: " << m_weights[itree] << std::endl;
46 }
47
48 private:
49 std::vector<float> m_weights;
51 };
52
53
54 /*
55 * Support TMVA processing
56 *
57 * Forest implementing the TMVA forest.
58 *
59 * Regression (GetResponse): offset + raw-response
60 * Classification (GetClassification): weighted average of the nodes
61 * MultiClassification: softmax of the raw-response
62 */
63 class ForestTMVA final : public ForestWeighted<NodeTMVA>
64 {
65 public:
66
68 explicit ForestTMVA(TTree* tree);
69 ForestTMVA() = default;
70 ForestTMVA (const ForestTMVA&) = default;
72 ForestTMVA (ForestTMVA&&) = default;
74 ~ForestTMVA()=default;
75 virtual TTree* WriteTree(TString name) const override;
76 virtual float GetResponse(const std::vector<float>& values) const override ;
77 virtual float GetResponse(const std::vector<float*>& pointers) const override;
78 virtual float GetClassification(const std::vector<float>& values) const override;
79 virtual float GetClassification(const std::vector<float*>& pointers) const override ;
80 virtual void PrintForest() const override;
81 virtual int GetNVars() const override { return m_max_var + 1; }
82 void setNVars(const int max_var) {m_max_var=max_var;}
83 private:
84 int m_max_var=0;
85 };
86
87}
89#endif
ForestTMVA(ForestTMVA &&)=default
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
virtual float GetClassification(const std::vector< float * > &pointers) const override
void setNVars(const int max_var)
Definition ForestTMVA.h:82
virtual int GetNVars() const override
Get the number of input variables to be passed with std::vector to the Get* methods.
Definition ForestTMVA.h:81
ForestTMVA & operator=(ForestTMVA &&)=default
virtual float GetClassification(const std::vector< float > &values) const override
Compute the prediction of a classification.
ForestTMVA(TTree *tree)
ForestTMVA & operator=(const ForestTMVA &)=default
virtual float GetResponse(const std::vector< float * > &pointers) const override
virtual void PrintForest() const override
ForestTMVA(const ForestTMVA &)=default
virtual float GetResponse(const std::vector< float > &values) const override
Compute the prediction for regression.
float GetTreeWeight(unsigned int itree) const
Definition ForestTMVA.h:38
float GetSumWeights() const
Definition ForestTMVA.h:39
float GetTreeResponseWeighted(const std::vector< float > &values, unsigned int itree) const
float m_sumWeights
The sum of the boost weights, cached so it need not be recomputed on each call.
Definition ForestTMVA.h:50
float GetWeightedResponse(const std::vector< float > &values) const
float GetWeightedResponse(const std::vector< float * > &pointers) const
virtual void PrintTree(unsigned int itree) const override
Definition ForestTMVA.h:43
virtual float GetOffset() const override
Return the offset of the forest.
Definition ForestTMVA.h:41
float GetTreeResponseWeighted(const std::vector< float * > &pointers, unsigned int itree) const
void newTree(const std::vector< Node_t > &nodes, float weight)
std::vector< float > m_weights
boost weights
Definition ForestTMVA.h:49
Generic Forest base class.
Definition Forest.h:54
virtual void PrintTree(unsigned int itree) const override
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94
Node for TMVA implementation.
Definition NodeImpl.h:36
std::vector< T * > pointers(std::vector< T > &v)
Definition rmain.cxx:367
TChain * tree