ATLAS Offline Software
compareFlatTrees.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2022 CERN for the benefit of the ATLAS collaboration
3 */
#include <ROOT/RDataFrame.hxx>
#include <ROOT/RLogger.hxx>
#include <TCanvas.h>
#include <TF1.h>
#include <TH1D.h>
#include <TRatioPlot.h>
#include <TROOT.h>
#include <TStyle.h>

#include "CxxUtils/checker_macros.h"

#include <boost/program_options.hpp>

#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <iostream>
#include <iterator>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
23 
// Intersection of two lists of strings, needed to get the variables (columns)
// that are common to both samples.
//
// Side effect: both inputs are sorted in place, because std::set_intersection
// requires sorted ranges. The returned list is in ascending lexicographic
// (std::string operator<) order.
std::vector<std::string> intersection(std::vector<std::string> &v1,
                                      std::vector<std::string> &v2)
{
  std::sort(v1.begin(), v1.end());
  std::sort(v2.begin(), v2.end());

  std::vector<std::string> v3;
  v3.reserve(std::min(v1.size(), v2.size()));
  // set_intersection on sorted inputs emits its output in sorted order, so the
  // result is already alphabetical and needs no further sorting.
  std::set_intersection(v1.begin(), v1.end(),
                        v2.begin(), v2.end(),
                        std::back_inserter(v3));
  return v3;
}
42 
// list of entries in a vector that are not in another
//
// Returns every element of v1 that does not occur anywhere in v2, keeping the
// order (and any repeats) of v1. Membership is checked against a hash set
// built from v2.
std::vector<std::string> remainder (const std::vector<std::string>& v1,
                                    const std::vector<std::string>& v2)
{
  const std::unordered_set<std::string> excluded(v2.begin(), v2.end());

  std::vector<std::string> leftovers;
  for (std::size_t idx = 0; idx < v1.size(); ++idx)
  {
    const std::string& candidate = v1[idx];
    if (excluded.count(candidate) != 0)
      continue;
    leftovers.push_back(candidate);
  }
  return leftovers;
}
57 
59 {
60  namespace po = boost::program_options;
61  po::options_description poDescription("Common options");
62  poDescription.add_options()
63  ("help", "produce help message")
64  ("require-same-branches", "require both trees to have the same branches")
65  ("tree-name", po::value<std::string>()->required(), "tree name")
66  ("reference-file", po::value<std::string>()->required(), "reference file(s), wildcards supported")
67  ("test-file", po::value<std::string>()->required(), "test file(s), wildcards supported")
68  ("branch-name", po::value<std::string>(), "base branch name (optional)");
69 
70  po::options_description poDescriptionAdvanced("Advanced options");
71  poDescriptionAdvanced.add_options()
72  ("scale", "scale histograms that both have the same event count")
73  ("rebin", "do smart rebinning")
74  ("benchmark", "benchmark the code")
75  ("verbose", "verbose logging");
76 
77  po::options_description poDescriptionAll;
78  poDescriptionAll.add(poDescription).add(poDescriptionAdvanced);
79 
80  po::positional_options_description poPositionalOptions;
81  poPositionalOptions.add("tree-name", 1);
82  poPositionalOptions.add("reference-file", 1);
83  poPositionalOptions.add("test-file", 1);
84  poPositionalOptions.add("branch-name", 1);
85 
86  po::variables_map poVariablesMap;
87  po::store(po::command_line_parser(argc, argv)
88  .options(poDescriptionAll)
89  .positional(poPositionalOptions).run(),
90  poVariablesMap);
91 
92  if (poVariablesMap.count("help"))
93  {
94  std::cout << "Usage: compareFlatTrees [OPTION] tree-name reference-file test-file [branch-name]" << std::endl;
95  std::cout << poDescriptionAll << std::endl;
96  return 0;
97  }
98 
99  po::notify(poVariablesMap);
100 
101  // Base name of branches to read
102  std::string treeName = poVariablesMap["tree-name"].as<std::string>();
103  std::string treeNameOut = treeName;
104  std::replace( treeNameOut.begin(), treeNameOut.end(), '/', '_');
105  std::string referenceInput = poVariablesMap["reference-file"].as<std::string>();
106  std::string testInput = poVariablesMap["test-file"].as<std::string>();
107  std::string baseBranchName;
108  std::string outputPDF;
109  if (poVariablesMap.count("branch-name") > 0)
110  {
111  baseBranchName = poVariablesMap["branch-name"].as<std::string>();
112  outputPDF = "comparison_" + treeNameOut + "_" + baseBranchName + ".pdf";
113  }
114  else
115  {
116  outputPDF = "comparison_" + treeNameOut + ".pdf";
117  }
118 
119  bool scale = poVariablesMap.count("scale") > 0;
120  bool rebin = poVariablesMap.count("rebin") > 0;
121  bool benchmark = poVariablesMap.count("benchmark") > 0;
122  bool verbose = poVariablesMap.count("verbose") > 0;
123 
124  // Verbose logging
125  // auto verbosity = ROOT::Experimental::RLogScopedVerbosity(ROOT::Detail::RDF::RDFLogChannel(), ROOT::Experimental::ELogLevel::kInfo);
126 
127  // Run in batch mode - output is pdf
128  gROOT->SetBatch(kTRUE);
129  // Suppress logging
130  if (!verbose)
131  {
132  gErrorIgnoreLevel = kWarning;
133  }
134  // No stats box
135  gStyle->SetOptStat(0);
136  // No error bar ticks
137  gStyle->SetEndErrorSize(0);
138  // Parallel processing where possible
139  ROOT::EnableImplicitMT(4);
140 
141  // Create RDataFrame
142  ROOT::RDataFrame dataFrameRefr(treeName, referenceInput);
143  ROOT::RDataFrame dataFrameTest(treeName, testInput);
144 
145  // Event count and ratio
146  auto eventsRefr{dataFrameRefr.Count()};
147  auto eventsTest{dataFrameTest.Count()};
148  double eventRatio{static_cast<float>(eventsRefr.GetValue()) / static_cast<float>(eventsTest.GetValue())};
149 
150  // Get column names for each file and then the intersection
151  auto colNamesRefr = dataFrameRefr.GetColumnNames();
152  auto colNamesTest = dataFrameTest.GetColumnNames();
153  auto colNames = intersection(colNamesRefr, colNamesTest);
154 
155  bool checkMissColumns = false;
156  std::vector<std::string> missColNamesRefr, missColNamesTest;
157  if (poVariablesMap.count("require-same-branches"))
158  {
159  checkMissColumns = true;
160  missColNamesRefr = remainder (colNamesRefr, colNames);
161  missColNamesTest = remainder (colNamesTest, colNames);
162  }
163 
164  // Loop over column names and get a list of the required columns
165  std::vector<std::string> requiredColumns;
166  std::cout << "Will attempt to plot the following columns:" << std::endl;
167  for (auto &&colName : colNames)
168  {
169  if ((baseBranchName.empty() || colName.find(baseBranchName) != std::string::npos) && // include
170  (colName.find("Trig") == std::string::npos) && // exclude, not meaningful
171  (colName.find("Link") == std::string::npos) && // exclude, elementlinks
172  (colName.find("m_persIndex") == std::string::npos) && // exclude, elementlinks
173  (colName.find("m_persKey") == std::string::npos) && // exclude, elementlinks
174  (colName.find("Parent") == std::string::npos) && // exclude, elementlinks
175  (colName.find("original") == std::string::npos) && // exclude, elementlinks
176  (colName.find("EventInfoAuxDyn.detDescrTags") == std::string::npos) && // exclude, std::pair
177  (dataFrameRefr.GetColumnType(colName).find("xAOD") == std::string::npos) && // exclude, needs ATLAS s/w
178  (dataFrameRefr.GetColumnType(colName) != "ROOT::VecOps::RVec<string>") && // exclude, needs ATLAS s/w
179  (dataFrameRefr.GetColumnType(colName).find("vector") == std::string::npos))
180  { // exclude, needs unwrapping
181  requiredColumns.push_back(colName);
182  std::cout << " " << colName << " " << dataFrameRefr.GetColumnType(colName) << std::endl;
183  }
184  }
185 
186  // Create canvas
187  auto c1 = new TCanvas("c1", "Tree comparison");
188 
189  // Set binning
190  int nBins{128};
191  size_t nBinsU = static_cast<size_t>(nBins);
192 
193  // Loop over the required columns and plot them for each sample along with the ratio
194  // Write resulting plots to a pdf file
195  bool fileOpen{};
196  size_t counter{};
197  size_t failedCount{};
198  std::chrono::seconds totalDuration{};
199  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesRefr;
200  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesTest;
201  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesRefr;
202  std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesTest;
203  std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistRefr;
204  std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistTest;
205 
206  std::cout << "Preparing ranges..." << std::endl;
207  for (const std::string &colName : requiredColumns)
208  {
209  mapMinValuesRefr.emplace(colName, dataFrameRefr.Min(colName));
210  mapMinValuesTest.emplace(colName, dataFrameTest.Min(colName));
211  mapMaxValuesRefr.emplace(colName, dataFrameRefr.Max(colName));
212  mapMaxValuesTest.emplace(colName, dataFrameTest.Max(colName));
213  }
214 
215  std::cout << "Preparing histograms..." << std::endl;
217  for (auto it = requiredColumns.begin(); it != requiredColumns.end();)
218  {
219  const std::string &colName = *it;
220  const char* colNameCh = colName.c_str();
221 
222  // Initial histogram range
223  double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
224  double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
225  if (std::isinf(min) || std::isinf(max))
226  {
227  std::cout << " skipping " << colName << " ..." << std::endl;
228  it = requiredColumns.erase(it);
229  continue;
230  } else {
231  ++it;
232  }
233 
234  if (max > 250e3 && min > 0.0)
235  {
236  min = 0.0;
237  }
238 
239  // Initial histograms
240  mapHistRefr.emplace(colName, dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
241  mapHistTest.emplace(colName, dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
242  }
244  auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
245  totalDuration += duration;
246  if (benchmark)
247  {
248  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
249  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
250  }
251 
252  if (rebin)
253  {
254  std::cout << "Rebinning histograms..." << std::endl;
256  for (const std::string &colName : requiredColumns)
257  {
258  const char* colNameCh = colName.c_str();
259  auto &histRefr = mapHistRefr[colName];
260  auto &histTest = mapHistTest[colName];
261 
262  // Initial histogram range
263  double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
264  double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
265  if (max > 250e3 && min > 0.0)
266  {
267  min = 0.0;
268  }
269 
270  // Check range - make sure that bins other than the first contain at least one per mille events
271  // Avoids case where max is determined by a single outlier leading to most events being in the 1st bin
272  bool rangeSatisfactory{};
273  size_t rangeItrCntr{};
274  while (!rangeSatisfactory && rangeItrCntr < 10)
275  {
276  ++rangeItrCntr;
277  if (verbose)
278  {
279  std::cout << std::endl
280  << " Range tuning... iteration number " << rangeItrCntr << std::endl;
281  }
282  double entriesFirstBin = histRefr.GetPtr()->GetBinContent(1);
283  double entriesLastBin = histRefr.GetPtr()->GetBinContent(nBins);
284  double entriesOtherBins{};
285  for (size_t i{2}; i < nBinsU; ++i)
286  {
287  entriesOtherBins += histRefr.GetPtr()->GetBinContent(i);
288  }
289  bool firstBinOK{((entriesOtherBins + entriesLastBin) / entriesFirstBin > 0.001f)};
290  bool lastBinOK{((entriesOtherBins + entriesFirstBin) / entriesLastBin > 0.001f)};
291  rangeSatisfactory = ((firstBinOK && lastBinOK) || entriesOtherBins == 0.0f);
292  if (!rangeSatisfactory)
293  {
294  if (verbose)
295  {
296  std::cout << "Min " << min << std::endl;
297  std::cout << "Max " << max << std::endl;
298  std::cout << "1st " << entriesFirstBin << std::endl;
299  std::cout << "Mid " << entriesOtherBins << std::endl;
300  std::cout << "End " << entriesLastBin << std::endl;
301  std::cout << "R/F " << (entriesOtherBins + entriesLastBin) / entriesFirstBin << std::endl;
302  std::cout << "R/L " << (entriesOtherBins + entriesFirstBin) / entriesLastBin << std::endl;
303  }
304  if (!firstBinOK)
305  {
306  max = (max - min) / static_cast<double>(nBins);
307  histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
308  histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
309  }
310  if (!lastBinOK)
311  {
312  min = max * (1.0f - (1.0f / static_cast<double>(nBins)));
313  histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
314  histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
315  }
316  }
317  }
318  }
319 
321  auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
322  totalDuration += duration;
323  if (benchmark)
324  {
325  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
326  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
327  }
328  }
329 
330  std::cout << "Running comparisons..." << std::endl;
332  for (const std::string &colName : requiredColumns)
333  {
334  ++counter;
335 
336  std::cout << "Processing column " << counter << " of " << requiredColumns.size() << " : " << colName << " ... ";
337 
338  auto h1 = mapHistRefr[colName].GetPtr();
339  auto h2 = mapHistTest[colName].GetPtr();
340 
341  if (scale)
342  {
343  h2->Scale(eventRatio);
344  }
345  h2->SetMarkerStyle(20);
346  h2->SetMarkerSize(0.8);
347 
348  if (!verbose)
349  {
350  gErrorIgnoreLevel = kError; // this is spammy due to empty bins
351  }
352  auto rp = std::unique_ptr<TRatioPlot>(new TRatioPlot(h2, h1));
353  if (!verbose)
354  {
355  gErrorIgnoreLevel = kWarning;
356  }
357 
358  rp->SetH1DrawOpt("PE");
359  rp->SetH2DrawOpt("hist");
360  rp->SetGraphDrawOpt("PE");
361  rp->Draw();
362  rp->GetUpperRefXaxis()->SetTitle(colName.c_str());
363  rp->GetUpperRefYaxis()->SetTitle("Count");
364  rp->GetLowerRefYaxis()->SetTitle("Test / Ref.");
365  rp->GetLowerRefGraph()->SetMarkerStyle(20);
366  rp->GetLowerRefGraph()->SetMarkerSize(0.8);
367 
368  bool valid{true};
369  for (int i{}; i < rp->GetLowerRefGraph()->GetN(); i++)
370  {
371  if (rp->GetLowerRefGraph()->GetY()[i] != 1.0)
372  {
373  valid = false;
374  break;
375  }
376  }
377  if (valid)
378  {
379  std::cout << "PASS" << std::endl;
380  continue;
381  }
382  else
383  {
384  std::cout << "FAILED" << std::endl;
385  ++failedCount;
386  }
387 
388  rp->GetLowerRefGraph()->SetMinimum(0.5);
389  rp->GetLowerRefGraph()->SetMaximum(1.5);
390  rp->GetLowYaxis()->SetNdivisions(505);
391 
392  if (valid) {
393  c1->SetTicks(0, 1);
394  c1->Update();
395  }
396 
397  if (!fileOpen)
398  {
399  // Open file
400  c1->Print((outputPDF + "[").c_str());
401  fileOpen = true;
402  }
403  // Actual plot
404  c1->Print(outputPDF.c_str());
405  c1->Clear();
406  }
407 
408  if (fileOpen)
409  {
410  // Close file
411  c1->Print((outputPDF + "]").c_str());
412  }
413 
415  duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
416  totalDuration += duration;
417  if (benchmark)
418  {
419  std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
420  std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
421  }
422 
423  std::cout << "========================" << std::endl;
424  std::cout << "Reference events: " << eventsRefr.GetValue() << std::endl;
425  std::cout << "Test events: " << eventsTest.GetValue() << std::endl;
426  std::cout << "Ratio: " << eventRatio << std::endl;
427  std::cout << "========================" << std::endl;
428  std::cout << "Tested columns: " << requiredColumns.size() << std::endl;
429  std::cout << "Passed: " << requiredColumns.size() - failedCount << std::endl;
430  std::cout << "Failed: " << failedCount << std::endl;
431  std::cout << "========================" << std::endl;
432  if (checkMissColumns)
433  {
434  std::cout << "Columns only in reference: " << missColNamesRefr.size();
435  for (const auto& column : missColNamesRefr)
436  std::cout << " " << column;
437  std::cout << std::endl;
438  std::cout << "Columns only in test: " << missColNamesTest.size();
439  for (const auto& column : missColNamesTest)
440  std::cout << " " << column;
441  std::cout << std::endl;
442  failedCount += missColNamesRefr.size() + missColNamesTest.size();
443  std::cout << "========================" << std::endl;
444  }
445 
446  if (failedCount)
447  {
448  return 1;
449  }
450 
451  return 0;
452 }
replace
std::string replace(std::string s, const std::string &s2, const std::string &s3)
Definition: hcg.cxx:307
SGTest::store
TestStore store
Definition: TestStore.cxx:23
get_generator_info.result
result
Definition: get_generator_info.py:21
max
constexpr double max()
Definition: ap_fixedTest.cxx:33
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
mergePhysValFiles.start
start
Definition: DataQuality/DataQualityUtils/scripts/mergePhysValFiles.py:14
extractSporadic.c1
c1
Definition: extractSporadic.py:134
skel.it
it
Definition: skel.GENtoEVGEN.py:396
run
int run(int argc, char *argv[])
Definition: ttree2hdf5.cxx:28
DeMoUpdate.column
dictionary column
Definition: DeMoUpdate.py:1110
athena.value
value
Definition: athena.py:124
PixelModuleFeMask_create_db.stop
int stop
Definition: PixelModuleFeMask_create_db.py:76
read_hist_ntuple.h1
h1
Definition: read_hist_ntuple.py:21
yodamerge_tmp.scale
scale
Definition: yodamerge_tmp.py:138
intersection
std::vector< std::string > intersection(std::vector< std::string > &v1, std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:25
DiTauMassTools::ignore
void ignore(T &&)
Definition: PhysicsAnalysis/TauID/DiTauMassTools/DiTauMassTools/HelperFunctions.h:58
main
int main(int, char **)
Main class for all the CppUnit test classes
Definition: CppUnit_SGtestdriver.cxx:141
calibdata.valid
list valid
Definition: calibdata.py:45
python.handimod.now
now
Definition: handimod.py:675
CheckAppliedSFs.e3
e3
Definition: CheckAppliedSFs.py:264
lumiFormat.i
int i
Definition: lumiFormat.py:85
LArCellNtuple.argv
argv
Definition: LArCellNtuple.py:152
Analysis::kError
@ kError
Definition: CalibrationDataVariables.h:60
dumpFileToPlots.treeName
string treeName
Definition: dumpFileToPlots.py:20
hist_file_dump.f
f
Definition: hist_file_dump.py:135
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:379
DQHistogramMergeRegExp.argc
argc
Definition: DQHistogramMergeRegExp.py:20
python.LArCalib_HVCorrConfig.seconds
seconds
Definition: LArCalib_HVCorrConfig.py:86
PixelAthHitMonAlgCfg.duration
duration
Definition: PixelAthHitMonAlgCfg.py:152
dumpTgcDigiJitter.nBins
list nBins
Definition: dumpTgcDigiJitter.py:29
remainder
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
list of entries in a vector that are not in another
Definition: compareFlatTrees.cxx:44
makeegammaturnon.rebin
def rebin(binning, data)
Definition: makeegammaturnon.py:17
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
gErrorIgnoreLevel
int gErrorIgnoreLevel
ReadCellNoiseFromCoolCompare.v2
v2
Definition: ReadCellNoiseFromCoolCompare.py:364
a
TList * a
Definition: liststreamerinfos.cxx:10
python.TriggerHandler.verbose
verbose
Definition: TriggerHandler.py:297
rp
ReadCards * rp
Definition: IReadCards.cxx:26
test_pyathena.counter
counter
Definition: test_pyathena.py:15
set_intersection
Set * set_intersection(Set *set1, Set *set2)
Perform an intersection of two sets.
checker_macros.h
Define macros for attributes used to control the static checker.
ATLAS_NOT_THREAD_SAFE
int main ATLAS_NOT_THREAD_SAFE(int argc, char *argv[])
Definition: compareFlatTrees.cxx:58