ATLAS Offline Software
convertLGBMToRootTree.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
4 
# Module metadata; __doc__ is also used as the argparse description of the
# command line interface at the bottom of the file.
__doc__ = "Convert LightGBM model to TTree to be used with MVAUtils."
__author__ = "Ruggero Turra"

# lightgbm is not mandatory at import time: if it is missing only an
# installation hint is printed, and the code using ``lgb`` fails when called.
try:
    import lightgbm as lgb
except ImportError:
    print(
        """cannot load lightgbm. Try to install it with
    pip install lightgbm
or (usually on lxplus)
    pip install numpy scipy scikit-learn
    pip install --no-binary :all: lightgbm
"""
    )
import ROOT
import time
import logging
import numpy as np

# NOTE(review): this configures the root logger (DEBUG level) at import time,
# so it also affects any module importing this file.
logging.basicConfig(level=logging.DEBUG)
25 
26 
def lgbm_rawresponse_each_tree(model, my_input):
    """
    Return the raw contribution of every boosting iteration of a LightGBM model.

    The contribution of iteration ``i`` is the difference between the raw
    score predicted with the first ``i`` iterations and the raw score
    predicted with the first ``i - 1`` iterations (zero for the first one).

    :param model: a ``lightgbm.Booster``
    :param my_input: one input (1D sequence of feature values) or a batch of
        inputs (2D, one input per row)
    :returns: numpy array with shape ``(n_iterations, n_inputs, n_classes)``,
        where ``n_classes`` is ``model.num_model_per_iteration()`` (1 for
        binary classification and regression)
    """
    nclasses = model.num_model_per_iteration()
    data = np.atleast_2d(my_input)
    ninputs = len(data)
    niterations = model.num_trees() // nclasses

    # cumulative raw score after 0, 1, ..., niterations iterations
    cumulative = [np.zeros((ninputs, nclasses))]
    for iteration in range(1, niterations + 1):
        raw = model.predict(data, raw_score=True, num_iteration=iteration)
        # single-output models return shape (ninputs,), multiclass models
        # (ninputs, nclasses): normalise to 2D so the stack is not ragged
        cumulative.append(np.reshape(raw, (ninputs, nclasses)))
    return np.diff(np.array(cumulative), axis=0)
38 
39 
def list2stdvector(values, dtype="float"):
    """
    Copy the elements of a Python iterable into a new ROOT ``std::vector``.

    :param values: iterable of values convertible to ``dtype``
    :param dtype: C++ element type of the vector (default ``"float"``)
    :returns: a ``ROOT.std.vector(dtype)`` holding ``values`` in order
    """
    vector_type = ROOT.std.vector(dtype)
    output = vector_type()
    append = output.push_back
    for element in values:
        append(element)
    return output
45 
46 
class LGBMTextNode(dict):
    """
    Adaptor from LGBM dictionary to tree.

    Wraps one node of the ``tree_structure`` dictionary returned by
    ``Booster.dump_model()``. Internal nodes carry ``split_feature``,
    ``threshold``, ``decision_type`` and the two children; leaves carry
    ``leaf_value``.

    :param structure: node dictionary (children are nested dictionaries)
    :param invert_as_tmva: if True, left and right children are swapped; the
        flag is inherited by the children, so the whole subtree is inverted
    """

    def __init__(self, structure, invert_as_tmva=False):
        super(LGBMTextNode, self).__init__(structure)
        self.invert_as_tmva = invert_as_tmva

    def _child(self, key):
        """Wrap the child stored under ``key``, inheriting the inversion flag."""
        return LGBMTextNode(self[key], self.invert_as_tmva)

    def get_split_feature(self):
        """Return the index of the splitting feature, -1 for a leaf."""
        if "split_feature" in self:
            return self["split_feature"]
        return -1

    def get_value(self):
        """Return the split threshold, or the leaf value for a leaf."""
        if "threshold" in self:
            return self["threshold"]
        return self["leaf_value"]

    def _get_left(self):
        # children in the order they have for a "<=" split; for any other
        # decision type LightGBM's left/right children are swapped
        if "left_child" not in self:
            return None
        if self["decision_type"] == "<=":
            return self._child("left_child")
        return self._child("right_child")

    def _get_right(self):
        if "right_child" not in self:
            return None
        if self["decision_type"] == "<=":
            return self._child("right_child")
        return self._child("left_child")

    def get_left(self):
        """Left child (None for a leaf), honouring ``invert_as_tmva``."""
        if self.invert_as_tmva:
            return self._get_right()
        return self._get_left()

    def get_right(self):
        """Right child (None for a leaf), honouring ``invert_as_tmva``."""
        if self.invert_as_tmva:
            return self._get_left()
        return self._get_right()

    def get_default_left(self):
        """Return the ``default_left`` flag of the node (True if absent)."""
        return self.get("default_left", True)
