ATLAS Offline Software
check_timing_mvautils.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3 */
4 
#include "MVAUtils/BDT.h"
#include "MVAUtils/NodeImpl.h"

#include "TFile.h"
#include "TKey.h"
#include "TTree.h"

#include <chrono>
#include <cstddef>
#include <iostream>
#include <random>
#include <string>
#include <vector>
15 
16 
17 TTree* get_tree(const std::string& filename)
18 {
19  TFile* f = TFile::Open(filename.c_str());
20  auto *keys = f->GetListOfKeys();
21  TTree* tree = nullptr;
22  for (int ikey=0; ikey != keys->GetSize(); ++ikey)
23  {
24  TObject* obj = f->Get(keys->At(ikey)->GetName());
25  if (std::string(obj->ClassName()) == "TTree") {
26  f->GetObject(obj->GetName(), tree);
27  }
28  }
29 
30  if (!tree) {
31  std::cout << "cannot find any ttree in file " << filename << std::endl;
32  }
33  return tree;
34 }
35 
36 int main(int argc, char** argv)
37 {
38  if (argc != 2)
39  {
40  std::cout << "need to provide a ROOT file with a TTree" << std::endl;
41  return 1;
42  }
43 
44  TTree* tree = get_tree(argv[1]);
45  if (!tree) { return 1; }
46 
47  MVAUtils::BDT bdt(tree);
48 
49  std::cout << "sizeof class NodeTMVA is " << sizeof(MVAUtils::NodeTMVA) << std::endl;
50  std::cout << "sizeof class NodeLGBMSimple is " << sizeof(MVAUtils::NodeLGBMSimple) << std::endl;
51  std::cout << "sizeof class NodeLGBM is " << sizeof(MVAUtils::NodeLGBM) << std::endl;
52  std::cout << "sizeof class NodeXGBoost is " << sizeof(MVAUtils::NodeXGBoost) << std::endl;
53 
54  if (sizeof(MVAUtils::NodeTMVA) != 8) { std::cout << "WARNING: NodeTMVA should be 8 bytes" << std::endl; }
55  if (sizeof(MVAUtils::NodeLGBMSimple) != 8) { std::cout << "WARNING: NodeLGBMSimple should be 8 bytes" << std::endl; }
56  if (sizeof(MVAUtils::NodeLGBM) != 8) { std::cout << "WARNING: NodeLGBM should be 8 bytes" << std::endl; }
57  if (sizeof(MVAUtils::NodeXGBoost) != 8) { std::cout << "WARNING: NodeXGBoost should be 8 bytes" << std::endl; }
58 
59  std::default_random_engine gen;
60  std::uniform_real_distribution<float> rnd_uniform(-100., 100.);
61 
62  const unsigned int NTEST = 10000;
63 
65 
66  int nvars = bdt.GetNVars();
67  std::vector<float> rnd_precomputed;
68 
69  // precompute since we don't want to impact timing
70  for (unsigned int itest = 0; itest != NTEST * nvars; ++itest)
71  {
72  rnd_precomputed.push_back(rnd_uniform(gen));
73  }
74 
75  auto rnd_it1 = rnd_precomputed.begin();
76  auto rnd_it2 = rnd_it1;
77  std::advance(rnd_it2, nvars);
78  for (unsigned int itest = 0; itest != NTEST; ++itest)
79  {
80  std::vector<float> input_values(rnd_it1, rnd_it2);
81  bdt.GetResponse(input_values);
82  std::advance(rnd_it1, nvars);
83  std::advance(rnd_it2, nvars);
84  }
86  std::cout << "timing: "
87  << std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count() / double(NTEST) / double(bdt.GetNTrees())
88  << " ns / events / trees\n";
89 
90 
91  return 0;
92 
93 }
MVAUtils::NodeTMVA
Node for TMVA implementation.
Definition: NodeImpl.h:36
python.CaloRecoConfig.f
f
Definition: CaloRecoConfig.py:127
MVAUtils::BDT::GetNTrees
unsigned int GetNTrees() const
Number of trees in the whole forest.
tree
TChain * tree
Definition: tile_monitor.h:30
ALFA_EventTPCnv_Dict::t1
std::vector< ALFA_RawDataCollection_p1 > t1
Definition: ALFA_EventTPCnvDict.h:43
MVAUtils::BDT
Simplified Boosted Regression Tree, support TMVA, lgbm, and xgboost.
Definition: BDT.h:34
LArCellConditions.argv
argv
Definition: LArCellConditions.py:112
MVAUtils::NodeXGBoost
Node for XGBoost with nan implementation.
Definition: NodeImpl.h:177
master.gen
gen
Definition: master.py:32
python.handimod.now
now
Definition: handimod.py:675
MVAUtils::NodeLGBMSimple
Node for LGBM without nan implementation.
Definition: NodeImpl.h:92
DQHistogramMergeRegExp.argc
argc
Definition: DQHistogramMergeRegExp.py:20
xAOD::double
double
Definition: CompositeParticle_v1.cxx:159
BDT.h
main
int main(int argc, char **argv)
Definition: check_timing_mvautils.cxx:36
ALFA_EventTPCnv_Dict::t2
std::vector< ALFA_RawDataContainer_p1 > t2
Definition: ALFA_EventTPCnvDict.h:44
MVAUtils::BDT::GetNVars
int GetNVars() const
Number of variables expected in the inputs.
MVAUtils::BDT::GetResponse
float GetResponse(const std::vector< float > &values) const
Get response of the forest, for regression.
CaloCellTimeCorrFiller.filename
filename
Definition: CaloCellTimeCorrFiller.py:24
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:790
MVAUtils::NodeLGBM
Node for LGBM with nan implementation.
Definition: NodeImpl.h:135
NodeImpl.h
python.PyAthena.obj
obj
Definition: PyAthena.py:135
get_tree
TTree * get_tree(const std::string &filename)
Definition: check_timing_mvautils.cxx:17