ATLAS Offline Software
Loading...
Searching...
No Matches
ForestTMVA.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include "TTree.h"
7#include <iostream>
8#include <stdexcept>
9
10using namespace MVAUtils;
11
14 , m_max_var(0)
15{
16 // variables read from the TTree
17 std::vector<int> *vars = nullptr;
18 std::vector<float> *values = nullptr;
19 float offset; // the offset is the weight
20
21 std::vector<NodeTMVA> nodes;
22
23 tree->SetBranchAddress("vars", &vars);
24 tree->SetBranchAddress("values", &values);
25 tree->SetBranchAddress("offset", &offset);
26
27 int numEntries = tree->GetEntries();
28 for (int entry = 0; entry < numEntries; ++entry) {
29 // each entry in the TTree is a decision tree
30 tree->GetEntry(entry);
31 if (!vars) {
32 throw std::runtime_error(
33 "vars pointer is null in ForestTMVA constructor");
34 }
35 if (!values) {
36 throw std::runtime_error(
37 "values pointers is null in ForestTMVA constructor");
38 }
39 if (vars->size() != values->size()) {
40 throw std::runtime_error(
41 "inconsistent size for vars and values in ForestTMVA constructor");
42 }
43
44 nodes.clear();
45
46 std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
47
48 for (size_t i = 0; i < vars->size(); ++i) {
49 nodes.emplace_back(vars->at(i), values->at(i), right[i]);
50 if (vars->at(i) > m_max_var) {
51 m_max_var = vars->at(i);
52 }
53 }
54 newTree(nodes, offset);
55 } // end loop on TTree, all decision tree loaded
56 delete vars;
57 delete values;
58}
59
60
61TTree* ForestTMVA::WriteTree(TString name) const
62{
63 TTree *tree = new TTree(name.Data(), "creator=TMVA");
64
65 std::vector<int> vars;
66 std::vector<float> values;
67 float offset;
68
69 tree->Branch("offset", &offset);
70 tree->Branch("vars", &vars);
71 tree->Branch("values", &values);
72
73 for (size_t itree = 0; itree < GetNTrees(); ++itree) {
74 vars.clear();
75 values.clear();
76 for(const auto& node : GetTree(itree)) {
77 vars.push_back(node.GetVar());
78 values.push_back(node.GetVal());
79 }
80 offset = GetTreeWeight(itree);
81 tree->Fill();
82 }
83 return tree;
84}
85
87{
88 std::cout << "***BDT TMVA: Printing entire forest***" << std::endl;
90}
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
virtual void PrintForest() const override
float GetTreeWeight(unsigned int itree) const
Definition ForestTMVA.h:38
void newTree(const std::vector< NodeTMVA > &nodes, float weight)
std::vector< NodeTMVA > GetTree(unsigned int itree) const
virtual void PrintForest() const override
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94
Node for TMVA implementation.
Definition NodeImpl.h:36
Definition node.h:24
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
TChain * tree