# NOTE(review): this file is a garbled extraction -- statements are split
# across lines and the original file's line numbers are fused into the text.
# Comments below annotate intent without altering any token.
# Module metadata for the LightGBM -> MVAUtils TTree converter.
5 __doc__ =
"Convert LightGBM model to TTree to be used with MVAUtils."
6 __author__ =
"Ruggero Turra"
# Error message fragment shown when lightgbm cannot be imported --
# presumably the body of a try/except around "import lightgbm";
# the surrounding lines are missing from this view (TODO confirm).
12 """cannot load lightgbm. Try to install it with
14 or (usually on lxplus)
15 pip install numpy scipy scikit-learn
16 pip install --no-binary :all: lightgbm
# Verbose logging for the whole conversion script.
24 logging.basicConfig(level=logging.DEBUG)
# Fragment of a helper that computes the raw response of each individual
# tree: it predicts with an increasing number of iterations (num_iteration
# = 1..ntrees/nclasses) and later differentiates the cumulative outputs.
# The enclosing "def" line is missing from this view.
28 nclasses = model.num_model_per_iteration()
29 output_values = np.array(
30 [np.array([[0] * nclasses])]
32 model.predict(np.atleast_2d(my_input), raw_score=
True, num_iteration=itree)
33 for itree
in range(1, (model.num_trees() // nclasses + 1))
# Per-tree contribution = difference between consecutive cumulative scores.
36 output_trees = np.diff(output_values, axis=0)
# Separate fragment: start of a list -> ROOT std::vector conversion helper.
41 result = ROOT.std.vector(dtype)()
# Fragments of class LGBMTextNode: adapts the nested dict produced by
# Booster.dump_model() ("tree_structure") to a tree-node interface.
# Several method headers and bodies are missing from this view.
49 Adaptor from LGBM dictionary to tree
52 def __init__(self, structure, invert_as_tmva=False):
53 super(LGBMTextNode, self).
__init__(structure)
# get_split_feature fragment: only internal nodes carry "split_feature".
57 if "split_feature" in self:
58 return self[
"split_feature"]
# get_value fragment: threshold for internal nodes, leaf_value for leaves.
63 if "threshold" in self:
64 return self[
"threshold"]
66 return self[
"leaf_value"]
# get_left / get_right fragments: children exist only on internal nodes;
# the "<=" decision_type check presumably selects which child is the
# "true" branch -- the returned expressions are missing from this view.
69 if "left_child" not in self:
71 if self[
"decision_type"] ==
"<=":
77 if "right_child" not in self:
79 if self[
"decision_type"] ==
"<=":
# get_default_left fragment: LightGBM routes missing values left by default.
97 return self.get(
"default_left",
True)
# Fragments of dump_tree(): serialize one decision tree, via a preorder
# traversal closure, into flat parallel arrays (split_features,
# split_values, default_left) for the output TTree.  "simple" appears to
# track whether the tree ever needs NaN/default-left handling -- TODO
# confirm, its assignment lines are missing from this view.
102 dump a single decision tree to arrays to be written into the TTree
113 split_features.append(node.get_split_feature())
114 split_values.append(node.get_value())
115 default_left.append(node.get_default_left())
117 if not node.get_default_left():
# Only numerical "<=" splits are supported; categorical splits are rejected.
120 if "decision_type" in node
and node[
"decision_type"] !=
"<=":
122 "do not support categorical input BDT (decision_type = %s)" % node[
"decision_type"]
# Missing-value types other than NaN/None cannot be represented downstream.
125 if "missing_type" in node:
126 if node[
"missing_type"]
not in (
"NaN",
"None"):
127 raise ValueError(
"do not support missing values different from NaN or None")
# Preorder recursion over both children (leaves return None children).
130 if node.get_left()
is not None:
131 preorder(node.get_left())
133 if node.get_right()
is not None:
134 preorder(node.get_right())
137 return split_features, split_values, default_left, simple[0]
# dump2ROOT: write the dumped LightGBM model into a ROOT TTree named
# output_treename inside a freshly recreated ROOT file.  Many interior
# lines are missing from this view; comments below annotate only what
# the visible fragments show.
140 def dump2ROOT(model, output_filename, output_treename="lgbm"):
141 model = model.dump_model()
142 fout = ROOT.TFile.Open(output_filename,
"recreate")
# One std::vector buffer per branch; refilled for every tree.
144 features_array = ROOT.std.vector(
"int")()
145 values_array = ROOT.std.vector(
"float")()
146 default_lefts_array = ROOT.std.vector(
"bool")()
# First pass over the trees: decide whether the simple (no default-left)
# node layout suffices; node_type is embedded in the TTree title.
149 node_type =
"node_type=lgbm_simple"
150 for tree
in model[
"tree_info"]:
151 tree_structure = tree[
"tree_structure"]
152 features, values, default_lefts, simple_tree =
dump_tree(tree_structure)
155 node_type =
"node_type=lgbm"
# Encode scalar model metadata as "key=value" pairs in the TTree title.
157 infos =
";".
join([
"%s=%s" % (k,
str(v))
for k, v
in model.items()
if type(v)
is not list])
158 title =
";".
join((
"creator=lgbm", node_type, infos))
159 root_tree = ROOT.TTree(output_treename, title)
160 root_tree.Branch(
"vars",
"vector<int>", ROOT.AddressOf(features_array))
161 root_tree.Branch(
"values",
"vector<float>", ROOT.AddressOf(values_array))
# The default_left branch is only created for the full implementation.
164 logging.info(
"tree support nan: using full implementation (LGBMNode)")
165 root_tree.Branch(
"default_left",
"vector<bool>", ROOT.AddressOf(default_lefts_array))
167 logging.info(
"tree do not support nan:" "using simple implementation (LGBMNodeSimple)")
# Second pass: fill one TTree entry per decision tree.
169 for tree
in model[
"tree_info"]:
170 tree_structure = tree[
"tree_structure"]
171 features, values, default_lefts, simple_tree =
dump_tree(tree_structure)
173 features_array.clear()
175 default_lefts_array.clear()
178 values_array.push_back(value)
179 for feature
in features:
180 features_array.push_back(feature)
181 for default_left
in default_lefts:
182 default_lefts_array.push_back(default_left)
188 return output_treename
# Docstring and body fragments of convertLGBMToRootTree (the "def" line is
# missing from this view): accepts either a model filename or a Booster
# object and delegates to dump2ROOT in both cases.
193 Model: - a string, in this case, it is the name of
194 the input file containing the lgbm model you
195 can get this model with lgbm with
196 `boosted.save_model('my_model.txt')
197 - directly a lgbm booster object
# NOTE(review): `type(model) is str` would be more idiomatic as
# isinstance(model, str) -- left untouched here, code is fragmentary.
199 if type(model)
is str:
200 model = lgb.Booster(model_file=model)
201 return dump2ROOT(model, output_filename, tree_name)
203 return dump2ROOT(model, output_filename, tree_name)
# test(): reload the model and the converted TTree, build an MVAUtils.BDT
# from the TTree, then dispatch to the objective-specific comparison
# (multiclass / binary / regression).  Interior lines are missing from
# this view, including some dispatch returns.
206 def test(model_file, tree_file, tree_name="lgbm", ntests=10000, test_file=None):
207 booster = lgb.Booster(model_file=model_file)
208 f = ROOT.TFile.Open(tree_file)
209 tree = f.Get(tree_name)
# Touch ROOT.MVAUtils.BDT to detect a missing MVAUtils build (presumably
# inside a try/except -- the surrounding lines are not visible).
211 _ = ROOT.MVAUtils.BDT
213 print(
"cannot import MVAUtils")
216 mva_utils = ROOT.MVAUtils.BDT(tree)
# Normalize the objective string (e.g. strip a "sigmoid:1" suffix).
218 objective = booster.dump_model()[
"objective"]
222 objective = objective.replace(
"sigmoid:1",
"")
223 objective = objective.strip()
# Alias tables mapping LightGBM objective spellings onto the three
# supported test modes.
228 binary_aliases = (
"binary",
"cross_entropy",
"xentropy")
229 regression_aliases = (
233 "mean_squared_error",
236 "root_mean_squared_error",
239 + (
"regression_l1",
"l1",
"mean_absolute_error",
"mae")
242 multiclass_aliases = (
"multiclass",
"softmax")
243 if objective
in multiclass_aliases:
244 logging.info(
"assuming multiclass, testing")
246 elif objective
in binary_aliases:
247 logging.info(
"assuming binary classification, testing")
248 return test_binary(booster, mva_utils, ntests, test_file)
249 elif objective
in regression_aliases:
250 logging.info(
"assuming regression, testing")
253 print(
"cannot understand objective '%s'" % objective)
# Fragments of get_test_data (the "def" line is missing from this view):
# build the test input matrix from a ROOT file ("filename:treename"),
# a numpy pickle, or random uniforms, then cast to float32 to mirror
# the MVAUtils float interface.
257 nvars = len(feature_names)
258 if test_file
is not None:
259 if ".root" in test_file:
260 if ":" not in test_file:
261 raise ValueError(
"when using ROOT file as test use the syntax filename:treename")
262 fn, tn = test_file.split(
":")
263 f = ROOT.TFile.Open(fn)
265 raise IOError(
"cannot find ROOT file %s" % fn)
# NOTE(review): the message reads "cannot find TTree %s in %s" but the
# arguments are (fn, tn) -- the tree/file order looks swapped; verify.
268 raise IOError(
"cannot find TTree %s in %s" % (fn, tn))
269 branch_names = [br.GetName()
for br
in tree.GetListOfBranches()]
270 for feature
in feature_names:
271 if feature
not in branch_names:
# NOTE(review): "%s" placeholder has no "% feature" argument applied, so
# the raised message never names the missing feature; verify upstream.
272 raise IOError(
"required feature %s not in TTree")
273 rdf = ROOT.RDataFrame(tree, feature_names)
274 data_input = rdf.AsNumpy()
# Stack per-branch columns into an (nevents, nvars) matrix, feature order
# taken from feature_names.
275 data_input = np.stack([data_input[k]
for k
in feature_names]).T
276 if ntests
is not None:
277 data_input = data_input[:ntests]
279 "using as input %s inputs from TTree %s from ROOT file %s", len(data_input), tn, fn
# Fallback: numpy pickle input file.
282 data_input = np.load(test_file)
283 if ntests
is not None:
284 data_input = data_input[:ntests]
285 logging.info(
"using as input %s inputs from pickle file %s", len(data_input), test_file)
# Last resort: random uniform inputs (explicitly flagged as unsafe).
289 logging.info(
"using as input %s random uniform inputs (-100,100)", ntests)
291 "using random uniform input as test: this is not safe" "provide an input test file"
293 data_input = np.random.uniform(-100, 100, size=(ntests, nvars))
296 data_input = data_input.astype(np.float32)
# Fragments of the generic comparison loop: time the vectorized LightGBM
# prediction, then evaluate MVAUtils event-by-event through a reused
# std::vector<float>, and count events where the two disagree
# (np.allclose, rtol=1e-4).  The enclosing "def" is missing from view.
302 results_lgbm = booster.predict(data_input)
303 logging.info(
"lgbm (vectorized) timing = %d/s", len(data_input) / (time.time() - start))
305 input_values_vector = ROOT.std.vector(
"float")()
306 results_MVAUtils = []
308 for input_values
in data_input:
309 input_values_vector.clear()
310 for v
in input_values:
311 input_values_vector.push_back(v)
312 output_MVAUtils = mvautils_predict(input_values_vector)
313 results_MVAUtils.append(output_MVAUtils)
315 "mvautils (not vectorized+overhead) timing = %d/s", len(data_input) / (time.time() - start)
# Event-by-event agreement check; ievent is 1-based via enumerate(..., 1).
319 nevents_different = 0
320 for ievent, (input_values, output_lgbm, output_MVAUtils)
in enumerate(
321 zip(data_input, results_lgbm, results_MVAUtils), 1
324 if not np.allclose(output_lgbm, output_MVAUtils, rtol=1e-4):
325 nevents_different += 1
327 "--> output are different on input %d/%d mvautils: %s lgbm: %s",
335 logging.info(
"number of different events %d/%d", nevents_different, nevents_tested)
340 def _ff(tree, node_infos):
341 if "left_child" in tree:
342 node_infos.append((tree[
"split_feature"], tree[
"threshold"]))
343 _ff(tree[
"left_child"])
344 _ff(tree[
"right_child"])
# Fragments of the per-tree diagnostic path, run when outputs disagree:
# log the input vector, compare the tree counts, find the first tree whose
# MVAUtils response differs from the LightGBM per-tree output, print it,
# and probe whether a double (lgbm) -> float (MVAUtils) threshold
# conversion explains the discrepancy.  Several lines are missing.
348 logging.info(
"input values")
349 for ivar, input_value
in enumerate(input_values):
350 logging.info(
"var %d: %.15f", ivar, input_value)
351 logging.info(
"=" * 50)
353 ntrees_mva_utils = mva_utils.GetNTrees()
354 if ntrees_mva_utils != booster.num_trees():
355 logging.info(
"Number of trees are different mvautils: %s lgbm: %s", ntrees_mva_utils, booster.num_trees())
359 is_problem_found =
False
360 for itree
in range(ntrees_mva_utils):
361 tree_output_mvautils = mva_utils.GetTreeResponse(
list2stdvector(input_values), itree)
362 tree_output_lgbm = tree_outputs_lgbm[itree][0]
363 if not np.allclose(tree_output_mvautils, tree_output_lgbm):
365 is_problem_found =
True
366 logging.info(
"tree %d/%d are different", itree, ntrees_mva_utils)
367 logging.info(
"lgbm: %f", tree_output_lgbm)
368 logging.info(
"MVAUtils: %f", tree_output_mvautils)
369 logging.info(
"Tree details from MVAUtils")
370 mva_utils.PrintTree(itree)
# Re-dump the offending tree's structure from the booster for inspection.
375 booster.dump_model()[
"tree_info"][itree][
# Check each (feature, threshold) pair for a float32 rounding flip:
# the comparison result changes when both sides are cast to float32.
386 for node_info
in node_infos:
387 value = input_values[node_info[0]]
388 threshold = node_info[1]
389 if not np.isnan(value)
and (value <= threshold) != (
390 np.float32(value) <= np.float32(threshold)
393 "the problem could be due to double"
394 "(lgbm) -> float (mvautil) conversion"
395 " for variable %d: %.10f and threshold %.10f",
# Body fragments of test_regression and test_binary (their "def" lines are
# missing from this view): both fetch test data and delegate to
# test_generic with the matching MVAUtils entry point.
414 data_input =
get_test_data(booster.feature_name(), test_file, ntests)
415 return test_generic(booster, mva_utils.GetResponse, mva_utils, data_input)
419 data_input =
get_test_data(booster.feature_name(), test_file, ntests)
420 return test_generic(booster, mva_utils.GetClassification, mva_utils, data_input)
# Fragments of test_multiclass (the "def" line is missing from this view):
# compare booster.predict against MVAUtils GetMultiResponse per event,
# then drill down per tree/class on the first disagreement.  Trees are
# stored interleaved: MVAUtils index = itree * nclasses + class.
426 nvars = booster.num_feature()
427 nclasses = booster.num_model_per_iteration()
428 logging.info(
"using %d input features with %d classes", nvars, nclasses)
430 data_input =
get_test_data(booster.feature_name(), test_file, ntests)
433 results_lgbm = booster.predict(data_input)
435 "lgbm (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input)
# Event-by-event MVAUtils evaluation via a reused std::vector<float>.
438 input_values_vector = ROOT.std.vector(
"float")()
439 results_MVAUtils = []
441 for input_values
in data_input:
442 input_values_vector.clear()
443 for v
in input_values:
444 input_values_vector.push_back(v)
445 output_MVAUtils = np.asarray(mva_utils.GetMultiResponse(input_values_vector, nclasses))
446 results_MVAUtils.append(output_MVAUtils)
448 "mvautils (not vectorized+overhead) timing = %s ms/input",
449 (time.time() - start) * 1000 / len(data_input),
# Outer loop: flag the first event whose full class-score vector differs.
452 stop_event_loop =
False
453 for ievent, (input_values, output_lgbm, output_MVAUtils)
in enumerate(
454 zip(data_input, results_lgbm, results_MVAUtils), 1
456 if not np.allclose(output_lgbm, output_MVAUtils):
457 stop_event_loop =
True
458 logging.info(
"--> output are different on input %d/%d:\n", ievent, len(data_input))
459 for ivar, input_value
in enumerate(input_values):
460 logging.info(
"var %d: %.15f", ivar, input_value)
461 logging.info(
"=" * 50)
462 logging.info(
" mvautils lgbm")
463 for ioutput, (o1, o2)
in enumerate(zip(output_MVAUtils, output_lgbm)):
464 diff_flag =
"" if np.allclose(o1, o2)
else "<---"
465 logging.info(
"output %3d %.5e %.5e %s", ioutput, o1, o2, diff_flag)
# Inner loop: locate the first (tree, class) pair that disagrees.
468 stop_tree_loop =
False
469 for itree, output_tree_lgbm
in enumerate(output_trees_lgbm):
470 output_tree_mva_utils = [
471 mva_utils.GetTreeResponse(
list2stdvector(input_values), itree * nclasses + c)
472 for c
in range(nclasses)
474 if not np.allclose(output_tree_mva_utils, output_tree_lgbm[0]):
475 stop_tree_loop =
True
476 logging.info(
"first tree/class with different answer (%d)", itree)
477 for isubtree, (ol, om)
in enumerate(
478 zip(output_tree_lgbm[0], output_tree_mva_utils)
480 if not np.allclose(ol, om):
481 logging.info(
"different in position %d", isubtree)
482 logging.info(
"lgbm: %f", ol)
483 logging.info(
"mvautils: %f", om)
484 logging.info(
"=" * 50)
486 "tree %d (itree) * %d (nclasses)" "+ %d (isubtree) = %d",
490 itree * nclasses + isubtree,
492 mva_utils.PrintTree(itree * nclasses + isubtree)
# Re-dump the offending sub-tree structure from the booster.
501 booster.dump_model()[
"tree_info"][itree * nclasses + isubtree][
# Probe for double -> float32 threshold-comparison flips, as in the
# binary/regression diagnostic path.
506 for node_info
in node_infos:
507 value = input_values[node_info[0]]
508 threshold = node_info[1]
509 if not np.isnan(value)
and (value <= threshold) != (
510 np.float32(value) <= np.float32(threshold)
513 "the problem could be due to double"
514 "(lgbm) -> float (mvautil) conversion"
515 "for variable %d: %f and threshold %f",
# Reset the stop flags (the surrounding control flow is not visible).
520 stop_tree_loop =
False
521 stop_event_loop =
False
# Fragments of check_file (the "def" line is missing from this view):
# sanity-check the produced ROOT file -- it must contain at least one key,
# the first key must be a TTree, and the tree must have entries.
531 f = ROOT.TFile.Open(fn)
532 keys = f.GetListOfKeys()
535 logging.info(
"file %s is empty", fn)
537 tree = f.Get(keys[0].GetName())
# NOTE(review): `type(tree) is not ROOT.TTree` rejects TTree subclasses;
# isinstance would be more tolerant -- left untouched, code is fragmentary.
538 if type(tree)
is not ROOT.TTree:
539 logging.info(
"cannot find TTree in file %s", fn)
541 if not tree.GetEntries():
542 logging.info(
"tree is empty")
# Script entry point: parse CLI arguments, convert the LightGBM text model
# to a ROOT file, optionally run the round-trip test, and print a C++
# usage snippet matching the model's objective.  Many lines are missing
# from this view.
547 if __name__ ==
"__main__":
550 parser = argparse.ArgumentParser(description=__doc__)
551 parser.add_argument(
"input", help=
"input text file from LGBM")
552 parser.add_argument(
"output", help=
"output ROOT filename", nargs=
"?")
553 parser.add_argument(
"--tree-name", default=
"lgbm")
554 parser.add_argument(
"--no-test", action=
"store_true", help=
"don't run test (not suggested)")
555 parser.add_argument(
"--ntests", type=int, help=
"number of test, default=1000")
557 "--test-file", help=
"numpy pickle or ROOT file (use filename.root:treename)"
560 args = parser.parse_args()
# Default output name: input basename with a .root extension.
562 if args.output
is None:
565 args.output = os.path.splitext(os.path.split(args.input)[1])[0] +
".root"
567 logging.info(
"converting input file %s to root file %s", args.input, args.output)
570 print(
"model has not been tested. Do not use it production!")
572 logging.info(
"testing model")
574 print(
"problem when checking file")
575 result =
test(args.input, args.output, args.tree_name, args.ntests, args.test_file)
578 "some problems during test." " Have you setup athena? Do not use this in production!"
# Success banner; the plain-ASCII fallback below handles terminals that
# cannot encode the unicode variant.
583 u"::: :) :) :) everything fine:" " LGBM output == MVAUtils output :) :) :) :::"
585 except UnicodeEncodeError:
586 print(
":::==> everything fine:" "LGBM output == MVAUtils output <==:::")
# Print an objective-specific C++ usage template, filled with the output
# file/tree names and the model's feature list.
587 booster = lgb.Booster(model_file=args.input)
588 objective = booster.dump_model()[
"objective"]
589 if "multiclass" in objective:
591 """In c++ use your BDT as:
592 #include "MVAUtils/BDT.h"
594 TFile* f = TFile::Open("%s");
595 TTree* tree = nullptr;
596 f->GetObject("%s", tree);
597 MVAUtils::BDT my_bdt(tree);
599 // std::vector<float> input_values(%d, 0.);
600 // fill the vector using the order as in the trainig: %s
602 std::vector<float> output = my_bdt.GetMultiResponse(input_values, %d);
607 booster.num_feature(),
608 ",".
join(booster.feature_name()),
609 booster.num_model_per_iteration(),
612 elif "binary" in objective:
614 """In c++ use your BDT as:
615 #include "MVAUtils/BDT.h"
617 TFile* f = TFile::Open("%s");
618 TTree* tree = nullptr;
619 f->GetObject("%s", tree);
620 MVAUtils::BDT my_bdt(tree);
622 // std::vector<float> input_values(%d, 0.);
623 // fill the vector using the order as in the trainig: %s
625 float output = my_bdt.GetClassification(input_values);
630 booster.num_feature(),
631 ",".
join(booster.feature_name()),
634 elif "regression" in objective:
636 """In c++ use your BDT as:
637 #include "MVAUtils/BDT.h"
639 TFile* f = TFile::Open("%s");
640 TTree* tree = nullptr;
641 f->GetObject("%s", tree);
642 MVAUtils::BDT my_bdt(tree);
644 // std::vector<float> input_values(%d, 0.);
645 // fill the vector using the order as in the trainig: %s
647 float output = my_bdt.Predict(input_values);
652 booster.num_feature(),
653 ",".
join(booster.feature_name()),