ATLAS Offline Software
Loading...
Searching...
No Matches
convertXGBoostToRootTree.py
Go to the documentation of this file.
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
4
5""" Convert XGBoost model to TTree to be used with MVAUtils. """
6
7__author__ = "Yuan-Tang Chou"
8
9
10try:
11 import xgboost as xgb
12except ImportError:
13 print("""cannot load xgboost. Try to install it with
14 pip install xgboost
15 """)
16import ROOT
17import time
18import json
19import logging
20logging.basicConfig(level=logging.DEBUG)
21
22
class XBGoostTextNode(dict):
    """
    Adaptor from one node of an XGBoost JSON dump (a plain dict) to the
    small tree interface used when flattening a tree.

    * XGBoost "yes" is ``<`` and "no" is ``>=``
    """

    def get_split_feature(self):
        """
        Return the node's 'split' entry, or -1 when the node has none
        (i.e. it is a leaf).

        NOTE(review): the value is returned as stored in the dump;
        presumably it is an integer feature index -- verify against the
        dump format of the XGBoost version in use.
        """
        if 'split' in self:
            return self['split']
        return -1

    def get_value(self):
        """
        Return 'split_condition' for a splitting node, otherwise the
        'leaf' value. Raises KeyError when the node has neither key.
        """
        if 'split_condition' in self:
            return self['split_condition']
        return self['leaf']

    def get_left(self):
        """Return the "yes" child wrapped as a node, or None for a leaf."""
        if 'children' not in self:
            return None
        # XGBoost "YES" is left branch in MVAUtils
        return XBGoostTextNode(self['children'][self.get_nodeid('yes')])

    def get_right(self):
        """Return the "no" child wrapped as a node, or None for a leaf."""
        if 'children' not in self:
            return None
        # XGBoost "NO" is right branch in MVAUtils
        return XBGoostTextNode(self['children'][self.get_nodeid('no')])

    def get_nodeid(self, node_type):
        """
        Return the position inside ``self['children']`` of the child whose
        'nodeid' equals ``self[node_type]``.

        ``node_type`` is one of the keys holding a child id ('yes', 'no',
        'missing'). Returns None when no child carries that id.
        """
        wanted = self[node_type]
        for idx, child in enumerate(self['children']):
            if child['nodeid'] == wanted:
                return idx
        return None

    def get_default_left(self):
        """
        Return True when missing values follow the "yes" (left) branch.
        Leaves always report True.
        """
        if 'children' not in self:
            return True
        return self.get_nodeid('yes') == self.get_nodeid('missing')
65
66
def dump_tree(tree_structure):
    """
    Flatten a single decision tree into three parallel lists, in pre-order
    (node, then "yes"/left subtree, then "no"/right subtree), ready to be
    written into the TTree.

    Returns the tuple (split_features, split_values, default_left).
    """
    split_features, split_values, default_left = [], [], []

    # Explicit stack instead of recursion; the right child is pushed first
    # so that the left child is popped (visited) before it.
    pending = [XBGoostTextNode(tree_structure)]
    while pending:
        current = pending.pop()

        split_features.append(current.get_split_feature())
        split_values.append(current.get_value())
        default_left.append(current.get_default_left())

        yes_child = current.get_left()
        no_child = current.get_right()
        for child in (no_child, yes_child):
            if child is not None:
                pending.append(child)

    return split_features, split_values, default_left
