5 """ Convert LightGBM model to TTree to be used with MVAUtils. """
7 __author__ =
"Ruggero Turra"
10 import lightgbm
as lgb
13 """cannot load lightgbm. Try to install it with
15 or (usually on lxplus)
16 pip install numpy scipy scikit-learn
17 pip install --no-binary :all: lightgbm
25 logging.basicConfig(level=logging.DEBUG)
29 nclasses = model.num_model_per_iteration()
30 output_values = np.array(
31 [np.array([[0] * nclasses])]
33 model.predict(np.atleast_2d(my_input), raw_score=
True, num_iteration=itree)
34 for itree
in range(1, (model.num_trees() // nclasses + 1))
37 output_trees = np.diff(output_values, axis=0)
42 result = ROOT.std.vector(dtype)()
50 Adaptor from LGBM dictionary to tree
53 def __init__(self, structure, invert_as_tmva=False):
54 super(LGBMTextNode, self).
__init__(structure)
58 if "split_feature" in self:
59 return self[
"split_feature"]
64 if "threshold" in self:
65 return self[
"threshold"]
67 return self[
"leaf_value"]
70 if "left_child" not in self:
72 if self[
"decision_type"] ==
"<=":
78 if "right_child" not in self:
80 if self[
"decision_type"] ==
"<=":
98 return self.get(
"default_left",
True)
103 dump a single decision tree to arrays to be written into the TTree
114 split_features.append(node.get_split_feature())
115 split_values.append(node.get_value())
116 default_left.append(node.get_default_left())
118 if not node.get_default_left():
121 if "decision_type" in node
and node[
"decision_type"] !=
"<=":
123 "do not support categorical input BDT (decision_type = %s)" % node[
"decision_type"]
126 if "missing_type" in node:
127 if node[
"missing_type"]
not in (
"NaN",
"None"):
128 raise ValueError(
"do not support missing values different from NaN or None")
131 if node.get_left()
is not None:
132 preorder(node.get_left())
134 if node.get_right()
is not None:
135 preorder(node.get_right())
138 return split_features, split_values, default_left, simple[0]
141 def dump2ROOT(model, output_filename, output_treename="lgbm"):
142 model = model.dump_model()
143 fout = ROOT.TFile.Open(output_filename,
"recreate")
145 features_array = ROOT.std.vector(
"int")()
146 values_array = ROOT.std.vector(
"float")()
147 default_lefts_array = ROOT.std.vector(
"bool")()
150 node_type =
"node_type=lgbm_simple"
151 for tree
in model[
"tree_info"]:
152 tree_structure = tree[
"tree_structure"]
153 features, values, default_lefts, simple_tree =
dump_tree(tree_structure)
156 node_type =
"node_type=lgbm"
158 infos =
";".
join([
"%s=%s" % (k,
str(v))
for k, v
in model.items()
if type(v)
is not list])
159 title =
";".
join((
"creator=lgbm", node_type, infos))
160 root_tree = ROOT.TTree(output_treename, title)
161 root_tree.Branch(
"vars",
"vector<int>", ROOT.AddressOf(features_array))
162 root_tree.Branch(
"values",
"vector<float>", ROOT.AddressOf(values_array))
165 logging.info(
"tree support nan: using full implementation (LGBMNode)")
166 root_tree.Branch(
"default_left",
"vector<bool>", ROOT.AddressOf(default_lefts_array))
168 logging.info(
"tree do not support nan:" "using simple implementation (LGBMNodeSimple)")
170 for tree
in model[
"tree_info"]:
171 tree_structure = tree[
"tree_structure"]
172 features, values, default_lefts, simple_tree =
dump_tree(tree_structure)
174 features_array.clear()
176 default_lefts_array.clear()
179 values_array.push_back(value)
180 for feature
in features:
181 features_array.push_back(feature)
182 for default_left
in default_lefts:
183 default_lefts_array.push_back(default_left)
189 return output_treename
Model: - a string: the name of the input file containing
the lgbm model. You can create this file with lgbm
by calling
`boosted.save_model('my_model.txt')`
- directly a lgbm booster object
200 if type(model)
is str:
201 model = lgb.Booster(model_file=model)
202 return dump2ROOT(model, output_filename, tree_name)
204 return dump2ROOT(model, output_filename, tree_name)
207 def test(model_file, tree_file, tree_name="lgbm", ntests=10000, test_file=None):
208 booster = lgb.Booster(model_file=model_file)
209 f = ROOT.TFile.Open(tree_file)
210 tree = f.Get(tree_name)
212 _ = ROOT.MVAUtils.BDT
214 print(
"cannot import MVAUtils")
217 mva_utils = ROOT.MVAUtils.BDT(tree)
219 objective = booster.dump_model()[
"objective"]
223 objective = objective.replace(
"sigmoid:1",
"")
224 objective = objective.strip()
229 binary_aliases = (
"binary",
"cross_entropy",
"xentropy")
230 regression_aliases = (
234 "mean_squared_error",
237 "root_mean_squared_error",
240 + (
"regression_l1",
"l1",
"mean_absolute_error",
"mae")
243 multiclass_aliases = (
"multiclass",
"softmax")
244 if objective
in multiclass_aliases:
245 logging.info(
"assuming multiclass, testing")
247 elif objective
in binary_aliases:
248 logging.info(
"assuming binary classification, testing")
249 return test_binary(booster, mva_utils, ntests, test_file)
250 elif objective
in regression_aliases:
251 logging.info(
"assuming regression, testing")
254 print(
"cannot understand objective '%s'" % objective)
258 nvars = len(feature_names)
259 if test_file
is not None:
260 if ".root" in test_file:
261 if ":" not in test_file:
262 raise ValueError(
"when using ROOT file as test use the syntax filename:treename")
263 fn, tn = test_file.split(
":")
264 f = ROOT.TFile.Open(fn)
266 raise IOError(
"cannot find ROOT file %s" % fn)
269 raise IOError(
"cannot find TTree %s in %s" % (fn, tn))
270 branch_names = [br.GetName()
for br
in tree.GetListOfBranches()]
271 for feature
in feature_names:
272 if feature
not in branch_names:
273 raise IOError(
"required feature %s not in TTree")
274 rdf = ROOT.RDataFrame(tree, feature_names)
275 data_input = rdf.AsNumpy()
276 data_input = np.stack([data_input[k]
for k
in feature_names]).T
277 if ntests
is not None:
278 data_input = data_input[:ntests]
280 "using as input %s inputs from TTree %s from ROOT file %s", len(data_input), tn, fn
283 data_input = np.load(test_file)
284 if ntests
is not None:
285 data_input = data_input[:ntests]
286 logging.info(
"using as input %s inputs from pickle file %s", len(data_input), test_file)
290 logging.info(
"using as input %s random uniform inputs (-100,100)", ntests)
292 "using random uniform input as test: this is not safe" "provide an input test file"
294 data_input = np.random.uniform(-100, 100, size=(ntests, nvars))
297 data_input = data_input.astype(np.float32)
303 results_lgbm = booster.predict(data_input)
304 logging.info(
"lgbm (vectorized) timing = %d/s", len(data_input) / (time.time() - start))
306 input_values_vector = ROOT.std.vector(
"float")()
307 results_MVAUtils = []
309 for input_values
in data_input:
310 input_values_vector.clear()
311 for v
in input_values:
312 input_values_vector.push_back(v)
313 output_MVAUtils = mvautils_predict(input_values_vector)
314 results_MVAUtils.append(output_MVAUtils)
316 "mvautils (not vectorized+overhead) timing = %d/s", len(data_input) / (time.time() - start)
320 nevents_different = 0
321 for ievent, (input_values, output_lgbm, output_MVAUtils)
in enumerate(
322 zip(data_input, results_lgbm, results_MVAUtils), 1
325 if not np.allclose(output_lgbm, output_MVAUtils, rtol=1e-4):
326 nevents_different += 1
328 "--> output are different on input %d/%d mvautils: %s lgbm: %s",
336 logging.info(
"number of different events %d/%d", nevents_different, nevents_tested)
341 def _ff(tree, node_infos):
342 if "left_child" in tree:
343 node_infos.append((tree[
"split_feature"], tree[
"threshold"]))
344 _ff(tree[
"left_child"])
345 _ff(tree[
"right_child"])
349 logging.info(
"input values")
350 for ivar, input_value
in enumerate(input_values):
351 logging.info(
"var %d: %.15f", ivar, input_value)
352 logging.info(
"=" * 50)
354 ntrees_mva_utils = mva_utils.GetNTrees()
355 if ntrees_mva_utils != booster.num_trees():
356 logging.info(
"Number of trees are different mvautils: %s lgbm: %s", ntrees_mva_utils, booster.num_trees())
360 is_problem_found =
False
361 for itree
in range(ntrees_mva_utils):
362 tree_output_mvautils = mva_utils.GetTreeResponse(
list2stdvector(input_values), itree)
363 tree_output_lgbm = tree_outputs_lgbm[itree][0]
364 if not np.allclose(tree_output_mvautils, tree_output_lgbm):
366 is_problem_found =
True
367 logging.info(
"tree %d/%d are different", itree, ntrees_mva_utils)
368 logging.info(
"lgbm: %f", tree_output_lgbm)
369 logging.info(
"MVAUtils: %f", tree_output_mvautils)
370 logging.info(
"Tree details from MVAUtils")
371 mva_utils.PrintTree(itree)
376 booster.dump_model()[
"tree_info"][itree][
387 for node_info
in node_infos:
388 value = input_values[node_info[0]]
389 threshold = node_info[1]
390 if not np.isnan(value)
and (value <= threshold) != (
391 np.float32(value) <= np.float32(threshold)
394 "the problem could be due to double"
395 "(lgbm) -> float (mvautil) conversion"
396 " for variable %d: %.10f and threshold %.10f",
415 data_input =
get_test_data(booster.feature_name(), test_file, ntests)
416 return test_generic(booster, mva_utils.GetResponse, mva_utils, data_input)
420 data_input =
get_test_data(booster.feature_name(), test_file, ntests)
421 return test_generic(booster, mva_utils.GetClassification, mva_utils, data_input)
427 nvars = booster.num_feature()
428 nclasses = booster.num_model_per_iteration()
429 logging.info(
"using %d input features with %d classes", nvars, nclasses)
431 data_input =
get_test_data(booster.feature_name(), test_file, ntests)
434 results_lgbm = booster.predict(data_input)
436 "lgbm (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input)
439 input_values_vector = ROOT.std.vector(
"float")()
440 results_MVAUtils = []
442 for input_values
in data_input:
443 input_values_vector.clear()
444 for v
in input_values:
445 input_values_vector.push_back(v)
446 output_MVAUtils = np.asarray(mva_utils.GetMultiResponse(input_values_vector, nclasses))
447 results_MVAUtils.append(output_MVAUtils)
449 "mvautils (not vectorized+overhead) timing = %s ms/input",
450 (time.time() - start) * 1000 / len(data_input),
453 stop_event_loop =
False
454 for ievent, (input_values, output_lgbm, output_MVAUtils)
in enumerate(
455 zip(data_input, results_lgbm, results_MVAUtils), 1
457 if not np.allclose(output_lgbm, output_MVAUtils):
458 stop_event_loop =
True
459 logging.info(
"--> output are different on input %d/%d:\n", ievent, len(data_input))
460 for ivar, input_value
in enumerate(input_values):
461 logging.info(
"var %d: %.15f", ivar, input_value)
462 logging.info(
"=" * 50)
463 logging.info(
" mvautils lgbm")
464 for ioutput, (o1, o2)
in enumerate(zip(output_MVAUtils, output_lgbm)):
465 diff_flag =
"" if np.allclose(o1, o2)
else "<---"
466 logging.info(
"output %3d %.5e %.5e %s", ioutput, o1, o2, diff_flag)
469 stop_tree_loop =
False
470 for itree, output_tree_lgbm
in enumerate(output_trees_lgbm):
471 output_tree_mva_utils = [
472 mva_utils.GetTreeResponse(
list2stdvector(input_values), itree * nclasses + c)
473 for c
in range(nclasses)
475 if not np.allclose(output_tree_mva_utils, output_tree_lgbm[0]):
476 stop_tree_loop =
True
477 logging.info(
"first tree/class with different answer (%d)", itree)
478 for isubtree, (ol, om)
in enumerate(
479 zip(output_tree_lgbm[0], output_tree_mva_utils)
481 if not np.allclose(ol, om):
482 logging.info(
"different in position %d", isubtree)
483 logging.info(
"lgbm: %f", ol)
484 logging.info(
"mvautils: %f", om)
485 logging.info(
"=" * 50)
487 "tree %d (itree) * %d (nclasses)" "+ %d (isubtree) = %d",
491 itree * nclasses + isubtree,
493 mva_utils.PrintTree(itree * nclasses + isubtree)
502 booster.dump_model()[
"tree_info"][itree * nclasses + isubtree][
507 for node_info
in node_infos:
508 value = input_values[node_info[0]]
509 threshold = node_info[1]
510 if not np.isnan(value)
and (value <= threshold) != (
511 np.float32(value) <= np.float32(threshold)
514 "the problem could be due to double"
515 "(lgbm) -> float (mvautil) conversion"
516 "for variable %d: %f and threshold %f",
521 stop_tree_loop =
False
522 stop_event_loop =
False
532 f = ROOT.TFile.Open(fn)
533 keys = f.GetListOfKeys()
536 logging.info(
"file %s is empty", fn)
538 tree = f.Get(keys[0].GetName())
539 if type(tree)
is not ROOT.TTree:
540 logging.info(
"cannot find TTree in file %s", fn)
542 if not tree.GetEntries():
543 logging.info(
"tree is empty")
548 if __name__ ==
"__main__":
551 parser = argparse.ArgumentParser(description=__doc__)
552 parser.add_argument(
"input", help=
"input text file from LGBM")
553 parser.add_argument(
"output", help=
"output ROOT filename", nargs=
"?")
554 parser.add_argument(
"--tree-name", default=
"lgbm")
555 parser.add_argument(
"--no-test", action=
"store_true", help=
"don't run test (not suggested)")
556 parser.add_argument(
"--ntests", type=int, help=
"number of test, default=1000")
558 "--test-file", help=
"numpy pickle or ROOT file (use filename.root:treename)"
561 args = parser.parse_args()
563 if args.output
is None:
566 args.output = os.path.splitext(os.path.split(args.input)[1])[0] +
".root"
568 logging.info(
"converting input file %s to root file %s", args.input, args.output)
571 print(
"model has not been tested. Do not use it production!")
573 logging.info(
"testing model")
575 print(
"problem when checking file")
576 result =
test(args.input, args.output, args.tree_name, args.ntests, args.test_file)
579 "some problems during test." " Have you setup athena? Do not use this in production!"
584 u"::: :) :) :) everything fine:" " LGBM output == MVAUtils output :) :) :) :::"
586 except UnicodeEncodeError:
587 print(
":::==> everything fine:" "LGBM output == MVAUtils output <==:::")
588 booster = lgb.Booster(model_file=args.input)
589 objective = booster.dump_model()[
"objective"]
590 if "multiclass" in objective:
592 """In c++ use your BDT as:
593 #include "MVAUtils/BDT.h"
595 TFile* f = TFile::Open("%s");
596 TTree* tree = nullptr;
597 f->GetObject("%s", tree);
598 MVAUtils::BDT my_bdt(tree);
600 // std::vector<float> input_values(%d, 0.);
601 // fill the vector using the order as in the trainig: %s
603 std::vector<float> output = my_bdt.GetMultiResponse(input_values, %d);
608 booster.num_feature(),
609 ",".
join(booster.feature_name()),
610 booster.num_model_per_iteration(),
613 elif "binary" in objective:
615 """In c++ use your BDT as:
616 #include "MVAUtils/BDT.h"
618 TFile* f = TFile::Open("%s");
619 TTree* tree = nullptr;
620 f->GetObject("%s", tree);
621 MVAUtils::BDT my_bdt(tree);
623 // std::vector<float> input_values(%d, 0.);
624 // fill the vector using the order as in the trainig: %s
626 float output = my_bdt.GetClassification(input_values);
631 booster.num_feature(),
632 ",".
join(booster.feature_name()),
635 elif "regression" in objective:
637 """In c++ use your BDT as:
638 #include "MVAUtils/BDT.h"
640 TFile* f = TFile::Open("%s");
641 TTree* tree = nullptr;
642 f->GetObject("%s", tree);
643 MVAUtils::BDT my_bdt(tree);
645 // std::vector<float> input_values(%d, 0.);
646 // fill the vector using the order as in the trainig: %s
648 float output = my_bdt.Predict(input_values);
653 booster.num_feature(),
654 ",".
join(booster.feature_name()),