ATLAS Offline Software
convertXGBoostToRootTree.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
4 
5 """ Convert XGBoost model to TTree to be used with MVAUtils. """
6 
7 __author__ = "Yuan-Tang Chou"
8 
9 
10 try:
11  import xgboost as xgb
12 except ImportError:
13  print("""cannot load xgboost. Try to install it with
14  pip install xgboost
15  """)
16 import ROOT
17 import time
18 import json
19 import logging
20 logging.basicConfig(level=logging.DEBUG)
21 
22 
class XBGoostTextNode(dict):
    """
    Adaptor from an XGBoost JSON node dictionary to a binary-tree interface.

    Branch convention (matching MVAUtils):
    * XGBoost "yes" (feature < split_condition) is the left child
    * XGBoost "no"  (feature >= split_condition) is the right child
    """

    def get_split_feature(self):
        # Leaf nodes carry no 'split' key; -1 marks them for MVAUtils.
        return self.get('split', -1)

    def get_value(self):
        # Internal nodes store the cut in 'split_condition',
        # leaves store the response in 'leaf'.
        return self['split_condition'] if 'split_condition' in self else self['leaf']

    def get_left(self):
        # XGBoost "YES" is the left branch in MVAUtils; None for a leaf.
        if 'children' not in self:
            return None
        return XBGoostTextNode(self['children'][self.get_nodeid('yes')])

    def get_right(self):
        # XGBoost "NO" is the right branch in MVAUtils; None for a leaf.
        if 'children' not in self:
            return None
        return XBGoostTextNode(self['children'][self.get_nodeid('no')])

    def get_nodeid(self, node_type):
        # Index inside 'children' of the child whose nodeid equals
        # self[node_type] ('yes' / 'no' / 'missing'); None if absent.
        target = self[node_type]
        for idx, child in enumerate(self['children']):
            if child['nodeid'] == target:
                return idx

    def get_default_left(self):
        # True when missing values follow the "yes" (left) branch;
        # leaves conventionally default to left.
        if 'children' not in self:
            return True
        return self.get_nodeid('yes') == self.get_nodeid('missing')
65 
66 
def dump_tree(tree_structure):
    """
    Flatten a single decision tree into parallel arrays to be written
    into the TTree.

    Returns (split_features, split_values, default_left), filled in
    preorder: node first, then the "yes"/left subtree, then the
    "no"/right subtree.
    """
    split_features = []
    split_values = []
    default_left = []

    # Iterative preorder traversal; the right child is pushed first so
    # the left ("yes") branch is always visited before the right ("no").
    stack = [XBGoostTextNode(tree_structure)]
    while stack:
        node = stack.pop()
        split_features.append(node.get_split_feature())
        split_values.append(node.get_value())
        default_left.append(node.get_default_left())

        right = node.get_right()
        if right is not None:
            stack.append(right)
        left = node.get_left()
        if left is not None:
            stack.append(left)

    return split_features, split_values, default_left
92 
def dump2ROOT(model, output_filename, output_treename='xgboost'):
    """
    Serialize an XGBoost model into a ROOT TTree understood by MVAUtils.

    One TTree entry is filled per boosted tree, holding three parallel
    vectors: the split variables, the split values / leaf responses, and
    the default direction taken for missing values.
    Returns the name of the TTree written into output_filename.
    """
    # XGBoost only dumps to a file, so go through an intermediate JSON dump.
    model.dump_model('dump_model.json', dump_format='json')
    with open('dump_model.json', 'r') as dump_json:
        trees = json.loads(dump_json.read())

    fout = ROOT.TFile.Open(output_filename, 'recreate')

    features_array = ROOT.std.vector('int')()
    values_array = ROOT.std.vector('float')()
    default_lefts_array = ROOT.std.vector('bool')()

    root_tree = ROOT.TTree(output_treename, 'creator=xgboost')
    root_tree.Branch('vars', 'vector<int>', ROOT.AddressOf(features_array))
    root_tree.Branch('values', 'vector<float>', ROOT.AddressOf(values_array))
    root_tree.Branch('default_left', 'vector<bool>', ROOT.AddressOf(default_lefts_array))

    logging.info("tree support nan: using XGBoost implementation")

    for tree_structure in trees:
        features, values, default_lefts = dump_tree(tree_structure)

        # The same vectors back every branch, so reset them per entry.
        features_array.clear()
        values_array.clear()
        default_lefts_array.clear()

        for feature in features:
            features_array.push_back(feature)
        for value in values:
            values_array.push_back(value)
        for flag in default_lefts:
            default_lefts_array.push_back(flag)

        root_tree.Fill()

    root_tree.Write()
    fout.Close()
    return output_treename
