ATLAS Offline Software
Loading...
Searching...
No Matches
ElectronPhotonVariableNFCorrectionTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6
9
11
12#include "TEnv.h"
13#include "TString.h"
14
15#include <cmath>
16#include <algorithm>
17
18
19
20// Ordered list of shower shapes used by the tool (11 variables).
21// The order must match the ONNX model inputs and outputs.
// The same index order is assumed by passSelectionCuts() (ss[0] = weta2 ... ss[10] = Eratio)
// and by the "<name>_original" decorations created in initialize().
22const std::vector<std::string> ElectronPhotonVariableNFCorrectionTool::s_ssVarNames = {
23 "weta2", "weta1", "Rphi", "Reta", "wtots1", "Rhad", "Rhad1", "f1", "fracs1", "DeltaE", "Eratio"
24};
25
26// Mapping of shower shapes to xAOD enums (same order as s_ssVarNames)
// applyCorrection() reads each entry with showerShapeValue() and writes the
// corrected value back with setShowerShapeValue(). The size of this vector is
// used there as the number of shower-shape model inputs/outputs (nSS), so it must
// have the same length as s_ssVarNames.
27const std::vector<xAOD::EgammaParameters::ShowerShapeType> ElectronPhotonVariableNFCorrectionTool::s_ssEnums = {
39};
40
41
42// Constructor, declares properties
46
47// Select fold index based on event number (and optionally the photon phi).
//
// Returns 0 when a single fold is forced (m_forceOneFold) or configured
// (m_nFolds <= 1). Otherwise returns key % m_nFolds, i.e. a value in [0, m_nFolds-1].
// The key is the event number, optionally offset by a phi bin index.
// NOTE(review): the phi offset is presumably applied only for the
// "eventNumber_phi" fold strategy (FoldStrategy::EventNumberPhi) — TODO confirm
// against the guarding condition.
48int ElectronPhotonVariableNFCorrectionTool::selectFold(unsigned long long eventNumber, float phi) const
49{
50 if(m_forceOneFold) return 0;
51 if (m_nFolds <= 1) return 0;
52
53 unsigned long long key = eventNumber;
54
 // Shift phi by pi so the bin index is non-negative (assuming phi in [-pi, pi]),
 // then discretise into 0.01 rad bins.
56 const long long phiBin = static_cast<long long>(std::floor((phi + static_cast<float>(M_PI)) * 100.0f));
57 key = eventNumber + static_cast<unsigned long long>(phiBin);
58 }
59
60 return static_cast<int>(key % m_nFolds);
61}
62
63// Convert the FoldStrategy string read from the config file to the FoldStrategy enum.
// Accepted values: "eventNumber" -> FoldStrategy::EventNumber,
//                  "eventNumber_phi" -> FoldStrategy::EventNumberPhi.
// Any other string emits a warning and yields a fallback value; initialize()
// reports an error for anything other than the two accepted strings.
66{
67 if (s == "eventNumber") return FoldStrategy::EventNumber;
68 if (s == "eventNumber_phi") return FoldStrategy::EventNumberPhi;
69 ATH_MSG_WARNING("Unknown FoldStrategy '" << s << "'");
71}
72
74 const xAOD::Photon& photon,
75 const std::vector<float>& ss) const
76{
 // Decide whether the NF correction should be applied to this photon.
 // `ss` holds the shower shapes in s_ssVarNames order; indices 0..10 are read
 // below without a size check, so ss.size() >= 11 is assumed.
 // Returns false as soon as one cut fails, true otherwise.
77 // pT cut (photon pT compared against m_pTcutMeV, in MeV)
78 if (photon.pt() < m_pTcutMeV) return false;
79
80 // TruthType cut
 // The warning text indicates this applies in ApplyTo = TruthPhotons mode.
 // Photons without a "truthType" decoration are rejected. Only truthType 13..15
 // (inclusive) is kept — presumably MCTruthClassifier UnknownPhoton / IsoPhoton /
 // NonIsoPhoton; TODO confirm.
82 static const SG::AuxElement::Accessor<int> acc_truthType("truthType");
83 if (!acc_truthType.isAvailable(photon)) {
84 ATH_MSG_WARNING("ApplyTo = TruthPhotons but truthType not available — skipping photon");
85 return false;
86 }
87 int truthType = acc_truthType(photon);
88 if (truthType < 13 || truthType > 15) return false;
89 }
90
91 // Shower shape cuts
 // Range cuts that reject photons whose shower shapes sit outside the listed
 // windows (initialize() describes them as removing default values).
 // Presumably gated on m_applyShowerShapeCuts — TODO confirm.