98 
99 
def dump_tree(tree_structure):
    """
    Dump a single decision tree to arrays to be written into the TTree.

    Nodes are visited in pre-order (node, left subtree, right subtree).

    :param tree_structure: the ``tree_structure`` dictionary of one entry of
        ``Booster.dump_model()["tree_info"]``
    :returns: tuple ``(split_features, split_values, default_left, simple)``
        with, per node, the split feature index (-1 for leaves), the threshold
        (leaf value for leaves) and the ``default_left`` flag, plus whether
        every node has ``default_left`` True
    :raises ValueError: for categorical splits (decision type other than
        ``<=``) or missing-value handling other than ``NaN`` / ``None``
    """
    split_values = []
    split_features = []
    default_left = []
    simple = True

    # explicit stack instead of recursion: very deep trees (large num_leaves
    # with unlimited depth) could otherwise exceed Python's recursion limit.
    # The right child is pushed first so the left one is popped first, which
    # reproduces the pre-order traversal.
    stack = [LGBMTextNode(tree_structure)]
    while stack:
        node = stack.pop()
        split_features.append(node.get_split_feature())
        split_values.append(node.get_value())
        default_left.append(node.get_default_left())

        if not node.get_default_left():
            simple = False

        if "decision_type" in node and node["decision_type"] != "<=":
            raise ValueError(
                "do not support categorical input BDT (decision_type = %s)" % node["decision_type"]
            )

        if "missing_type" in node:
            if node["missing_type"] not in ("NaN", "None"):
                raise ValueError("do not support missing values different from NaN or None")

        right = node.get_right()
        if right is not None:
            stack.append(right)
        left = node.get_left()
        if left is not None:
            stack.append(left)

    return split_features, split_values, default_left, simple
138 
139 
def dump2ROOT(model, output_filename, output_treename="lgbm"):
    """
    Write a LightGBM booster into a ROOT file as a TTree for MVAUtils.

    Each TTree entry is one decision tree, stored as pre-order arrays of split
    variables (``vars``) and split/leaf values (``values``). When at least one
    node has ``default_left`` False, the ``default_left`` flags are stored as
    well. The TTree title carries the node type and the non-list fields of the
    dumped model.

    :param model: a ``lightgbm.Booster``
    :param output_filename: name of the ROOT file to (re)create
    :param output_treename: name of the TTree
    :returns: the name of the written TTree
    :raises ValueError: if the model contains unsupported splits
    :raises IOError: if the output ROOT file cannot be created
    """
    model = model.dump_model()

    # flatten each tree only once: the result decides the node type (part of
    # the title, needed before creating the TTree) and then fills the TTree
    dumped_trees = [dump_tree(tree["tree_structure"]) for tree in model["tree_info"]]
    simple = all(is_simple for _, _, _, is_simple in dumped_trees)
    node_type = "node_type=lgbm_simple" if simple else "node_type=lgbm"

    infos = ";".join(["%s=%s" % (k, str(v)) for k, v in model.items() if type(v) is not list])
    title = ";".join(("creator=lgbm", node_type, infos))

    fout = ROOT.TFile.Open(output_filename, "recreate")
    if not fout:
        raise IOError("cannot create ROOT file %s" % output_filename)
    try:
        features_array = ROOT.std.vector("int")()
        values_array = ROOT.std.vector("float")()
        default_lefts_array = ROOT.std.vector("bool")()

        # created after opening the file, so the TTree is attached to it
        root_tree = ROOT.TTree(output_treename, title)
        root_tree.Branch("vars", "vector<int>", ROOT.AddressOf(features_array))
        root_tree.Branch("values", "vector<float>", ROOT.AddressOf(values_array))

        if simple:
            logging.info("tree do not support nan: using simple implementation (LGBMNodeSimple)")
        else:
            logging.info("tree support nan: using full implementation (LGBMNode)")
            root_tree.Branch("default_left", "vector<bool>", ROOT.AddressOf(default_lefts_array))

        for features, values, default_lefts, _ in dumped_trees:
            features_array.clear()
            values_array.clear()
            default_lefts_array.clear()

            for value in values:
                values_array.push_back(value)
            for feature in features:
                features_array.push_back(feature)
            for default_left in default_lefts:
                default_lefts_array.push_back(default_left)

            root_tree.Fill()

        root_tree.Write()
    finally:
        fout.Close()
    return output_treename
