ATLAS Offline Software
Loading...
Searching...
No Matches
check_timing_mvautils.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3*/
4
#include "MVAUtils/BDT.h"
#include "MVAUtils/NodeImpl.h"

#include "TFile.h"
#include "TKey.h"
#include "TTree.h"

#include <chrono>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <random>
#include <string>
#include <vector>
15
16
17TTree* get_tree(const std::string& filename)
18{
19 TFile* f = TFile::Open(filename.c_str());
20 auto *keys = f->GetListOfKeys();
21 TTree* tree = nullptr;
22 for (int ikey=0; ikey != keys->GetSize(); ++ikey)
23 {
24 TObject* obj = f->Get(keys->At(ikey)->GetName());
25 if (std::string(obj->ClassName()) == "TTree") {
26 f->GetObject(obj->GetName(), tree);
27 }
28 }
29
30 if (!tree) {
31 std::cout << "cannot find any ttree in file " << filename << std::endl;
32 }
33 return tree;
34}
35
36int main(int argc, char** argv)
37{
38 if (argc != 2)
39 {
40 std::cout << "need to provide a ROOT file with a TTree" << std::endl;
41 return 1;
42 }
43
44 TTree* tree = get_tree(argv[1]);
45 if (!tree) { return 1; }
46
48
49 std::cout << "sizeof class NodeTMVA is " << sizeof(MVAUtils::NodeTMVA) << std::endl;
50 std::cout << "sizeof class NodeLGBMSimple is " << sizeof(MVAUtils::NodeLGBMSimple) << std::endl;
51 std::cout << "sizeof class NodeLGBM is " << sizeof(MVAUtils::NodeLGBM) << std::endl;
52 std::cout << "sizeof class NodeXGBoost is " << sizeof(MVAUtils::NodeXGBoost) << std::endl;
53
54 if (sizeof(MVAUtils::NodeTMVA) != 8) { std::cout << "WARNING: NodeTMVA should be 8 bytes" << std::endl; }
55 if (sizeof(MVAUtils::NodeLGBMSimple) != 8) { std::cout << "WARNING: NodeLGBMSimple should be 8 bytes" << std::endl; }
56 if (sizeof(MVAUtils::NodeLGBM) != 8) { std::cout << "WARNING: NodeLGBM should be 8 bytes" << std::endl; }
57 if (sizeof(MVAUtils::NodeXGBoost) != 8) { std::cout << "WARNING: NodeXGBoost should be 8 bytes" << std::endl; }
58
59 std::default_random_engine gen;
60 std::uniform_real_distribution<float> rnd_uniform(-100., 100.);
61
62 const unsigned int NTEST = 10000;
63
64 auto t1 = std::chrono::high_resolution_clock::now();
65
66 int nvars = bdt.GetNVars();
67 std::vector<float> rnd_precomputed;
68
69 // precompute since we don't want to impact timing
70 for (unsigned int itest = 0; itest != NTEST * nvars; ++itest)
71 {
72 rnd_precomputed.push_back(rnd_uniform(gen));
73 }
74
75 auto rnd_it1 = rnd_precomputed.begin();
76 auto rnd_it2 = rnd_it1;
77 std::advance(rnd_it2, nvars);
78 for (unsigned int itest = 0; itest != NTEST; ++itest)
79 {
80 std::vector<float> input_values(rnd_it1, rnd_it2);
81 bdt.GetResponse(input_values);
82 std::advance(rnd_it1, nvars);
83 std::advance(rnd_it2, nvars);
84 }
85 auto t2 = std::chrono::high_resolution_clock::now();
86 std::cout << "timing: "
87 << std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count() / double(NTEST) / double(bdt.GetNTrees())
88 << " ns / events / trees\n";
89
90
91 return 0;
92
93}
TTree * get_tree(const std::string &filename)
Simplified Boosted Regression Tree, support TMVA, lgbm, and xgboost.
int GetNVars() const
Number of variables expected in the inputs.
unsigned int GetNTrees() const
Number of trees in the whole forest.
float GetResponse(const std::vector< float > &values) const
Get response of the forest, for regression.
Node for LGBM without nan implementation.
Definition NodeImpl.h:92
Node for LGBM with nan implementation.
Definition NodeImpl.h:135
Node for TMVA implementation.
Definition NodeImpl.h:36
Node for XGBoost with nan implementation.
Definition NodeImpl.h:177
int main()
Definition hello.cxx:18
TChain * tree