92
def dump2ROOT(model, output_filename, output_treename='xgboost'):
    """
    Write the trees of an XGBoost booster into a TTree stored in a ROOT file.

    model: object providing ``get_dump(dump_format='json')`` (an XGBoost
           Booster), yielding one JSON document per decision tree
    output_filename: ROOT file to (re)create
    output_treename: name of the TTree written into the file

    One TTree entry is filled per decision tree, with the branches
    'vars', 'values' and 'default_left' holding the flattened nodes.

    Returns output_treename. Raises IOError when the output file cannot
    be opened.
    """
    # Take the dump in memory: no scratch 'dump_model.json' is left behind
    # in (or required to be writable in) the current working directory.
    trees = [json.loads(tree_json) for tree_json in model.get_dump(dump_format='json')]

    fout = ROOT.TFile.Open(output_filename, 'recreate')
    if not fout or fout.IsZombie():
        raise IOError("cannot open output file %s" % output_filename)

    features_array = ROOT.std.vector('int')()
    values_array = ROOT.std.vector('float')()
    default_lefts_array = ROOT.std.vector('bool')()

    title = 'creator=xgboost'
    root_tree = ROOT.TTree(output_treename, title)
    root_tree.Branch('vars', 'vector<int>', ROOT.AddressOf(features_array))
    root_tree.Branch('values', 'vector<float>', ROOT.AddressOf(values_array))
    root_tree.Branch('default_left', 'vector<bool>', ROOT.AddressOf(default_lefts_array))

    logging.info("tree support nan: using XGBoost implementation")

    for tree_structure in trees:
        features, values, default_lefts = dump_tree(tree_structure)

        # The branch buffers are reused for every entry, so reset them first.
        features_array.clear()
        values_array.clear()
        default_lefts_array.clear()

        for value in values:
            values_array.push_back(value)
        for feature in features:
            features_array.push_back(feature)
        for default_left in default_lefts:
            default_lefts_array.push_back(default_left)

        root_tree.Fill()

    root_tree.Write()
    fout.Close()
    return output_treename
132
def convertXGBoostToRootTree(model, output_filename, tree_name='xgboost'):
    """
    Convert an XGBoost model to a TTree stored in output_filename.

    Model: - a string, in this case, it is the name of the input file containing the xgboost model
             you can get this model with xgboost with `bst.save_model('my_model.model')`
           - directly a xgboost booster object

    Returns the name of the tree that was written.
    """
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses of str as a file name.
    if isinstance(model, str):
        booster = xgb.Booster()
        booster.load_model(model)
    else:
        booster = model
    return dump2ROOT(booster, output_filename, tree_name)
145
146
def test(model_file, tree_file, objective, tree_name='xgboost', ntests=10000, test_file=None):
    """
    Compare the XGBoost model stored in model_file with the MVAUtils BDT
    built from the tree tree_name found in tree_file.

    The comparison routine is picked from the objective string: 'binary'
    selects the binary test, 'multi' the multi-class test, anything else
    the regression test. Returns the result of that routine, or None when
    MVAUtils is not available from ROOT.
    """
    booster = xgb.Booster()
    booster.load_model(model_file)

    root_file = ROOT.TFile.Open(tree_file)
    bdt_tree = root_file.Get(tree_name)

    # Looking the class up fails when the MVAUtils library is not set up.
    try:
        bdt_class = ROOT.MVAUtils.BDT
    except Exception:
        print("cannot import MVAUtils")
        return None

    mva_utils = bdt_class(bdt_tree)

    if 'binary' in objective:
        message, compare = "testing binary", test_binary
    elif 'multi' in objective:
        message, compare = "testing multi-class", test_multiclass
    else:
        message, compare = "testing regression", test_regression

    logging.info(message)
    return compare(booster, mva_utils, objective, ntests, test_file)