189 
190 
def convertLGBMToRootTree(model, output_filename, tree_name="lgbm"):
    """
    Convert a LightGBM model into a ROOT TTree usable by MVAUtils.

    :param model: either a string, the name of the text file containing the
        lgbm model (you can get it with ``booster.save_model('my_model.txt')``),
        or directly a lgbm booster object
    :param output_filename: name of the ROOT file to create
    :param tree_name: name of the TTree inside the ROOT file
    :returns: the name of the written TTree
    """
    if isinstance(model, str):
        model = lgb.Booster(model_file=model)
    return dump2ROOT(model, output_filename, tree_name)
204 
205 
def test(model_file, tree_file, tree_name="lgbm", ntests=10000, test_file=None):
    """
    Check that MVAUtils reproduces the LightGBM predictions of a model.

    The kind of test (multiclass, binary classification, regression) is
    chosen from the objective stored in the model.

    :param model_file: LightGBM text model file
    :param tree_file: ROOT file produced by the conversion
    :param tree_name: name of the TTree in ``tree_file``
    :param ntests: number of inputs to test (see ``get_test_data``)
    :param test_file: optional file with the inputs (see ``get_test_data``)
    :returns: the test result (True/False), or None if the test could not be
        run (missing file/TTree, MVAUtils not available, unknown objective)
    """
    booster = lgb.Booster(model_file=model_file)
    f = ROOT.TFile.Open(tree_file)
    if not f:
        print("cannot open ROOT file %s" % tree_file)
        return None
    tree = f.Get(tree_name)
    if not tree:
        print("cannot find TTree %s in %s" % (tree_name, tree_file))
        return None
    try:
        _ = ROOT.MVAUtils.BDT
    except Exception:
        print("cannot import MVAUtils")
        return None

    mva_utils = ROOT.MVAUtils.BDT(tree)

    # LightGBM serialises the objective together with its options, e.g.
    # "binary sigmoid:1" or "multiclass num_class:3". Drop the options that do
    # not change the output (the default sigmoid, the number of classes); any
    # other option is kept so the objective is not recognised: we don't
    # support non-default options.
    objective = booster.dump_model()["objective"]
    objective = " ".join(
        token
        for token in objective.split()
        if token != "sigmoid:1" and not token.startswith("num_class:")
    )

    # binary and xentropy are not the exact same thing when training but the output value is the same
    # same for l1/l2/huber/... regression
    # (https://lightgbm.readthedocs.io/en/latest/Parameters.html)
    binary_aliases = ("binary", "cross_entropy", "xentropy")
    regression_aliases = (
        (
            "regression",  # name used by LightGBM when dumping the L2 objective
            "regression_l2",
            "l2",
            "mean_squared_error",
            "mse",
            "l2_root",
            "root_mean_squared_error",
            "rmse",
        )
        + ("regression_l1", "l1", "mean_absolute_error", "mae")
        + ("huber",)
    )
    multiclass_aliases = ("multiclass", "softmax")
    if objective in multiclass_aliases:
        logging.info("assuming multiclass, testing")
        return test_multiclass(booster, mva_utils, ntests, test_file)
    elif objective in binary_aliases:
        logging.info("assuming binary classification, testing")
        return test_binary(booster, mva_utils, ntests, test_file)
    elif objective in regression_aliases:
        logging.info("assuming regression, testing")
        return test_regression(booster, mva_utils, ntests, test_file)
    print("cannot understand objective '%s'" % objective)
    return None
