def test_multiclass(booster, mva_utils, ntests=10000, test_file=None):
    """Compare multiclass predictions of ``booster`` with those of ``mva_utils``.

    The same inputs are evaluated in one call with ``booster.predict`` and one
    at a time with ``mva_utils.GetMultiResponse``. For every input where the
    two outputs are not close, the input values, both outputs and a
    tree-by-tree comparison are logged to help locate the discrepancy.

    Args:
        booster: model exposing ``num_feature``, ``num_model_per_iteration``,
            ``feature_name``, ``predict`` and ``dump_model`` (presumably a
            ``lightgbm.Booster`` -- confirm against the caller).
        mva_utils: object exposing ``GetMultiResponse``, ``GetTreeResponse``
            and ``PrintTree``.
        ntests: forwarded to ``get_test_data``.
        test_file: forwarded to ``get_test_data``.

    Returns:
        ``False`` as soon as an input shows a discrepancy for which no
        double -> float rounding explanation was found, ``True`` otherwise.
    """
    import numpy as np

    nvars = booster.num_feature()
    nclasses = booster.num_model_per_iteration()
    logging.info("using %d input features with %d classes", nvars, nclasses)

    data_input = get_test_data(booster.feature_name(), test_file, ntests)

    start = time.time()
    results_lgbm = booster.predict(data_input)
    logging.info(
        "lgbm (vectorized) timing = %s ms/input", (time.time() - start) * 1000 / len(data_input)
    )

    # A single vector of floats is reused for all inputs: it is cleared and
    # refilled for each one before being handed to GetMultiResponse.
    input_values_vector = ROOT.std.vector("float")()
    results_MVAUtils = []
    start = time.time()
    for input_values in data_input:
        input_values_vector.clear()
        for v in input_values:
            input_values_vector.push_back(v)
        output_MVAUtils = np.asarray(mva_utils.GetMultiResponse(input_values_vector, nclasses))
        results_MVAUtils.append(output_MVAUtils)
    logging.info(
        "mvautils (not vectorized+overhead) timing = %s ms/input",
        (time.time() - start) * 1000 / len(data_input),
    )

    stop_event_loop = False
    for ievent, (input_values, output_lgbm, output_MVAUtils) in enumerate(
        zip(data_input, results_lgbm, results_MVAUtils), 1
    ):
        if not np.allclose(output_lgbm, output_MVAUtils):
            # Assume a real failure until a rounding explanation is found below.
            stop_event_loop = True
            logging.info("--> output are different on input %d/%d:\n", ievent, len(data_input))
            for ivar, input_value in enumerate(input_values):
                logging.info("var %d: %.15f", ivar, input_value)
            logging.info("=" * 50)
            logging.info(" mvautils lgbm")
            for ioutput, (o1, o2) in enumerate(zip(output_MVAUtils, output_lgbm)):
                diff_flag = "" if np.allclose(o1, o2) else "<---"
                logging.info("output %3d %.5e %.5e %s", ioutput, o1, o2, diff_flag)
            output_trees_lgbm = lgbm_rawresponse_each_tree(booster, [input_values])

            # The converted input does not depend on the tree or the class,
            # so it is built once per input instead of once per tree response.
            input_values_std = list2stdvector(input_values)

            stop_tree_loop = False
            for itree, output_tree_lgbm in enumerate(output_trees_lgbm):
                # The tree of class c at iteration itree is addressed with the
                # flat index itree * nclasses + c.
                output_tree_mva_utils = [
                    mva_utils.GetTreeResponse(input_values_std, itree * nclasses + c)
                    for c in range(nclasses)
                ]
                if not np.allclose(output_tree_mva_utils, output_tree_lgbm[0]):
                    stop_tree_loop = True
                    logging.info("first tree/class with different answer (%d)", itree)
                    for isubtree, (ol, om) in enumerate(
                        zip(output_tree_lgbm[0], output_tree_mva_utils)
                    ):
                        if not np.allclose(ol, om):
                            tree_index = itree * nclasses + isubtree
                            logging.info("different in position %d", isubtree)
                            logging.info("lgbm: %f", ol)
                            logging.info("mvautils: %f", om)
                            logging.info("=" * 50)
                            logging.info(
                                "tree %d (itree) * %d (nclasses) + %d (isubtree) = %d",
                                itree,
                                nclasses,
                                isubtree,
                                tree_index,
                            )
                            mva_utils.PrintTree(tree_index)

                            # _ff is defined elsewhere in the file; judging from
                            # how the entries are read below it is expected to
                            # fill node_infos with (variable index, threshold)
                            # items -- TODO confirm against its definition.
                            node_infos = []
                            _ff(
                                booster.dump_model()["tree_info"][tree_index]["tree_structure"],
                                node_infos,
                            )
                            for node_info in node_infos:
                                value = input_values[node_info[0]]
                                threshold = node_info[1]
                                # A split whose outcome changes when value and
                                # threshold are rounded to float32 can explain
                                # the different answers.
                                if not np.isnan(value) and (value <= threshold) != (
                                    np.float32(value) <= np.float32(threshold)
                                ):
                                    logging.info(
                                        "the problem could be due to double "
                                        "(lgbm) -> float (mvautil) conversion "
                                        "for variable %d: %f and threshold %f",
                                        node_info[0],
                                        value,
                                        threshold,
                                    )
                                    # NOTE(review): once cleared here,
                                    # stop_event_loop is not set again for this
                                    # input, so a later tree with an unexplained
                                    # difference stops the tree loop but does
                                    # not make the function return False --
                                    # confirm this is intended.
                                    stop_tree_loop = False
                                    stop_event_loop = False

                if stop_tree_loop:
                    break
        if stop_event_loop:
            return False
    return True
529
530