132 
def convertXGBoostToRootTree(model, output_filename, tree_name='xgboost'):
    """
    Convert an XGBoost model to a ROOT TTree usable with MVAUtils.

    model: either
        - a string: the name of the input file containing the xgboost
          model, as produced by `bst.save_model('my_model.model')`
        - an xgboost Booster object
    output_filename: name of the ROOT file to create
    tree_name: name of the TTree inside the output file
    Returns the name of the TTree written.
    """
    # isinstance is the idiomatic type check (was: `type(model) is str`,
    # which rejects str subclasses).
    if isinstance(model, str):
        bst = xgb.Booster()
        bst.load_model(model)
        model = bst
    return dump2ROOT(model, output_filename, tree_name)
145 
146 
def test(model_file, tree_file, objective, tree_name='xgboost', ntests=10000, test_file=None):
    """
    Compare the converted TTree (via MVAUtils.BDT) against the original
    XGBoost model, dispatching on the training objective.

    Returns the result of the objective-specific test (True/False), or
    None when MVAUtils is not available.
    """
    bst = xgb.Booster()
    bst.load_model(model_file)

    f = ROOT.TFile.Open(tree_file)
    tree = f.Get(tree_name)

    # MVAUtils is only present in an athena setup; bail out gracefully.
    try:
        _ = ROOT.MVAUtils.BDT
    except Exception:
        print("cannot import MVAUtils")
        return None

    mva_utils = ROOT.MVAUtils.BDT(tree)

    if 'binary' in objective:
        logging.info("testing binary")
        return test_binary(bst, mva_utils, objective, ntests, test_file)
    if 'multi' in objective:
        logging.info("testing multi-class")
        return test_multiclass(bst, mva_utils, objective, ntests, test_file)
    logging.info("testing regression")
    return test_regression(bst, mva_utils, objective, ntests, test_file)
169 
def test_regression(booster, mva_utils, objective, ntests=10000, test_file=None):
    """
    Check that MVAUtils.BDT.GetResponse reproduces the XGBoost regression
    prediction for every input row read from test_file (a numpy table).

    Returns True when all outputs agree within rtol=1E-4, False otherwise.
    Raises ValueError when no test file is provided.
    """
    import numpy as np
    logging.info("Testing input features with regression")

    if test_file is None:
        # Previously this only logged and then crashed on an undefined
        # variable; fail explicitly instead.
        logging.error("Please provide an input test file for testing")
        raise ValueError("test_file is required for testing")
    data_input = np.load(test_file)
    logging.info("using as input %s inputs from file %s", len(data_input), test_file)

    start = time.time()
    dTest = xgb.DMatrix(data_input)
    results_xgboost = booster.predict(dTest)
    logging.info("xgboost (vectorized) timing = %s ms/input",
                 (time.time() - start) * 1000 / len(data_input))

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        results_MVAUtils.append(mva_utils.GetResponse(input_values_vector))
    logging.info("mvautils (not vectorized+overhead) timing = %s ms/input",
                 (time.time() - start) * 1000 / len(data_input))

    for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
        if not np.allclose(output_xgb, output_MVAUtils, rtol=1E-4):
            logging.info("output are different:"
                         "mvautils: %s\n"
                         "xgboost: %s\n"
                         "inputs: %s", output_MVAUtils, output_xgb, input_values)
            return False
    return True
204 
205 
def test_binary(booster, mva_utils, objective, ntests=10000, test_file=None):
    """
    Check that MVAUtils.BDT.GetClassification reproduces the XGBoost binary
    classification score for every input row read from test_file.

    Returns True when all outputs agree (np.allclose defaults), False
    otherwise. Raises ValueError when no test file is provided.
    """
    import numpy as np
    logging.info("Testing input features with binary classification")

    if test_file is None:
        # Previously this only logged and then crashed on an undefined
        # variable; fail explicitly instead.
        logging.error("Please provide an input test file for testing")
        raise ValueError("test_file is required for testing")
    data_input = np.load(test_file)
    logging.info("using as input %s inputs from file %s", len(data_input), test_file)

    start = time.time()
    dTest = xgb.DMatrix(data_input)
    results_xgboost = booster.predict(dTest)
    logging.info("xgboost (vectorized) timing = %s ms/input",
                 (time.time() - start) * 1000 / len(data_input))

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        results_MVAUtils.append(mva_utils.GetClassification(input_values_vector))
    logging.info("mvautils (not vectorized+overhead) timing = %s ms/input",
                 (time.time() - start) * 1000 / len(data_input))

    for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
        if not np.allclose(output_xgb, output_MVAUtils):
            logging.info("output are different:"
                         "mvautils: %s\n"
                         "xgboost: %s\n"
                         "inputs: %s", output_MVAUtils, output_xgb, input_values)
            return False
    return True