254 
255 
def get_test_data(feature_names, test_file=None, ntests=None):
    """
    Build the inputs used to compare LightGBM and MVAUtils.

    :param feature_names: names of the input features, in training order
    :param test_file: None to generate random uniform inputs in (-100, 100);
        ``filename.root:treename`` to read the features from the branches of a
        TTree; any other value is read with ``numpy.load``
    :param ntests: maximum number of inputs. For random inputs it defaults to
        10000; when reading a file, all the inputs are used if None
    :returns: float32 numpy array with shape ``(n_inputs, len(feature_names))``
    :raises ValueError: if a ROOT file is given without the tree name
    :raises IOError: if the ROOT file, the TTree or a feature branch is missing
    """
    nvars = len(feature_names)
    if test_file is None:
        if ntests is None:
            ntests = 10000
        logging.info("using as input %s random uniform inputs (-100,100)", ntests)
        logging.warning(
            "using random uniform input as test: this is not safe, provide an input test file"
        )
        data_input = np.random.uniform(-100, 100, size=(ntests, nvars))
    elif ".root" in test_file:
        if ":" not in test_file:
            raise ValueError("when using ROOT file as test use the syntax filename:treename")
        # split on the last colon only: the file name itself may contain
        # colons (e.g. root://server//path/file.root:treename)
        fn, tn = test_file.rsplit(":", 1)
        f = ROOT.TFile.Open(fn)
        if not f:
            raise IOError("cannot find ROOT file %s" % fn)
        tree = f.Get(tn)
        if not tree:
            raise IOError("cannot find TTree %s in %s" % (tn, fn))
        branch_names = set(br.GetName() for br in tree.GetListOfBranches())
        for feature in feature_names:
            if feature not in branch_names:
                raise IOError("required feature %s not in TTree %s" % (feature, tn))
        rdf = ROOT.RDataFrame(tree, feature_names)
        columns = rdf.AsNumpy()
        data_input = np.stack([columns[k] for k in feature_names]).T
        if ntests is not None:
            data_input = data_input[:ntests]
        logging.info(
            "using as input %s inputs from TTree %s from ROOT file %s", len(data_input), tn, fn
        )
    else:
        data_input = np.load(test_file)
        if ntests is not None:
            data_input = data_input[:ntests]
        logging.info("using as input %s inputs from pickle file %s", len(data_input), test_file)

    # to match what mvautils is doing (using c-float)
    return data_input.astype(np.float32)
298 
299 
def test_generic(booster, mvautils_predict, mva_utils, data_input):
    """
    Compare, input by input, the LightGBM prediction with an MVAUtils one.

    :param booster: the ``lightgbm.Booster``
    :param mvautils_predict: bound MVAUtils method called with a
        ``std::vector<float>`` of input values (e.g. ``GetResponse``)
    :param mva_utils: the ``MVAUtils::BDT``, used to inspect differences tree
        by tree via ``test_detail_event``
    :param data_input: 2D array, one input per row
    :returns: False as soon as ``test_detail_event`` rejects a differing
        input, True otherwise
    """
    ninputs = len(data_input)

    t0 = time.time()
    lgbm_outputs = booster.predict(data_input)
    logging.info("lgbm (vectorized) timing = %d/s", ninputs / (time.time() - t0))

    row_vector = ROOT.std.vector("float")()
    mvautils_outputs = []
    t0 = time.time()
    for row in data_input:
        row_vector.clear()
        for value in row:
            row_vector.push_back(value)
        mvautils_outputs.append(mvautils_predict(row_vector))
    logging.info(
        "mvautils (not vectorized+overhead) timing = %d/s", ninputs / (time.time() - t0)
    )

    ntested = 0
    ndifferent = 0
    rows = zip(data_input, lgbm_outputs, mvautils_outputs)
    for position, (row, out_lgbm, out_mvautils) in enumerate(rows, 1):
        ntested += 1
        if np.allclose(out_lgbm, out_mvautils, rtol=1e-4):
            continue
        ndifferent += 1
        logging.info(
            "--> output are different on input %d/%d mvautils: %s lgbm: %s",
            position,
            ninputs,
            out_mvautils,
            out_lgbm,
        )
        if not test_detail_event(booster, mva_utils, row):
            return False
    logging.info("number of different events %d/%d", ndifferent, ntested)
    return True