93 // weta2
94 if (ss[0] <= -10.f || ss[0] >= 10.f) return false;
95 // weta1
96 if (ss[1] <= -10.f || ss[1] >= 10.f) return false;
97 // Rphi
98 if (ss[2] <= -10.f || ss[2] >= 10.f) return false;
99 // Reta
100 if (ss[3] <= -10.f || ss[3] >= 10.f) return false;
101 // wtots1
102 if (ss[4] < -2.f || ss[4] >= 10.f) return false;
103 // Rhad
104 if (ss[5] < -2.f || ss[5] > 2.f) return false;
105 // Rhad1
106 if (ss[6] < -2.f || ss[6] > 2.f) return false;
107 // f1
108 if (ss[7] <= -2.f || ss[7] >= 2.f) return false;
109 // fracs1
110 if (ss[8] <= -2.f || ss[8] >= 5.f) return false;
111 // DeltaE
112 if (ss[9] < 0.f || ss[9] >= 5000.f) return false;
113 // Eratio
114 if (ss[10] < 0.f || ss[10] > 1.f) return false;
115 }
116
117 return true;
118}
119
120
121
122// Initialize tool: read config, setup ONNX tools and accessors
//
// Steps:
//  1. Resolve m_configFile with PathResolverFindCalibFile and read it with TEnv.
//  2. Read NFolds (forced to 1 when m_forceOneFold), ONNXnamePattern and FoldStrategy.
//  3. Check that the forward/backward ONNX tool arrays both have NFolds entries.
//  4. Parse the ApplyTo property ("TruthPhotons" or "All") and ApplyShowerShapeCuts.
//  5. Retrieve the ONNX tools, create the "<var>_original" accessors, and
//     initialize the EventInfo read key.
// Returns StatusCode::FAILURE on any missing or invalid configuration.
124{
125 if (m_configFile.empty()) {
126 ATH_MSG_ERROR("ConfigFile property is empty. Please provide a config file to the tool.");
127 return StatusCode::FAILURE;
128 }
129
130 std::string resolvedConfig = PathResolverFindCalibFile(m_configFile);
131 if (resolvedConfig.empty()) {
132 ATH_MSG_ERROR("Failed to resolve config file \"" << m_configFile << "\"");
133 return StatusCode::FAILURE;
134 }
135 ATH_MSG_DEBUG("Use configuration file " << m_configFile);
136
137 TEnv env;
 // The return value of ReadFile is not checked; a missing or unreadable file
 // surfaces below as an invalid NFolds / ONNXnamePattern.
138 env.ReadFile(resolvedConfig.c_str(), kEnvLocal);
 // NOTE(review): this is called after ReadFile, so it presumably does not affect
 // how duplicate entries in the file were handled — confirm intent.
139 env.IgnoreDuplicates(false);
140
 // m_forceOneFold overrides the config and uses a single fold.
141 m_nFolds = (m_forceOneFold)?1:env.GetValue("NFolds", 0);
142 if (m_nFolds <= 0) {
143 ATH_MSG_ERROR("NFolds not set or invalid in config: " << resolvedConfig);
144 return StatusCode::FAILURE;
145 }
146
147 TString pattern = env.GetValue("ONNXnamePattern", "");
148 if (pattern.IsNull()) {
149 ATH_MSG_ERROR("ONNXnamePattern not set in config: " << resolvedConfig);
150 return StatusCode::FAILURE;
151 }
152 m_onnxPattern = pattern.Data();
153
154
 // FoldStrategy defaults to "eventNumber" when absent from the config.
155 TString fs = env.GetValue("FoldStrategy", "eventNumber");
156 std::string fsStr = fs.Data();
157
159
161 ATH_MSG_ERROR("FoldStrategy must be 'eventNumber' or 'eventNumber_phi', but got '" << fsStr << "' in config: " << resolvedConfig);
162 return StatusCode::FAILURE;
163 }
164
165
166 ATH_MSG_VERBOSE("NFolds = " << m_nFolds << ", pattern = " << m_onnxPattern << ", FoldStrategy = " << fsStr);
167
 // One forward and one backward model per fold; selectFold() indexes these arrays.
