#include "MVAUtils/BDT.h"
#include "MVAUtils/NodeImpl.h"
#include "TFile.h"
#include "TTree.h"
#include <vector>
#include <random>
#include <chrono>
#include <iostream>
Go to the source code of this file.
◆ get_tree()
TTree* get_tree(const std::string& filename)
Definition at line 17 of file check_timing_mvautils.cxx.
20 auto *keys = f->GetListOfKeys();
21 TTree* tree = nullptr;
22 for (int ikey=0; ikey != keys->GetSize(); ++ikey)
24 TObject* obj = f->Get(keys->At(ikey)->GetName());
25 if (std::string(obj->ClassName()) == "TTree") {
31 std::cout << "cannot find any ttree in file " << filename << std::endl;
◆ main()
int main(int argc, char** argv)
Definition at line 36 of file check_timing_mvautils.cxx.
40 std::cout << "need to provide a ROOT file with a TTree" << std::endl;
45 if (!tree) { return 1; }
54 if (sizeof(MVAUtils::NodeTMVA) != 8) { std::cout << "WARNING: NodeTMVA should be 8 bytes" << std::endl; }
55 if (sizeof(MVAUtils::NodeLGBMSimple) != 8) { std::cout << "WARNING: NodeLGBMSimple should be 8 bytes" << std::endl; }
56 if (sizeof(MVAUtils::NodeLGBM) != 8) { std::cout << "WARNING: NodeLGBM should be 8 bytes" << std::endl; }
57 if (sizeof(MVAUtils::NodeXGBoost) != 8) { std::cout << "WARNING: NodeXGBoost should be 8 bytes" << std::endl; }
59 std::default_random_engine gen;
60 std::uniform_real_distribution<float> rnd_uniform(-100., 100.);
62 const unsigned int NTEST = 10000;
66 int nvars = bdt.GetNVars();
67 std::vector<float> rnd_precomputed;
70 for (unsigned int itest = 0; itest != NTEST * nvars; ++itest)
72 rnd_precomputed.push_back(rnd_uniform(gen));
75 auto rnd_it1 = rnd_precomputed.begin();
76 auto rnd_it2 = rnd_it1;
77 std::advance(rnd_it2, nvars);
78 for (unsigned int itest = 0; itest != NTEST; ++itest)
80 std::vector<float> input_values(rnd_it1, rnd_it2);
81 bdt.GetResponse(input_values);
82 std::advance(rnd_it1, nvars);
83 std::advance(rnd_it2, nvars);
86 std::cout << "timing: "
87 << std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count() / double(NTEST) / double(bdt.GetNTrees())
88 << " ns / events / trees\n";