239 
def test_multiclass(booster, mva_utils, objective, ntests=10000, test_file=None):
    """
    Check that MVAUtils.BDT.GetMultiResponse reproduces the XGBoost
    multiclass probabilities for every input row read from test_file.

    The number of classes is taken from the shape of the XGBoost output.
    Returns True when all outputs agree (np.allclose defaults), False
    otherwise. Raises ValueError when no test file is provided.
    """
    import numpy as np
    logging.info("using multiclass model")

    if test_file is None:
        # Previously this only logged and then crashed on an undefined
        # variable; fail explicitly instead.
        logging.error("Please provide an input test file for testing")
        raise ValueError("test_file is required for testing")
    data_input = np.load(test_file)
    logging.info("using as input %s inputs from file %s", len(data_input), test_file)

    start = time.time()
    dTest = xgb.DMatrix(data_input)
    results_xgboost = booster.predict(dTest)

    nclasses = results_xgboost.shape[1]
    logging.info("xgboost (vectorized) timing = %s ms/input",
                 (time.time() - start) * 1000 / len(data_input))

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        results_MVAUtils.append(mva_utils.GetMultiResponse(input_values_vector, nclasses))
    logging.info("mvautils (not vectorized+overhead) timing = %s ms/input",
                 (time.time() - start) * 1000 / len(data_input))

    for input_values, output_xgb, output_MVAUtils in zip(data_input, results_xgboost, results_MVAUtils):
        if not np.allclose(output_xgb, output_MVAUtils):
            logging.info("output are different:"
                         "mvautils: %s\n"
                         "xgboost: %s\n"
                         "inputs: %s", output_MVAUtils, output_xgb, input_values)
            return False
    return True
277 
278 
def check_file(fn):
    """
    Sanity-check a converted ROOT file.

    The file must open, contain exactly one key, and that key must be a
    non-empty TTree. Returns True when the file looks usable.
    """
    f = ROOT.TFile.Open(fn)
    if not f:
        # TFile.Open returns a null pointer on failure; the original code
        # would crash on GetListOfKeys here.
        logging.info("cannot open file %s", fn)
        return False
    try:
        keys = list(f.GetListOfKeys())
        if len(keys) != 1:
            logging.info("file %s is empty", fn)
            return False
        tree = f.Get(keys[0].GetName())
        if type(tree) is not ROOT.TTree:
            logging.info("cannot find TTree in file %s", fn)
            return False
        if not tree.GetEntries():
            logging.info("tree is empty")
            return False
        return True
    finally:
        # The original leaked the open TFile on every path.
        f.Close()
