ATLAS Offline Software
convertXGBoostToRootTree.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
4 
5 __doc__ = "Convert XGBoost model to TTree to be used with MVAUtils."
6 __author__ = "Yuan-Tang Chou"
7 
8 
9 try:
10  import xgboost as xgb
11 except ImportError:
12  print("""cannot load xgboost. Try to install it with
13  pip install xgboost
14  """)
15 import ROOT
16 import time
17 import json
18 import logging
19 logging.basicConfig(level=logging.DEBUG)
20 
21 
class XBGoostTextNode(dict):
    """
    Dictionary view of one node of an XGBoost JSON tree dump.

    Leaf nodes carry a 'leaf' value; split nodes carry 'split',
    'split_condition', the child ids 'yes' / 'no' / 'missing' and a
    'children' list.  XGBoost "yes" means feature < split_condition and
    "no" means feature >= split_condition; in MVAUtils "yes" is mapped to
    the left child and "no" to the right child.
    """

    def get_split_feature(self):
        """Return the node's 'split' entry, or -1 for a leaf (no 'split' key)."""
        return self.get('split', -1)

    def get_value(self):
        """Return the split threshold of a split node, or the leaf value of a leaf."""
        return self['split_condition'] if 'split_condition' in self else self['leaf']

    def _child(self, branch):
        """Wrap the child reached through ``branch`` ('yes' or 'no'); None for a leaf."""
        if 'children' in self:
            return XBGoostTextNode(self['children'][self.get_nodeid(branch)])
        return None

    def get_left(self):
        """Return the XGBoost "yes" child (left branch in MVAUtils), or None for a leaf."""
        return self._child('yes')

    def get_right(self):
        """Return the XGBoost "no" child (right branch in MVAUtils), or None for a leaf."""
        return self._child('no')

    def get_nodeid(self, node_type):
        """
        Locate a child by the id stored under ``node_type``.

        :param node_type: key holding a child node id, e.g. 'yes', 'no' or 'missing'.
        :return: the position in 'children' of the first child whose 'nodeid'
            equals ``self[node_type]``, or None if no child matches.
        """
        return next((position for position, child in enumerate(self['children'])
                     if child['nodeid'] == self[node_type]), None)

    def get_default_left(self):
        """
        Return True if missing values follow the left ("yes") branch.

        Leaves report True; split nodes report whether the 'missing' child is
        the same child as the 'yes' child.
        """
        return 'children' not in self or self.get_nodeid('yes') == self.get_nodeid('missing')
64 
65 
def dump_tree(tree_structure):
    """
    Flatten one XGBoost decision tree into parallel lists for the TTree.

    Nodes are visited in pre-order: the node itself, then its "yes" (left)
    subtree, then its "no" (right) subtree.

    :param tree_structure: one tree of the XGBoost JSON dump (nested dict).
    :return: tuple ``(split_features, split_values, default_left)`` with one
        entry per visited node.
    """
    features, thresholds, defaults_left = [], [], []
    pending = [XBGoostTextNode(tree_structure)]

    while pending:
        node = pending.pop()
        features.append(node.get_split_feature())
        thresholds.append(node.get_value())
        defaults_left.append(node.get_default_left())

        yes_child = node.get_left()
        no_child = node.get_right()
        # Push the right ("no") child first so the left ("yes") subtree is
        # popped and fully walked before it, reproducing recursive pre-order.
        if no_child is not None:
            pending.append(no_child)
        if yes_child is not None:
            pending.append(yes_child)

    return features, thresholds, defaults_left
91 
def dump2ROOT(model, output_filename, output_treename='xgboost'):
    """
    Write every tree of an XGBoost booster into a ROOT TTree for MVAUtils.

    One TTree entry is filled per boosted tree. Each entry holds the tree
    flattened in pre-order by ``dump_tree``, in three branches:

    * ``vars``         (vector<int>)   node 'split' entry, -1 for leaves
    * ``values``       (vector<float>) split threshold, or leaf value for leaves
    * ``default_left`` (vector<bool>)  True when missing values go to the "yes"/left child

    The model is dumped in memory with ``Booster.get_dump``. No intermediate
    JSON file is written to the working directory.

    :param model: an ``xgboost.Booster``.
    :param output_filename: ROOT file to (re)create.
    :param output_treename: name of the TTree written into the file.
    :return: ``output_treename``.
    :raises IOError: if the output ROOT file cannot be created.
    """
    # get_dump returns one JSON document per tree; this is the same content
    # dump_model(..., dump_format='json') would write, without a temp file.
    trees = [json.loads(dumped) for dumped in model.get_dump(dump_format='json')]

    fout = ROOT.TFile.Open(output_filename, 'recreate')
    if not fout or fout.IsZombie():
        raise IOError("cannot create output file %s" % output_filename)

    try:
        features_array = ROOT.std.vector('int')()
        values_array = ROOT.std.vector('float')()
        default_lefts_array = ROOT.std.vector('bool')()

        title = 'creator=xgboost'
        # Created after TFile.Open so the tree is attached to fout, which is
        # now the current ROOT directory.
        root_tree = ROOT.TTree(output_treename, title)
        root_tree.Branch('vars', 'vector<int>', ROOT.AddressOf(features_array))
        root_tree.Branch('values', 'vector<float>', ROOT.AddressOf(values_array))
        root_tree.Branch('default_left', 'vector<bool>', ROOT.AddressOf(default_lefts_array))

        logging.info("tree support nan: using XGBoost implementation")

        for tree_structure in trees:
            features, values, default_lefts = dump_tree(tree_structure)

            # The branch buffers are reused for every entry.
            features_array.clear()
            values_array.clear()
            default_lefts_array.clear()

            for value in values:
                values_array.push_back(value)
            for feature in features:
                features_array.push_back(feature)
            for default_left in default_lefts:
                default_lefts_array.push_back(default_left)

            root_tree.Fill()

        root_tree.Write()
    finally:
        fout.Close()
    return output_treename
