Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
compareFlatTrees.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 #include <ROOT/RDataFrame.hxx>
5 #include <ROOT/RLogger.hxx>
6 #include <TCanvas.h>
7 #include <TF1.h>
8 #include <TH1D.h>
9 #include <TRatioPlot.h>
10 #include <TROOT.h>
11 #include <TStyle.h>
12 
14 
15 #include <boost/program_options.hpp>
16 
17 #include <algorithm>
18 #include <chrono>
19 #include <iostream>
20 #include <memory>
21 #include <string>
22 #include <unordered_set>
23 
/// Return the names that appear in both @p v1 and @p v2, in ascending order.
///
/// Needed to get the variables (columns) that are common to both samples.
///
/// @param v1 first list of names; sorted in place as a side effect
/// @param v2 second list of names; sorted in place as a side effect
/// @return the common names, sorted ascending. A name repeated in both
///         inputs is kept min(count in v1, count in v2) times
///         (std::set_intersection semantics).
std::vector<std::string> intersection(std::vector<std::string> &v1,
                                      std::vector<std::string> &v2)
{
  // std::set_intersection requires both input ranges to be sorted
  std::sort(v1.begin(), v1.end());
  std::sort(v2.begin(), v2.end());

  std::vector<std::string> common;
  common.reserve(std::min(v1.size(), v2.size()));

  // The output of std::set_intersection is already in sorted (alphabetical)
  // order, so no additional sort of the result is needed.
  std::set_intersection(v1.begin(), v1.end(),
                        v2.begin(), v2.end(),
                        std::back_inserter(common));

  return common;
}
42 
/// Return the entries of @p v1 that do not occur in @p v2.
///
/// @param v1 list to filter
/// @param v2 list of entries to leave out
/// @return the surviving entries of @p v1, in their original order;
///         duplicates in @p v1 are kept
std::vector<std::string> remainder (const std::vector<std::string>& v1,
                                    const std::vector<std::string>& v2)
{
  // Hash set of the entries to drop, so each membership test is a single lookup
  const std::unordered_set<std::string> excluded (v2.begin(), v2.end());

  std::vector<std::string> kept;
  for (const std::string& candidate : v1)
  {
    if (excluded.count (candidate) != 0)
      continue;
    kept.push_back (candidate);
  }
  return kept;
}
57 
59 {
60  namespace po = boost::program_options;
61  po::options_description poDescription("Common options");
62  poDescription.add_options()
63  ("help", "produce help message")
64  ("require-same-branches", "require both trees to have the same branches")
65  ("tree-name", po::value<std::string>()->required(), "tree name")
66  ("reference-file", po::value<std::string>()->required(), "reference file(s), wildcards supported")
67  ("test-file", po::value<std::string>()->required(), "test file(s), wildcards supported")
68  ("branch-name", po::value<std::string>(), "base branch name (optional)");
69 
70  po::options_description poDescriptionAdvanced("Advanced options");
71  poDescriptionAdvanced.add_options()
72  ("scale", "scale histograms that both have the same event count")
73  ("rebin", "do smart rebinning")
74  ("benchmark", "benchmark the code")
75  ("verbose", "verbose logging");
76 
77  po::options_description poDescriptionAll;
78  poDescriptionAll.add(poDescription).add(poDescriptionAdvanced);
79 
80  po::positional_options_description poPositionalOptions;
81  poPositionalOptions.add("tree-name", 1);
82  poPositionalOptions.add("reference-file", 1);
83  poPositionalOptions.add("test-file", 1);
84  poPositionalOptions.add("branch-name", 1);
85 
86  po::variables_map poVariablesMap;
87  po::store(po::command_line_parser(argc, argv)
88  .options(poDescriptionAll)
89  .positional(poPositionalOptions).run(),
90  poVariablesMap);
91 
92  if (poVariablesMap.count("help"))
93  {
94  std::cout << "Usage: compareFlatTrees [OPTION] tree-name reference-file test-file [branch-name]" << std::endl;
95  std::cout << poDescriptionAll << std::endl;
96  return 0;
97  }
98 
99  po::notify(poVariablesMap);
100 
101  // Base name of branches to read
102  std::string treeName = poVariablesMap["tree-name"].as<std::string>();
103  std::string treeNameOut = treeName;
104  std::replace( treeNameOut.begin(), treeNameOut.end(), '/', '_');
105  std::string referenceInput = poVariablesMap["reference-file"].as<std::string>();
106  std::string testInput = poVariablesMap["test-file"].as<std::string>();
107  std::string baseBranchName;
108  std::string outputPDF;
109  if (poVariablesMap.count("branch-name") > 0)
110  {
111  baseBranchName = poVariablesMap["branch-name"].as<std::string>();
112  outputPDF = "comparison_" + treeNameOut + "_" + baseBranchName + ".pdf";
113  }
114  else
115  {
116  outputPDF = "comparison_" + treeNameOut + ".pdf";
117  }
118 
119  bool scale = poVariablesMap.count("scale") > 0;
120  bool rebin = poVariablesMap.count("rebin") > 0;
121  bool benchmark = poVariablesMap.count("benchmark") > 0;
122  bool verbose = poVariablesMap.count("verbose") > 0;
123 
124  // Verbose logging
125  // auto verbosity = ROOT::Experimental::RLogScopedVerbosity(ROOT::Detail::RDF::RDFLogChannel(), ROOT::Experimental::ELogLevel::kInfo);
126 
127  // Run in batch mode - output is pdf
128  gROOT->SetBatch(kTRUE);
129  // Suppress logging
130  if (!verbose)
131  {
132  gErrorIgnoreLevel = kWarning;
133  }
134  // No stats box
135  gStyle->SetOptStat(0);
136  // No error bar ticks
137  gStyle->SetEndErrorSize(0);
138  // Parallel processing where possible
139  ROOT::EnableImplicitMT(4);
140 
141  // Create RDataFrame
142  ROOT::RDataFrame dataFrameRefr(treeName, referenceInput);
143  ROOT::RDataFrame dataFrameTest(treeName, testInput);
144 
145  // Event count and ratio
146  auto eventsRefr{dataFrameRefr.Count()};
147  auto eventsTest{dataFrameTest.Count()};
148  double eventRatio{static_cast<float>(eventsRefr.GetValue()) / static_cast<float>(eventsTest.GetValue())};
149 
150  // Get column names for each file and then the intersection
151  auto colNamesRefr = dataFrameRefr.GetColumnNames();
152  auto colNamesTest = dataFrameTest.GetColumnNames();
153  auto colNames = intersection(colNamesRefr, colNamesTest);
154 
155  bool checkMissColumns = false;
156  std::vector<std::string> missColNamesRefr, missColNamesTest;
157  if (poVariablesMap.count("require-same-branches"))
158  {
159  checkMissColumns = true;
160  missColNamesRefr = remainder (colNamesRefr, colNames);
161  missColNamesTest = remainder (colNamesTest, colNames);
162  }
163 
164  // Loop over column names and get a list of the required columns
165  std::vector<std::string> requiredColumns;
166  std::cout << "Will attempt to plot the following columns:" << std::endl;
167  for (auto &&colName : colNames)
168  {
169  if ((baseBranchName.empty() || colName.find(baseBranchName) != std::string::npos) && // include
170  (colName.find("Trig") == std::string::npos) && // exclude, not meaningful
171  (colName.find("Link") == std::string::npos) && // exclude, elementlinks
172  (colName.find("m_persIndex") == std::string::npos) && // exclude, elementlinks
173  (colName.find("m_persKey") == std::string::npos) && // exclude, elementlinks
174  (colName.find("Parent") == std::string::npos) && // exclude, elementlinks
175  (colName.find("original") == std::string::npos) && // exclude, elementlinks
176  (colName.find("EventInfoAuxDyn.detDescrTags") == std::string::npos) && // exclude, std::pair
177  (dataFrameRefr.GetColumnType(colName).find("xAOD") == std::string::npos) && // exclude, needs ATLAS s/w
178  (dataFrameRefr.GetColumnType(colName) != "ROOT::VecOps::RVec<string>") && // exclude, needs ATLAS s/w
179  (!dataFrameRefr.GetColumnType(colName).starts_with("ROOT::VecOps::RVec<pair")) && // exclude, std::pair
180  (dataFrameRefr.GetColumnType(colName).find("vector") == std::string::npos))
181  { // exclude, needs unwrapping
182  requiredColumns.push_back(colName);
183  std::cout << " " << colName << " " << dataFrameRefr.GetColumnType(colName) << std::endl;
184  }
185  }
186 
187  // Create canvas
188  auto c1 = new TCanvas("c1", "Tree comparison");
189 
190  // Set binning
191  int nBins{128};
192  size_t nBinsU = static_cast<size_t>(nBins);
193 
194  // Loop over the required columns and plot them for each sample along with the ratio
195  // Write resulting plots to a pdf file
196  bool fileOpen{};
197  size_t counter{};
198  size_t failedCount{};
199  std::chrono::seconds totalDuration{};
200  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesRefr;
201  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesTest;
202  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesRefr;
203  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesTest;
204  std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistRefr;
205  std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistTest;
206 
207  std::cout << "Preparing ranges..." << std::endl;
208  for (const std::string &colName : requiredColumns)
209  {
210  mapMinValuesRefr.emplace(colName, dataFrameRefr.Min(colName));
211  mapMinValuesTest.emplace(colName, dataFrameTest.Min(colName));
212  mapMaxValuesRefr.emplace(colName, dataFrameRefr.Max(colName));
213  mapMaxValuesTest.emplace(colName, dataFrameTest.Max(colName));
214  }
215 
216  std::cout << "Preparing histograms..." << std::endl;
218  for (auto it = requiredColumns.begin(); it != requiredColumns.end();)
219  {
220  const std::string &colName = *it;
221  const char* colNameCh = colName.c_str();
222 
223  // Initial histogram range
224  double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
225  double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
226  if (std::isinf(min) || std::isinf(max))
227  {
228  std::cout << " skipping " << colName << " ..." << std::endl;
229  it = requiredColumns.erase(it);
230  continue;
231  } else {
232  ++it;
233  }
234 
235  if (max > 250e3 && min > 0.0)
236  {
237  min = 0.0;
238  }
239 
240  // Initial histograms
241  mapHistRefr.emplace(colName, dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
242  mapHistTest.emplace(colName, dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
243  }
245  auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
246  totalDuration += duration;
247  if (benchmark)
248  {
249  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
250  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
251  }
252 
253  if (rebin)
254  {
255  std::cout << "Rebinning histograms..." << std::endl;
257  for (const std::string &colName : requiredColumns)
258  {
259  const char* colNameCh = colName.c_str();
260  auto &histRefr = mapHistRefr[colName];
261  auto &histTest = mapHistTest[colName];
262 
263  // Initial histogram range
264  double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
265  double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
266  if (max > 250e3 && min > 0.0)
267  {
268  min = 0.0;
269  }
270 
271  // Check range - make sure that bins other than the first contain at least one per mille events
272  // Avoids case where max is determined by a single outlier leading to most events being in the 1st bin
273  bool rangeSatisfactory{};
274  size_t rangeItrCntr{};
275  while (!rangeSatisfactory && rangeItrCntr < 10)
276  {
277  ++rangeItrCntr;
278  if (verbose)
279  {
280  std::cout << std::endl
281  << " Range tuning... iteration number " << rangeItrCntr << std::endl;
282  }
283  double entriesFirstBin = histRefr.GetPtr()->GetBinContent(1);
284  double entriesLastBin = histRefr.GetPtr()->GetBinContent(nBins);
285  double entriesOtherBins{};
286  for (size_t i{2}; i < nBinsU; ++i)
287  {
288  entriesOtherBins += histRefr.GetPtr()->GetBinContent(i);
289  }
290  bool firstBinOK{((entriesOtherBins + entriesLastBin) / entriesFirstBin > 0.001f)};
291  bool lastBinOK{((entriesOtherBins + entriesFirstBin) / entriesLastBin > 0.001f)};
292  rangeSatisfactory = ((firstBinOK && lastBinOK) || entriesOtherBins == 0.0f);
293  if (!rangeSatisfactory)
294  {
295  if (verbose)
296  {
297  std::cout << "Min " << min << std::endl;
298  std::cout << "Max " << max << std::endl;
299  std::cout << "1st " << entriesFirstBin << std::endl;
300  std::cout << "Mid " << entriesOtherBins << std::endl;
301  std::cout << "End " << entriesLastBin << std::endl;
302  std::cout << "R/F " << (entriesOtherBins + entriesLastBin) / entriesFirstBin << std::endl;
303  std::cout << "R/L " << (entriesOtherBins + entriesFirstBin) / entriesLastBin << std::endl;
304  }
305  if (!firstBinOK)
306  {
307  max = (max - min) / static_cast<double>(nBins);
308  histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
309  histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
310  }
311  if (!lastBinOK)
312  {
313  min = max * (1.0f - (1.0f / static_cast<double>(nBins)));
314  histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
315  histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
316  }
317  }
318  }
319  }
320 
322  auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
323  totalDuration += duration;
324  if (benchmark)
325  {
326  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
327  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
328  }
329  }
330 
331  std::cout << "Running comparisons..." << std::endl;
333  for (const std::string &colName : requiredColumns)
334  {
335  ++counter;
336 
337  std::cout << "Processing column " << counter << " of " << requiredColumns.size() << " : " << colName << " ... ";
338 
339  auto h1 = mapHistRefr[colName].GetPtr();
340  auto h2 = mapHistTest[colName].GetPtr();
341 
342  if (scale)
343  {
344  h2->Scale(eventRatio);
345  }
346  h2->SetMarkerStyle(20);
347  h2->SetMarkerSize(0.8);
348 
349  if (!verbose)
350  {
351  gErrorIgnoreLevel = kError; // this is spammy due to empty bins
352  }
353  auto rp = std::unique_ptr<TRatioPlot>(new TRatioPlot(h2, h1));
354  if (!verbose)
355  {
356  gErrorIgnoreLevel = kWarning;
357  }
358 
359  rp->SetH1DrawOpt("PE");
360  rp->SetH2DrawOpt("hist");
361  rp->SetGraphDrawOpt("PE");
362  rp->Draw();
363  rp->GetUpperRefXaxis()->SetTitle(colName.c_str());
364  rp->GetUpperRefYaxis()->SetTitle("Count");
365  rp->GetLowerRefYaxis()->SetTitle("Test / Ref.");
366  rp->GetLowerRefGraph()->SetMarkerStyle(20);
367  rp->GetLowerRefGraph()->SetMarkerSize(0.8);
368 
369  bool valid{true};
370  for (int i{}; i < rp->GetLowerRefGraph()->GetN(); i++)
371  {
372  if (rp->GetLowerRefGraph()->GetY()[i] != 1.0)
373  {
374  valid = false;
375  break;
376  }
377  }
378  if (valid)
379  {
380  std::cout << "PASS" << std::endl;
381  continue;
382  }
383  else
384  {
385  std::cout << "FAILED" << std::endl;
386  ++failedCount;
387  }
388 
389  rp->GetLowerRefGraph()->SetMinimum(0.5);
390  rp->GetLowerRefGraph()->SetMaximum(1.5);
391  rp->GetLowYaxis()->SetNdivisions(505);
392 
393  if (valid) {
394  c1->SetTicks(0, 1);
395  c1->Update();
396  }
397 
398  if (!fileOpen)
399  {
400  // Open file
401  c1->Print((outputPDF + "[").c_str());
402  fileOpen = true;
403  }
404  // Actual plot
405  c1->Print(outputPDF.c_str());
406  c1->Clear();
407  }
408 
409  if (fileOpen)
410  {
411  // Close file
412  c1->Print((outputPDF + "]").c_str());
413  }
414 
416  duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
417  totalDuration += duration;
418  if (benchmark)
419  {
420  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
421  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
422  }
423 
424  std::cout << "========================" << std::endl;
425  std::cout << "Reference events: " << eventsRefr.GetValue() << std::endl;
426  std::cout << "Test events: " << eventsTest.GetValue() << std::endl;
427  std::cout << "Ratio: " << eventRatio << std::endl;
428  std::cout << "========================" << std::endl;
429  std::cout << "Tested columns: " << requiredColumns.size() << std::endl;
430  std::cout << "Passed: " << requiredColumns.size() - failedCount << std::endl;
431  std::cout << "Failed: " << failedCount << std::endl;
432  std::cout << "========================" << std::endl;
433  if (checkMissColumns)
434  {
435  std::cout << "Columns only in reference: " << missColNamesRefr.size();
436  for (const auto& column : missColNamesRefr)
437  std::cout << " " << column;
438  std::cout << std::endl;
439  std::cout << "Columns only in test: " << missColNamesTest.size();
440  for (const auto& column : missColNamesTest)
441  std::cout << " " << column;
442  std::cout << std::endl;
443  failedCount += missColNamesRefr.size() + missColNamesTest.size();
444  std::cout << "========================" << std::endl;
445  }
446 
447  if (failedCount)
448  {
449  return 1;
450  }
451 
452  return 0;
453 }
replace
std::string replace(std::string s, const std::string &s2, const std::string &s3)
Definition: hcg.cxx:307
SGTest::store
TestStore store
Definition: TestStore.cxx:23
get_generator_info.result
result
Definition: get_generator_info.py:21
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
covarianceToolsLibrary.gErrorIgnoreLevel
gErrorIgnoreLevel
Definition: covarianceToolsLibrary.py:21
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
mergePhysValFiles.start
start
Definition: DataQuality/DataQualityUtils/scripts/mergePhysValFiles.py:14
extractSporadic.c1
c1
Definition: extractSporadic.py:134
skel.it
it
Definition: skel.GENtoEVGEN.py:407
run
int run(int argc, char *argv[])
Definition: ttree2hdf5.cxx:28
DeMoUpdate.column
dictionary column
Definition: DeMoUpdate.py:1110
athena.value
value
Definition: athena.py:124
PixelModuleFeMask_create_db.stop
int stop
Definition: PixelModuleFeMask_create_db.py:76
read_hist_ntuple.h1
h1
Definition: read_hist_ntuple.py:21
yodamerge_tmp.scale
scale
Definition: yodamerge_tmp.py:138
intersection
std::vector< std::string > intersection(std::vector< std::string > &v1, std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:25
DiTauMassTools::ignore
void ignore(T &&)
Definition: PhysicsAnalysis/TauID/DiTauMassTools/DiTauMassTools/HelperFunctions.h:58
main
int main(int, char **)
Main class for all the CppUnit test classes
Definition: CppUnit_SGtestdriver.cxx:141
calibdata.valid
list valid
Definition: calibdata.py:45
python.handimod.now
now
Definition: handimod.py:675
CheckAppliedSFs.e3
e3
Definition: CheckAppliedSFs.py:264
lumiFormat.i
int i
Definition: lumiFormat.py:85
LArCellNtuple.argv
argv
Definition: LArCellNtuple.py:152
Analysis::kError
@ kError
Definition: CalibrationDataVariables.h:60
dumpFileToPlots.treeName
string treeName
Definition: dumpFileToPlots.py:20
hist_file_dump.f
f
Definition: hist_file_dump.py:141
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:379
DQHistogramMergeRegExp.argc
argc
Definition: DQHistogramMergeRegExp.py:20
python.LArCalib_HVCorrConfig.seconds
seconds
Definition: LArCalib_HVCorrConfig.py:98
PixelAthHitMonAlgCfg.duration
duration
Definition: PixelAthHitMonAlgCfg.py:152
dumpTgcDigiJitter.nBins
list nBins
Definition: dumpTgcDigiJitter.py:29
remainder
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:44
makeegammaturnon.rebin
def rebin(binning, data)
Definition: makeegammaturnon.py:17
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
ReadCellNoiseFromCoolCompare.v2
v2
Definition: ReadCellNoiseFromCoolCompare.py:364
a
TList * a
Definition: liststreamerinfos.cxx:10
python.TriggerHandler.verbose
verbose
Definition: TriggerHandler.py:297
rp
ReadCards * rp
Definition: IReadCards.cxx:26
test_pyathena.counter
counter
Definition: test_pyathena.py:15
set_intersection
Set * set_intersection(Set *set1, Set *set2)
Perform an intersection of two sets.
checker_macros.h
Define macros for attributes used to control the static checker.
ATLAS_NOT_THREAD_SAFE
int main ATLAS_NOT_THREAD_SAFE(int argc, char *argv[])
Definition: compareFlatTrees.cxx:58