337 
338 
339 # helper for tree traversal
340 def _ff(tree, node_infos):
341  if "left_child" in tree:
342  node_infos.append((tree["split_feature"], tree["threshold"]))
343  _ff(tree["left_child"])
344  _ff(tree["right_child"])
345 
346 
def test_detail_event(booster, mva_utils, input_values):
    """
    Inspect tree by tree an input on which LightGBM and MVAUtils disagree.

    Meant for single-output models (binary classification, regression). A
    differing tree is accepted only if the difference can be explained by the
    precision: LightGBM compares in double, MVAUtils in float. It is
    explained if, for at least one node, ``value <= threshold`` gives
    different answers in the two precisions.

    :param booster: the ``lightgbm.Booster``
    :param mva_utils: the ``MVAUtils::BDT``
    :param input_values: the input (1D sequence of feature values)
    :returns: True if at least one tree differs and every differing tree is
        explained by the precision, False otherwise
    """
    logging.info("input values")
    for ivar, input_value in enumerate(input_values):
        logging.info("var %d: %.15f", ivar, input_value)
    logging.info("=" * 50)

    ntrees_mva_utils = mva_utils.GetNTrees()
    if ntrees_mva_utils != booster.num_trees():
        logging.info("Number of trees are different mvautils: %s lgbm: %s", ntrees_mva_utils, booster.num_trees())
        # structurally different models: comparing the trees one by one is
        # meaningless (and would index past the shorter list)
        return False
    tree_outputs_lgbm = lgbm_rawresponse_each_tree(booster, input_values)
    input_vector = list2stdvector(input_values)
    tree_info = None  # model dump, computed only if a tree differs

    is_problem_found = False
    for itree in range(ntrees_mva_utils):
        tree_output_mvautils = mva_utils.GetTreeResponse(input_vector, itree)
        # single-output model: take the only value whatever the array shape
        tree_output_lgbm = float(np.ravel(tree_outputs_lgbm[itree])[0])
        if np.allclose(tree_output_mvautils, tree_output_lgbm):
            continue
        is_problem_found = True
        logging.info("tree %d/%d are different", itree, ntrees_mva_utils)
        logging.info("lgbm: %f", tree_output_lgbm)
        logging.info("MVAUtils: %f", tree_output_mvautils)
        logging.info("Tree details from MVAUtils")
        mva_utils.PrintTree(itree)

        # dump the tree from lightgbm
        if tree_info is None:
            tree_info = booster.dump_model()["tree_info"]
        node_infos = []
        _ff(tree_info[itree]["tree_structure"], node_infos)

        # we know which tree is failing, check if this is
        # due to input values very close to the threshold
        # the problem is that lgbm is using double,
        # while mva_utils is using float. Convert explicitly to Python float
        # for the double comparison: with NumPy 2 (NEP 50) a float32 value
        # compared with a Python float is evaluated in float32.
        for split_feature, threshold in node_infos:
            value = input_values[split_feature]
            if not np.isnan(value) and (float(value) <= float(threshold)) != (
                np.float32(value) <= np.float32(threshold)
            ):
                logging.info(
                    "the problem could be due to double"
                    " (lgbm) -> float (mvautil) conversion"
                    " for variable %d: %.10f and threshold %.10f",
                    split_feature,
                    value,
                    threshold,
                )
                # we consider this ok
                break
        else:
            # no node explains the difference of this tree
            return False

    # if no tree differs, the difference of the total output is not explained
    return is_problem_found
411 
412 
def test_regression(booster, mva_utils, ntests=None, test_file=None):
    """
    Compare LightGBM and MVAUtils on a regression model.

    The MVAUtils prediction is ``GetResponse``; the comparison is delegated
    to ``test_generic``.

    :param ntests: number of inputs (see ``get_test_data``)
    :param test_file: optional file with the inputs (see ``get_test_data``)
    :returns: the result of ``test_generic``
    """
    inputs = get_test_data(booster.feature_name(), test_file, ntests)
    mvautils_predict = mva_utils.GetResponse
    return test_generic(booster, mvautils_predict, mva_utils, inputs)
416 
417 
def test_binary(booster, mva_utils, ntests=None, test_file=None):
    """
    Compare LightGBM and MVAUtils on a binary classification model.

    The MVAUtils prediction is ``GetClassification``; the comparison is
    delegated to ``test_generic``.

    :param ntests: number of inputs (see ``get_test_data``)
    :param test_file: optional file with the inputs (see ``get_test_data``)
    :returns: the result of ``test_generic``
    """
    inputs = get_test_data(booster.feature_name(), test_file, ntests)
    mvautils_predict = mva_utils.GetClassification
    return test_generic(booster, mvautils_predict, mva_utils, inputs)
