ATLAS Offline Software
Loading...
Searching...
No Matches
compareFlatTrees.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4#include <ROOT/RDataFrame.hxx>
5#include <ROOT/RLogger.hxx>
6#include <TCanvas.h>
7#include <TF1.h>
8#include <TH1D.h>
9#include <TRatioPlot.h>
10#include <TROOT.h>
11#include <TStyle.h>
12
14
15#include <boost/program_options.hpp>
16
17#include <algorithm>
18#include <chrono>
19#include <iostream>
20#include <memory>
21#include <string>
22#include <unordered_set>
23
// Intersection of two lists of names, needed to get the variables that are
// common to both samples.
//
// Returns the entries present in both v1 and v2 in ascending (alphabetical)
// order. An entry that occurs several times in both inputs appears as often
// as in the input where it is rarer.
//
// NOTE: both arguments are sorted in place as a side effect, because
// std::set_intersection requires sorted input ranges.
std::vector<std::string> intersection(std::vector<std::string> &v1,
                                      std::vector<std::string> &v2)
{
  std::sort(v1.begin(), v1.end());
  std::sort(v2.begin(), v2.end());

  std::vector<std::string> common;
  // The result can never be larger than the smaller input
  common.reserve(std::min(v1.size(), v2.size()));

  // The output of std::set_intersection is already sorted, so the result
  // is in alphabetical order without any further sorting.
  std::set_intersection(v1.begin(), v1.end(),
                        v2.begin(), v2.end(),
                        std::back_inserter(common));

  return common;
}
42
// Entries of v1 that do not occur anywhere in v2.
// The relative order of v1 is kept, and repeated entries of v1 are kept too.
std::vector<std::string> remainder (const std::vector<std::string>& v1,
                                    const std::vector<std::string>& v2)
{
  // Hash set of everything to drop, so each membership test is a single lookup
  const std::unordered_set<std::string> excluded (v2.cbegin(), v2.cend());

  std::vector<std::string> kept;
  std::copy_if (v1.cbegin(), v1.cend(), std::back_inserter (kept),
                [&excluded] (const std::string& candidate)
                { return excluded.count (candidate) == 0; });
  return kept;
}
57
58int main ATLAS_NOT_THREAD_SAFE(int argc, char *argv[])
59{
60 namespace po = boost::program_options;
61 po::options_description poDescription("Common options");
62 poDescription.add_options()
63 ("help", "produce help message")
64 ("require-same-branches", "require both trees to have the same branches")
65 ("tree-name", po::value<std::string>()->required(), "tree name")
66 ("reference-file", po::value<std::string>()->required(), "reference file(s), wildcards supported")
67 ("test-file", po::value<std::string>()->required(), "test file(s), wildcards supported")
68 ("branch-name", po::value<std::string>(), "base branch name (optional)");
69
70 po::options_description poDescriptionAdvanced("Advanced options");
71 poDescriptionAdvanced.add_options()
72 ("scale", "scale histograms that both have the same event count")
73 ("rebin", "do smart rebinning")
74 ("benchmark", "benchmark the code")
75 ("verbose", "verbose logging");
76
77 po::options_description poDescriptionAll;
78 poDescriptionAll.add(poDescription).add(poDescriptionAdvanced);
79
80 po::positional_options_description poPositionalOptions;
81 poPositionalOptions.add("tree-name", 1);
82 poPositionalOptions.add("reference-file", 1);
83 poPositionalOptions.add("test-file", 1);
84 poPositionalOptions.add("branch-name", 1);
85
86 po::variables_map poVariablesMap;
87 po::store(po::command_line_parser(argc, argv)
88 .options(poDescriptionAll)
89 .positional(poPositionalOptions).run(),
90 poVariablesMap);
91
92 if (poVariablesMap.count("help"))
93 {
94 std::cout << "Usage: compareFlatTrees [OPTION] tree-name reference-file test-file [branch-name]" << std::endl;
95 std::cout << poDescriptionAll << std::endl;
96 return 0;
97 }
98
99 po::notify(poVariablesMap);
100
101 // Base name of branches to read
102 std::string treeName = poVariablesMap["tree-name"].as<std::string>();
103 std::string treeNameOut = treeName;
104 std::replace( treeNameOut.begin(), treeNameOut.end(), '/', '_');
105 std::string referenceInput = poVariablesMap["reference-file"].as<std::string>();
106 std::string testInput = poVariablesMap["test-file"].as<std::string>();
107 std::string baseBranchName;
108 std::string outputPDF;
109 if (poVariablesMap.count("branch-name") > 0)
110 {
111 baseBranchName = poVariablesMap["branch-name"].as<std::string>();
112 outputPDF = "comparison_" + treeNameOut + "_" + baseBranchName + ".pdf";
113 }
114 else
115 {
116 outputPDF = "comparison_" + treeNameOut + ".pdf";
117 }
118
119 bool scale = poVariablesMap.count("scale") > 0;
120 bool rebin = poVariablesMap.count("rebin") > 0;
121 bool benchmark = poVariablesMap.count("benchmark") > 0;
122 bool verbose = poVariablesMap.count("verbose") > 0;
123
124 // Verbose logging
125 // auto verbosity = ROOT::Experimental::RLogScopedVerbosity(ROOT::Detail::RDF::RDFLogChannel(), ROOT::Experimental::ELogLevel::kInfo);
126
127 // Run in batch mode - output is pdf
128 gROOT->SetBatch(kTRUE);
129 // Suppress logging
130 if (!verbose)
131 {
132 gErrorIgnoreLevel = kWarning;
133 }
134 // No stats box
135 gStyle->SetOptStat(0);
136 // No error bar ticks
137 gStyle->SetEndErrorSize(0);
138 // Parallel processing where possible
139 ROOT::EnableImplicitMT(4);
140
141 // Create RDataFrame
142 ROOT::RDataFrame dataFrameRefr(treeName, referenceInput);
143 ROOT::RDataFrame dataFrameTest(treeName, testInput);
144
145 // Event count and ratio
146 auto eventsRefr{dataFrameRefr.Count()};
147 auto eventsTest{dataFrameTest.Count()};
148 double eventRatio{static_cast<float>(eventsRefr.GetValue()) / static_cast<float>(eventsTest.GetValue())};
149
150 // Get column names for each file and then the intersection
151 auto colNamesRefr = dataFrameRefr.GetColumnNames();
152 auto colNamesTest = dataFrameTest.GetColumnNames();
153 auto colNames = intersection(colNamesRefr, colNamesTest);
154
155 bool checkMissColumns = false;
156 std::vector<std::string> missColNamesRefr, missColNamesTest;
157 if (poVariablesMap.count("require-same-branches"))
158 {
159 checkMissColumns = true;
160 missColNamesRefr = remainder (colNamesRefr, colNames);
161 missColNamesTest = remainder (colNamesTest, colNames);
162 }
163
164 // Loop over column names and get a list of the required columns
165 std::vector<std::string> requiredColumns;
166 std::cout << "Will attempt to plot the following columns:" << std::endl;
167 for (auto &&colName : colNames)
168 {
169 if ((baseBranchName.empty() || colName.find(baseBranchName) != std::string::npos) && // include
170 (colName.find("Trig") == std::string::npos) && // exclude, not meaningful
171 (colName.find("Link") == std::string::npos) && // exclude, elementlinks
172 (colName.find("m_persIndex") == std::string::npos) && // exclude, elementlinks
173 (colName.find("m_persKey") == std::string::npos) && // exclude, elementlinks
174 (colName.find("Parent") == std::string::npos) && // exclude, elementlinks
175 (colName.find("original") == std::string::npos) && // exclude, elementlinks
176 (colName.find("EventInfoAuxDyn.detDescrTags") == std::string::npos) && // exclude, std::pair
177 (dataFrameRefr.GetColumnType(colName).find("xAOD") == std::string::npos) && // exclude, needs ATLAS s/w
178 (dataFrameRefr.GetColumnType(colName) != "ROOT::VecOps::RVec<string>") && // exclude, needs ATLAS s/w
179 (!dataFrameRefr.GetColumnType(colName).starts_with("ROOT::VecOps::RVec<pair")) && // exclude, std::pair
180 (dataFrameRefr.GetColumnType(colName).find("vector") == std::string::npos))
181 { // exclude, needs unwrapping
182 requiredColumns.push_back(colName);
183 std::cout << " " << colName << " " << dataFrameRefr.GetColumnType(colName) << std::endl;
184 }
185 }
186
187 // Set binning
188 int nBins{128};
189 size_t nBinsU = static_cast<size_t>(nBins);
190
191 // Loop over the required columns and plot them for each sample along with the ratio
192 // Write resulting plots to a pdf file
193 bool fileOpen{};
194 size_t counter{};
195 size_t failedCount{};
196 std::chrono::seconds totalDuration{};
197 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesRefr;
198 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesTest;
199 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesRefr;
200 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesTest;
201 std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistRefr;
202 std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistTest;
203
204 std::cout << "Preparing ranges..." << std::endl;
205 for (const std::string &colName : requiredColumns)
206 {
207 mapMinValuesRefr.emplace(colName, dataFrameRefr.Min(colName));
208 mapMinValuesTest.emplace(colName, dataFrameTest.Min(colName));
209 mapMaxValuesRefr.emplace(colName, dataFrameRefr.Max(colName));
210 mapMaxValuesTest.emplace(colName, dataFrameTest.Max(colName));
211 }
212
213 std::cout << "Preparing histograms..." << std::endl;
214 auto start = std::chrono::high_resolution_clock::now();
215 for (auto it = requiredColumns.begin(); it != requiredColumns.end();)
216 {
217 const std::string &colName = *it;
218 const char* colNameCh = colName.c_str();
219
220 // Initial histogram range
221 double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
222 double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
223 if (std::isinf(min) || std::isinf(max))
224 {
225 std::cout << " skipping " << colName << " ..." << std::endl;
226 it = requiredColumns.erase(it);
227 continue;
228 } else {
229 ++it;
230 }
231
232 if (max > 250e3 && min > 0.0)
233 {
234 min = 0.0;
235 }
236
237 // Initial histograms
238 mapHistRefr.emplace(colName, dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
239 mapHistTest.emplace(colName, dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
240 }
241 auto stop = std::chrono::high_resolution_clock::now();
242 auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
243 totalDuration += duration;
244 if (benchmark)
245 {
246 std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
247 std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
248 }
249
250 if (rebin)
251 {
252 std::cout << "Rebinning histograms..." << std::endl;
253 auto start = std::chrono::high_resolution_clock::now();
254 for (const std::string &colName : requiredColumns)
255 {
256 const char* colNameCh = colName.c_str();
257 auto &histRefr = mapHistRefr[colName];
258 auto &histTest = mapHistTest[colName];
259
260 // Initial histogram range
261 double min = std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
262 double max = std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
263 if (max > 250e3 && min > 0.0)
264 {
265 min = 0.0;
266 }
267
268 // Check range - make sure that bins other than the first contain at least one per mille events
269 // Avoids case where max is determined by a single outlier leading to most events being in the 1st bin
270 bool rangeSatisfactory{};
271 size_t rangeItrCntr{};
272 while (!rangeSatisfactory && rangeItrCntr < 10)
273 {
274 ++rangeItrCntr;
275 if (verbose)
276 {
277 std::cout << std::endl
278 << " Range tuning... iteration number " << rangeItrCntr << std::endl;
279 }
280 double entriesFirstBin = histRefr.GetPtr()->GetBinContent(1);
281 double entriesLastBin = histRefr.GetPtr()->GetBinContent(nBins);
282 double entriesOtherBins{};
283 for (size_t i{2}; i < nBinsU; ++i)
284 {
285 entriesOtherBins += histRefr.GetPtr()->GetBinContent(i);
286 }
287 bool firstBinOK{((entriesOtherBins + entriesLastBin) / entriesFirstBin > 0.001f)};
288 bool lastBinOK{((entriesOtherBins + entriesFirstBin) / entriesLastBin > 0.001f)};
289 rangeSatisfactory = ((firstBinOK && lastBinOK) || entriesOtherBins == 0.0f);
290 if (!rangeSatisfactory)
291 {
292 if (verbose)
293 {
294 std::cout << "Min " << min << std::endl;
295 std::cout << "Max " << max << std::endl;
296 std::cout << "1st " << entriesFirstBin << std::endl;
297 std::cout << "Mid " << entriesOtherBins << std::endl;
298 std::cout << "End " << entriesLastBin << std::endl;
299 std::cout << "R/F " << (entriesOtherBins + entriesLastBin) / entriesFirstBin << std::endl;
300 std::cout << "R/L " << (entriesOtherBins + entriesFirstBin) / entriesLastBin << std::endl;
301 }
302 if (!firstBinOK)
303 {
304 max = (max - min) / static_cast<double>(nBins);
305 histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
306 histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
307 }
308 if (!lastBinOK)
309 {
310 min = max * (1.0f - (1.0f / static_cast<double>(nBins)));
311 histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
312 histTest = dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh);
313 }
314 }
315 }
316 }
317
318 auto stop = std::chrono::high_resolution_clock::now();
319 auto duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
320 totalDuration += duration;
321 if (benchmark)
322 {
323 std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
324 std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
325 }
326 }
327
328 std::cout << "Running comparisons..." << std::endl;
329 start = std::chrono::high_resolution_clock::now();
330
331 // Store only last canvas
332 std::unique_ptr<TCanvas> lastCanvas;
333
334 for (const std::string &colName : requiredColumns)
335 {
336 ++counter;
337
338 std::cout << "Processing column " << counter << " of " << requiredColumns.size() << " : " << colName << " ... ";
339
340 auto h1 = mapHistRefr[colName].GetPtr();
341 auto h2 = mapHistTest[colName].GetPtr();
342
343 if (scale)
344 {
345 h2->Scale(eventRatio);
346 }
347 h2->SetMarkerStyle(20);
348 h2->SetMarkerSize(0.8);
349
350 if (!verbose)
351 {
352 gErrorIgnoreLevel = kError; // this is spammy due to empty bins
353 }
354 auto c1 = std::make_unique<TCanvas>();
355 auto rp = std::make_unique<TRatioPlot>(h2, h1);
356 if (!verbose)
357 {
358 gErrorIgnoreLevel = kWarning;
359 }
360
361 rp->SetH1DrawOpt("PE");
362 rp->SetH2DrawOpt("hist");
363 rp->SetGraphDrawOpt("PE");
364 rp->Draw();
365 rp->GetUpperRefXaxis()->SetTitle(colName.c_str());
366 rp->GetUpperRefYaxis()->SetTitle("Count");
367 rp->GetLowerRefYaxis()->SetTitle("Test / Ref.");
368 rp->GetLowerRefGraph()->SetMarkerStyle(20);
369 rp->GetLowerRefGraph()->SetMarkerSize(0.8);
370
371 bool valid{true};
372 for (int i{}; i < rp->GetLowerRefGraph()->GetN(); i++)
373 {
374 if (rp->GetLowerRefGraph()->GetY()[i] != 1.0)
375 {
376 valid = false;
377 break;
378 }
379 }
380 if (valid)
381 {
382 std::cout << "PASS" << std::endl;
383 continue;
384 }
385 else
386 {
387 std::cout << "FAILED" << std::endl;
388 ++failedCount;
389 }
390
391 rp->GetLowerRefGraph()->SetMinimum(0.5);
392 rp->GetLowerRefGraph()->SetMaximum(1.5);
393 rp->GetLowYaxis()->SetNdivisions(505);
394
395 if (!fileOpen)
396 {
397 // Open file
398 c1->Print((outputPDF + "[").c_str());
399 fileOpen = true;
400 }
401 // Actual plot
402 c1->Print(outputPDF.c_str());
403 c1->Clear();
404 lastCanvas = std::move(c1);
405 }
406
407 if (fileOpen)
408 {
409 // Close file
410 lastCanvas->Print((outputPDF + "]").c_str());
411 lastCanvas.reset();
412 }
413
414 stop = std::chrono::high_resolution_clock::now();
415 duration = std::chrono::duration_cast<std::chrono::seconds>(stop - start);
416 totalDuration += duration;
417 if (benchmark)
418 {
419 std::cout << " Time for this step: " << duration.count() << " seconds " << std::endl;
420 std::cout << " Elapsed time: " << totalDuration.count() << " seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() << " minutes)" << std::endl;
421 }
422
423 std::cout << "========================" << std::endl;
424 std::cout << "Reference events: " << eventsRefr.GetValue() << std::endl;
425 std::cout << "Test events: " << eventsTest.GetValue() << std::endl;
426 std::cout << "Ratio: " << eventRatio << std::endl;
427 std::cout << "========================" << std::endl;
428 std::cout << "Tested columns: " << requiredColumns.size() << std::endl;
429 std::cout << "Passed: " << requiredColumns.size() - failedCount << std::endl;
430 std::cout << "Failed: " << failedCount << std::endl;
431 std::cout << "========================" << std::endl;
432 if (checkMissColumns)
433 {
434 std::cout << "Columns only in reference: " << missColNamesRefr.size();
435 for (const auto& column : missColNamesRefr)
436 std::cout << " " << column;
437 std::cout << std::endl;
438 std::cout << "Columns only in test: " << missColNamesTest.size();
439 for (const auto& column : missColNamesTest)
440 std::cout << " " << column;
441 std::cout << std::endl;
442 failedCount += missColNamesRefr.size() + missColNamesTest.size();
443 std::cout << "========================" << std::endl;
444 }
445
446 if (failedCount)
447 {
448 return 1;
449 }
450
451 return 0;
452}
int main(int, char **)
Main class for all the CppUnit test classes.
ReadCards * rp
static Double_t a
#define min(a, b)
Definition cfImp.cxx:40
#define max(a, b)
Definition cfImp.cxx:41
Define macros for attributes used to control the static checker.
#define ATLAS_NOT_THREAD_SAFE
getNoisyStrip() Find noisy strips from hitmaps and write out into xml/db formats
std::vector< std::string > intersection(std::vector< std::string > &v1, std::vector< std::string > &v2)
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
bool verbose
Definition hcg.cxx:73
Definition run.py:1
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.