ATLAS Offline Software
Loading...
Searching...
No Matches
egammaTransformerCalibTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef XAOD_ANALYSIS
7#include "Identifier/Identifier.h"
8
10
11#include "xAODEgamma/Egamma.h"
12#include "xAODEgamma/Photon.h"
13#include "xAODEgamma/Electron.h"
16
19
20#include "TFile.h"
21#include "TMath.h"
22#include "TObjString.h"
23#include "TTree.h"
24#include "TClass.h"
25
26#include <cmath>
27#include <format>
28
29#include "GaudiKernel/SystemOfUnits.h"
30using Gaudi::Units::GeV;
31
33 asg::AsgTool(name)
// Constructor fragment: forwards the instance name to the asg::AsgTool base; the body is empty.
// NOTE(review): the constructor's signature line is not present in this listing.
34{
35}
36
37// Need to declare this out-of-line since the full type of m_funcs
38// isn't available in the header.
42
43
45{
// Tool initialisation (fragment). Several source lines are absent from this listing
// (the signature, a number of `if` conditions and the `case` labels of the switch),
// so the comments below describe only the statements that are visible.
// NOTE(review): the condition guarding this fatal branch is not visible here.
47 ATH_MSG_FATAL("Particle type not set: you have to set property ParticleType to a valid value");
48 return StatusCode::FAILURE;
49 }
50 ATH_MSG_DEBUG("Initializing with particle " << m_particleType);
51
52 if (m_isMC) {
53 ATH_MSG_DEBUG("Input is MC");
54 } else {
55 ATH_MSG_DEBUG("Input is data");
56 }
57
// An empty tool handle selects the "disabled" branch for the cell recovery tool.
// NOTE(review): the statements following each debug message (source lines 60 and 63) are not visible.
58 if (!m_egammaCellRecoveryTool.empty()) {
59 ATH_MSG_DEBUG("Retrieving cell recovery tool");
61 } else {
62 ATH_MSG_DEBUG("Disabling cell recovery tool");
64 }
65
// NOTE(review): the `if` opening this branch (source line 66) is not visible; judging by the
// debug messages it presumably tests m_useLayerCorrected — confirm.
67 ATH_MSG_DEBUG("Using layer-corrected energies as input to Transformer");
68 //
// The layer recalibration tool is owned by this tool and built from the tune name and the Sacc-correction flag.
69 m_layerRecalibTool = std::make_unique<egammaLayerRecalibTool>(m_layerCalibTune, m_useSaccCorrection);
71 m_layerRecalibTool->disable_LayerclEdecoration();
72 // by default it will not apply timing cut fix, we apply the timing cut fix here by default
73 } else {
74 ATH_MSG_DEBUG("Not using layer tool, Using raw layer energies as input to Transformer");
75 }
76
77 // get the Transformer models and initialize functions
78 ATH_MSG_DEBUG("get Transformer ONNX models in folder: " << m_folder);
// One case per supported particle type; the case labels and case bodies are not visible in this listing.
79 switch (m_particleType) {
81 {
85 }
86 break;
88 {
92 }
93 break;
95 {
99 }
100 break;
101
// Any other particle type value aborts initialisation.
102 default:
103 ATH_MSG_FATAL("Particle type not set properly: " << m_particleType);
104 return StatusCode::FAILURE;
105 }
106
107 return StatusCode::SUCCESS;
108}
109
110
// Loads the model from `fileName` into m_saltModel and runs one dummy inference on
// zero-filled inputs, logging every returned output at DEBUG level.
// @param fileName  path handed to the FlavorTagInference::SaltModel constructor
// @return StatusCode::SUCCESS (no failure path is visible in this listing)
111StatusCode egammaTransformerCalibTool::setupTransformerModel(const std::string& fileName)
112{
113 ATH_MSG_DEBUG("initialize() initialize salt model...");
114
115 m_saltModel = std::make_unique<FlavorTagInference::SaltModel>(fileName);
116
117 // return StatusCode::SUCCESS;
118
119 // set up decorators using a dummy query of the onnx model
// NOTE(review): `gnn_input` is used below but its declaration (source line 120) is not visible
// in this listing; it is presumably a FlavorTagInference::InputMap — confirm.
121
122 ATH_MSG_DEBUG("initialize() initialize cluster-level features...");
// Dummy cluster-level input: one row of m_num_cluster_features zeros.
123 std::vector<float> cluster_feat(m_num_cluster_features, 0.);
124 std::vector<int64_t> cluster_feat_dim = {1, static_cast<int64_t>(cluster_feat.size())};
125 FlavorTagInference::Inputs elec_info(cluster_feat, cluster_feat_dim);
126 gnn_input.insert({"cluster_features", elec_info}); // inputs are keyed by name ("cluster_features" here); the Inputs/InputMap types are borrowed from the flavour-tagging inference code
127
128 ATH_MSG_DEBUG("initialize() initialize cell-level features...");
// Dummy cell-level input: a single cell (first dimension 1) with m_num_cell_features zeros.
129 std::vector<float> cell_feat(m_num_cell_features, 0.);
130 std::vector<int64_t> cell_feat_dim = {1, m_num_cell_features};
131 FlavorTagInference::Inputs track_info(cell_feat, cell_feat_dim);
132 gnn_input.insert({"cell_features", track_info});
133
134 ATH_MSG_DEBUG("initialize() initialize dummy evaluation...");
// The three results are name->float, name->vector<char> and name->vector<float> collections,
// as shown by how they are iterated below.
135 auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); // the dummy evaluation
136
137 ATH_MSG_DEBUG("initialize() finished dummy evaluation...");
138 ATH_MSG_DEBUG("initialize() Output Float(s):");
139 for (auto &singlefloat : out_f)
140 {
141 ATH_MSG_DEBUG("initialize() " << singlefloat.first << " = " << singlefloat.second);
142 }
143 ATH_MSG_DEBUG("initialize() Output vector char(s):");
144 for (auto &vecchar : out_vc)
145 {
146 ATH_MSG_DEBUG("initialize() " << vecchar.first << " = ");
147 for (auto &cc : vecchar.second)
148 {
149 ATH_MSG_DEBUG("initialize() " << cc);
150 }
151 }
152
153 ATH_MSG_DEBUG("initialize() Output vector float(s):");
154 for (auto &vecfloat : out_vf)
155 {
156 ATH_MSG_DEBUG("initialize() " << vecfloat.first << " = ");
157 for (auto &ff : vecfloat.second)
158 {
159 ATH_MSG_DEBUG("initialize() " << ff);
160 }
161 }
162
// The dummy outputs are only logged; nothing from them is stored.
163 return StatusCode::SUCCESS;
164}
165
167 const xAOD::Egamma* eg,
168 const egammaMVACalib::GlobalEventInfo& gei) const
169{
// Returns the calibrated energy for `clus`: the model response multiplied by the summed
// (optionally layer-scaled) cell energy of layers 0-3.
// Fallbacks visible below:
//   - model not loaded or `eg` null      -> clus.e() if m_clusterEif0, else 0
//   - no usable cells collected          -> 0
//   - model response exactly 0 (or empty)-> clus.e() if m_clusterEif0, else 0
// NOTE(review): the first line of the signature (source line 166) and source lines
// 182, 200, 381, 399 and 405 are not present in this listing.
170 // 0. Safety Checks
171 if (!m_saltModel || !eg) {
172 if (m_clusterEif0) {
173 ATH_MSG_WARNING("Model not loaded or Egamma pointer is null, returning cluster energy");
174 return clus.e();
175 } else {
// NOTE(review): a FATAL message is logged but execution continues with a 0 return value.
176 ATH_MSG_FATAL("Model not loaded or Egamma pointer is null, and useClusterIf0 is false, cannot proceed");
177 return 0.0f;
178 }
179 }
180
181 // --- 1. Cell Recovery (Timing Cut Fix) ---
// NOTE(review): the declaration of `recoveryInfo` (source line 182) is not visible here.
// A failure in either step only clears recoverySucceeded; processing continues.
183 bool recoverySucceeded = true;
184 if (!m_egammaCellRecoveryTool.empty()) {
185 egammaCellUtils::MaxECell maxECell(&clus);
186 if (maxECell.sc == StatusCode::FAILURE) {
187 ATH_MSG_WARNING("Issues in finding maximum energy cell.");
188 recoverySucceeded = false;
189 } else {
190 recoveryInfo.etamax = maxECell.etaCell;
191 recoveryInfo.phimax = maxECell.phiCell;
192 if (m_egammaCellRecoveryTool->execute(clus, recoveryInfo).isFailure()) {
193 ATH_MSG_WARNING("Cell Recovery Tool failed. Proceeding without recovered cells.");
194 recoverySucceeded = false;
195 }
196 }
197 }
198
199 // --- 2. Apply Layer Calibration if needed ---
// NOTE(review): the definition of `isForward` (source line 200) is not visible here.
201 auto array_layer_scales = std::array<double, 4>{1.0, 1.0, 1.0, 1.0}; // default scales
202
203 if (m_layerRecalibTool && !m_isMC && !isForward) {
204 ATH_MSG_DEBUG("Applying layer recalibration for GNN on data.");
205
206 // Fetch the per-layer correction factors for this object and event
// NOTE(review): gei.eventInfo is dereferenced without a null check — confirm callers always set it.
207 const xAOD::EventInfo* eventInfo = gei.eventInfo;
208 array_layer_scales = m_layerRecalibTool->getLayerCorrections(*eg, *eventInfo);
209 }
210
211 if ( m_useExtraLayerScales ) {
212 ATH_MSG_DEBUG("Applying extra layer scales for systematic studies, normally this is for MC events.");
213 if ( !m_isMC ) {
214 ATH_MSG_WARNING("You are applying extra layer scales but the input is not MC! Are you sure this is intended?");
215 }
216 // extract scales from global event info
// The extra scales multiply whatever scales are already in place.
217 for (std::size_t i = 0; i < 4; ++i)
218 array_layer_scales[i] *= gei.scaleEs[i];
219 }
220
221 // --- 3. Calculate Scale Factors ---
222
223 // Raw energies + Recovered Energy (Timing Fix)
224 // double raw_Es0 = clus.energyBE(0); // LG: not sure if this is still needed but keeping it here for consistency
// These raw layer energies are only used further down for the convEtOverPt photon feature.
225 double raw_Es1 = clus.energyBE(1);
226 double raw_Es2 = clus.energyBE(2) + (recoverySucceeded && m_useFixForMissingCells ? recoveryInfo.eCells[0] : 0.0);
227 double raw_Es3 = clus.energyBE(3) + (recoverySucceeded && m_useFixForMissingCells ? recoveryInfo.eCells[1] : 0.0);
228
229 // --- 4. Cell Gathering ---
230 std::vector<float> cells_E, cells_eta, cells_phi, cells_x, cells_y, cells_z;
231 std::vector<int> cells_layer;
232 std::vector<Identifier> included_cells; // Track cells to avoid duplicates
233
234 // Layer sums
235 double sum_cell_E_L0 = 0.0, sum_cell_E_L1 = 0.0, sum_cell_E_L2 = 0.0, sum_cell_E_L3 = 0.0, sum_cell_E_Gap = 0.0;
236
237 // A. Iterate over Standard Cluster Cells
238 const CaloClusterCellLink* cellLinks = clus.getCellLinks();
239 if (cellLinks) {
240 for (const CaloCell* cell : *cellLinks) {
241 if (!cell) continue;
242
// NOTE(review): cell->caloDDE() is not null-checked here, unlike in the recovered-cell loop below.
243 int sampling = cell->caloDDE()->getSampling();
244 double scale_factor = 1.0;
245 int layer_idx = -1;
246
// Map the sampling to a layer index 0-3 (barrel and endcap merged) or 4 for TileGap3;
// cells in any other sampling are skipped.
247 switch (sampling) {
248 case CaloCell_ID::PreSamplerB: case CaloCell_ID::PreSamplerE:
249 scale_factor = array_layer_scales[0]; layer_idx = 0; break;
250 case CaloCell_ID::EMB1: case CaloCell_ID::EME1:
251 scale_factor = array_layer_scales[1]; layer_idx = 1; break;
252 case CaloCell_ID::EMB2: case CaloCell_ID::EME2:
253 scale_factor = array_layer_scales[2]; layer_idx = 2;
254 // Track cells that might be already recovered (those with time > timing cut)
255 if (cell->time() > m_timeCut) { // threshold is the m_timeCut member
256 included_cells.push_back(cell->ID());
257 }
258 break;
259 case CaloCell_ID::EMB3: case CaloCell_ID::EME3:
260 scale_factor = array_layer_scales[3]; layer_idx = 3;
261 if (cell->time() > m_timeCut) {
262 included_cells.push_back(cell->ID());
263 }
264 break;
265 case CaloCell_ID::TileGap3:
266 scale_factor = 1.0; layer_idx = 4; break;
267 default: continue;
268 }
269
270 double final_E = cell->e() * scale_factor;
271
272 cells_E.push_back(final_E);
273 cells_eta.push_back(cell->eta());
274 cells_phi.push_back(cell->phi());
275 cells_x.push_back(cell->x());
276 cells_y.push_back(cell->y());
277 cells_z.push_back(cell->z());
278 cells_layer.push_back(layer_idx);
279
280 // Accumulate Sums
281 switch(layer_idx) {
282 case 0: sum_cell_E_L0 += final_E; break;
283 case 1: sum_cell_E_L1 += final_E; break;
284 case 2: sum_cell_E_L2 += final_E; break;
285 case 3: sum_cell_E_L3 += final_E; break;
286 case 4: sum_cell_E_Gap += final_E; break;
287 }
288 }
289 }
290
291 // B. Iterate over Recovered Cells (from Tool) - Skip Duplicates
292 // Added cells are only expected in layers 2 and 3, so the dedup list only tracks those layers.
// NOTE(review): this loop runs regardless of recoverySucceeded and m_useFixForMissingCells;
// it relies on recoveryInfo.addedCells being empty when recovery did not run — confirm.
293 for (const CaloCell* cell : recoveryInfo.addedCells) {
294 if (!cell || !cell->caloDDE()) continue;
295
296 // Skip if this cell is already in the cluster
// Linear search over the identifiers collected in loop A.
297 if (std::find(included_cells.begin(), included_cells.end(), cell->ID()) != included_cells.end()) {
298 ATH_MSG_WARNING("Recovered cell " << cell->ID() << " already included in cluster. Skipping to avoid double counting.");
299 continue;
300 }
301 else {
302 ATH_MSG_DEBUG("Adding recovered cell " << cell->ID() << " to cluster inputs.");
303 }
304
305 int sampling = cell->caloDDE()->getSampling();
306 double scale_factor = 1.0;
307 int layer_idx = -1;
308
309 if (sampling == CaloCell_ID::EMB2 || sampling == CaloCell_ID::EME2) {
310 scale_factor = array_layer_scales[2]; layer_idx = 2;
311 } else if (sampling == CaloCell_ID::EMB3 || sampling == CaloCell_ID::EME3) {
312 scale_factor = array_layer_scales[3]; layer_idx = 3;
313 } else {
314 // Fallback
// Presampler and layer-1 cells are still accepted; anything else (including TileGap3) is skipped.
315 if (sampling == CaloCell_ID::PreSamplerB || sampling == CaloCell_ID::PreSamplerE) {
316 scale_factor = array_layer_scales[0]; layer_idx = 0;
317 } else if (sampling == CaloCell_ID::EMB1 || sampling == CaloCell_ID::EME1) {
318 scale_factor = array_layer_scales[1]; layer_idx = 1;
319 } else {
320 continue;
321 }
322 }
323
324 double final_E = cell->e() * scale_factor;
325
326 cells_E.push_back(final_E);
327 cells_eta.push_back(cell->eta());
328 cells_phi.push_back(cell->phi());
329 cells_x.push_back(cell->x());
330 cells_y.push_back(cell->y());
331 cells_z.push_back(cell->z());
332 cells_layer.push_back(layer_idx);
333
// layer_idx is never 4 in this loop, so the Gap case below is not reached here.
334 switch(layer_idx) {
335 case 0: sum_cell_E_L0 += final_E; break;
336 case 1: sum_cell_E_L1 += final_E; break;
337 case 2: sum_cell_E_L2 += final_E; break;
338 case 3: sum_cell_E_L3 += final_E; break;
339 case 4: sum_cell_E_Gap += final_E; break;
340 }
341 }
342
343 // --- 5. Calculate Derived Features (Post-Loop) ---
344 const size_t nCells = cells_E.size();
// NOTE(review): with no cells the function returns 0 irrespective of m_clusterEif0 — confirm intended.
345 if (nCells == 0) return 0.0f;
346
// The total deliberately or accidentally leaves out sum_cell_E_Gap; NOTE(review): confirm
// this matches how the model was trained.
347 double sum_cell_E_total = sum_cell_E_L0 + sum_cell_E_L1 + sum_cell_E_L2 + sum_cell_E_L3;
348 const double cluster_eta = clus.eta();
349 const double cluster_phi = clus.phi();
350
351 std::vector<float> cells_deta, cells_dphi, cells_eFrac;
352 cells_deta.reserve(nCells);
353 cells_dphi.reserve(nCells);
354 cells_eFrac.reserve(nCells);
355
356 for (size_t i = 0; i < nCells; ++i) {
357 float deta = cells_eta[i] - cluster_eta;
358 float dphi = cells_phi[i] - cluster_phi;
// Shift by 3*pi, reduce modulo 2*pi and shift back by pi to wrap the difference into [-pi, pi).
359 dphi = std::fmod(dphi + 3.0f * M_PI, 2.0f * M_PI) - M_PI;
360
361 cells_deta.push_back(deta);
362 cells_dphi.push_back(dphi);
363
// Energy fraction of the cell within its own layer; 0 when the layer sum is exactly 0.
364 float eFrac_layer = 0.0f;
365 switch (cells_layer[i]) {
366 case 0: eFrac_layer = (sum_cell_E_L0 != 0) ? (cells_E[i] / sum_cell_E_L0) : 0.0f; break;
367 case 1: eFrac_layer = (sum_cell_E_L1 != 0) ? (cells_E[i] / sum_cell_E_L1) : 0.0f; break;
368 case 2: eFrac_layer = (sum_cell_E_L2 != 0) ? (cells_E[i] / sum_cell_E_L2) : 0.0f; break;
369 case 3: eFrac_layer = (sum_cell_E_L3 != 0) ? (cells_E[i] / sum_cell_E_L3) : 0.0f; break;
370 case 4: eFrac_layer = (sum_cell_E_Gap != 0) ? (cells_E[i] / sum_cell_E_Gap) : 0.0f; break;
371 }
372 cells_eFrac.push_back(eFrac_layer);
373 }
374
375 // --- 6. Prepare GNN Inputs and Run Inference ---
376 double ratio_L1_L2 = (sum_cell_E_L2 != 0) ? (sum_cell_E_L1 / sum_cell_E_L2) : 0.0;
377 double main_layers_sum = sum_cell_E_L1 + sum_cell_E_L2 + sum_cell_E_L3;
378 double ratio_L0_total = (main_layers_sum != 0) ? (sum_cell_E_L0 / main_layers_sum) : 0.0;
379 double ratio_Tile_total = (main_layers_sum != 0) ? (sum_cell_E_Gap / main_layers_sum) : 0.0;
380
// NOTE(review): the declaration of `gnn_input` (source line 381) is not visible here.
382
383 // Cluster Features
// The order of the entries below is the order in which they are handed to the model.
384 std::vector<float> cluster_feats = {
385 static_cast<float>(sum_cell_E_total),
386 static_cast<float>(sum_cell_E_L0),
387 static_cast<float>(sum_cell_E_L1),
388 static_cast<float>(sum_cell_E_L2),
389 static_cast<float>(sum_cell_E_L3),
390 static_cast<float>(sum_cell_E_Gap),
391 static_cast<float>(cluster_eta),
392 static_cast<float>(cluster_phi),
393 static_cast<float>(ratio_L1_L2),
394 static_cast<float>(ratio_L0_total),
395 static_cast<float>(ratio_Tile_total)
396 };
397
398 // For converted photons, append conversion-specific features in order: convR, convEtOverPt, convPtRatio, conversionType.
// NOTE(review): the statement opening this section (source line 399) is not visible; the
// closing brace at source line 443 belongs to it. Presumably it tests for the converted-photon
// particle type — confirm.
400 const xAOD::Photon* photon = dynamic_cast<const xAOD::Photon*>(eg);
401 if (photon) {
402 // - convR
// 799 is the value kept when the pt condition below is not met.
// NOTE(review): the body of the following `if` (source line 405) is not visible here.
403 float convR = 799.0f;
404 if (egammaMVAFunctions::compute_ptconv(photon) > 3 * GeV) {
406 }
407
408 // - convEtOverPt
409 float convEtOverPt = 0.0f;
410 float ptconv = egammaMVAFunctions::compute_ptconv(photon);
411 if (xAOD::EgammaHelpers::numberOfSiTracks(photon) == 2 && ptconv > 0.0f) {
412 float eacc = (m_useLayerCorrected ?
413 (raw_Es1 * array_layer_scales[1] + raw_Es2 * array_layer_scales[2] + raw_Es3 * array_layer_scales[3]) :
414 (raw_Es1 + raw_Es2 + raw_Es3));
// NOTE(review): eg->caloCluster() is dereferenced without a null check.
415 float cl_eta = eg->caloCluster()->eta();
416 convEtOverPt = std::max(0.0f, eacc / (std::cosh(cl_eta) * ptconv));
417 }
// The ratio is confined to [0, 2]: floored at 0 above, capped at 2 here.
418 convEtOverPt = std::min(convEtOverPt, 2.0f);
419
420 // - convPtRatio
// Stays at 1 unless the conversion has exactly two Si tracks with a positive pt sum.
421 float convPtRatio = 1.0f;
422 if (xAOD::EgammaHelpers::numberOfSiTracks(photon) == 2) {
423 float pt1 = egammaMVAFunctions::compute_pt1conv(photon);
424 float pt2 = egammaMVAFunctions::compute_pt2conv(photon);
425 if ((pt1 + pt2) > 0.0f) {
426 convPtRatio = std::max(pt1, pt2) / (pt1 + pt2);
427 }
428 }
429
430 // - conversionType
431 float conversionType = static_cast<float>(photon->conversionType());
432 // must push back in this order as the model expects features in this order
433 cluster_feats.push_back(convR);
434 cluster_feats.push_back(convEtOverPt);
435 cluster_feats.push_back(convPtRatio);
436 cluster_feats.push_back(conversionType);
437 } else {
// The cast to Photon failed: pad with four zeros so the feature count stays the same.
438 cluster_feats.push_back(0.0f);
439 cluster_feats.push_back(0.0f);
440 cluster_feats.push_back(0.0f);
441 cluster_feats.push_back(0.0f);
442 }
443 }
444
445
// Cluster features are passed as a single row of cluster_feats.size() values.
446 gnn_input["cluster_features"] = FlavorTagInference::Inputs(cluster_feats, {1, (int64_t)cluster_feats.size()});
447
448 // Cell Features
// Flattened row-major: one row per cell, seven values per row in the order pushed below.
// NOTE(review): the declared width is m_num_cell_features, so that member must equal 7 — confirm.
// Note that the cell energy itself is not an input, only its in-layer fraction.
449 std::vector<float> cell_feats_flat;
450 cell_feats_flat.reserve(nCells * m_num_cell_features);
451 for (size_t i = 0; i < nCells; ++i) {
452 cell_feats_flat.push_back(cells_eFrac[i]);
453 cell_feats_flat.push_back(cells_deta[i]);
454 cell_feats_flat.push_back(cells_dphi[i]);
455 cell_feats_flat.push_back(cells_x[i]);
456 cell_feats_flat.push_back(cells_y[i]);
457 cell_feats_flat.push_back(cells_z[i]);
458 cell_feats_flat.push_back(static_cast<float>(cells_layer[i]));
459 }
460 gnn_input["cell_features"] = FlavorTagInference::Inputs(cell_feats_flat, {(int64_t)nCells, m_num_cell_features});
461
462 // Run Inference
463 auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input);
464
// The response is the first element of the first vector-float output; an empty output leaves it at 0.
465 float el_gnn_score = 0.0f;
466 if (out_vf.empty() || out_vf.begin()->second.empty()) {
467 ATH_MSG_DEBUG("GNN inference output is empty!");
468 } else {
469 el_gnn_score = out_vf.begin()->second.front();
470 }
471
472 // what to do if the Transformer response is 0;
// Exact comparison: this also catches the "empty output" case, where the score was left at its initial 0.
473 if (el_gnn_score == 0.0f) {
474 return m_clusterEif0 ? clus.e() : 0.0f;
475 }
476
// The response acts as a multiplicative correction to the summed layer 0-3 cell energy.
477 return el_gnn_score * static_cast<float>(sum_cell_E_total);
478}
479#endif
#define M_PI
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Define macros for attributes used to control the static checker.
Data object for each calorimeter readout cell.
Definition CaloCell.h:57
std::vector< const CaloCell * > addedCells
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
std::unique_ptr< const FlavorTagInference::SaltModel > m_saltModel
Gaudi::Property< bool > m_isMC
Layer calibration related properties.
Gaudi::Property< std::string > m_unconvertedPhotonModelFile
StatusCode setupTransformerModel(const std::string &fileName)
a utility to set up the transformer model, separated from initialize for better readability
Gaudi::Property< std::string > m_layerCalibTune
Gaudi::Property< bool > m_useLayerCorrected
Gaudi::Property< bool > m_useFixForMissingCells
Gaudi::Property< std::string > m_convertedPhotonModelFile
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
Gaudi::Property< std::string > m_folder
string with folder for weight files
Gaudi::Property< bool > m_useExtraLayerScales
egammaTransformerCalibTool(const std::string &type)
Gaudi::Property< std::string > m_electronModelFile
Gaudi::Property< bool > m_useSaccCorrection
std::unique_ptr< egammaLayerRecalibTool > m_layerRecalibTool
float getEnergy(const xAOD::CaloCluster &clus, const xAOD::Egamma *eg, const egammaMVACalib::GlobalEventInfo &gei=egammaMVACalib::GlobalEventInfo()) const override final
returns the calibrated energy
ToolHandle< IegammaCellRecoveryTool > m_egammaCellRecoveryTool
Pointer to the egammaCellRecoveryTool.
Gaudi::Property< bool > m_clusterEif0
const CaloClusterCellLink * getCellLinks() const
Get a pointer to the CaloClusterCellLink object (const version).
virtual double eta() const
The pseudorapidity (η) of the particle.
virtual double e() const
The total energy of the particle.
float energyBE(const unsigned layer) const
Get the energy in one layer of the EM Calo.
virtual double phi() const
The azimuthal angle (φ) of the particle.
const xAOD::CaloCluster * caloCluster(size_t index=0) const
Pointer to the xAOD::CaloCluster/s that define the electron candidate.
std::map< std::string, Inputs, std::less<> > InputMap
Definition ISaltModel.h:37
float compute_ptconv(const xAOD::Photon *ph)
This ptconv is the old one used by MVACalib.
float compute_pt2conv(const xAOD::Photon *ph)
float compute_pt1conv(const xAOD::Photon *ph)
std::size_t numberOfSiTracks(const xAOD::Photon *eg)
return the number of Si tracks in the conversion
float conversionRadius(const xAOD::Vertex *vx)
return the conversion radius or 9999.
EventInfo_v1 EventInfo
Definition of the latest event info version.
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
Egamma_v1 Egamma
Definition of the current "egamma version".
Definition Egamma.h:17
Photon_v1 Photon
Definition of the current "egamma version".
A structure holding some global event information.
std::array< float, 4 > scaleEs
const xAOD::EventInfo * eventInfo