421 
422 
def test_multiclass(booster, mva_utils, ntests=10000, test_file=None):
    """
    Compare LightGBM and MVAUtils (``GetMultiResponse``) on a multiclass model.

    When an input gives different outputs, the trees are compared one by one.
    A differing tree is accepted only if the difference is explained by the
    precision (LightGBM compares in double, MVAUtils in float) in one of its
    splits.

    :param booster: the ``lightgbm.Booster``
    :param mva_utils: the ``MVAUtils::BDT``
    :param ntests: number of inputs (see ``get_test_data``)
    :param test_file: optional file with the inputs (see ``get_test_data``)
    :returns: False at the first input whose difference is not explained,
        True otherwise
    """
    nvars = booster.num_feature()
    nclasses = booster.num_model_per_iteration()
    logging.info("using %d input features with %d classes", nvars, nclasses)

    data_input = get_test_data(booster.feature_name(), test_file, ntests)

    start = time.time()
    results_lgbm = booster.predict(data_input)
    logging.info(
        "lgbm (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input)
    )

    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        output_MVAUtils = np.asarray(mva_utils.GetMultiResponse(input_values_vector, nclasses))
        results_MVAUtils.append(output_MVAUtils)
    logging.info(
        "mvautils (not vectorized+overhead) timing = %s ms/input",
        (time.time() - start) * 1000 / len(data_input),
    )

    tree_info = None  # model dump, computed only if a tree differs
    for ievent, (input_values, output_lgbm, output_MVAUtils) in enumerate(
        zip(data_input, results_lgbm, results_MVAUtils), 1
    ):
        if np.allclose(output_lgbm, output_MVAUtils):
            continue
        logging.info("--> output are different on input %d/%d:\n", ievent, len(data_input))
        for ivar, input_value in enumerate(input_values):
            logging.info("var %d: %.15f", ivar, input_value)
        logging.info("=" * 50)
        logging.info(" mvautils lgbm")
        for ioutput, (o1, o2) in enumerate(zip(output_MVAUtils, output_lgbm)):
            diff_flag = "" if np.allclose(o1, o2) else "<---"
            logging.info("output %3d %.5e %.5e %s", ioutput, o1, o2, diff_flag)

        input_vector = list2stdvector(input_values)
        output_trees_lgbm = lgbm_rawresponse_each_tree(booster, [input_values])

        is_problem_found = False
        for itree, output_tree_lgbm in enumerate(output_trees_lgbm):
            output_tree_mva_utils = [
                mva_utils.GetTreeResponse(input_vector, itree * nclasses + c)
                for c in range(nclasses)
            ]
            if np.allclose(output_tree_mva_utils, output_tree_lgbm[0]):
                continue
            is_problem_found = True
            logging.info("first tree/class with different answer (%d)", itree)

            # every differing class tree of this iteration must be explained
            tree_explained = True
            for isubtree, (ol, om) in enumerate(zip(output_tree_lgbm[0], output_tree_mva_utils)):
                if np.allclose(ol, om):
                    continue
                index = itree * nclasses + isubtree
                logging.info("different in position %d", isubtree)
                logging.info("lgbm: %f", ol)
                logging.info("mvautils: %f", om)
                logging.info("=" * 50)
                logging.info(
                    "tree %d (itree) * %d (nclasses) + %d (isubtree) = %d",
                    itree,
                    nclasses,
                    isubtree,
                    index,
                )
                mva_utils.PrintTree(index)

                # we know which tree is failing, check if this is
                # due to input values very close to the threshold
                # the problem is that lgbm is using double,
                # while mva_utils is using float. Convert explicitly to Python
                # float for the double comparison: with NumPy 2 (NEP 50) a
                # float32 compared with a Python float is evaluated in float32.
                if tree_info is None:
                    tree_info = booster.dump_model()["tree_info"]
                node_infos = []
                _ff(tree_info[index]["tree_structure"], node_infos)
                for split_feature, threshold in node_infos:
                    value = input_values[split_feature]
                    if not np.isnan(value) and (float(value) <= float(threshold)) != (
                        np.float32(value) <= np.float32(threshold)
                    ):
                        logging.info(
                            "the problem could be due to double"
                            " (lgbm) -> float (mvautil) conversion"
                            " for variable %d: %f and threshold %f",
                            split_feature,
                            value,
                            threshold,
                        )
                        break
                else:
                    tree_explained = False

            if not tree_explained:
                return False

        if not is_problem_found:
            # the outputs differ but no single tree does: not explained
            return False
    return True
