ATLAS Offline Software
Loading...
Searching...
No Matches
ForestLGBM.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include "TTree.h"
7#include <iostream>
8#include <stdexcept>
9
10using namespace MVAUtils;
11
14 , m_max_var(0)
15{
16
17
18 // variables read from the TTree
19 std::vector<int> *vars = nullptr;
20 std::vector<float> *values = nullptr;
21
22 std::vector<NodeLGBMSimple> nodes;
23
24 tree->SetBranchAddress("vars", &vars);
25 tree->SetBranchAddress("values", &values);
26
27 int numEntries = tree->GetEntries();
28 for (int entry = 0; entry < numEntries; ++entry)
29 {
30 // each entry in the TTree is a decision tree
31 tree->GetEntry(entry);
32 if (!vars) {
33 throw std::runtime_error(
34 "vars pointer is null in ForestLGBMSimple constructor");
35 }
36 if (!values) {
37 throw std::runtime_error(
38 "values pointers is null in ForestLGBMSimple constructor");
39 }
40 if (vars->size() != values->size()) {
41 throw std::runtime_error("inconsistent size for vars and values in "
42 "ForestLGBMSimple constructor");
43 }
44
45 nodes.clear();
46
47 std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
48
49 for (size_t i = 0; i < vars->size(); ++i) {
50 nodes.emplace_back(vars->at(i), values->at(i), right[i]);
51 if (vars->at(i) > m_max_var) { m_max_var = vars->at(i); }
52 }
53 newTree(nodes);
54 } // end loop on TTree, all decision tree loaded
55 delete vars;
56 delete values;
57}
58
59
60TTree* ForestLGBMSimple::WriteTree(TString name) const
61{
62 TTree *tree = new TTree(name.Data(), "creator=lgbm;node_type=lgbm_simple");
63
64 std::vector<int> vars;
65 std::vector<float> values;
66
67 tree->Branch("vars", &vars);
68 tree->Branch("values", &values);
69
70 for (size_t itree = 0; itree < GetNTrees(); ++itree) {
71 vars.clear();
72 values.clear();
73 for(const auto& node : GetTree(itree)) {
74 vars.push_back(node.GetVar());
75 values.push_back(node.GetVal());
76 }
77 tree->Fill();
78 }
79 return tree;
80}
81
83{
84 std::cout << "***BDT LGBMSimple: Printing entire forest***" << std::endl;
86}
87
90 , m_max_var(0)
91{
92
93
94 // variables read from the TTree
95 std::vector<int> *vars = nullptr;
96 std::vector<float> *values = nullptr;
97 std::vector<bool> *default_left = nullptr;
98
99 std::vector<NodeLGBM> nodes;
100
101 tree->SetBranchAddress("vars", &vars);
102 tree->SetBranchAddress("values", &values);
103 tree->SetBranchAddress("default_left", &default_left);
104 int numEntries = tree->GetEntries();
105 for (int entry = 0; entry < numEntries; ++entry) {
106 // each entry in the TTree is a decision tree
107 tree->GetEntry(entry);
108 if (!vars) {
109 throw std::runtime_error(
110 "vars pointer is null in ForestLGBM constructor");
111 }
112 if (!values) {
113 throw std::runtime_error(
114 "values pointers is null in ForestLGBM constructor");
115 }
116 if (!default_left) {
117 throw std::runtime_error(
118 "default_left pointers is null in ForestLGBM constructor");
119 }
120 if (vars->size() != values->size()) {
121 throw std::runtime_error(
122 "inconsistent size for vars and values in ForestLGBM constructor");
123 }
124 if (default_left->size() != values->size()) {
125 throw std::runtime_error("inconsistent size for default_left and "
126 "values in ForestLGBM constructor");
127 }
128
129 nodes.clear();
130
131 std::vector<MVAUtils::index_t> right = detail::computeRight(*vars);
132
133 for (size_t i = 0; i < vars->size(); ++i) {
134 nodes.emplace_back(
135 vars->at(i), values->at(i), right[i], default_left->at(i));
136 if (vars->at(i) > m_max_var) {
137 m_max_var = vars->at(i);
138 }
139 }
140 newTree(nodes);
141 } // end loop on TTree, all decision tree loaded
142 delete vars;
143 delete values;
144 delete default_left;
145}
146
147
148TTree* ForestLGBM::WriteTree(TString name) const
149{
150 TTree *tree = new TTree(name.Data(), "creator=lgbm;node_type=lgbm");
151
152 std::vector<int> vars;
153 std::vector<float> values;
154 std::vector<bool> default_left;
155
156 tree->Branch("vars", &vars);
157 tree->Branch("values", &values);
158 tree->Branch("default_left", &default_left);
159
160 for (size_t itree = 0; itree < GetNTrees(); ++itree) {
161 vars.clear();
162 values.clear();
163 default_left.clear();
164 for(const auto& node : GetTree(itree)) {
165 vars.push_back(node.GetVar());
166 values.push_back(node.GetVal());
167 default_left.push_back(node.GetDefaultLeft());
168 }
169 tree->Fill();
170 }
171 return tree;
172}
173
175{
176 std::cout << "***BDT LGBM: Printing entire forest***" << std::endl;
178}
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
virtual void PrintForest() const override
virtual TTree * WriteTree(TString name) const override
Return a TTree representing the BDT.
virtual void PrintForest() const override
void newTree(const std::vector< NodeLGBMSimple > &nodes)
std::vector< NodeLGBMSimple > GetTree(unsigned int itree) const
virtual void PrintForest() const override
virtual unsigned int GetNTrees() const override final
Definition Forest.h:94
Node for LGBM without nan implementation.
Definition NodeImpl.h:92
Node for LGBM with nan implementation.
Definition NodeImpl.h:135
Definition node.h:24
std::vector< index_t > computeRight(const std::vector< int > &vars)
Compute the offsets between the nodes to their right children from a serialized representation of the...
TChain * tree