168 if (static_cast<int>(m_onnxToolsForward.size()) != m_nFolds ||
169 static_cast<int>(m_onnxToolsBackward.size()) != m_nFolds) {
170 ATH_MSG_ERROR("Expected "<<m_nFolds<<" forward/backward tools, "<< "but got "<<m_onnxToolsForward.size()<<" / "<< m_onnxToolsBackward.size());
171 return StatusCode::FAILURE;
172 }
173
174
176 else if (m_applyToStr == "All") m_applyToMode = ApplyToMode::All;
177 else {
178 ATH_MSG_ERROR("ApplyTo must be TruthPhotons or All, but got '" << m_applyToStr << "'");
179 return StatusCode::FAILURE;
180 }
181
182 // Cuts on SS vars to remove default values
 // Enabled by default; any config value other than 1 disables them.
183 m_applyShowerShapeCuts = (env.GetValue("ApplyShowerShapeCuts", 1) == 1);
184
185 ATH_MSG_INFO("ApplyTo = " << m_applyToStr << ", pTcut=" << m_pTcutMeV << " MeV, ApplyShowerShapeCuts=" << m_applyShowerShapeCuts);
186
187
188 ATH_CHECK(m_onnxToolsForward.retrieve());
189 ATH_CHECK(m_onnxToolsBackward.retrieve());
190
 // NOTE(review): the gate is DEBUG but the messages inside are VERBOSE, so at
 // DEBUG level only the printModelInfo() output appears — confirm intent.
191 if (msgLvl(MSG::DEBUG)) {
192 for (int i = 0; i < m_nFolds; ++i) {
193 ATH_MSG_VERBOSE("Fold " << i << " forward model info:");
194 m_onnxToolsForward[i]->printModelInfo();
195 ATH_MSG_VERBOSE("Fold " << i << " backward model info:");
196 m_onnxToolsBackward[i]->printModelInfo();
197 }
198 }
199
200 // Prepare decorations for each shower shape
 // One "<var>_original" float accessor per entry of s_ssVarNames.
 // NOTE(review): applyCorrection() indexes m_accessors up to s_ssEnums.size(),
 // so s_ssVarNames and s_ssEnums must have the same length.