528 
529 
def check_file(fn):
    """
    Sanity check of the ROOT file written by the conversion.

    :param fn: name of the ROOT file
    :returns: True if the file contains exactly one object and it is a
        non-empty TTree, False otherwise (the reason is logged)
    """
    f = ROOT.TFile.Open(fn)
    if not f:
        logging.info("cannot open file %s", fn)
        return False
    try:
        keys = list(f.GetListOfKeys())
        if not keys:
            logging.info("file %s is empty", fn)
            return False
        if len(keys) != 1:
            logging.info("file %s contains %d objects, expected only one", fn, len(keys))
            return False
        tree = f.Get(keys[0].GetName())
        if not isinstance(tree, ROOT.TTree):
            logging.info("cannot find TTree in file %s", fn)
            return False
        if not tree.GetEntries():
            logging.info("tree is empty")
            return False
        return True
    finally:
        f.Close()
545 
546 
def _print_cpp_usage(booster, output_filename, output_treename):
    """
    Print a C++ snippet showing how to use the converted model with MVAUtils.

    The MVAUtils method in the snippet is chosen from the first token of the
    model objective; nothing is printed for objectives not recognised.
    """
    objective_tokens = booster.dump_model()["objective"].split()
    objective = objective_tokens[0] if objective_tokens else ""
    if objective in ("multiclass", "softmax"):
        call = "std::vector<float> output = my_bdt.GetMultiResponse(input_values, %d);" % (
            booster.num_model_per_iteration()
        )
    elif objective in ("binary", "cross_entropy", "xentropy"):
        call = "float output = my_bdt.GetClassification(input_values);"
    elif objective.startswith("regression") or objective == "huber":
        # GetResponse is the MVAUtils method used by test_regression
        call = "float output = my_bdt.GetResponse(input_values);"
    else:
        return
    print(
        """In c++ use your BDT as:
#include "MVAUtils/BDT.h"

TFile* f = TFile::Open("%s");
TTree* tree = nullptr;
f->GetObject("%s", tree);
MVAUtils::BDT my_bdt(tree);
// ...
// std::vector<float> input_values(%d, 0.);
// fill the vector using the order as in the training: %s
// ...
%s
"""
        % (
            output_filename,
            output_treename,
            booster.num_feature(),
            ",".join(booster.feature_name()),
            call,
        )
    )


if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input", help="input text file from LGBM")
    parser.add_argument("output", help="output ROOT filename", nargs="?")
    parser.add_argument("--tree-name", default="lgbm")
    parser.add_argument("--no-test", action="store_true", help="don't run test (not suggested)")
    parser.add_argument(
        "--ntests",
        type=int,
        help="number of tests, default=10000 random inputs or all the inputs of --test-file",
    )
    parser.add_argument(
        "--test-file", help="numpy pickle or ROOT file (use filename.root:treename)"
    )

    args = parser.parse_args()

    if args.output is None:
        args.output = os.path.splitext(os.path.split(args.input)[1])[0] + ".root"

    logging.info("converting input file %s to root file %s", args.input, args.output)
    output_treename = convertLGBMToRootTree(args.input, args.output, args.tree_name)
    if args.no_test:
        print("model has not been tested. Do not use it in production!")
    else:
        logging.info("testing model")
        if check_file(args.output):
            result = test(args.input, args.output, args.tree_name, args.ntests, args.test_file)
        else:
            # the produced file is broken: testing it would only crash
            print("problem when checking file")
            result = False
        if not result:
            print(
                "some problems during test. Have you setup athena? Do not use this in production!"
            )
        else:
            try:
                print(u"::: :) :) :) everything fine: LGBM output == MVAUtils output :) :) :) :::")
            except UnicodeEncodeError:
                print(":::==> everything fine: LGBM output == MVAUtils output <==:::")
    _print_cpp_usage(lgb.Booster(model_file=args.input), args.output, output_treename)