131 
def convertXGBoostToRootTree(model, output_filename, tree_name='xgboost'):
    """
    Convert an XGBoost model into a ROOT file holding an MVAUtils-readable TTree.

    :param model: either the path (``str`` or ``os.PathLike``) of a model file,
        e.g. one saved with ``bst.save_model('my_model.model')``, or an
        ``xgboost.Booster`` object.
    :param output_filename: ROOT file to (re)create.
    :param tree_name: name of the TTree inside the output file.
    :return: the name of the written TTree.
    """
    import os

    if isinstance(model, (str, os.PathLike)):
        booster = xgb.Booster()
        # Normalise path-like objects to a plain path string for load_model.
        booster.load_model(os.fspath(model))
        model = booster
    return dump2ROOT(model, output_filename, tree_name)
144 
145 
def test(model_file, tree_file, objective, tree_name='xgboost', ntests=10000, test_file=None):
    """
    Check that MVAUtils reproduces the XGBoost predictions of a converted model.

    Dispatches on ``objective``:

    * objectives containing 'binary' go to ``test_binary``;
    * objectives containing 'multi' go to ``test_multiclass``;
    * everything else goes to ``test_regression``.

    :param model_file: XGBoost model file that was converted.
    :param tree_file: ROOT file produced by the conversion.
    :param objective: XGBoost objective string, e.g. 'binary:logistic'.
    :param tree_name: name of the TTree inside ``tree_file``.
    :param ntests: forwarded to the per-objective test function.
    :param test_file: numpy file with the input rows to compare on.
    :return: the result of the per-objective test (True / False).
        Returns None if MVAUtils, the ROOT file or the TTree cannot be loaded.
    """
    bst = xgb.Booster()
    bst.load_model(model_file)

    try:
        _ = ROOT.MVAUtils.BDT
    except Exception:
        print("cannot import MVAUtils")
        return None

    f = ROOT.TFile.Open(tree_file)
    if not f or f.IsZombie():
        logging.error("cannot open file %s", tree_file)
        return None

    # Keep the file open while MVAUtils is in use; always close it afterwards.
    try:
        tree = f.Get(tree_name)
        if not tree:
            logging.error("cannot find tree %s in file %s", tree_name, tree_file)
            return None

        mva_utils = ROOT.MVAUtils.BDT(tree)

        if 'binary' in objective:
            logging.info("testing binary")
            return test_binary(bst, mva_utils, objective, ntests, test_file)
        if 'multi' in objective:
            logging.info("testing multi-class")
            return test_multiclass(bst, mva_utils, objective, ntests, test_file)
        logging.info("testing regression")
        return test_regression(bst, mva_utils, objective, ntests, test_file)
    finally:
        f.Close()
