ATLAS Offline Software
Loading...
Searching...
No Matches
egammaMVACalibTool.cxx
Go to the documentation of this file.
1 /*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7#include "xAODEgamma/Egamma.h"
9
12
13#include "TFile.h"
14#include "TMath.h"
15#include "TObjString.h"
16#include "TTree.h"
17#include "TClass.h"
18
19#include <cmath>
20#include <format>
21
22#ifndef XAOD_ANALYSIS
23#include "GaudiKernel/SystemOfUnits.h"
24using Gaudi::Units::GeV;
25#else
26#define GeV 1000
27#endif
28
30
31struct Funcs
32{
33 std::vector<std::vector<std::function<float(const xAOD::Egamma*, const xAOD::CaloCluster*)> > > funcs;
34};
35
36} // namespace egammaMVACalibTool_detail
37
39 asg::AsgTool(name),
40 m_funcs (std::make_unique<egammaMVACalibTool_detail::Funcs>())
41{
42}
43
44// Need to declare this out-of-line since the full type of m_funcs
45// isn't available in the header.
49
50
52{
54 ATH_MSG_FATAL("Particle type not set: you have to set property ParticleType to a valid value");
55 return StatusCode::FAILURE;
56 }
57 ATH_MSG_DEBUG("Initializing with particle " << m_particleType);
58
60 ATH_MSG_DEBUG("Using Mean10 shift");
61 } else if (m_shiftType == NOSHIFT) {
62 ATH_MSG_DEBUG("Not using a shift");
63 } else {
64 ATH_MSG_FATAL("Unsupported shift: " << m_shiftType);
65 return StatusCode::FAILURE;
66 }
67
68 // get the BDTs and initialize functions
69 ATH_MSG_DEBUG("get BDTs in folder: " << m_folder);
70 switch (m_particleType) {
72 {
73 std::unique_ptr<egammaMVAFunctions::funcMap_t> funcLibraryPtr =
75 ATH_CHECK(setupBDT(*funcLibraryPtr,
76 PathResolverFindCalibFile(m_folder + "/MVACalib_electron.weights.root")));
77 }
78 break;
80 {
81 std::unique_ptr<egammaMVAFunctions::funcMap_t> funcLibraryPtr =
83 ATH_CHECK(setupBDT(*funcLibraryPtr,
84 PathResolverFindCalibFile(m_folder + "/MVACalib_unconvertedPhoton.weights.root")));
85 }
86 break;
88 {
89 std::unique_ptr<egammaMVAFunctions::funcMap_t> funcLibraryPtr =
91 ATH_CHECK(setupBDT(*funcLibraryPtr,
92 PathResolverFindCalibFile(m_folder + "/MVACalib_convertedPhoton.weights.root")));
93 }
94 break;
96 {
97 std::unique_ptr<egammaMVAFunctions::funcMap_t> funcLibraryPtr =
99 ATH_CHECK(setupBDT(*funcLibraryPtr,
100 PathResolverFindCalibFile("egammaMVACalib/MVACalib_fwdelectron.weights.root")));
101 }
102 break;
103
104 default:
105 ATH_MSG_FATAL("Particle type not set properly: " << m_particleType);
106 return StatusCode::FAILURE;
107 }
108
109 // Load these dictionaries now, so we don't need to try to do so
110 // while multiple threads are running.
111 TClass::GetClass ("TH2Poly");
112 TClass::GetClass ("TMultiGraph");
113
114 return StatusCode::SUCCESS;
115}
116
118 const std::string& fileName)
119{
120 ATH_MSG_DEBUG("Trying to open " << fileName);
121
122 std::unique_ptr<TFile> f(TFile::Open(fileName.c_str()));
123 if (!f || f->IsZombie()) {
124 ATH_MSG_FATAL("Could not open file: " << fileName);
125 return StatusCode::FAILURE;
126 }
127
128 // Load hPoly
129 TH2Poly *hPoly = nullptr;
130 f->GetObject("hPoly", hPoly);
131 if (!hPoly) {
132 ATH_MSG_FATAL("Could not find hPoly");
133 return StatusCode::FAILURE;
134 }
135 //pass ownership to class variable
136 m_hPoly.reset(static_cast<TH2Poly*>(hPoly));
137 m_hPoly->SetDirectory(nullptr);
138
139 // Load variables
140 TObjArray *variablesTmp = nullptr;
141 f->GetObject("variables", variablesTmp);
142 if (!variablesTmp) {
143 ATH_MSG_FATAL("Could not find variables");
144 return StatusCode::FAILURE;
145 }
146 auto variables = std::unique_ptr<TObjArray>(variablesTmp);
147 variables->SetOwner(); // to delete the objects when d-tor is called
148
149 // Load shifts
150 TObjArray *shiftsTmp = nullptr;
151 f->GetObject("shifts", shiftsTmp);
152 if (!shiftsTmp) {
153 ATH_MSG_FATAL("Could not find shifts");
154 return StatusCode::FAILURE;
155 }
156 auto shifts = std::unique_ptr<TObjArray>(shiftsTmp);
157 shifts->SetOwner(); // to delete the objects when d-tor is called
158
159 // Load trees
160 TObjArray *treesTmp = nullptr;
161 std::unique_ptr<TObjArray> trees;
162 f->GetObject("trees", treesTmp);
163 if (treesTmp) {
164 trees = std::unique_ptr<TObjArray>(treesTmp);
165 trees->SetOwner(); // to delete the objects when d-tor is called
166 ATH_MSG_DEBUG("setupBDT " << "BDTs read from TObjArray");
167 } else {
168 ATH_MSG_DEBUG("setupBDT " << "Reading trees individually");
169 trees = std::make_unique<TObjArray>();
170 trees->SetOwner(); // to delete the objects when d-tor is called
171 for (int i = 0; i < variables->GetEntries(); ++i)
172 {
173 TTree *tree = nullptr;
174 f->GetObject(Form("BDT%d", i), tree);
175 if (tree) tree->SetCacheSize(0);
176 trees->AddAtAndExpand(tree, i);
177 }
178 }
179
180 // Ensure the objects have (the same number of) entries
181 if (!trees->GetEntries() || !(trees->GetEntries() == variables->GetEntries())) {
182 ATH_MSG_FATAL("Tree has size " << trees->GetEntries()
183 << " while variables has size " << variables->GetEntries());
184 return StatusCode::FAILURE;
185 }
186
187 // Loop simultaneously over trees, variables and shifts
188 // Define the BDTs, the list of variables and the shift for each BDT
189 TObjString *str2;
190 TObjString *shift;
191 TTree *tree;
192 TIter nextTree(trees.get());
193 TIter nextVariables(variables.get());
194 TIter nextShift(shifts.get());
195 for (int i=0; (tree = (TTree*) nextTree()) && ((TObjString*) nextVariables()); ++i)
196 {
197 m_BDTs.emplace_back(tree);
198
199 std::vector<std::function<float(const xAOD::Egamma*, const xAOD::CaloCluster*)> > funcs;
200 // Loop over variables, which are separated by comma
201 char separator_var = ';';
202 if (getString(variables->At(i)).Index(";") < 1) separator_var = ','; // old versions
203 std::unique_ptr<TObjArray> tokens(getString(variables->At(i)).Tokenize(separator_var));
204 TIter nextVar(tokens.get());
205 while ((str2 = (TObjString*) nextVar()))
206 {
207 const TString& varName = getString(str2);
208 if (varName.Contains("npv") || varName.Contains("actualIntPerXing"))
209 continue;
210 if (!varName.Length()) {
211 ATH_MSG_FATAL("There was an empty variable name!");
212 return StatusCode::FAILURE;
213 }
214 try {
215 funcs.push_back(funcLibrary.at(varName.Data()));
216 } catch(const std::out_of_range& e) {
217 ATH_MSG_FATAL("Could not find formula for variable " << varName << ", error: " << e.what());
218 return StatusCode::FAILURE;
219 }
220 }
221 m_funcs->funcs.push_back(std::move(funcs));
222
223 if (m_shiftType == MEAN10TOTRUE) {
224 shift = (TObjString*) nextShift();
225 const TString& shiftFormula = getString(shift);
226 m_shifts.emplace_back("", shiftFormula);
227 }
228 }
229 return StatusCode::SUCCESS;
230
231}
232
233const TString& egammaMVACalibTool::getString(TObject* obj)
234{
235 TObjString *objS = dynamic_cast<TObjString*>(obj);
236 if (!objS) {
237 throw std::runtime_error("egammaMVACalibTool::getString was passed something that was not a string object");
238 }
239 return objS->GetString();
240}
241
243 const xAOD::Egamma* eg,
244 const egammaMVACalib::GlobalEventInfo& gei) const
245{
246
247 ATH_MSG_DEBUG("calling getEnergy with cluster index " << clus.index());
248
249 // find the bin of BDT and the shift
250 const auto initEnergy =
252 float(clus.e()) : (m_useLayerCorrected ?
255
256 const auto etVarGeV = (initEnergy / std::cosh(clus.eta())) / GeV;
257 const auto etaVar = std::abs(clus.eta());
258
259 ATH_MSG_DEBUG("Looking at object with initEnergy = " << initEnergy
260 << ", etVarGeV = " << etVarGeV
261 << ", etaVar = " << etaVar
262 << ", clus->e() = " << clus.e());
263
264 // Normally, we'd just use FindFixBin here. But TH2Poly overrides FindBin
265 // to handle its special bin definitions, but it doesn't also override
266 // FindFixBin. But TH2Poly::FindBin (unlike TH1::FindBin) doesn't actually
267 // do anything non-const, so just suppress the warning here.
268 TH2Poly* hPoly ATLAS_THREAD_SAFE = m_hPoly.get();
269 const int bin = hPoly->FindBin(etaVar, etVarGeV) - 1; // poly bins are shifted by one
270
271 ATH_MSG_DEBUG("Using bin: " << bin);
272
273 if (bin < 0) {
274 ATH_MSG_DEBUG("The bin is under/overflow; just return the energy");
275 return clus.e();
276 }
277
278 if (bin >= static_cast<int>(m_BDTs.size())) {
279 ATH_MSG_WARNING("The bin is outside the range, so just return the energy");
280 return clus.e();
281 }
282
283
284 // select the bdt and functions. (shifts are done later if needed)
285 // if there is only one BDT just use that
286 const int bin_BDT = m_BDTs.size() != 1 ? bin : 0;
287 const auto& bdt = m_BDTs[bin_BDT];
288 const auto& funcs = m_funcs->funcs[bin_BDT];
289
290 const size_t sz = funcs.size();
291
292 // could consider adding std::array support to the BDTs
293 std::vector<float> vars(sz);
294
295 for (size_t i = 0; i < sz; ++i) {
296 vars[i] = funcs[i](eg, &clus);
297 ATH_MSG_DEBUG("Variable " << i << " = " << std::format("{:10.4f}", vars[i]));
298 }
299
300 // Retrieve nPV and mu in case of forward electron;
301 // they are the last two variables in the input vars
303 vars.insert(vars.end(),{float(gei.nPV), gei.acmu});
304 }
305
306 // evaluate the BDT response
307 const float mvaOutput = bdt.GetResponse(vars);
308 ATH_MSG_DEBUG("BDT response = " << std::format("{:10.6f}", mvaOutput));
309
310 // what to do if the MVA response is 0;
311 if (mvaOutput == 0.) {
312 if (m_clusterEif0) {
313 return clus.e();
314 }
315 return 0.;
316 }
317
318 // calculate the unshifted energy
319 const auto energy = (m_calibrationType == fullCalibration) ?
320 mvaOutput : (initEnergy * mvaOutput);
321
322 ATH_MSG_DEBUG("energy after MVA = " << std::format("{:.3f}", energy));
323
324 if (m_shiftType == NOSHIFT) {
325 // if no shift, just return the unshifted energy
326 return energy;
327 }
328
329 // have to do a shift if here. It's based on the corrected Et in GeV
330 const auto etGeV = (energy / std::cosh(clus.eta())) / GeV;
331
332 // evaluate the TFormula associated with the bin
333 const auto shift = m_shifts[bin].Eval(etGeV);
334 ATH_MSG_DEBUG("shift = " << shift);
335 if (shift > 0.5) {
336 ATH_MSG_DEBUG("energy after MVA and shift = " << std::format("{:.3f}", energy/shift));
337 return energy / shift;
338 }
339 ATH_MSG_WARNING("Shift value too small: " << shift << "; not applying shift");
340 return energy;
341
342}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
static Double_t sz
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Define macros for attributes used to control the static checker.
#define ATLAS_THREAD_SAFE
size_t index() const
Return the index of this element within its container.
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
Gaudi::Property< std::string > m_folder
string with folder for weight files
float getEnergy(const xAOD::CaloCluster &clus, const xAOD::Egamma *eg, const egammaMVACalib::GlobalEventInfo &gei=egammaMVACalib::GlobalEventInfo()) const override final
returns the calibrated energy
Gaudi::Property< bool > m_clusterEif0
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
std::vector< TFormula > m_shifts
shifts formulas
egammaMVACalibTool(const std::string &type)
Gaudi::Property< bool > m_useLayerCorrected
StatusCode setupBDT(const egammaMVAFunctions::funcMap_t &funcLibrary, const std::string &fileName)
a function called by initialize to setup the BDT, funcs, and shifts.
std::unique_ptr< egammaMVACalibTool_detail::Funcs > m_funcs
Holds the pointers to the functions that calculate the input variables for each BDT
Gaudi::Property< int > m_calibrationType
Gaudi::Property< int > m_shiftType
virtual ~egammaMVACalibTool() override
std::unique_ptr< TH2Poly > m_hPoly
A TH2Poly used to extract bin numbers. Note there is an offset of 1.
Gaudi::Property< int > m_particleType
std::vector< MVAUtils::BDT > m_BDTs
Where the BDTs are stored.
static const TString & getString(TObject *obj)
a utility to get a TString out of a TObjString pointer
virtual double eta() const
The pseudorapidity (\(\eta\)) of the particle.
virtual double e() const
The total energy of the particle.
std::unique_ptr< funcMap_t > initializeUnconvertedPhotonFuncs(bool useLayerCorrected)
A function to build the map for unconverted photons.
float compute_correctedcl_Eacc(const xAOD::CaloCluster &cl)
std::unique_ptr< funcMap_t > initializeElectronFuncs(bool useLayerCorrected)
A function to build the map for electrons.
float compute_rawcl_Eacc(const xAOD::CaloCluster &cl)
std::unique_ptr< funcMap_t > initializeForwardElectronFuncs(bool useLayerCorrected)
NEW: A function to build the map for forward electrons.
std::unique_ptr< funcMap_t > initializeConvertedPhotonFuncs(bool useLayerCorrected)
A function to build the map for converted photons.
std::unordered_map< std::string, std::function< float(const xAOD::Egamma *, const xAOD::CaloCluster *)> > funcMap_t
Define the map type since it's long.
STL namespace.
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
Egamma_v1 Egamma
Definition of the current "egamma version".
Definition Egamma.h:17
std::vector< std::vector< std::function< float(const xAOD::Egamma *, const xAOD::CaloCluster *)> > > funcs
A structure holding some global event information.
TChain * tree