653  ",".join(booster.feature_name()),
654  )
655  )
util.convertLGBMToRootTree.LGBMTextNode.get_right
def get_right(self)
Definition: convertLGBMToRootTree.py:90
util.convertLGBMToRootTree.LGBMTextNode._get_left
def _get_left(self)
Definition: convertLGBMToRootTree.py:68
util.convertLGBMToRootTree.test
def test(model_file, tree_file, tree_name="lgbm", ntests=10000, test_file=None)
Definition: convertLGBMToRootTree.py:206
util.convertLGBMToRootTree.dump2ROOT
def dump2ROOT(model, output_filename, output_treename="lgbm")
Definition: convertLGBMToRootTree.py:140
util.convertLGBMToRootTree.list2stdvector
def list2stdvector(values, dtype="float")
Definition: convertLGBMToRootTree.py:40
util.convertLGBMToRootTree.LGBMTextNode._get_right
def _get_right(self)
Definition: convertLGBMToRootTree.py:76
util.convertLGBMToRootTree._ff
def _ff(tree, node_infos)
Definition: convertLGBMToRootTree.py:340
util.convertLGBMToRootTree.type
type
Definition: convertLGBMToRootTree.py:555
util.convertLGBMToRootTree.get_test_data
def get_test_data(feature_names, test_file=None, ntests=None)
Definition: convertLGBMToRootTree.py:256
util.convertLGBMToRootTree.LGBMTextNode.get_value
def get_value(self)
Definition: convertLGBMToRootTree.py:62
util.convertLGBMToRootTree.LGBMTextNode.get_default_left
def get_default_left(self)
Definition: convertLGBMToRootTree.py:96
util.convertLGBMToRootTree.check_file
def check_file(fn)
Definition: convertLGBMToRootTree.py:530
plotBeamSpotVxVal.range
range
Definition: plotBeamSpotVxVal.py:195
util.convertLGBMToRootTree.LGBMTextNode.__init__
def __init__(self, structure, invert_as_tmva=False)
Definition: convertLGBMToRootTree.py:52
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
util.convertLGBMToRootTree.lgbm_rawresponse_each_tree
def lgbm_rawresponse_each_tree(model, my_input)
Definition: convertLGBMToRootTree.py:27
util.convertLGBMToRootTree.test_detail_event
def test_detail_event(booster, mva_utils, input_values)
Definition: convertLGBMToRootTree.py:347
TCS::join
std::string join(const std::vector< std::string > &v, const char c=',')
Definition: Trigger/TrigT1/L1Topo/L1TopoCommon/Root/StringUtils.cxx:10
util.convertLGBMToRootTree.dump_tree
def dump_tree(tree_structure)
Definition: convertLGBMToRootTree.py:100
util.convertLGBMToRootTree.test_generic
def test_generic(booster, mvautils_predict, mva_utils, data_input)
Definition: convertLGBMToRootTree.py:300
util.convertLGBMToRootTree.LGBMTextNode.get_left
def get_left(self)
Definition: convertLGBMToRootTree.py:84
util.convertLGBMToRootTree.test_binary
def test_binary(booster, mva_utils, ntests=None, test_file=None)
Definition: convertLGBMToRootTree.py:418
util.convertLGBMToRootTree.LGBMTextNode
Definition: convertLGBMToRootTree.py:47
util.convertLGBMToRootTree.test_multiclass
def test_multiclass(booster, mva_utils, ntests=10000, test_file=None)
Definition: convertLGBMToRootTree.py:423
util.convertLGBMToRootTree.LGBMTextNode.get_split_feature
def get_split_feature(self)
Definition: convertLGBMToRootTree.py:56
util.convertLGBMToRootTree.convertLGBMToRootTree
def convertLGBMToRootTree(model, output_filename, tree_name="lgbm")
Definition: convertLGBMToRootTree.py:191
str
Definition: BTagTrackIpAccessor.cxx:11
dbg::print
void print(std::FILE *stream, std::format_string< Args... > fmt, Args &&... args)
Definition: SGImplSvc.cxx:70
util.convertLGBMToRootTree.test_regression
def test_regression(booster, mva_utils, ntests=None, test_file=None)
Definition: convertLGBMToRootTree.py:413
util.convertLGBMToRootTree.LGBMTextNode.invert_as_tmva
invert_as_tmva
Definition: convertLGBMToRootTree.py:54