168 
169 def test_regression(booster, mva_utils, objective, ntests=10000, test_file=None):
170  import numpy as np
171  logging.info("Tesing input features with regression")
172 
173  if test_file is not None:
174  data_input = np.load(test_file)
175  logging.info("using as input %s inputs from file %s", len(data_input), test_file)
176  else:
177  logging.error("Please provide an input test file for testing")
178 
179  start = time.time()
180  dTest = xgb.DMatrix(data_input)
181  results_xgboost = booster.predict(dTest)
182  logging.info("xgboost (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input))
183 
184  input_values_vector = ROOT.std.vector("float")()
185  results_MVAUtils = []
186  start = time.time()
187  for input_values in data_input:
188  input_values_vector.clear()
189  for v in input_values:
190  input_values_vector.push_back(v)
191  output_MVAUtils = mva_utils.GetResponse(input_values_vector)
192  results_MVAUtils.append(output_MVAUtils)
193  logging.info("mvautils (not vectorized+overhead) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input))
194 
195  for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
196  if not np.allclose(output_xgb, output_MVAUtils, rtol=1E-4):
197  logging.info("output are different:"
198  "mvautils: %s\n"
199  "xgboost: %s\n"
200  "inputs: %s", output_MVAUtils, output_xgb, input_values)
201  return False
202  return True
203 
204 
205 def test_binary(booster, mva_utils, objective, ntests=10000, test_file=None):
206  import numpy as np
207  logging.info("Testing input features with binary classification")
208  if test_file is not None:
209  data_input = np.load(test_file)
210  logging.info("using as input %s inputs from file %s", len(data_input), test_file)
211  else:
212  logging.error("Please provide an input test file for testing")
213 
214  start = time.time()
215  dTest = xgb.DMatrix(data_input)
216  results_xgboost = booster.predict(dTest)
217  logging.info("xgboost (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input))
218 
219  input_values_vector = ROOT.std.vector("float")()
220  results_MVAUtils = []
221  start = time.time()
222  for input_values in data_input:
223  input_values_vector.clear()
224  for v in input_values:
225  input_values_vector.push_back(v)
226  output_MVAUtils = mva_utils.GetClassification(input_values_vector)
227  results_MVAUtils.append(output_MVAUtils)
228  logging.info("mvautils (not vectorized+overhead) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input))
229 
230  for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
231  if not np.allclose(output_xgb, output_MVAUtils):
232  logging.info("output are different:"
233  "mvautils: %s\n"
234  "xgboost: %s\n"
235  "inputs: %s", output_MVAUtils, output_xgb, input_values)
236  return False
237  return True
238 
239 def test_multiclass(booster, mva_utils, objective, ntests=10000, test_file=None):
240  import numpy as np
241  logging.info("using multiclass model")
242 
243  if test_file is not None:
244  data_input = np.load(test_file)
245  logging.info("using as input %s inputs from file %s", len(data_input), test_file)
246  else:
247  logging.error("Please provide an input test file for testing")
248 
249  start = time.time()
250  dTest = xgb.DMatrix(data_input)
251  results_xgboost = booster.predict(dTest)
252 
253  nclasses = results_xgboost.shape[1]
254  logging.info("xgboost (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input))
255 
256  input_values_vector = ROOT.std.vector("float")()
257  results_MVAUtils = []
258  start = time.time()
259  for input_values in data_input:
260  input_values_vector.clear()
261  for v in input_values:
262  input_values_vector.push_back(v)
263  output_MVAUtils = mva_utils.GetMultiResponse(input_values_vector, nclasses)
264  results_MVAUtils.append(output_MVAUtils)
265 
266  logging.info("mvautils (not vectorized+overhead) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input))
267 
268  for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
269  if not np.allclose(output_xgb, output_MVAUtils):
270  logging.info("output are different:"
271  "mvautils: %s\n"
272  "xgboost: %s\n"
273  "inputs: %s", output_MVAUtils, output_xgb, input_values)
274  return False
275  return True
276 
277 
def check_file(fn):
    """
    Sanity-check a converted ROOT file.

    :param fn: path of the ROOT file to inspect.
    :return: True if the file can be opened, holds exactly one key, that key
        is a TTree (or TTree subclass), and the tree has at least one entry.
        False otherwise, with the reason logged.
    """
    f = ROOT.TFile.Open(fn)
    if not f or f.IsZombie():
        logging.info("cannot open file %s", fn)
        return False

    try:
        keys = list(f.GetListOfKeys())
        if not keys:
            logging.info("file %s is empty", fn)
            return False
        if len(keys) != 1:
            logging.info("file %s contains %d keys, expected exactly one", fn, len(keys))
            return False

        tree = f.Get(keys[0].GetName())
        if not isinstance(tree, ROOT.TTree):
            logging.info("cannot find TTree in file %s", fn)
            return False
        # Entries must be read before the file is closed below.
        if not tree.GetEntries():
            logging.info("tree is empty")
            return False
        return True
    finally:
        f.Close()
293 
294 
if __name__ == "__main__":
    import argparse

    # 'reg:linear' has been renamed 'reg:squarederror' in newer versions of xgboost (> 0.90)
    supported_objective = ['binary:logistic', 'reg:linear', 'reg:squarederror', 'multi:softprob']

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('input', help='input xgboost model')
    # nargs='?' lets the declared default actually apply when output is omitted.
    parser.add_argument('output', type=str, nargs='?', default='xgboost_model.root', help='Output file name, it must end with .root')
    parser.add_argument('--tree-name', type=str, default='xgboost', help="tree name in Output root file")
    parser.add_argument('--no-test', action='store_true', help="don't run test (not suggested)")
    parser.add_argument('--ntests', type=int, default=1000, help="number of random test, default=1000")
    parser.add_argument('--test-file', type=str, help='numpy table')
    parser.add_argument('--objective', type=str, help='Specify the learning task and the corresponding learning objective, currently support options: binary:logistic, reg:linear(squarederror), multi:softprob')

    args = parser.parse_args()

    # Validate everything before converting, so a bad invocation writes no file.
    if args.objective not in supported_objective:
        parser.error('''
Current version does NOT support this objective!!
Only the following objectives are supported and tested:
 - binary:logistic
 - reg:linear(or squarederror)
 - multi:softprob
''')

    if not args.input:
        parser.error('Model file name not given!')

    if not args.output.endswith('.root'):
        parser.error("The outputfile name must end with .root!!")

    if not args.no_test and not args.test_file:
        parser.error("Attempting to do test but no test file was provided, pass this with '--test-file <test_file>' or use option '--no-test'")

    logging.info("converting input file %s to root file %s", args.input, args.output)
    output_treename = convertXGBoostToRootTree(args.input, args.output, args.tree_name)

    if args.no_test:
        print("model has not been tested. Do not use it in production!")
    else:
        logging.info("testing model")
        if not check_file(args.output):
            print("problem when checking file")
        result = test(args.input, args.output, args.objective, args.tree_name, args.ntests, args.test_file)
        if not result:
            print("some problems during test. Have you setup athena? Do not use this in production!")
        else:
            print(u"::: everything fine: XGBoost output == MVAUtils output :::")
            import numpy as np
            nvars = len(np.load(args.test_file)[0])

            # Use the same MVAUtils::BDT method that the test above exercised.
            if 'binary' in args.objective:
                call = 'float output = my_bdt.GetClassification(input_values);'
            elif 'reg' in args.objective:
                call = 'float output = my_bdt.GetResponse(input_values);'
            else:  # 'multi': one response per class
                call = ('// nclasses: number of classes used in the training\n'
                        'std::vector<float> output = my_bdt.GetMultiResponse(input_values, nclasses);')

            print('''In c++ use your BDT as:
#include "MVAUtils/BDT.h"

TFile* f = TFile::Open("%s");
TTree* tree = nullptr;
f->GetObject("%s", tree);
MVAUtils::BDT my_bdt(tree);
// ...
// std::vector<float> input_values(%d, 0.);
// fill the vector using the order as in the training
// ...
%s
''' % (args.output, output_treename, nvars, call))
util.convertXGBoostToRootTree.dump2ROOT
def dump2ROOT(model, output_filename, output_treename='xgboost')
Definition: convertXGBoostToRootTree.py:92
util.convertXGBoostToRootTree.check_file
def check_file(fn)
Definition: convertXGBoostToRootTree.py:278
util.convertXGBoostToRootTree.XBGoostTextNode
Definition: convertXGBoostToRootTree.py:22
util.convertXGBoostToRootTree.dump_tree
def dump_tree(tree_structure)
Definition: convertXGBoostToRootTree.py:66
util.convertXGBoostToRootTree.XBGoostTextNode.get_value
def get_value(self)
Definition: convertXGBoostToRootTree.py:34
util.convertXGBoostToRootTree.XBGoostTextNode.get_split_feature
def get_split_feature(self)
Definition: convertXGBoostToRootTree.py:28
util.convertXGBoostToRootTree.XBGoostTextNode.get_right
def get_right(self)
Definition: convertXGBoostToRootTree.py:46
util.convertXGBoostToRootTree.type
type
Definition: convertXGBoostToRootTree.py:300
util.convertXGBoostToRootTree.test
def test(model_file, tree_file, objective, tree_name='xgboost', ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:146
util.convertXGBoostToRootTree.XBGoostTextNode.get_nodeid
def get_nodeid(self, node_type)
Definition: convertXGBoostToRootTree.py:52
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
util.convertXGBoostToRootTree.test_regression
def test_regression(booster, mva_utils, objective, ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:169
print
void print(char *figname, TCanvas *c1)
Definition: TRTCalib_StrawStatusPlots.cxx:25
util.convertXGBoostToRootTree.XBGoostTextNode.get_default_left
def get_default_left(self)
Definition: convertXGBoostToRootTree.py:57
Trk::open
@ open
Definition: BinningType.h:40
util.convertXGBoostToRootTree.test_binary
def test_binary(booster, mva_utils, objective, ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:205
util.convertXGBoostToRootTree.XBGoostTextNode.get_left
def get_left(self)
Definition: convertXGBoostToRootTree.py:40
util.convertXGBoostToRootTree.convertXGBoostToRootTree
def convertXGBoostToRootTree(model, output_filename, tree_name='xgboost')
Definition: convertXGBoostToRootTree.py:132
util.convertXGBoostToRootTree.test_multiclass
def test_multiclass(booster, mva_utils, objective, ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:239