60 namespace po = boost::program_options;
61 po::options_description poDescription(
"Common options");
62 poDescription.add_options()
63 (
"help",
"produce help message")
64 (
"require-same-branches",
"require both trees to have the same branches")
65 (
"tree-name", po::value<std::string>()->required(),
"tree name")
66 (
"reference-file", po::value<std::string>()->required(),
"reference file(s), wildcards supported")
67 (
"test-file", po::value<std::string>()->required(),
"test file(s), wildcards supported")
68 (
"branch-name", po::value<std::string>(),
"base branch name (optional)");
70 po::options_description poDescriptionAdvanced(
"Advanced options");
71 poDescriptionAdvanced.add_options()
72 (
"scale",
"scale histograms that both have the same event count")
73 (
"rebin",
"do smart rebinning")
74 (
"benchmark",
"benchmark the code")
75 (
"verbose",
"verbose logging");
77 po::options_description poDescriptionAll;
78 poDescriptionAll.add(poDescription).add(poDescriptionAdvanced);
80 po::positional_options_description poPositionalOptions;
81 poPositionalOptions.add(
"tree-name", 1);
82 poPositionalOptions.add(
"reference-file", 1);
83 poPositionalOptions.add(
"test-file", 1);
84 poPositionalOptions.add(
"branch-name", 1);
86 po::variables_map poVariablesMap;
89 .positional(poPositionalOptions).
run(),
92 if (poVariablesMap.count(
"help"))
94 std::cout <<
"Usage: compareFlatTrees [OPTION] tree-name reference-file test-file [branch-name]" << std::endl;
95 std::cout << poDescriptionAll << std::endl;
99 po::notify(poVariablesMap);
102 std::string
treeName = poVariablesMap[
"tree-name"].as<std::string>();
104 std::replace( treeNameOut.begin(), treeNameOut.end(),
'/',
'_');
105 std::string referenceInput = poVariablesMap[
"reference-file"].as<std::string>();
106 std::string testInput = poVariablesMap[
"test-file"].as<std::string>();
107 std::string baseBranchName;
108 std::string outputPDF;
109 if (poVariablesMap.count(
"branch-name") > 0)
111 baseBranchName = poVariablesMap[
"branch-name"].as<std::string>();
112 outputPDF =
"comparison_" + treeNameOut +
"_" + baseBranchName +
".pdf";
116 outputPDF =
"comparison_" + treeNameOut +
".pdf";
119 bool scale = poVariablesMap.count(
"scale") > 0;
120 bool rebin = poVariablesMap.count(
"rebin") > 0;
121 bool benchmark = poVariablesMap.count(
"benchmark") > 0;
122 bool verbose = poVariablesMap.count(
"verbose") > 0;
128 gROOT->SetBatch(kTRUE);
135 gStyle->SetOptStat(0);
137 gStyle->SetEndErrorSize(0);
139 ROOT::EnableImplicitMT(4);
142 ROOT::RDataFrame dataFrameRefr(
treeName, referenceInput);
143 ROOT::RDataFrame dataFrameTest(
treeName, testInput);
146 auto eventsRefr{dataFrameRefr.Count()};
147 auto eventsTest{dataFrameTest.Count()};
148 double eventRatio{
static_cast<float>(eventsRefr.GetValue()) /
static_cast<float>(eventsTest.GetValue())};
151 auto colNamesRefr = dataFrameRefr.GetColumnNames();
152 auto colNamesTest = dataFrameTest.GetColumnNames();
153 auto colNames =
intersection(colNamesRefr, colNamesTest);
155 bool checkMissColumns =
false;
156 std::vector<std::string> missColNamesRefr, missColNamesTest;
157 if (poVariablesMap.count(
"require-same-branches"))
159 checkMissColumns =
true;
160 missColNamesRefr =
remainder (colNamesRefr, colNames);
161 missColNamesTest =
remainder (colNamesTest, colNames);
165 std::vector<std::string> requiredColumns;
166 std::cout <<
"Will attempt to plot the following columns:" << std::endl;
167 for (
auto &&colName : colNames)
169 if ((baseBranchName.empty() || colName.find(baseBranchName) != std::string::npos) &&
170 (colName.find(
"Trig") == std::string::npos) &&
171 (colName.find(
"Link") == std::string::npos) &&
172 (colName.find(
"m_persIndex") == std::string::npos) &&
173 (colName.find(
"m_persKey") == std::string::npos) &&
174 (colName.find(
"Parent") == std::string::npos) &&
175 (colName.find(
"original") == std::string::npos) &&
176 (colName.find(
"EventInfoAuxDyn.detDescrTags") == std::string::npos) &&
177 (dataFrameRefr.GetColumnType(colName).find(
"xAOD") == std::string::npos) &&
178 (dataFrameRefr.GetColumnType(colName) !=
"ROOT::VecOps::RVec<string>") &&
179 (!dataFrameRefr.GetColumnType(colName).starts_with(
"ROOT::VecOps::RVec<pair")) &&
180 (dataFrameRefr.GetColumnType(colName).find(
"vector") == std::string::npos))
182 requiredColumns.push_back(colName);
183 std::cout <<
" " << colName <<
" " << dataFrameRefr.GetColumnType(colName) << std::endl;
188 auto c1 =
new TCanvas(
"c1",
"Tree comparison");
192 size_t nBinsU =
static_cast<size_t>(
nBins);
198 size_t failedCount{};
200 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesRefr;
201 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMinValuesTest;
202 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesRefr;
203 std::unordered_map<std::string, ROOT::RDF::RResultPtr<double>> mapMaxValuesTest;
204 std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistRefr;
205 std::unordered_map<std::string, ROOT::RDF::RResultPtr<TH1D>> mapHistTest;
207 std::cout <<
"Preparing ranges..." << std::endl;
208 for (
const std::string &colName : requiredColumns)
210 mapMinValuesRefr.emplace(colName, dataFrameRefr.Min(colName));
211 mapMinValuesTest.emplace(colName, dataFrameTest.Min(colName));
212 mapMaxValuesRefr.emplace(colName, dataFrameRefr.Max(colName));
213 mapMaxValuesTest.emplace(colName, dataFrameTest.Max(colName));
216 std::cout <<
"Preparing histograms..." << std::endl;
218 for (
auto it = requiredColumns.begin();
it != requiredColumns.end();)
220 const std::string &colName = *
it;
221 const char* colNameCh = colName.c_str();
224 double min =
std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
225 double max =
std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
226 if (std::isinf(
min) || std::isinf(
max))
228 std::cout <<
" skipping " << colName <<
" ..." << std::endl;
229 it = requiredColumns.erase(
it);
241 mapHistRefr.emplace(colName, dataFrameRefr.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
242 mapHistTest.emplace(colName, dataFrameTest.Histo1D({colNameCh, colNameCh, nBins, min, max}, colNameCh));
249 std::cout <<
" Time for this step: " <<
duration.count() <<
" seconds " << std::endl;
250 std::cout <<
" Elapsed time: " << totalDuration.count() <<
" seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() <<
" minutes)" << std::endl;
255 std::cout <<
"Rebinning histograms..." << std::endl;
257 for (
const std::string &colName : requiredColumns)
259 const char* colNameCh = colName.c_str();
260 auto &histRefr = mapHistRefr[colName];
261 auto &histTest = mapHistTest[colName];
264 double min =
std::min(mapMinValuesRefr[colName].GetValue(), mapMinValuesTest[colName].GetValue());
265 double max =
std::max(mapMaxValuesRefr[colName].GetValue(), mapMaxValuesTest[colName].GetValue()) * 1.02;
273 bool rangeSatisfactory{};
274 size_t rangeItrCntr{};
275 while (!rangeSatisfactory && rangeItrCntr < 10)
280 std::cout << std::endl
281 <<
" Range tuning... iteration number " << rangeItrCntr << std::endl;
283 double entriesFirstBin = histRefr.GetPtr()->GetBinContent(1);
284 double entriesLastBin = histRefr.GetPtr()->GetBinContent(
nBins);
285 double entriesOtherBins{};
286 for (
size_t i{2};
i < nBinsU; ++
i)
288 entriesOtherBins += histRefr.GetPtr()->GetBinContent(
i);
290 bool firstBinOK{((entriesOtherBins + entriesLastBin) / entriesFirstBin > 0.001
f)};
291 bool lastBinOK{((entriesOtherBins + entriesFirstBin) / entriesLastBin > 0.001
f)};
292 rangeSatisfactory = ((firstBinOK && lastBinOK) || entriesOtherBins == 0.0
f);
293 if (!rangeSatisfactory)
297 std::cout <<
"Min " <<
min << std::endl;
298 std::cout <<
"Max " <<
max << std::endl;
299 std::cout <<
"1st " << entriesFirstBin << std::endl;
300 std::cout <<
"Mid " << entriesOtherBins << std::endl;
301 std::cout <<
"End " << entriesLastBin << std::endl;
302 std::cout <<
"R/F " << (entriesOtherBins + entriesLastBin) / entriesFirstBin << std::endl;
303 std::cout <<
"R/L " << (entriesOtherBins + entriesFirstBin) / entriesLastBin << std::endl;
308 histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh,
nBins,
min,
max}, colNameCh);
309 histTest = dataFrameTest.Histo1D({colNameCh, colNameCh,
nBins,
min,
max}, colNameCh);
313 min =
max * (1.0f - (1.0f /
static_cast<double>(
nBins)));
314 histRefr = dataFrameRefr.Histo1D({colNameCh, colNameCh,
nBins,
min,
max}, colNameCh);
315 histTest = dataFrameTest.Histo1D({colNameCh, colNameCh,
nBins,
min,
max}, colNameCh);
326 std::cout <<
" Time for this step: " <<
duration.count() <<
" seconds " << std::endl;
327 std::cout <<
" Elapsed time: " << totalDuration.count() <<
" seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() <<
" minutes)" << std::endl;
331 std::cout <<
"Running comparisons..." << std::endl;
333 for (
const std::string &colName : requiredColumns)
337 std::cout <<
"Processing column " <<
counter <<
" of " << requiredColumns.size() <<
" : " << colName <<
" ... ";
339 auto h1 = mapHistRefr[colName].GetPtr();
340 auto h2 = mapHistTest[colName].GetPtr();
344 h2->Scale(eventRatio);
346 h2->SetMarkerStyle(20);
347 h2->SetMarkerSize(0.8);
353 auto rp = std::unique_ptr<TRatioPlot>(
new TRatioPlot(h2,
h1));
359 rp->SetH1DrawOpt(
"PE");
360 rp->SetH2DrawOpt(
"hist");
361 rp->SetGraphDrawOpt(
"PE");
363 rp->GetUpperRefXaxis()->SetTitle(colName.c_str());
364 rp->GetUpperRefYaxis()->SetTitle(
"Count");
365 rp->GetLowerRefYaxis()->SetTitle(
"Test / Ref.");
366 rp->GetLowerRefGraph()->SetMarkerStyle(20);
367 rp->GetLowerRefGraph()->SetMarkerSize(0.8);
370 for (
int i{};
i <
rp->GetLowerRefGraph()->GetN();
i++)
372 if (
rp->GetLowerRefGraph()->GetY()[
i] != 1.0)
380 std::cout <<
"PASS" << std::endl;
385 std::cout <<
"FAILED" << std::endl;
389 rp->GetLowerRefGraph()->SetMinimum(0.5);
390 rp->GetLowerRefGraph()->SetMaximum(1.5);
391 rp->GetLowYaxis()->SetNdivisions(505);
401 c1->Print((outputPDF +
"[").c_str());
405 c1->Print(outputPDF.c_str());
412 c1->Print((outputPDF +
"]").c_str());
420 std::cout <<
" Time for this step: " <<
duration.count() <<
" seconds " << std::endl;
421 std::cout <<
" Elapsed time: " << totalDuration.count() <<
" seconds (" << std::chrono::duration_cast<std::chrono::minutes>(totalDuration).count() <<
" minutes)" << std::endl;
424 std::cout <<
"========================" << std::endl;
425 std::cout <<
"Reference events: " << eventsRefr.GetValue() << std::endl;
426 std::cout <<
"Test events: " << eventsTest.GetValue() << std::endl;
427 std::cout <<
"Ratio: " << eventRatio << std::endl;
428 std::cout <<
"========================" << std::endl;
429 std::cout <<
"Tested columns: " << requiredColumns.size() << std::endl;
430 std::cout <<
"Passed: " << requiredColumns.size() - failedCount << std::endl;
431 std::cout <<
"Failed: " << failedCount << std::endl;
432 std::cout <<
"========================" << std::endl;
433 if (checkMissColumns)
435 std::cout <<
"Columns only in reference: " << missColNamesRefr.size();
436 for (
const auto&
column : missColNamesRefr)
437 std::cout <<
" " <<
column;
438 std::cout << std::endl;
439 std::cout <<
"Columns only in test: " << missColNamesTest.size();
440 for (
const auto&
column : missColNamesTest)
441 std::cout <<
" " <<
column;
442 std::cout << std::endl;
443 failedCount += missColNamesRefr.size() + missColNamesTest.size();
444 std::cout <<
"========================" << std::endl;