169
def test_regression(booster, mva_utils, objective, ntests=10000, test_file=None):
    """
    Check that MVAUtils reproduces the XGBoost regression output.

    booster: XGBoost booster used as reference (``predict``)
    mva_utils: MVAUtils BDT under test (``GetResponse``)
    objective, ntests: accepted for a uniform signature, not used here
    test_file: file readable by ``numpy.load``, one row per input

    Returns True when every output agrees within rtol=1E-4, False on the
    first mismatch or when no usable test input is available.
    """
    logging.info("Tesing input features with regression")

    if test_file is None:
        # Nothing to compare against: report a failed test instead of
        # running into an undefined input table below.
        logging.error("Please provide an input test file for testing")
        return False

    import numpy as np

    data_input = np.load(test_file)
    n_inputs = len(data_input)
    logging.info("using as input %s inputs from file %s", n_inputs, test_file)
    if n_inputs == 0:
        # An empty table verifies nothing and would divide by zero in the
        # per-input timings.
        logging.error("test file %s contains no inputs", test_file)
        return False

    start = time.time()
    dTest = xgb.DMatrix(data_input)
    results_xgboost = booster.predict(dTest)
    logging.info("xgboost (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / n_inputs)

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        output_MVAUtils = mva_utils.GetResponse(input_values_vector)
        results_MVAUtils.append(output_MVAUtils)
    logging.info("mvautils (not vectorized+overhead) timing = %s ms/input", (time.time() - start) * 1000 / n_inputs)

    for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
        if not np.allclose(output_xgb, output_MVAUtils, rtol=1E-4):
            logging.info("output are different:"
                         "mvautils: %s\n"
                         "xgboost: %s\n"
                         "inputs: %s", output_MVAUtils, output_xgb, input_values)
            return False
    return True
204
205
def test_binary(booster, mva_utils, objective, ntests=10000, test_file=None):
    """
    Check that MVAUtils reproduces the XGBoost binary-classification output.

    booster: XGBoost booster used as reference (``predict``)
    mva_utils: MVAUtils BDT under test (``GetClassification``)
    objective, ntests: accepted for a uniform signature, not used here
    test_file: file readable by ``numpy.load``, one row per input

    Returns True when every output agrees (``numpy.allclose`` defaults),
    False on the first mismatch or when no usable test input is available.
    """
    logging.info("Testing input features with binary classification")

    if test_file is None:
        # Nothing to compare against: report a failed test instead of
        # running into an undefined input table below.
        logging.error("Please provide an input test file for testing")
        return False

    import numpy as np

    data_input = np.load(test_file)
    n_inputs = len(data_input)
    logging.info("using as input %s inputs from file %s", n_inputs, test_file)
    if n_inputs == 0:
        # An empty table verifies nothing and would divide by zero in the
        # per-input timings.
        logging.error("test file %s contains no inputs", test_file)
        return False

    start = time.time()
    dTest = xgb.DMatrix(data_input)
    results_xgboost = booster.predict(dTest)
    logging.info("xgboost (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / n_inputs)

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        output_MVAUtils = mva_utils.GetClassification(input_values_vector)
        results_MVAUtils.append(output_MVAUtils)
    logging.info("mvautils (not vectorized+overhead) timing = %s ms/input", (time.time() - start) * 1000 / n_inputs)

    for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
        if not np.allclose(output_xgb, output_MVAUtils):
            logging.info("output are different:"
                         "mvautils: %s\n"
                         "xgboost: %s\n"
                         "inputs: %s", output_MVAUtils, output_xgb, input_values)
            return False
    return True
239
def test_multiclass(booster, mva_utils, objective, ntests=10000, test_file=None):
    """
    Check that MVAUtils reproduces the XGBoost multi-class output.

    booster: XGBoost booster used as reference (``predict``); the number
             of classes is taken from the second axis of its prediction
    mva_utils: MVAUtils BDT under test (``GetMultiResponse``)
    objective, ntests: accepted for a uniform signature, not used here
    test_file: file readable by ``numpy.load``, one row per input

    Returns True when every output agrees (``numpy.allclose`` defaults),
    False on the first mismatch or when no usable test input is available.
    """
    logging.info("using multiclass model")

    if test_file is None:
        # Nothing to compare against: report a failed test instead of
        # running into an undefined input table below.
        logging.error("Please provide an input test file for testing")
        return False

    import numpy as np

    data_input = np.load(test_file)
    n_inputs = len(data_input)
    logging.info("using as input %s inputs from file %s", n_inputs, test_file)
    if n_inputs == 0:
        # An empty table verifies nothing and would divide by zero in the
        # per-input timings.
        logging.error("test file %s contains no inputs", test_file)
        return False

    start = time.time()
    dTest = xgb.DMatrix(data_input)
    results_xgboost = booster.predict(dTest)

    nclasses = results_xgboost.shape[1]
    logging.info("xgboost (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / n_inputs)

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        output_MVAUtils = mva_utils.GetMultiResponse(input_values_vector, nclasses)
        results_MVAUtils.append(output_MVAUtils)

    logging.info("mvautils (not vectorized+overhead) timing = %s ms/input", (time.time() - start) * 1000 / n_inputs)

    for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
        if not np.allclose(output_xgb, output_MVAUtils):
            logging.info("output are different:"
                         "mvautils: %s\n"
                         "xgboost: %s\n"
                         "inputs: %s", output_MVAUtils, output_xgb, input_values)
            return False
    return True
277
278
def check_file(fn):
    """
    Sanity-check a converted ROOT file.

    The file named fn must open, hold exactly one key, and that key must
    be a non-empty TTree. Returns True when all checks pass, False
    otherwise (the reason is logged).
    """
    f = ROOT.TFile.Open(fn)
    if not f or f.IsZombie():
        logging.info("cannot open file %s", fn)
        return False

    try:
        keys = list(f.GetListOfKeys())
        if not keys:
            logging.info("file %s is empty", fn)
            return False
        if len(keys) != 1:
            logging.info("file %s contains %s objects, expected exactly one", fn, len(keys))
            return False
        tree = f.Get(keys[0].GetName())
        if not isinstance(tree, ROOT.TTree):
            logging.info("cannot find TTree in file %s", fn)
            return False
        if not tree.GetEntries():
            logging.info("tree is empty")
            return False
        return True
    finally:
        # Only a boolean leaves this function, so the file can be released.
        f.Close()
294
295
if __name__ == "__main__":
    # Command-line entry point: convert the model, then (unless --no-test)
    # compare XGBoost and MVAUtils outputs on the inputs from --test-file
    # and print a C++ usage snippet.
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('input', help='input xgboost model')
    # nargs='?' is what makes the default reachable: a positional argument
    # without it is always required, so its default would never be used.
    parser.add_argument('output', type=str, nargs='?', default='xgboost_model.root',
                        help='Output file name, it must end with .root')
    parser.add_argument('--tree-name', type=str, default='xgboost', help="tree name in Output root file")
    parser.add_argument('--no-test', action='store_true', help="don't run test (not suggested)")
    parser.add_argument('--ntests', type=int, default=1000, help="number of random test, default=1000")
    parser.add_argument('--test-file', type=str, help='numpy table')
    parser.add_argument('--objective', type=str, help='Specify the learning task and the corresponding learning objective, currently support options: binary:logistic, reg:linear(squarederror), multi:softprob')

    args = parser.parse_args()
    logging.info("converting input file %s to root file %s", args.input, args.output)

    # 'reg:linear'is been named as 'reg:squarederror' in newer version of xgboost (> 0.90)
    supported_objective = ['binary:logistic', 'reg:linear', 'reg:squarederror', 'multi:softprob']

    if args.objective not in supported_objective:
        parser.error('''
        Current version does NOT support this objective!!
        Only the following objectives are supported and tested:
        - binary:logistic
        - reg:linear(or squarederror)
        - multi:softprob
        ''')

    if not args.input:
        parser.error('Model file name not given!')

    # The message promises a suffix check, so test the suffix rather than
    # a substring anywhere in the name.
    if not args.output.endswith(".root"):
        parser.error("The outputfile name must end with .root!!")

    # Reject an unusable test request before doing the conversion work.
    if not args.no_test and not args.test_file:
        parser.error("Attempting to do test but no test file was provided, pass this with '--test-file <test_file>' or use option '--no-test'")

    output_treename = convertXGBoostToRootTree(args.input, args.output, args.tree_name)

    if args.no_test:
        print("model has not been tested. Do not use it production!")
    else:
        logging.info("testing model")
        if not check_file(args.output):
            print("problem when checking file")
        result = test(args.input, args.output, args.objective, args.tree_name, args.ntests, args.test_file)
        if not result:
            print("some problems during test. Have you setup athena? Do not use this in production!")
        else:
            print(u"::: everything fine: XGBoost output == MVAUtils output :::")
            objective = args.objective
            import numpy as np
            data = np.load(args.test_file)

            # Only the final call differs between the supported objectives.
            if 'binary' in objective:
                cpp_call = 'GetClassification(input_values)'
            elif 'reg' in objective:
                cpp_call = 'Predict(input_values)'
            elif 'multi' in objective:
                cpp_call = 'GetMultiResponse(input_values, nclasses)'
            else:
                cpp_call = None

            if cpp_call is not None:
                print('''In c++ use your BDT as:
#include "MVAUtils/BDT.h"

TFile* f = TFile::Open("%s");
TTree* tree = nullptr;
f->GetObject("%s", tree);
MVAUtils::BDT my_bdt(tree);
// ...
// std::vector<float> input_values(%d, 0.);
// fill the vector using the order as in the trainig
// ...
float output = my_bdt.%s;
''' % (args.output, output_treename, len(data[0]), cpp_call))
void print(char *figname, TCanvas *c1)
test(model_file, tree_file, objective, tree_name='xgboost', ntests=10000, test_file=None)
test_binary(booster, mva_utils, objective, ntests=10000, test_file=None)
dump2ROOT(model, output_filename, output_treename='xgboost')
test_multiclass(booster, mva_utils, objective, ntests=10000, test_file=None)
test_regression(booster, mva_utils, objective, ntests=10000, test_file=None)