ATLAS Offline Software
Loading...
Searching...
No Matches
ElectronPhotonVariableNFCorrectionTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6
9
11
12#include "TEnv.h"
13#include "TString.h"
14
15#include <cmath>
16#include <algorithm>
17
18
19
20// Ordered list of shower shapes used by the tool
21// The order must match the ONNX model inputs and outputs
22const std::vector<std::string> ElectronPhotonVariableNFCorrectionTool::s_ssVarNames = {
23 "weta2", "weta1", "Rphi", "Reta", "wtots1", "Rhad", "Rhad1", "f1", "fracs1", "DeltaE", "Eratio"
24};
25
26// Mapping of shower shapes to xAOD enums (same order as s_ssVarNames).
// Entry i is the ShowerShapeType used both to read (showerShapeValue) and to
// write back (setShowerShapeValue) the variable named s_ssVarNames[i].
// NOTE(review): applyCorrection sizes its ss vector from this list, while
// passSelectionCuts reads ss[0]..ss[10] and the "_original" accessors are sized
// from s_ssVarNames; this list must therefore have exactly the same length
// (11) as s_ssVarNames -- consider a static_assert-style check.
27const std::vector<xAOD::EgammaParameters::ShowerShapeType> ElectronPhotonVariableNFCorrectionTool::s_ssEnums = {
// NOTE(review): the enum entries are not visible in this view; confirm they
// follow the order weta2, weta1, Rphi, Reta, wtots1, Rhad, Rhad1, f1, fracs1,
// DeltaE, Eratio.
39};
40
41
42// Constructor, declares properties
46
47// Select the model fold for a photon, based on the event number and
// (optionally, with the "eventNumber_phi" strategy) the photon phi.
//
// @param eventNumber  event number taken from EventInfo
// @param phi          photon azimuth; only used by the phi-based strategy
// @return fold index in [0, m_nFolds); always 0 when m_nFolds <= 1
48int ElectronPhotonVariableNFCorrectionTool::selectFold(unsigned long long eventNumber, float phi) const
49{
50 if (m_nFolds <= 1) return 0;
51
52 unsigned long long key = eventNumber;
53
// Phi-based strategy only. The opening condition of this scope is not visible
// in this view (presumably m_foldStrategy == FoldStrategy::EventNumberPhi --
// TODO confirm). phi is shifted by pi and binned in 0.01 rad steps, so two
// photons of the same event can land in different folds.
// NOTE(review): assumes phi is finite; casting a NaN to long long is undefined
// behaviour, and phi < -pi would give a negative bin that wraps on the
// unsigned conversion below.
55 const long long phiBin = static_cast<long long>(std::floor((phi + static_cast<float>(M_PI)) * 100.0f));
56 key = eventNumber + static_cast<unsigned long long>(phiBin);
57 }
58
// m_nFolds > 1 here, so its conversion to unsigned for the modulo is safe.
59 return static_cast<int>(key % m_nFolds);
60}
61
62// Convert the FoldStrategy string read from the config file into the enum.
// "eventNumber"     -> FoldStrategy::EventNumber
// "eventNumber_phi" -> FoldStrategy::EventNumberPhi
// Any other value emits a warning and reaches a fallback return that is not
// visible in this view (presumably a sentinel that initialize() rejects with
// an error -- TODO confirm).
65{
66 if (s == "eventNumber") return FoldStrategy::EventNumber;
67 if (s == "eventNumber_phi") return FoldStrategy::EventNumberPhi;
68 ATH_MSG_WARNING("Unknown FoldStrategy '" << s << "'");
70}
71
// Decide whether the NF correction should be applied to this photon.
//
// @param photon  the photon under consideration
// @param ss      its shower-shape values in s_ssVarNames order
//                (ss[0] = weta2 ... ss[10] = Eratio); when the shower-shape
//                window checks run, entries 0..10 are read, so ss must hold
//                at least 11 values
// @return false as soon as one requirement fails (pT, truth type, or a
//         shower-shape window); true otherwise
//
// NOTE(review): in this view the first signature line and the conditions that
// open the truth-type and shower-shape sections are not shown. The warning
// text suggests the truth-type section applies only for ApplyTo = TruthPhotons,
// and the m_applyShowerShapeCuts member suggests it gates the window checks --
// TODO confirm against the full source.
73 const xAOD::Photon& photon,
74 const std::vector<float>& ss) const
75{
76 // pT cut
77 if (photon.pt() < m_pTcutMeV) return false;
78
79 // TruthType cut
// Photons without a "truthType" decoration are rejected with a warning.
// NOTE(review): the warning is emitted once per such photon and may flood the
// log when the input lacks the decoration entirely; consider warning once.
81 static const SG::AuxElement::Accessor<int> acc_truthType("truthType");
82 if (!acc_truthType.isAvailable(photon)) {
83 ATH_MSG_WARNING("ApplyTo = TruthPhotons but truthType not available — skipping photon");
84 return false;
85 }
86 int truthType = acc_truthType(photon);
// Only truth types 13, 14 and 15 are accepted -- TODO confirm which truth
// classifier categories these values are meant to select.
87 if (truthType < 13 || truthType > 15) return false;
88 }
89
90 // Shower shape cuts
// Per-variable acceptance windows: values outside them are treated as
// default/sentinel values and the photon is left uncorrected. The mix of
// inclusive and exclusive bounds is intentional as written; keep it unless the
// model training selection changes.
92 // weta2
93 if (ss[0] <= -10.f || ss[0] >= 10.f) return false;
94 // weta1
95 if (ss[1] <= -10.f || ss[1] >= 10.f) return false;
96 // Rphi
97 if (ss[2] <= -10.f || ss[2] >= 10.f) return false;
98 // Reta
99 if (ss[3] <= -10.f || ss[3] >= 10.f) return false;
100 // wtots1
101 if (ss[4] < -2.f || ss[4] >= 10.f) return false;
102 // Rhad
103 if (ss[5] < -2.f || ss[5] > 2.f) return false;
104 // Rhad1
105 if (ss[6] < -2.f || ss[6] > 2.f) return false;
106 // f1
107 if (ss[7] <= -2.f || ss[7] >= 2.f) return false;
108 // fracs1
109 if (ss[8] <= -2.f || ss[8] >= 5.f) return false;
110 // DeltaE
111 if (ss[9] < 0.f || ss[9] >= 5000.f) return false;
112 // Eratio
113 if (ss[10] < 0.f || ss[10] > 1.f) return false;
114 }
115
116 return true;
117}
118
119
120
121// Initialize tool: read config, setup ONNX tools and accessors.
//
// Reads the TEnv-format ConfigFile (keys NFolds, ONNXnamePattern, FoldStrategy,
// ApplyShowerShapeCuts), checks that the forward and backward ONNX tool arrays
// each hold NFolds tools, parses the ApplyTo property, retrieves the ONNX
// tools, creates one "<var>_original" float accessor per shower shape and
// initializes the EventInfo read key.
// @return StatusCode::FAILURE on any configuration or retrieval error,
//         StatusCode::SUCCESS otherwise.
// NOTE(review): the signature line and a few lines (FoldStrategy parsing and
// its check, the ApplyTo == "TruthPhotons" branch) are not visible in this view.
123{
124 if (m_configFile.empty()) {
125 ATH_MSG_ERROR("ConfigFile property is empty. Please provide a config file to the tool.");
126 return StatusCode::FAILURE;
127 }
128
129 std::string resolvedConfig = PathResolverFindCalibFile(m_configFile);
130 if (resolvedConfig.empty()) {
131 ATH_MSG_ERROR("Failed to resolve config file \"" << m_configFile << "\"");
132 return StatusCode::FAILURE;
133 }
// NOTE(review): this logs the logical name; resolvedConfig is the path that is
// actually read and would be more useful for debugging.
134 ATH_MSG_DEBUG("Use configuration file " << m_configFile);
135
136 TEnv env;
137 env.ReadFile(resolvedConfig.c_str(), kEnvLocal);
// NOTE(review): IgnoreDuplicates only influences duplicate handling/warnings
// for entries added afterwards, and false appears to be ROOT's default, so this
// call after ReadFile is likely a no-op -- confirm, or move it before ReadFile.
138 env.IgnoreDuplicates(false);
139
// Number of folds; must be positive and equal to the size of both tool arrays.
140 m_nFolds = env.GetValue("NFolds", 0);
141 if (m_nFolds <= 0) {
142 ATH_MSG_ERROR("NFolds not set or invalid in config: " << resolvedConfig);
143 return StatusCode::FAILURE;
144 }
145
146 TString pattern = env.GetValue("ONNXnamePattern", "");
147 if (pattern.IsNull()) {
148 ATH_MSG_ERROR("ONNXnamePattern not set in config: " << resolvedConfig);
149 return StatusCode::FAILURE;
150 }
151 m_onnxPattern = pattern.Data();
152
153
// Fold strategy defaults to "eventNumber" when the key is absent.
154 TString fs = env.GetValue("FoldStrategy", "eventNumber");
155 std::string fsStr = fs.Data();
156
// NOTE(review): the lines converting fsStr (presumably via parseFoldStrategy)
// and opening the error branch below are not visible in this view.
158
160 ATH_MSG_ERROR("FoldStrategy must be 'eventNumber' or 'eventNumber_phi', but got '" << fsStr << "' in config: " << resolvedConfig);
161 return StatusCode::FAILURE;
162 }
163
164
165 ATH_MSG_VERBOSE("NFolds = " << m_nFolds << ", pattern = " << m_onnxPattern << ", FoldStrategy = " << fsStr);
166
// One forward and one backward model per fold.
167 if (static_cast<int>(m_onnxToolsForward.size()) != m_nFolds ||
168 static_cast<int>(m_onnxToolsBackward.size()) != m_nFolds) {
169 ATH_MSG_ERROR("Expected "<<m_nFolds<<" forward/backward tools, "<< "but got "<<m_onnxToolsForward.size()<<" / "<< m_onnxToolsBackward.size());
170 return StatusCode::FAILURE;
171 }
172
173
// ApplyTo accepts "TruthPhotons" or "All"; the "TruthPhotons" branch that
// precedes this else-if is not visible in this view.
175 else if (m_applyToStr == "All") m_applyToMode = ApplyToMode::All;
176 else {
177 ATH_MSG_ERROR("ApplyTo must be TruthPhotons or All, but got '" << m_applyToStr << "'");
178 return StatusCode::FAILURE;
179 }
180
181 // Cuts on SS vars to remove default values
// Enabled by default; only the exact value 1 enables them.
182 m_applyShowerShapeCuts = (env.GetValue("ApplyShowerShapeCuts", 1) == 1);
183
184 ATH_MSG_INFO("ApplyTo = " << m_applyToStr << ", pTcut=" << m_pTcutMeV << " MeV, ApplyShowerShapeCuts=" << m_applyShowerShapeCuts);
185
186
187 ATH_CHECK(m_onnxToolsForward.retrieve());
188 ATH_CHECK(m_onnxToolsBackward.retrieve());
189
// NOTE(review): this block is guarded at DEBUG, but its "Fold i ..." header
// lines are logged at VERBOSE, so at DEBUG level any model-info output appears
// without headers. Use a single consistent level.
190 if (msgLvl(MSG::DEBUG)) {
191 for (int i = 0; i < m_nFolds; ++i) {
192 ATH_MSG_VERBOSE("Fold " << i << " forward model info:");
193 m_onnxToolsForward[i]->printModelInfo();
194 ATH_MSG_VERBOSE("Fold " << i << " backward model info:");
195 m_onnxToolsBackward[i]->printModelInfo();
196 }
197 }
198
199 // Prepare decorations for each shower shape
// One "<var>_original" accessor per shower shape, index-aligned with
// s_ssVarNames; applyCorrection stores the uncorrected values through them.
200 m_accessors.resize(s_ssVarNames.size());
201 for (size_t i = 0; i < s_ssVarNames.size(); ++i) {
202 const std::string& var = s_ssVarNames[i];
203 m_accessors[i].original = std::make_unique<SG::AuxElement::Accessor<float>>(var + "_original");
204 }
205
// EventInfo is read per photon to obtain the event number for fold selection.
206 ATH_CHECK(m_eventInfoKey.initialize());
207
208 ATH_MSG_INFO("NF correction tool initialized with " << m_nFolds << " folds. ");
209
210 return StatusCode::SUCCESS;
211}
212
213
214// Apply the normalizing-flow (NF) correction to the photon's shower shapes.
//
// Steps:
//  1. read the shower shapes listed in s_ssEnums and store each in its
//     "<var>_original" decoration;
//  2. decorate "NFCorrectedShowerShapes" with the selection outcome (1/0) and
//     leave the shower shapes untouched if passSelectionCuts fails;
//  3. pick the fold from the EventInfo event number (and the photon phi);
//  4. forward model: (kinematics, shower shapes) -> latent vector z;
//  5. backward model: (kinematics, z, original shower shapes) -> corrected
//     shower shapes, written back with setShowerShapeValue.
// Kinematic model inputs are {pT [GeV], eta, phi, isConverted (0/1)}.
//
// NOTE(review): the signature line, the EventInfo ReadHandle declaration and
// the single-line exits after each ATH_MSG_ERROR (and at the end) are not
// visible in this view; they presumably return CP::CorrectionCode::Error / Ok
// -- TODO confirm.
216{
217
// NOTE(review): nSS comes from s_ssEnums but m_accessors is sized from
// s_ssVarNames; both lists must have the same length.
218 const size_t nSS = s_ssEnums.size();
219 std::vector<float> ss(nSS);
220
221 // Read shower shapes, then store original values
222 for (size_t i = 0; i < nSS; ++i) {
223 ss[i] = photon.showerShapeValue(s_ssEnums[i]);
224 (*m_accessors[i].original)(photon) = ss[i];
225 }
226
227
228 static const SG::AuxElement::Decorator<char> dec_pass("NFCorrectedShowerShapes");
229
230 // Photon selection
231 bool pass = passSelectionCuts(photon, ss);
232
// NOTE(review): the flag is set to 1 before any inference runs and is not reset
// on the error paths below, so a photon whose inference fails keeps its original
// shower shapes but is still flagged as corrected. Consider setting it only
// after the corrected values have been written.
233 dec_pass(photon) = pass ? 1 : 0;
234
235 if (!pass) {
236 // If the selection fails, the shower shapes are left unchanged (equal to the stored "_original" values)
238 }
239
240
241 // Get event info and select fold
243 if (!h.isValid()) {
244 ATH_MSG_ERROR("Failed to read EventInfo via key " << m_eventInfoKey.key());
246 }
247
248 const unsigned long long eventNumber = h->eventNumber();
249 float ptGeV = photon.pt() / 1000.0f;
250 const float phi = static_cast<float>(photon.phi());
251 const int fold = selectFold(eventNumber, phi);
252
253
// Defensive check: selectFold already returns a value in [0, m_nFolds).
254 if (fold < 0 || fold >= m_nFolds) {
255 ATH_MSG_ERROR("Selected fold " << fold << " out of range [0," << (m_nFolds-1) << "]");
257 }
258
259
260 // Kinematic inputs
// Any conversion type other than "unconverted" counts as converted.
261 const bool isConv = photon.conversionType() != xAOD::EgammaParameters::unconverted;
262 std::vector<float> kinematic = {
263 ptGeV,
264 static_cast<float>(photon.eta()),
265 static_cast<float>(photon.phi()),
266 static_cast<float>(isConv)
267 };
268
269 // Forward inference
270 std::vector<Ort::Value> inputTensors;
271
272 const auto& onnxToolForward = m_onnxToolsForward[fold];
273
274 // index 0 is for kinematics
275 int64_t batchSizeKin = onnxToolForward->getBatchSize(
276 static_cast<int64_t>(kinematic.size()), 0);
277 if (onnxToolForward->addInput(inputTensors, kinematic, 0, batchSizeKin).isFailure()) {
278 ATH_MSG_ERROR("Fold " << fold << ": failed to add kinematic input tensor");
280 }
281
282 // index 1 is for shower shape variables
283 int64_t batchSizeSS = onnxToolForward->getBatchSize(
284 static_cast<int64_t>(ss.size()), 1);
285 if (onnxToolForward->addInput(inputTensors, ss, 1, batchSizeSS).isFailure()) {
286 ATH_MSG_ERROR("Fold " << fold << ": failed to add shower shape input tensor");
288 }
289
290 std::vector<Ort::Value> outputTensors;
291 std::vector<float> outputData;
292 if (onnxToolForward->addOutput(outputTensors, outputData, 0, batchSizeKin).isFailure()) {
293 ATH_MSG_ERROR("Fold " << fold << ": failed to add forward output tensor");
295 }
296
297 if (onnxToolForward->inference(inputTensors, outputTensors).isFailure()) {
298 ATH_MSG_ERROR("Fold " << fold << ": forward inference failed");
300 }
301
// NOTE(review): unlike the backward output (checked below), the forward output
// is not size-checked; if the model's latent dimension is smaller than nSS,
// this copy reads past the end of the tensor buffer. Compare
// GetTensorTypeAndShapeInfo().GetElementCount() with nSS first.
302 float* zPtr = outputTensors[0].GetTensorMutableData<float>();
303 std::vector<float> zVec(zPtr, zPtr + nSS);
304
305
306 // Backward inference
307 std::vector<Ort::Value> inputTensorsBack;
308 std::vector<Ort::Value> outputTensorsBack;
309 std::vector<float> outputDataBack;
310
311 const auto& onnxToolBackward = m_onnxToolsBackward[fold];
312
313 // index 0 is for kinematics
314 int64_t batchSizeKinBack = onnxToolBackward->getBatchSize(
315 static_cast<int64_t>(kinematic.size()), 0);
316 if (onnxToolBackward->addInput(inputTensorsBack, kinematic, 0, batchSizeKinBack).isFailure()) {
317 ATH_MSG_ERROR("Fold " << fold << ": failed to add kinematic input tensor for backward model");
319 }
320
321 // index 1 is for shower shapes in latent space
322 int64_t batchSizeZBack = onnxToolBackward->getBatchSize(static_cast<int64_t>(zVec.size()), 1);
323
324 if (onnxToolBackward->addInput(inputTensorsBack, zVec, 1, batchSizeZBack).isFailure()) {
325 ATH_MSG_ERROR("Fold " << fold << ": failed to add z input tensor for backward model");
327 }
328
329 // index 2 is for original shower shapes (models use them to cut on std values [-5, 5])
// NOTE(review): input 2 reuses the batch size computed for input 0 rather than
// calling getBatchSize(ss.size(), 2); this only matches while both describe a
// single photon -- consider computing it explicitly for consistency.
330 if (onnxToolBackward->addInput(inputTensorsBack, ss, 2, batchSizeKinBack).isFailure()) {
331 ATH_MSG_ERROR("Fold " << fold << ": failed to add original SS input tensor for backward model");
333 }
334
335 if (onnxToolBackward->addOutput(outputTensorsBack, outputDataBack, 0, batchSizeZBack).isFailure()) {
336 ATH_MSG_ERROR("Fold " << fold << ": failed to add backward output tensor");
338 }
339
340 if (onnxToolBackward->inference(inputTensorsBack, outputTensorsBack).isFailure()) {
341 ATH_MSG_ERROR("Fold " << fold << ": backward inference failed");
343 }
344
// The backward output must hold exactly one corrected value per shower shape.
345 const auto infoB = outputTensorsBack[0].GetTensorTypeAndShapeInfo();
346 const auto nElB = infoB.GetElementCount();
347 if (nElB != nSS) {
348 ATH_MSG_ERROR("Fold "<<fold <<": backward output has " <<nElB<< " elements, expected "<<nSS);
350 }
351
352 // Write corrected shower shapes
353 float* corrPtr = outputTensorsBack[0].GetTensorMutableData<float>();
354 for (size_t i = 0; i < nSS; ++i) {
355 photon.setShowerShapeValue(corrPtr[i], s_ssEnums[i]);
356 }
357
358 ATH_MSG_DEBUG("NF correction applied successfully");
359
360
362}
363
364// Electrons are not supported: this overload only logs an error and leaves
// the electron untouched. Its return statement is not visible in this view
// (presumably CP::CorrectionCode::Error -- TODO confirm).
366{
367 ATH_MSG_ERROR("ElectronPhotonVariableNFCorrectionTool does not support electrons.");
369}
370
371// Create a corrected copy of a photon.
// Copy-constructs in_photon into a newly allocated xAOD::Photon, stores it in
// out_photon and returns the result of applyCorrection on that copy.
// Ownership: the caller owns *out_photon. It is allocated and returned even when
// applyCorrection reports an error or the photon fails the selection, so the
// caller must delete it in every case.
373 xAOD::Photon*& out_photon) const
374{
375
376 out_photon = new xAOD::Photon(in_photon);
377 return applyCorrection(*out_photon);
378}
379
380// Create an uncorrected copy of an electron (electrons are not supported).
// Logs an error, then copy-constructs in_electron into a newly allocated
// xAOD::Electron stored in out_electron; the caller owns and must delete it.
// The return statement is not visible in this view (presumably
// CP::CorrectionCode::Error -- TODO confirm). Callers that discard the result
// on error must still free the copy.
382 xAOD::Electron*& out_electron) const
383{
384 ATH_MSG_ERROR("ElectronPhotonVariableNFCorrectionTool cannot correct electrons.");
385 out_electron = new xAOD::Electron(in_electron);
387}
#define M_PI
Scalar phi() const
phi method
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
static Double_t fs
static Double_t ss
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
bool msgLvl(const MSG::Level lvl) const
Header file for AthHistogramAlgorithm.
Return value from object correction CP tools.
@ Error
Some error happened during the object correction.
@ Ok
The correction was done successfully.
std::vector< SSAccessors > m_accessors
Per-variable accessors aligned with s_ssVarNames.
ToolHandleArray< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxToolsForward
ToolHandleArray for forward ONNX models (one tool per fold)
static const std::vector< std::string > s_ssVarNames
List of shower shape variable names (order must match model I/O)
FoldStrategy parseFoldStrategy(const std::string &s) const
Parse fold strategy string from config.
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfoKey
ReadHandleKey for EventInfo used for fold selection.
int selectFold(unsigned long long eventNumber, float phi) const
Select fold index for the current event/photon.
virtual const CP::CorrectionCode correctedCopy(const xAOD::Photon &in_photon, xAOD::Photon *&out_photon) const override
Make a corrected copy of the passed photon.
std::string m_onnxPattern
Models path pattern string from config.
ToolHandleArray< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxToolsBackward
ToolHandleArray for backward ONNX models (one tool per fold)
bool passSelectionCuts(const xAOD::Photon &photon, const std::vector< float > &ss) const
Returns true if NF correction should be applied to this photon.
bool m_applyShowerShapeCuts
Cuts applied to remove default values of shower shapes.
ElectronPhotonVariableNFCorrectionTool(const std::string &name)
Standard constructor.
static const std::vector< xAOD::EgammaParameters::ShowerShapeType > s_ssEnums
Egamma shower shape enum mapping for reading/writing values (order matches s_ssVarNames)
int m_nFolds
Number of model folds configured (must match tool handle array sizes)
virtual const CP::CorrectionCode applyCorrection(xAOD::Photon &photon) const override
Apply the Normalizing Flow correction to the passed photon.
virtual StatusCode initialize() override
Initialize the class instance.
Gaudi::Property< std::string > m_configFile
The configuration file for the tool, application mode and minimum photon pT cut in MeV.
FoldStrategy m_foldStrategy
Selected fold strategy (configured via FoldStrategy in the config)
SG::Decorator< T, ALLOC > Decorator
Definition AuxElement.h:576
SG::Accessor< T, ALLOC > Accessor
Definition AuxElement.h:573
bool isAvailable(const ELT &e) const
Test to see if this variable exists in the store.
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
@ unconverted
unconverted photon
@ wtots1
shower width is determined in a window Δη×Δφ = 0.0625 × ~0.2, corresponding typically to 20 strips...
@ f1
E1/E = fraction of energy reconstructed in the first sampling, where E1 is energy in all strips belon...
Definition EgammaEnums.h:53
@ Eratio
(emaxs1-e2tsts1)/(emaxs1+e2tsts1)
@ DeltaE
e2tsts1-emins1
@ fracs1
shower shape in the shower core : [E(+/-3)-E(+/-1)]/E(+/-1), where E(+/-n) is the energy in ± n strip...
@ weta2
the lateral width is calculated with a window of 3x5 cells using the energy weighted sum over all cel...
@ weta1
shower width using +/-3 strips around the one with the maximal energy deposit: w3 strips = sqrt{sum(E...
Definition EgammaEnums.h:98
Photon_v1 Photon
Definition of the current "egamma version".
Electron_v1 Electron
Definition of the current "egamma version".