ATLAS Offline Software
check_timing_mvautils.cxx File Reference
#include "MVAUtils/BDT.h"
#include "MVAUtils/NodeImpl.h"
#include "TFile.h"
#include "TTree.h"
#include <vector>
#include <random>
#include <chrono>
#include <iostream>
Include dependency graph for check_timing_mvautils.cxx:

Go to the source code of this file.

Functions

TTree * get_tree (const std::string &filename)
int main (int argc, char **argv)

Function Documentation

◆ get_tree()

TTree * get_tree ( const std::string & filename)

Definition at line 17 of file check_timing_mvautils.cxx.

18{
19 TFile* f = TFile::Open(filename.c_str());
20 auto *keys = f->GetListOfKeys();
21 TTree* tree = nullptr;
22 for (int ikey=0; ikey != keys->GetSize(); ++ikey)
23 {
24 TObject* obj = f->Get(keys->At(ikey)->GetName());
25 if (std::string(obj->ClassName()) == "TTree") {
26 f->GetObject(obj->GetName(), tree);
27 }
28 }
29
30 if (!tree) {
31 std::cout << "cannot find any ttree in file " << filename << std::endl;
32 }
33 return tree;
34}
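Note: `tree` in the listing above is the local `TTree*` variable; Doxygen's cross-reference to an unrelated `TChain * tree` symbol is spurious.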

◆ main()

int main ( int argc,
char ** argv )

Definition at line 36 of file check_timing_mvautils.cxx.

37{
38 if (argc != 2)
39 {
40 std::cout << "need to provide a ROOT file with a TTree" << std::endl;
41 return 1;
42 }
43
44 TTree* tree = get_tree(argv[1]);
45 if (!tree) { return 1; }
46
48
49 std::cout << "sizeof class NodeTMVA is " << sizeof(MVAUtils::NodeTMVA) << std::endl;
50 std::cout << "sizeof class NodeLGBMSimple is " << sizeof(MVAUtils::NodeLGBMSimple) << std::endl;
51 std::cout << "sizeof class NodeLGBM is " << sizeof(MVAUtils::NodeLGBM) << std::endl;
52 std::cout << "sizeof class NodeXGBoost is " << sizeof(MVAUtils::NodeXGBoost) << std::endl;
53
54 if (sizeof(MVAUtils::NodeTMVA) != 8) { std::cout << "WARNING: NodeTMVA should be 8 bytes" << std::endl; }
55 if (sizeof(MVAUtils::NodeLGBMSimple) != 8) { std::cout << "WARNING: NodeLGBMSimple should be 8 bytes" << std::endl; }
56 if (sizeof(MVAUtils::NodeLGBM) != 8) { std::cout << "WARNING: NodeLGBM should be 8 bytes" << std::endl; }
57 if (sizeof(MVAUtils::NodeXGBoost) != 8) { std::cout << "WARNING: NodeXGBoost should be 8 bytes" << std::endl; }
58
59 std::default_random_engine gen;
60 std::uniform_real_distribution<float> rnd_uniform(-100., 100.);
61
62 const unsigned int NTEST = 10000;
63
64 auto t1 = std::chrono::high_resolution_clock::now();
65
66 int nvars = bdt.GetNVars();
67 std::vector<float> rnd_precomputed;
68
69 // precompute since we don't want to impact timing
70 for (unsigned int itest = 0; itest != NTEST * nvars; ++itest)
71 {
72 rnd_precomputed.push_back(rnd_uniform(gen));
73 }
74
75 auto rnd_it1 = rnd_precomputed.begin();
76 auto rnd_it2 = rnd_it1;
77 std::advance(rnd_it2, nvars);
78 for (unsigned int itest = 0; itest != NTEST; ++itest)
79 {
80 std::vector<float> input_values(rnd_it1, rnd_it2);
81 bdt.GetResponse(input_values);
82 std::advance(rnd_it1, nvars);
83 std::advance(rnd_it2, nvars);
84 }
85 auto t2 = std::chrono::high_resolution_clock::now();
86 std::cout << "timing: "
87 << std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count() / double(NTEST) / double(bdt.GetNTrees())
88 << " ns / events / trees\n";
89
90
91 return 0;
92
93}
TTree * get_tree(const std::string &filename)
Simplified boosted regression tree; supports TMVA, LightGBM, and XGBoost.
Node implementation for LightGBM without NaN handling.
Definition NodeImpl.h:92
Node implementation for LightGBM with NaN handling.
Definition NodeImpl.h:135
Node implementation for TMVA.
Definition NodeImpl.h:36
Node implementation for XGBoost with NaN handling.
Definition NodeImpl.h:177
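Note: `t1` and `t2` in main() are local `std::chrono::high_resolution_clock` time points. Doxygen's cross-references to `std::vector< ALFA_RawDataContainer_p1 > t2` and `std::vector< ALFA_RawDataCollection_p1 > t1` are spurious.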