294 
295 
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('input', help='input xgboost model')
    parser.add_argument('output', type=str, default='xgboost_model.root',
                        help='Output file name, it must end with .root')
    parser.add_argument('--tree-name', type=str, default='xgboost',
                        help="tree name in Output root file")
    parser.add_argument('--no-test', action='store_true',
                        help="don't run test (not suggested)")
    parser.add_argument('--ntests', type=int, default=1000,
                        help="number of random test, default=1000")
    parser.add_argument('--test-file', type=str, help='numpy table')
    parser.add_argument('--objective', type=str,
                        help='Specify the learning task and the corresponding learning objective, '
                             'currently support options: binary:logistic, reg:linear(squarederror), multi:softprob')

    args = parser.parse_args()
    logging.info("converting input file %s to root file %s", args.input, args.output)

    # 'reg:linear' has been renamed 'reg:squarederror' in newer versions of xgboost (> 0.90)
    supported_objective = ['binary:logistic', 'reg:linear', 'reg:squarederror', 'multi:softprob']

    if args.objective not in supported_objective:
        parser.error('''
Current version does NOT support this objective!!
Only the following objectives are supported and tested:
- binary:logistic
- reg:linear(or squarederror)
- multi:softprob
''')

    if not args.input:
        parser.error('Model file name not given!')

    # was: `"root" not in args.output`, which accepted names like "rootfile.txt"
    if not args.output.endswith(".root"):
        parser.error("The outputfile name must end with .root!!")

    output_treename = convertXGBoostToRootTree(args.input, args.output, args.tree_name)

    if args.no_test:
        print("model has not been tested. Do not use it in production!")
    else:
        logging.info("testing model")
        if not args.test_file:
            # the actual flag is --no-test (the old message said --no_test)
            parser.error("Attempting to do test but no test file was provided, "
                         "pass this with '--test-file <test_file>' or use option '--no-test'")
        if not check_file(args.output):
            print("problem when checking file")
        result = test(args.input, args.output, args.objective, args.tree_name, args.ntests, args.test_file)
        if not result:
            print("some problems during test. Have you setup athena? Do not use this in production!")
        else:
            print("::: everything fine: XGBoost output == MVAUtils output :::")
            objective = args.objective
            import numpy as np
            data = np.load(args.test_file)
            # common part of the usage snippet; only the final call differs per objective
            snippet = '''In c++ use your BDT as:
#include "MVAUtils/BDT.h"

TFile* f = TFile::Open("%s");
TTree* tree = nullptr;
f->GetObject("%s", tree);
MVAUtils::BDT my_bdt(tree);
// ...
// std::vector<float> input_values(%d, 0.);
// fill the vector using the order as in the training
// ...
''' % (args.output, output_treename, len(data[0]))
            if 'binary' in objective:
                print(snippet + "float output = my_bdt.GetClassification(input_values);")
            elif 'reg' in objective:
                print(snippet + "float output = my_bdt.Predict(input_values);")
            elif "multi" in objective:
                print(snippet + "float output = my_bdt.GetMultiResponse(input_values, nclasses);")
util.convertXGBoostToRootTree.dump2ROOT
def dump2ROOT(model, output_filename, output_treename='xgboost')
Definition: convertXGBoostToRootTree.py:93
util.convertXGBoostToRootTree.check_file
def check_file(fn)
Definition: convertXGBoostToRootTree.py:279
util.convertXGBoostToRootTree.XBGoostTextNode
Definition: convertXGBoostToRootTree.py:23
util.convertXGBoostToRootTree.dump_tree
def dump_tree(tree_structure)
Definition: convertXGBoostToRootTree.py:67
util.convertXGBoostToRootTree.XBGoostTextNode.get_value
def get_value(self)
Definition: convertXGBoostToRootTree.py:35
util.convertXGBoostToRootTree.XBGoostTextNode.get_split_feature
def get_split_feature(self)
Definition: convertXGBoostToRootTree.py:29
util.convertXGBoostToRootTree.XBGoostTextNode.get_right
def get_right(self)
Definition: convertXGBoostToRootTree.py:47
util.convertXGBoostToRootTree.type
type
Definition: convertXGBoostToRootTree.py:301
util.convertXGBoostToRootTree.test
def test(model_file, tree_file, objective, tree_name='xgboost', ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:147
util.convertXGBoostToRootTree.XBGoostTextNode.get_nodeid
def get_nodeid(self, node_type)
Definition: convertXGBoostToRootTree.py:53
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
util.convertXGBoostToRootTree.test_regression
def test_regression(booster, mva_utils, objective, ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:170
print
void print(char *figname, TCanvas *c1)
Definition: TRTCalib_StrawStatusPlots.cxx:25
util.convertXGBoostToRootTree.XBGoostTextNode.get_default_left
def get_default_left(self)
Definition: convertXGBoostToRootTree.py:58
Trk::open
@ open
Definition: BinningType.h:40
util.convertXGBoostToRootTree.test_binary
def test_binary(booster, mva_utils, objective, ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:206
util.convertXGBoostToRootTree.XBGoostTextNode.get_left
def get_left(self)
Definition: convertXGBoostToRootTree.py:41
util.convertXGBoostToRootTree.convertXGBoostToRootTree
def convertXGBoostToRootTree(model, output_filename, tree_name='xgboost')
Definition: convertXGBoostToRootTree.py:133
util.convertXGBoostToRootTree.test_multiclass
def test_multiclass(booster, mva_utils, objective, ntests=10000, test_file=None)
Definition: convertXGBoostToRootTree.py:240