ATLAS Offline Software
Loading...
Searching...
No Matches
ForestXGBoost.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include "TTree.h"
7#include <iostream>
8#include <stdexcept>
9
10using namespace MVAUtils;
11
14 , m_max_var(0)
15{
16
17
18 // variables read from the TTree
19 std::vector<int> *vars = nullptr;
20 std::vector<float> *values = nullptr;
21 std::vector<bool> *default_left = nullptr;
22
23 std::vector<NodeXGBoost> nodes;
24
25 tree->SetBranchAddress("vars", &vars);
26 tree->SetBranchAddress("values", &values);
27 tree->SetBranchAddress("default_left", &default_left);
28
29 for (int i = 0; i < tree->GetEntries(); ++i)
30 {
31 // each entry in the TTree is a decision tree
32 tree->GetEntry(i);
33 if (!vars) { throw std::runtime_error("vars pointer is null in ForestXGBoost constructor"); }
34 if (!values) { throw std::runtime_error("values pointers is null in ForestXGBoost constructor"); }
35 if (!default_left) { throw std::runtime_error("default_left pointers is null in ForestXGBoost constructor"); }
36 if (vars->size() != values->size()) { throw std::runtime_error("inconsistent size for vars and values in ForestXGBoost constructor"); }
37 if (default_left->size() != values->size()) { throw std::runtime_error("inconsistent size for default_left and values in ForestXGBoost constructor"); }
38
39 nodes.clear();
40
41 std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
42
43 for (size_t i = 0; i < vars->size(); ++i) {
44 nodes.emplace_back(vars->at(i), values->at(i), right[i], default_left->at(i));
45 if (vars->at(i) > m_max_var) { m_max_var = vars->at(i); }
46 }
47 newTree(nodes);
48 } // end loop on TTree, all decision tree loaded
49 delete vars;
50 delete values;
51 delete default_left;
52}
53
54
55TTree* ForestXGBoost::WriteTree(TString name) const
56{
57 TTree *tree = new TTree(name.Data(), "creator=xgboost");
58
59 std::vector<int> vars;
60 std::vector<float> values;
61 std::vector<bool> default_left;
62
63 tree->Branch("vars", &vars);
64 tree->Branch("values", &values);
65 tree->Branch("default_left", &default_left);
66
67 for (size_t itree = 0; itree < GetNTrees(); ++itree) {
68 vars.clear();
69 values.clear();
70 default_left.clear();
71 for(const auto& node : GetTree(itree)) {
72 vars.push_back(node.GetVar());
73 values.push_back(node.GetVal());
74 default_left.push_back(node.GetDefaultLeft());
75 }
76 tree->Fill();
77 }
78 return tree;
79}
80
82{
// Prints a banner line to std::cout announcing that the whole forest follows.
// NOTE(review): this extract skips original line 84 (the numbering jumps from
// 83 to 85), so the rest of the body — presumably the code that actually
// dumps the trees — is not visible here; TODO confirm against the real file.
83 std::cout << "***BDT XGBoost: Printing entire forest***" << std::endl;
85}
virtual void PrintForest() const override
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
void newTree(const std::vector< NodeXGBoost > &nodes)
std::vector< NodeXGBoost > GetTree(unsigned int itree) const
virtual void PrintForest() const override
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94
Node for XGBoost with nan implementation.
Definition NodeImpl.h:177
Definition node.h:24
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
TChain * tree