ATLAS Offline Software
compareFlatTrees.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 #include <ROOT/RDataFrame.hxx>
5 #include <ROOT/RLogger.hxx>
6 #include <TCanvas.h>
7 #include <TF1.h>
8 #include <TH1D.h>
9 #include <TRatioPlot.h>
10 #include <TROOT.h>
11 #include <TStyle.h>
12 
14 
15 #include <boost/program_options.hpp>
16 
#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
23 
// Intersection of two lists of strings, used to find the branches (columns) that are
// common to both samples.
//
// Side effect: both inputs are sorted in place (std::set_intersection requires sorted
// ranges). Callers that iterate v1/v2 afterwards therefore see them in alphabetical order.
//
// Returns the common entries in ascending (alphabetical) order. std::set_intersection
// already emits elements in the order of its sorted inputs, so no extra sort is needed.
// An entry present m times in v1 and n times in v2 appears min(m, n) times.
std::vector<std::string> intersection(std::vector<std::string> &v1,
                                      std::vector<std::string> &v2)
{
  std::sort(v1.begin(), v1.end());
  std::sort(v2.begin(), v2.end());

  std::vector<std::string> v3;
  // The intersection can never be larger than the smaller input.
  v3.reserve(std::min(v1.size(), v2.size()));
  std::set_intersection(v1.begin(), v1.end(),
                        v2.begin(), v2.end(),
                        std::back_inserter(v3));
  return v3;
}
42 
// Returns the entries of v1 that do not occur anywhere in v2.
// The order of v1 is kept, and duplicates within v1 are kept as well.
// Used to list the branches that exist in only one of the two samples.
std::vector<std::string> remainder (const std::vector<std::string>& v1,
                                    const std::vector<std::string>& v2)
{
  // Hash the exclusion list once so every membership test is an average O(1) lookup.
  const std::unordered_set<std::string> excluded(v2.begin(), v2.end());

  std::vector<std::string> kept;
  for (const std::string& entry : v1)
  {
    const bool isExcluded = excluded.count(entry) != 0;
    if (!isExcluded)
    {
      kept.push_back(entry);
    }
  }
  return kept;
}
57 
59 {
  // Entry point of compareFlatTrees. For every plottable branch (column) that the reference
  // and test inputs have in common, it does the following:
  //  - books a histogram per sample with ROOT::RDataFrame;
  //  - compares the two histograms through a TRatioPlot;
  //  - reports PASS when every ratio point is exactly 1.0, otherwise FAILED.
  // Failing columns are drawn into a multi-page PDF (comparison_<tree>[_<branch>].pdf).
  // Passing columns are not drawn.
  // Exit code: 0 if everything passed. 1 if any column failed, or if (with
  // --require-same-branches) some column exists in only one of the two inputs.
  // NOTE(review): this view of the file does not show the function signature, nor the
  // statements that define/refresh the benchmark timestamps `start` and `stop`.
60  namespace po = boost::program_options;
  // Command-line options. tree-name, reference-file and test-file are mandatory. They may also
  // be given positionally in that order, optionally followed by the base branch name.
61  po::options_description poDescription("Common options");
62  poDescription.add_options()
63  ("help", "produce help message")
64  ("require-same-branches", "require both trees to have the same branches")
65  ("tree-name", po::value<std::string>()->required(), "tree name")
66  ("reference-file", po::value<std::string>()->required(), "reference file(s), wildcards supported")
67  ("test-file", po::value<std::string>()->required(), "test file(s), wildcards supported")
68  ("branch-name", po::value<std::string>(), "base branch name (optional)");
69 
70  po::options_description poDescriptionAdvanced("Advanced options");
71  poDescriptionAdvanced.add_options()
72  ("scale", "scale histograms that both have the same event count")
73  ("rebin", "do smart rebinning")
74  ("benchmark", "benchmark the code")
75  ("verbose", "verbose logging");
76 
77  po::options_description poDescriptionAll;
78  poDescriptionAll.add(poDescription).add(poDescriptionAdvanced);
79 
80  po::positional_options_description poPositionalOptions;
81  poPositionalOptions.add("tree-name", 1);
82  poPositionalOptions.add("reference-file", 1);
83  poPositionalOptions.add("test-file", 1);
84  poPositionalOptions.add("branch-name", 1);
85 
86  po::variables_map poVariablesMap;
87  po::store(po::command_line_parser(argc, argv)
88  .options(poDescriptionAll)
89  .positional(poPositionalOptions).run(),
90  poVariablesMap);
91 
  // --help is handled before po::notify(), which is where the ->required() options are
  // enforced, so help can be printed without supplying them.
92  if (poVariablesMap.count("help"))
93  {
94  std::cout << "Usage: compareFlatTrees [OPTION] tree-name reference-file test-file [branch-name]" << std::endl;
95  std::cout << poDescriptionAll << std::endl;
96  return 0;
97  }
98 
99  po::notify(poVariablesMap);
100 
101  // Base name of branches to read
  // '/' in the tree name is replaced by '_' so that the name can be embedded in the PDF file name.
102  std::string treeName = poVariablesMap["tree-name"].as<std::string>();
103  std::string treeNameOut = treeName;
104  std::replace( treeNameOut.begin(), treeNameOut.end(), '/', '_');
105  std::string referenceInput = poVariablesMap["reference-file"].as<std::string>();
106  std::string testInput = poVariablesMap["test-file"].as<std::string>();
107  std::string baseBranchName;
108  std::string outputPDF;
109  if (poVariablesMap.count("branch-name") > 0)
110  {
111  baseBranchName = poVariablesMap["branch-name"].as<std::string>();
112  outputPDF = "comparison_" + treeNameOut + "_" + baseBranchName + ".pdf";
113  }
114  else
115  {
116  outputPDF = "comparison_" + treeNameOut + ".pdf";
117  }
118 
119  bool scale = poVariablesMap.count("scale") > 0;
120  bool rebin = poVariablesMap.count("rebin") > 0;
121  bool benchmark = poVariablesMap.count("benchmark") > 0;
122  bool verbose = poVariablesMap.count("verbose") > 0;
123 
124  // Verbose logging
125  // auto verbosity = ROOT::Experimental::RLogScopedVerbosity(ROOT::Detail::RDF::RDFLogChannel(), ROOT::Experimental::ELogLevel::kInfo);
126 
127  // Run in batch mode - output is pdf
128  gROOT->SetBatch(kTRUE);
129  // Suppress logging
130  if (!verbose)
131  {
132  gErrorIgnoreLevel = kWarning;
133  }
134  // No stats box
135  gStyle->SetOptStat(0);
136  // No error bar ticks
137  gStyle->SetEndErrorSize(0);
138  // Parallel processing where possible
  // NOTE(review): the worker-thread count (4) is hard-coded.
139  ROOT::EnableImplicitMT(4);
140 
141  // Create RDataFrame
142  ROOT::RDataFrame dataFrameRefr(treeName, referenceInput);
143  ROOT::RDataFrame dataFrameTest(treeName, testInput);
144 
145  // Event count and ratio
  // eventRatio = reference events / test events. It is used by --scale and printed in the summary.
  // NOTE(review): the counts are cast to float, which has a 24-bit mantissa, so counts above
  // 16777216 are rounded before dividing. An empty test sample divides by zero (inf/NaN).
  // Casting to double would avoid the precision loss.
146  auto eventsRefr{dataFrameRefr.Count()};
147  auto eventsTest{dataFrameTest.Count()};
148  double eventRatio{static_cast<float>(eventsRefr.GetValue()) / static_cast<float>(eventsTest.GetValue())};
149 
150  // Get column names for each file and then the intersection
  // intersection() sorts colNamesRefr and colNamesTest in place. The remainder() lists below
  // therefore come out in alphabetical order.
151  auto colNamesRefr = dataFrameRefr.GetColumnNames();
152  auto colNamesTest = dataFrameTest.GetColumnNames();
153  auto colNames = intersection(colNamesRefr, colNamesTest);
154 
155  bool checkMissColumns = false;
156  std::vector<std::string> missColNamesRefr, missColNamesTest;
157  if (poVariablesMap.count("require-same-branches"))
158  {
159  checkMissColumns = true;
160  missColNamesRefr = remainder (colNamesRefr, colNames);
161  missColNamesTest = remainder (colNamesTest, colNames);
162  }
163 
164  // Loop over column names and get a list of the required columns
165  std::vector<std::string> requiredColumns;
166  std::cout << "Will attempt to plot the following columns:" << std::endl;
167  for (auto &&colName : colNames)
168  {
  // Keep a column only if all of the following hold:
  //  - it matches the optional base branch name (substring match);
  //  - its name does not contain any of the excluded substrings (reasons inline);
  //  - its reference-sample type is not one that cannot be histogrammed directly:
  //    xAOD types, RVec<string>, RVec<pair...>, anything containing "vector".
169  if ((baseBranchName.empty() || colName.find(baseBranchName) != std::string::npos) && // include
170  (colName.find("Trig") == std::string::npos) && // exclude, not meaningful
171  (colName.find("Link") == std::string::npos) && // exclude, elementlinks
172  (colName.find("m_persIndex") == std::string::npos) && // exclude, elementlinks
173  (colName.find("m_persKey") == std::string::npos) && // exclude, elementlinks
174  (colName.find("Parent") == std::string::npos) && // exclude, elementlinks
175  (colName.find("original") == std::string::npos) && // exclude, elementlinks
176  (colName.find("EventInfoAuxDyn.detDescrTags") == std::string::npos) && // exclude, std::pair
177  (dataFrameRefr.GetColumnType(colName).find("xAOD") == std::string::npos) && // exclude, needs ATLAS s/w
178  (dataFrameRefr.GetColumnType(colName) != "ROOT::VecOps::RVec<string>") && // exclude, needs ATLAS s/w
179  (!dataFrameRefr.GetColumnType(colName).starts_with("ROOT::VecOps::RVec<pair")) && // exclude, std::pair
180  (dataFrameRefr.GetColumnType(colName).find("vector") == std::string::npos))
181  { // exclude, needs unwrapping
182  requiredColumns.push_back(colName);
183  std::cout << " " << colName << " " << dataFrameRefr.GetColumnType(colName) << std::endl;
184  }
185  }
186 
187  // Set binning
  // Every histogram uses the same fixed number of bins.
188  int nBins{128};
189  size_t nBinsU = static_cast<size_t>(nBins);
190 
191  // Loop over the required columns and plot them for each sample along with the ratio
192  // Write resulting plots to a pdf file
193  bool fileOpen{};
194  size_t counter{};
195  size_t failedCount{};
196  std::chrono::seconds totalDuration{};
  // RDataFrame results keyed by column name. Per RDataFrame's lazy-action model, nothing is
  // read when an action is booked. The first GetValue()/GetPtr() on a data frame runs one event
  // loop that fills all actions booked on it so far.
197  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesRefr;
198  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesTest;
199  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesRefr;
200  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesTest;
201  std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistRefr;
202  std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistTest;
203 
204  std::cout << "Preparing ranges..." << std::endl;
205  for (const std::string &colName : requiredColumns)
206  {
207  mapMinValuesRefr.emplace(colName, dataFrameRefr.Min(colName));
208  mapMinValuesTest.emplace(colName, dataFrameTest.Min(colName));
209  mapMaxValuesRefr.emplace(colName, dataFrameRefr.Max(colName));
210  mapMaxValuesTest.emplace(colName, dataFrameTest.Max(colName));
211  }
212 
213  std::cout << "Preparing histograms..." << std::endl;
  // Book one histogram per sample for each column, over the combined [min, max] range of both
  // samples. Columns whose range is infinite are erased from requiredColumns, so they are
  // neither plotted nor counted.
215  for (auto it = requiredColumns.begin(); it != requiredColumns.end();)
216  {
217  const std::string &colName = *it;
218  const char* colNameCh = colName.c_str();
219 
220  // Initial histogram range
  // NOTE(review): "* 1.02" only widens the range when max > 0. For max == 0 it leaves the
  // range unchanged. For max < 0 it moves max below the true maximum. In both cases the
  // largest entries then land in the overflow bin.
221  double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
222  double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
223  if (std::isinf(min) || std::isinf(max))
224  {
225  std::cout << " skipping " << colName << " ..." << std::endl;
226  it = requiredColumns.erase(it);
227  continue;
228  } else {
229  ++it;
230  }
231 
  // Large, strictly positive ranges are started at 0.
232  if (max > 250e3 && min > 0.0)
233  {
234  min = 0.0;
235  }
236 
237  // Initial histograms
238  mapHistRefr.emplace(colName, dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
239  mapHistTest.emplace(colName, dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
240  }
242  auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
243  totalDuration += duration;
244  if (benchmark)
245  {
246  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
247  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
248  }
249 
  // Optional "smart rebinning". For each column, up to 10 times, narrow the range to the first
  // (or last) bin while that bin of the reference histogram dominates. Re-book both histograms
  // with the new range each time.
250  if (rebin)
251  {
252  std::cout << "Rebinning histograms..." << std::endl;
254  for (const std::string &colName : requiredColumns)
255  {
256  const char* colNameCh = colName.c_str();
257  auto &histRefr = mapHistRefr[colName];
258  auto &histTest = mapHistTest[colName];
259 
260  // Initial histogram range
  // This starts from the same range as the booking loop above.
261  double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
262  double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
263  if (max > 250e3 && min > 0.0)
264  {
265  min = 0.0;
266  }
267 
268  // Check range - make sure that bins other than the first contain at least one per mille events
269  // Avoids case where max is determined by a single outlier leading to most events being in the 1st bin
  // The same test is also applied to the last bin (see lastBinOK below).
270  bool rangeSatisfactory{};
271  size_t rangeItrCntr{};
272  while (!rangeSatisfactory && rangeItrCntr < 10)
273  {
274  ++rangeItrCntr;
275  if (verbose)
276  {
277  std::cout << std::endl
278  << " Range tuning... iteration number " << rangeItrCntr << std::endl;
279  }
  // NOTE(review): after histRefr/histTest are re-booked below, this GetPtr() presumably starts
  // a new event loop over the data on every iteration (RDataFrame lazy semantics). That makes
  // --rebin expensive.
280  double entriesFirstBin = histRefr.GetPtr()->GetBinContent(1);
281  double entriesLastBin = histRefr.GetPtr()->GetBinContent(nBins);
  // Sum over bins 2 .. nBins-1, i.e. everything except the first and last bin.
282  double entriesOtherBins{};
283  for (size_t i{2}; i < nBinsU; ++i)
284  {
285  entriesOtherBins += histRefr.GetPtr()->GetBinContent(i);
286  }
  // A first/last bin "dominates" when the rest of the histogram holds at most 0.1% of its
  // content. A histogram whose middle bins are all empty is accepted as is.
287  bool firstBinOK{((entriesOtherBins + entriesLastBin) / entriesFirstBin > 0.001f)};
288  bool lastBinOK{((entriesOtherBins + entriesFirstBin) / entriesLastBin > 0.001f)};
289  rangeSatisfactory = ((firstBinOK && lastBinOK) || entriesOtherBins == 0.0f);
290  if (!rangeSatisfactory)
291  {
292  if (verbose)
293  {
294  std::cout << "Min " << min << std::endl;
295  std::cout << "Max " << max << std::endl;
296  std::cout << "1st " << entriesFirstBin << std::endl;
297  std::cout << "Mid " << entriesOtherBins << std::endl;
298  std::cout << "End " << entriesLastBin << std::endl;
299  std::cout << "R/F " << (entriesOtherBins + entriesLastBin) / entriesFirstBin << std::endl;
300  std::cout << "R/L " << (entriesOtherBins + entriesFirstBin) / entriesLastBin << std::endl;
301  }
302  if (!firstBinOK)
303  {
  // NOTE(review): this sets max to the bin width (max - min) / nBins. The upper edge of the
  // first bin is min + (max - min) / nBins, so the two agree only when min == 0. If min is
  // larger than the bin width, the new max falls below min. Suggested fix:
  // max = min + (max - min) / nBins.
304  max = (max - min) / static_cast<double>(nBins);
305  histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
306  histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
307  }
308  if (!lastBinOK)
309  {
  // NOTE(review): max * (1 - 1/nBins) equals max - max / nBins. The lower edge of the last bin
  // is max - (max - min) / nBins, so the two agree only when min == 0. Suggested fix:
  // min = max - (max - min) / nBins.
310  min = max * (1.0f - (1.0f / static_cast<double>(nBins)));
311  histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
312  histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
313  }
314  }
315  }
316  }
317 
319  auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
320  totalDuration += duration;
321  if (benchmark)
322  {
323  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
324  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
325  }
326  }
327 
328  std::cout << "Running comparisons..." << std::endl;
330 
331  // Store only last canvas
  // The last canvas is kept because it is needed after the loop to close the multi-page PDF.
332  std::unique_ptr<TCanvas> lastCanvas;
333 
334  for (const std::string &colName : requiredColumns)
335  {
336  ++counter;
337 
338  std::cout << "Processing column " << counter << " of " << requiredColumns.size() << " : " << colName << " ... ";
339 
340  auto h1 = mapHistRefr[colName].GetPtr();
341  auto h2 = mapHistTest[colName].GetPtr();
342 
  // With --scale, only the test histogram is scaled, by reference/test event count (eventRatio).
343  if (scale)
344  {
345  h2->Scale(eventRatio);
346  }
347  h2->SetMarkerStyle(20);
348  h2->SetMarkerSize(0.8);
349 
  // Building the TRatioPlot is noisy, so the ROOT error threshold is raised for that call
  // and then restored.
350  if (!verbose)
351  {
352  gErrorIgnoreLevel = kError; // this is spammy due to empty bins
353  }
354  auto c1 = std::make_unique<TCanvas>();
355  auto rp = std::make_unique<TRatioPlot>(h2, h1);
356  if (!verbose)
357  {
358  gErrorIgnoreLevel = kWarning;
359  }
360 
361  rp->SetH1DrawOpt("PE");
362  rp->SetH2DrawOpt("hist");
363  rp->SetGraphDrawOpt("PE");
364  rp->Draw();
365  rp->GetUpperRefXaxis()->SetTitle(colName.c_str());
366  rp->GetUpperRefYaxis()->SetTitle("Count");
367  rp->GetLowerRefYaxis()->SetTitle("Test / Ref.");
368  rp->GetLowerRefGraph()->SetMarkerStyle(20);
369  rp->GetLowerRefGraph()->SetMarkerSize(0.8);
370 
  // PASS only when every point of the test/reference ratio graph is exactly 1.0.
  // NOTE(review): TRatioPlot decides which bins appear in the graph (e.g. bins empty in both
  // samples). Confirm that this matches the intended definition of "identical".
371  bool valid{true};
372  for (int i{}; i < rp->GetLowerRefGraph()->GetN(); i++)
373  {
374  if (rp->GetLowerRefGraph()->GetY()[i] != 1.0)
375  {
376  valid = false;
377  break;
378  }
379  }
  // Passing columns skip the PDF output entirely.
380  if (valid)
381  {
382  std::cout << "PASS" << std::endl;
383  continue;
384  }
385  else
386  {
387  std::cout << "FAILED" << std::endl;
388  ++failedCount;
389  }
390 
  // Only failing columns reach this point. Zoom the ratio pad to [0.5, 1.5] and append a page.
391  rp->GetLowerRefGraph()->SetMinimum(0.5);
392  rp->GetLowerRefGraph()->SetMaximum(1.5);
393  rp->GetLowYaxis()->SetNdivisions(505);
394 
395  if (!fileOpen)
396  {
397  // Open file
  // A trailing "[" opens a multi-page output file without drawing a page.
398  c1->Print((outputPDF + "[").c_str());
399  fileOpen = true;
400  }
401  // Actual plot
402  c1->Print(outputPDF.c_str());
403  c1->Clear();
404  lastCanvas = std::move(c1);
405  }
406 
407  if (fileOpen)
408  {
409  // Close file
  // A trailing "]" closes the multi-page output file opened with "[" above.
410  lastCanvas->Print((outputPDF + "]").c_str());
411  lastCanvas.reset();
412  }
413 
415  duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
416  totalDuration += duration;
417  if (benchmark)
418  {
419  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
420  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
421  }
422 
  // Summary. Up to here failedCount counts only histogram comparisons, so it never exceeds
  // requiredColumns.size().
423  std::cout << "========================" << std::endl;
424  std::cout << "Reference events: " << eventsRefr.GetValue() << std::endl;
425  std::cout << "Test events: " << eventsTest.GetValue() << std::endl;
426  std::cout << "Ratio: " << eventRatio << std::endl;
427  std::cout << "========================" << std::endl;
428  std::cout << "Tested columns: " << requiredColumns.size() << std::endl;
429  std::cout << "Passed: " << requiredColumns.size() - failedCount << std::endl;
430  std::cout << "Failed: " << failedCount << std::endl;
431  std::cout << "========================" << std::endl;
  // With --require-same-branches, columns present in only one input are listed and each one
  // counts as a failure for the exit code. They are not included in the "Failed:" line above.
432  if (checkMissColumns)
433  {
434  std::cout << "Columns only in reference: " << missColNamesRefr.size();
435  for (const auto& column : missColNamesRefr)
436  std::cout << " " << column;
437  std::cout << std::endl;
438  std::cout << "Columns only in test: " << missColNamesTest.size();
439  for (const auto& column : missColNamesTest)
440  std::cout << " " << column;
441  std::cout << std::endl;
442  failedCount += missColNamesRefr.size() + missColNamesTest.size();
443  std::cout << "========================" << std::endl;
444  }
445 
  // Exit code: 1 if anything failed, 0 otherwise.
446  if (failedCount)
447  {
448  return 1;
449  }
450 
451  return 0;
452 }
replace
std::string replace(std::string s, const std::string &s2, const std::string &s3)
Definition: hcg.cxx:307
SGTest::store
TestStore store
Definition: TestStore.cxx:23
get_generator_info.result
result
Definition: get_generator_info.py:21
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
covarianceToolsLibrary.gErrorIgnoreLevel
gErrorIgnoreLevel
Definition: covarianceToolsLibrary.py:21
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
mergePhysValFiles.start
start
Definition: DataQuality/DataQualityUtils/scripts/mergePhysValFiles.py:13
extractSporadic.c1
c1
Definition: extractSporadic.py:133
skel.it
it
Definition: skel.GENtoEVGEN.py:407
run
int run(int argc, char *argv[])
Definition: ttree2hdf5.cxx:28
DeMoUpdate.column
dictionary column
Definition: DeMoUpdate.py:1110
athena.value
value
Definition: athena.py:124
PixelModuleFeMask_create_db.stop
int stop
Definition: PixelModuleFeMask_create_db.py:76
read_hist_ntuple.h1
h1
Definition: read_hist_ntuple.py:21
yodamerge_tmp.scale
scale
Definition: yodamerge_tmp.py:138
intersection
std::vector< std::string > intersection(std::vector< std::string > &v1, std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:25
DiTauMassTools::ignore
void ignore(T &&)
Definition: PhysicsAnalysis/TauID/DiTauMassTools/DiTauMassTools/HelperFunctions.h:58
main
int main(int, char **)
Main class for all the CppUnit test classes
Definition: CppUnit_SGtestdriver.cxx:141
calibdata.valid
list valid
Definition: calibdata.py:44
python.handimod.now
now
Definition: handimod.py:674
CheckAppliedSFs.e3
e3
Definition: CheckAppliedSFs.py:264
lumiFormat.i
int i
Definition: lumiFormat.py:85
LArCellNtuple.argv
argv
Definition: LArCellNtuple.py:152
Analysis::kError
@ kError
Definition: CalibrationDataVariables.h:60
dumpFileToPlots.treeName
string treeName
Definition: dumpFileToPlots.py:19
hist_file_dump.f
f
Definition: hist_file_dump.py:140
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:378
DQHistogramMergeRegExp.argc
argc
Definition: DQHistogramMergeRegExp.py:19
python.LArCalib_HVCorrConfig.seconds
seconds
Definition: LArCalib_HVCorrConfig.py:100
PixelAthHitMonAlgCfg.duration
duration
Definition: PixelAthHitMonAlgCfg.py:152
dumpTgcDigiJitter.nBins
list nBins
Definition: dumpTgcDigiJitter.py:29
remainder
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:44
makeegammaturnon.rebin
def rebin(binning, data)
Definition: makeegammaturnon.py:15
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:76
ReadCellNoiseFromCoolCompare.v2
v2
Definition: ReadCellNoiseFromCoolCompare.py:364
a
TList * a
Definition: liststreamerinfos.cxx:10
python.TriggerHandler.verbose
verbose
Definition: TriggerHandler.py:296
rp
ReadCards * rp
Definition: IReadCards.cxx:26
test_pyathena.counter
counter
Definition: test_pyathena.py:15
set_intersection
Set * set_intersection(Set *set1, Set *set2)
Perform an intersection of two sets.
checker_macros.h
Define macros for attributes used to control the static checker.
ATLAS_NOT_THREAD_SAFE
int main ATLAS_NOT_THREAD_SAFE(int argc, char *argv[])
Definition: compareFlatTrees.cxx:58