ATLAS Offline Software
Loading...
Searching...
No Matches
ForestTMVA.icc
Go to the documentation of this file.
/*
 Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration
*/
4namespace MVAUtils {
5
6template<typename Node_t>
7float
8ForestWeighted<Node_t>::GetTreeResponseWeighted(
9 const std::vector<float>& values,
10 unsigned int itree) const
11{
12 return Forest<Node_t>::GetTreeResponse(values, itree) * m_weights[itree];
13}
14template<typename Node_t>
15float
16ForestWeighted<Node_t>::GetTreeResponseWeighted(
17 const std::vector<float*>& pointers,
18 unsigned int itree) const
19{
20 return Forest<Node_t>::GetTreeResponse(pointers, itree) * m_weights[itree];
21}
22
23template<typename Node_t>
24float
25ForestWeighted<Node_t>::GetWeightedResponse(
26 const std::vector<float>& values) const
27{
28 float result = 0.;
29 for (unsigned int itree = 0; itree != GetNTrees(); ++itree) {
30 result += GetTreeResponseWeighted(values, itree);
31 }
32 return result;
33}
34
35template<typename Node_t>
36float
37ForestWeighted<Node_t>::GetWeightedResponse(
38 const std::vector<float*>& pointers) const
39{
40 float result = 0.;
41 for (unsigned int itree = 0; itree != GetNTrees(); ++itree) {
42 result += GetTreeResponseWeighted(pointers, itree);
43 }
44 return result;
45}
46
47template<typename Node_t>
48void
49ForestWeighted<Node_t>::newTree(const std::vector<Node_t>& nodes, float weight)
50{
51 newTree(nodes);
52 m_weights.push_back(weight);
53 m_sumWeights += weight;
54}
55
56inline float
57ForestTMVA::GetResponse(const std::vector<float>& values) const
58{
59 return GetRawResponse(values) + GetOffset();
60}
61
62inline float
63ForestTMVA::GetResponse(const std::vector<float*>& pointers) const
64{
65 return GetRawResponse(pointers) + GetOffset();
66}
67
68inline float
69ForestTMVA::GetClassification(const std::vector<float>& values) const
70{
71 float result = GetWeightedResponse(values);
72 return result / GetSumWeights();
73}
74
75inline float
76ForestTMVA::GetClassification(const std::vector<float*>& pointers) const
77{
78 float result = GetWeightedResponse(pointers);
79 return result / GetSumWeights();
80}
81}