201 m_accessors.resize(s_ssVarNames.size());
202 for (size_t i = 0; i < s_ssVarNames.size(); ++i) {
203 const std::string& var = s_ssVarNames[i];
204 m_accessors[i].original = std::make_unique<SG::AuxElement::Accessor<float>>(var + "_original");
205 }
206
207 ATH_CHECK(m_eventInfoKey.initialize());
208
209 ATH_MSG_INFO("NF correction tool initialized with " << m_nFolds << " folds. ");
210
211 return StatusCode::SUCCESS;
212}
213
214
215// Apply NF correction to photon shower shapes.
//
// For every photon:
//  - the current shower shapes (s_ssEnums order) are copied into the
//    "<var>_original" decorations;
//  - the char decoration "NFCorrectedShowerShapes" is set to 1 if the photon
//    passes passSelectionCuts(), else 0.
// For photons passing the selection:
//  - a fold is chosen from the EventInfo event number and the photon phi;
//  - the forward model maps (kinematics, shower shapes) to a latent vector z;
//  - the backward model maps (kinematics, z, original shower shapes) to
//    corrected shower shapes, which are written back with setShowerShapeValue().
217{
218
219 const size_t nSS = s_ssEnums.size();
220 std::vector<float> ss(nSS);
221
222 // Read shower shapes, then store original values
223 for (size_t i = 0; i < nSS; ++i) {
224 ss[i] = photon.showerShapeValue(s_ssEnums[i]);
225 (*m_accessors[i].original)(photon) = ss[i];
226 }
227
228
229 static const SG::AuxElement::Decorator<char> dec_pass("NFCorrectedShowerShapes");
230
231 // Photon selection
232 bool pass = passSelectionCuts(photon, ss);
233
234 dec_pass(photon) = pass ? 1 : 0;
235
236 if (!pass) {
237 // If the selection is not passed, the shower shapes are left unchanged (same as the originals)
239 }
240
241
242 // Get event info and select fold
244 if (!h.isValid()) {
245 ATH_MSG_ERROR("Failed to read EventInfo via key " << m_eventInfoKey.key());
247 }
248
249 const unsigned long long eventNumber = h->eventNumber();
 // MeV -> GeV for the model's kinematic input
250 float ptGeV = photon.pt() / 1000.0f;
251 const float phi = static_cast<float>(photon.phi());
252 const int fold = selectFold(eventNumber, phi);
253
254
255 if (fold < 0 || fold >= m_nFolds) {
256 ATH_MSG_ERROR("Selected fold " << fold << " out of range [0," << (m_nFolds-1) << "]");
258 }
259
260
261 // Kinematic inputs
 // Order fed to the models: pT [GeV], eta, phi, isConverted (0 or 1).
262 const bool isConv = photon.conversionType() != xAOD::EgammaParameters::unconverted;
263 std::vector<float> kinematic = {
264 ptGeV,
265 static_cast<float>(photon.eta()),
266 static_cast<float>(photon.phi()),
267 static_cast<float>(isConv)
268 };
269
270 // Forward inference
271 std::vector<Ort::Value> inputTensors;
272
273 const auto& onnxToolForward = m_onnxToolsForward[fold];
274
275 // index 0 is for kinematics
276 int64_t batchSizeKin = onnxToolForward->getBatchSize(
277 static_cast<int64_t>(kinematic.size()), 0);
278 if (onnxToolForward->addInput(inputTensors, kinematic, 0, batchSizeKin).isFailure()) {
279 ATH_MSG_ERROR("Fold " << fold << ": failed to add kinematic input tensor");
281 }
282
283 // index 1 is for shower shape variables
284 int64_t batchSizeSS = onnxToolForward->getBatchSize(
285 static_cast<int64_t>(ss.size()), 1);
286 if (onnxToolForward->addInput(inputTensors, ss, 1, batchSizeSS).isFailure()) {
287 ATH_MSG_ERROR("Fold " << fold << ": failed to add shower shape input tensor");
289 }
290
291 std::vector<Ort::Value> outputTensors;
292 std::vector<float> outputData;
 // NOTE(review): the output batch size reuses batchSizeKin (input 0); this is
 // only correct if all inputs share the same batch size (presumably 1).
293 if (onnxToolForward->addOutput(outputTensors, outputData, 0, batchSizeKin).isFailure()) {
294 ATH_MSG_ERROR("Fold " << fold << ": failed to add forward output tensor");
296 }
297
298 if (onnxToolForward->inference(inputTensors, outputTensors).isFailure()) {
299 ATH_MSG_ERROR("Fold " << fold << ": forward inference failed");
301 }
302
 // NOTE(review): unlike the backward output (checked against nSS below), the
 // forward output's element count is NOT checked before nSS floats are copied
 // from zPtr. A model emitting fewer than nSS values would cause an
 // out-of-bounds read. Consider checking
 // outputTensors[0].GetTensorTypeAndShapeInfo().GetElementCount() == nSS here.
303 float* zPtr = outputTensors[0].GetTensorMutableData<float>();
304 std::vector<float> zVec(zPtr, zPtr + nSS);
305
306
307 // Backward inference
308 std::vector<Ort::Value> inputTensorsBack;
309 std::vector<Ort::Value> outputTensorsBack;
310 std::vector<float> outputDataBack;
311
312 const auto& onnxToolBackward = m_onnxToolsBackward[fold];
313
314 // index 0 is for kinematics
315 int64_t batchSizeKinBack = onnxToolBackward->getBatchSize(
316 static_cast<int64_t>(kinematic.size()), 0);
317 if (onnxToolBackward->addInput(inputTensorsBack, kinematic, 0, batchSizeKinBack).isFailure()) {
318 ATH_MSG_ERROR("Fold " << fold << ": failed to add kinematic input tensor for backward model");
320 }
321
322 // index 1 is for shower shapes in latent space
323 int64_t batchSizeZBack = onnxToolBackward->getBatchSize(static_cast<int64_t>(zVec.size()), 1);
324
325 if (onnxToolBackward->addInput(inputTensorsBack, zVec, 1, batchSizeZBack).isFailure()) {
326 ATH_MSG_ERROR("Fold " << fold << ": failed to add z input tensor for backward model");
328 }
329
330 // index 2 is for original shower shapes (models use them to cut on std values [-5, 5])
 // NOTE(review): input 2 reuses batchSizeKinBack instead of a batch size computed
 // for input 2; this is equivalent only if all inputs share one batch size — TODO confirm.
331 if (onnxToolBackward->addInput(inputTensorsBack, ss, 2, batchSizeKinBack).isFailure()) {
332 ATH_MSG_ERROR("Fold " << fold << ": failed to add original SS input tensor for backward model");
334 }
335
336 if (onnxToolBackward->addOutput(outputTensorsBack, outputDataBack, 0, batchSizeZBack).isFailure()) {
337 ATH_MSG_ERROR("Fold " << fold << ": failed to add backward output tensor");
339 }
340
341 if (onnxToolBackward->inference(inputTensorsBack, outputTensorsBack).isFailure()) {
342 ATH_MSG_ERROR("Fold " << fold << ": backward inference failed");
344 }
345
 // Guard against a model whose output size does not match the number of shower shapes.
346 const auto infoB = outputTensorsBack[0].GetTensorTypeAndShapeInfo();
347 const auto nElB = infoB.GetElementCount();
348 if (nElB != nSS) {
349 ATH_MSG_ERROR("Fold "<<fold <<": backward output has " <<nElB<< " elements, expected "<<nSS);
351 }
352
353 // Write corrected shower shapes
354 float* corrPtr = outputTensorsBack[0].GetTensorMutableData<float>();
355 for (size_t i = 0; i < nSS; ++i) {
356 photon.setShowerShapeValue(corrPtr[i], s_ssEnums[i]);
357 }
358
359 ATH_MSG_DEBUG("NF correction applied successfully");
360
361
363}
364
365// Electrons are not supported.
367{
368 ATH_MSG_ERROR("ElectronPhotonVariableNFCorrectionTool does not support electrons.");
370}
371
372// Create corrected copy of photon.
374 xAOD::Photon*& out_photon) const
375{
376
377 out_photon = new xAOD::Photon(in_photon);
378 return applyCorrection(*out_photon);
379}
380
381// Create copy of electron (no correction).
383 xAOD::Electron*& out_electron) const
384{
385 ATH_MSG_ERROR("ElectronPhotonVariableNFCorrectionTool cannot correct electrons.");
386 out_electron = new xAOD::Electron(in_electron);
388}
#define M_PI
Scalar phi() const
phi method
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
static Double_t fs
static Double_t ss
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
bool msgLvl(const MSG::Level lvl) const
Header file for AthHistogramAlgorithm.
Return value from object correction CP tools.
@ Error
Some error happened during the object correction.
@ Ok
The correction was done successfully.
std::vector< SSAccessors > m_accessors
Per-variable accessors aligned with s_ssVarNames.
ToolHandleArray< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxToolsForward
ToolHandleArray for forward ONNX models (one tool per fold).
static const std::vector< std::string > s_ssVarNames
List of shower shape variable names (order must match model I/O).
FoldStrategy parseFoldStrategy(const std::string &s) const
Parse fold strategy string from config.
SG::ReadHandleKey< xAOD::EventInfo > m_eventInfoKey
ReadHandleKey for EventInfo used for fold selection.
int selectFold(unsigned long long eventNumber, float phi) const
Select fold index for the current event/photon.
virtual const CP::CorrectionCode correctedCopy(const xAOD::Photon &in_photon, xAOD::Photon *&out_photon) const override
Make a corrected copy of the passed photon.
std::string m_onnxPattern
Models path pattern string from config.
ToolHandleArray< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxToolsBackward
ToolHandleArray for backward ONNX models (one tool per fold).
bool passSelectionCuts(const xAOD::Photon &photon, const std::vector< float > &ss) const
Returns true if NF correction should be applied to this photon.
bool m_applyShowerShapeCuts
Cuts applied to remove default values of shower shapes.
ElectronPhotonVariableNFCorrectionTool(const std::string &name)
Standard constructor.
static const std::vector< xAOD::EgammaParameters::ShowerShapeType > s_ssEnums
Egamma shower shape enum mapping for reading/writing values (order matches s_ssVarNames).
int m_nFolds
Number of model folds configured (must match tool handle array sizes).
virtual const CP::CorrectionCode applyCorrection(xAOD::Photon &photon) const override
Apply the Normalizing Flow correction to the passed photon.
virtual StatusCode initialize() override
Initialize the class instance.
Gaudi::Property< std::string > m_configFile
The configuration file for the tool, application mode and minimum photon pT cut in MeV.
FoldStrategy m_foldStrategy
Selected fold strategy (configured via FoldStrategy in the config).
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
@ unconverted
unconverted photon
@ wtots1
shower width is determined in a window Δη×Δφ = 0.0625 × ~0.2, corresponding typically to 20 strips...
@ f1
E1/E = fraction of energy reconstructed in the first sampling, where E1 is energy in all strips belon...
Definition EgammaEnums.h:53
@ Eratio
(emaxs1-e2tsts1)/(emaxs1+e2tsts1)
@ DeltaE
e2tsts1-emins1
@ fracs1
shower shape in the shower core : [E(+/-3)-E(+/-1)]/E(+/-1), where E(+/-n) is the energy in ± n strip...
@ weta2
the lateral width is calculated with a window of 3x5 cells using the energy weighted sum over all cel...
@ weta1
shower width using +/-3 strips around the one with the maximal energy deposit: w3 strips = sqrt{sum(E...
Definition EgammaEnums.h:98
Photon_v1 Photon
Definition of the current "egamma version".
Electron_v1 Electron
Definition of the current "egamma version".