ATLAS Offline Software
Loading...
Searching...
No Matches
egammaTransformerCalibTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6
8
9#include "xAODEgamma/Egamma.h"
10#include "xAODEgamma/Photon.h"
11#include "xAODEgamma/Electron.h"
14
17
18#include "TFile.h"
19#include "TMath.h"
20#include "TObjString.h"
21#include "TTree.h"
22#include "TClass.h"
23
24#include <cmath>
25#include <format>
26
27#ifndef XAOD_ANALYSIS
28#include "GaudiKernel/SystemOfUnits.h"
29using Gaudi::Units::GeV;
30#else
31#define GeV 1000
32#endif
33
35 asg::AsgTool(name)
36{
37}
38
39// Need to declare this out-of-line since the full type of m_funcs
40// isn't available in the header.
44
45
47{
49 ATH_MSG_FATAL("Particle type not set: you have to set property ParticleType to a valid value");
50 return StatusCode::FAILURE;
51 }
52 ATH_MSG_DEBUG("Initializing with particle " << m_particleType);
53
54 if (m_isMC) {
55 ATH_MSG_DEBUG("Input is MC");
56 } else {
57 ATH_MSG_DEBUG("Input is data");
58 }
59
61 ATH_MSG_DEBUG("Using layer-corrected energies as input to Transformer");
62 //
63 m_layerRecalibTool = std::make_unique<egammaLayerRecalibTool>(m_layerCalibTune, m_useSaccCorrection);
65 m_layerRecalibTool->disable_LayerclEdecoration();
66 // by default it will not apply timing cut fix, we apply the timing cut fix here by default
67 } else {
68 ATH_MSG_DEBUG("Not using layer tool, Using raw layer energies as input to Transformer");
69 }
70
71 // get the Transformer models and initialize functions
72 ATH_MSG_DEBUG("get Transformer ONNX models in folder: " << m_folder);
73 switch (m_particleType) {
75 {
79 }
80 break;
82 {
86 }
87 break;
89 {
93 }
94 break;
96 {
99 // Forward electron is not implemented, will use model for electron
100 ATH_MSG_WARNING("Forward electron Transformer model is not implemented, will use electron model instead");
102 }
103 break;
104
105 default:
106 ATH_MSG_FATAL("Particle type not set properly: " << m_particleType);
107 return StatusCode::FAILURE;
108 }
109
110 return StatusCode::SUCCESS;
111}
112
113
114StatusCode egammaTransformerCalibTool::setupTransformerModel(const std::string& fileName)
115{
116 ATH_MSG_DEBUG("initialize() initialize salt model...");
117
118 m_saltModel = std::make_unique<FlavorTagInference::SaltModel>(fileName);
119
120 // return StatusCode::SUCCESS;
121
122 // set up decorators using a dummy query of the onnx model
123 std::map<std::string, FlavorTagInference::Inputs> gnn_input;
124
125 ATH_MSG_DEBUG("initialize() initialize cluster-level features...");
126 std::vector<float> cluster_feat(m_num_cluster_features, 0.);
127 std::vector<int64_t> cluster_feat_dim = {1, static_cast<int64_t>(cluster_feat.size())};
128 FlavorTagInference::Inputs elec_info(cluster_feat, cluster_feat_dim);
129 gnn_input.insert({"cluster_features", elec_info}); // need to use the "jet_features" keyword as we are borrowing flavour tagging code
130
131 ATH_MSG_DEBUG("initialize() initialize cell-level features...");
132 std::vector<float> cell_feat(m_num_cell_features, 0.);
133 std::vector<int64_t> cell_feat_dim = {1, m_num_cell_features};
134 FlavorTagInference::Inputs track_info(cell_feat, cell_feat_dim);
135 gnn_input.insert({"cell_features", track_info});
136
137 ATH_MSG_DEBUG("initialize() initialize dummy evaluation...");
138 auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input); // the dummy evaluation
139
140 ATH_MSG_DEBUG("initialize() finished dummy evaluation...");
141 ATH_MSG_DEBUG("initialize() Output Float(s):");
142 for (auto &singlefloat : out_f)
143 {
144 ATH_MSG_DEBUG("initialize() " << singlefloat.first << " = " << singlefloat.second);
145 }
146 ATH_MSG_DEBUG("initialize() Output vector char(s):");
147 for (auto &vecchar : out_vc)
148 {
149 ATH_MSG_DEBUG("initialize() " << vecchar.first << " = ");
150 for (auto &cc : vecchar.second)
151 {
152 ATH_MSG_DEBUG("initialize() " << cc);
153 }
154 }
155
156 ATH_MSG_DEBUG("initialize() Output vector float(s):");
157 for (auto &vecfloat : out_vf)
158 {
159 ATH_MSG_DEBUG("initialize() " << vecfloat.first << " = ");
160 for (auto &ff : vecfloat.second)
161 {
162 ATH_MSG_DEBUG("initialize() " << ff);
163 }
164 }
165
166 return StatusCode::SUCCESS;
167}
168
170 const xAOD::Egamma* eg,
171 const egammaMVACalib::GlobalEventInfo& gei) const
172{
173 // 0. Safety Checks
174 if (!m_saltModel || !eg) {
175 if (m_clusterEif0) {
176 ATH_MSG_WARNING("Model not loaded or Egamma pointer is null, returning cluster energy");
177 return clus.e();
178 } else {
179 ATH_MSG_FATAL("Model not loaded or Egamma pointer is null, and useClusterIf0 is false, cannot proceed");
180 return 0.0f;
181 }
182 }
183
184 // --- 1. Cell Recovery (Timing Cut Fix) ---
186 bool recoverySucceeded = false;
187 if (!m_egammaCellRecoveryTool.empty()) {
188 if (m_egammaCellRecoveryTool->execute(clus, recoveryInfo).isFailure()) {
189 ATH_MSG_WARNING("Cell Recovery Tool failed. Proceeding without recovered cells.");
190 } else {
191 recoverySucceeded = true;
192 }
193 }
194
195 // --- 2. Apply Layer Calibration if needed ---
196 const xAOD::CaloCluster* clusterForTransformer = eg->caloCluster();
197 std::unique_ptr<xAOD::Egamma> temp_eg; // Manages the lifetime of the temporary object
198
200 auto array_layer_scales = std::array<double, 4>{1.0, 1.0, 1.0, 1.0}; // default scales
201
202 if (m_layerRecalibTool && !m_isMC && !isForward) {
203 ATH_MSG_DEBUG("Applying layer recalibration for GNN on data.");
204
205 // A. Create a new object of the correct concrete type (Electron or Photon)
206 // We use a switch based on the configured particle type.
207 switch (m_particleType) {
209 temp_eg = std::make_unique<xAOD::Electron>();
210 temp_eg->makePrivateStore(*eg); // B. Copy data from the original object
211 break;
212 }
215 temp_eg = std::make_unique<xAOD::Photon>();
216 temp_eg->makePrivateStore(*eg); // B. Copy data from the original object
217 break;
218 }
219 default:
220 ATH_MSG_WARNING("Unknown particle type set in tool for layer calibration: " << m_particleType);
221 temp_eg = nullptr;
222 break;
223 }
224
225 // D. Apply correction to the new, non-const object
226 if (temp_eg) {
227 const xAOD::EventInfo* eventInfo = gei.eventInfo;
228 array_layer_scales = m_layerRecalibTool->getLayerCorrections(*temp_eg, *eventInfo);
229 // E. Get the calibrated cluster from the temporary object
230 clusterForTransformer = temp_eg->caloCluster();
231 }
232 }
233 else
234 {
235 // Get the cluster from the (possibly calibrated) local object, if not using layer corrections, this will just be the original cluster
236 clusterForTransformer = eg->caloCluster();
237 ATH_MSG_DEBUG("Not using layer tool, Using raw layer energies as input to Transformer");
238 }
239
240 if ( m_useExtraLayerScales ) {
241 ATH_MSG_DEBUG("Applying extra layer scales for systematic studies, normally this is for MC events.");
242 if ( !m_isMC ) {
243 ATH_MSG_WARNING("You are applying extra layer scales but the input is not MC! Are you sure this is intended?");
244 }
245 // extract scales from global event info
246 for (std::size_t i = 0; i < 4; ++i)
247 array_layer_scales[i] *= gei.scaleEs[i];
248 }
249
250 // --- 3. Calculate Scale Factors ---
251
252 // Raw energies + Recovered Energy (Timing Fix)
253 // double raw_Es0 = clus.energyBE(0); // LG: not sure if this is still needed but keeping it here for consistency
254 double raw_Es1 = clus.energyBE(1);
255 double raw_Es2 = clus.energyBE(2) + (recoverySucceeded && m_useFixForMissingCells ? recoveryInfo.eCells[0] : 0.0);
256 double raw_Es3 = clus.energyBE(3) + (recoverySucceeded && m_useFixForMissingCells ? recoveryInfo.eCells[1] : 0.0);
257
258 // --- 4. Cell Gathering ---
259 std::vector<float> cells_E, cells_eta, cells_phi, cells_x, cells_y, cells_z;
260 std::vector<int> cells_layer;
261 std::vector<Identifier> included_cells; // Track cells to avoid duplicates
262
263 // Layer sums
264 double sum_cell_E_L0 = 0.0, sum_cell_E_L1 = 0.0, sum_cell_E_L2 = 0.0, sum_cell_E_L3 = 0.0, sum_cell_E_Gap = 0.0;
265
266 // A. Iterate over Standard Cluster Cells
267 const CaloClusterCellLink* cellLinks = clus.getCellLinks();
268 if (cellLinks) {
269 for (const CaloCell* cell : *cellLinks) {
270 if (!cell) continue;
271
272 int sampling = cell->caloDDE()->getSampling();
273 double scale_factor = 1.0;
274 int layer_idx = -1;
275
276 switch (sampling) {
277 case CaloCell_ID::PreSamplerB: case CaloCell_ID::PreSamplerE:
278 scale_factor = array_layer_scales[0]; layer_idx = 0; break;
279 case CaloCell_ID::EMB1: case CaloCell_ID::EME1:
280 scale_factor = array_layer_scales[1]; layer_idx = 1; break;
281 case CaloCell_ID::EMB2: case CaloCell_ID::EME2:
282 scale_factor = array_layer_scales[2]; layer_idx = 2;
283 // Track cells that might be already recovered (those with time > timing cut)
284 if (cell->time() > m_timeCut) { // Use your actual timing cut threshold
285 included_cells.push_back(cell->ID());
286 }
287 break;
288 case CaloCell_ID::EMB3: case CaloCell_ID::EME3:
289 scale_factor = array_layer_scales[3]; layer_idx = 3;
290 if (cell->time() > m_timeCut) {
291 included_cells.push_back(cell->ID());
292 }
293 break;
294 case CaloCell_ID::TileGap3:
295 scale_factor = 1.0; layer_idx = 4; break;
296 default: continue;
297 }
298
299 double final_E = cell->e() * scale_factor;
300
301 cells_E.push_back(final_E);
302 cells_eta.push_back(cell->eta());
303 cells_phi.push_back(cell->phi());
304 cells_x.push_back(cell->x());
305 cells_y.push_back(cell->y());
306 cells_z.push_back(cell->z());
307 cells_layer.push_back(layer_idx);
308
309 // Accumulate Sums
310 switch(layer_idx) {
311 case 0: sum_cell_E_L0 += final_E; break;
312 case 1: sum_cell_E_L1 += final_E; break;
313 case 2: sum_cell_E_L2 += final_E; break;
314 case 3: sum_cell_E_L3 += final_E; break;
315 case 4: sum_cell_E_Gap += final_E; break;
316 }
317 }
318 }
319
320 // B. Iterate over Recovered Cells (from Tool) - Skip Duplicates
321 // Added cells are only expected in layers 2 and 3, so the dedup list only tracks those layers.
322 for (const CaloCell* cell : recoveryInfo.addedCells) {
323 if (!cell || !cell->caloDDE()) continue;
324
325 // Skip if this cell is already in the cluster
326 if (std::find(included_cells.begin(), included_cells.end(), cell->ID()) != included_cells.end()) {
327 ATH_MSG_DEBUG("Recovered cell " << cell->ID() << " already included in cluster. Skipping to avoid double counting.");
328 continue;
329 }
330 else {
331 ATH_MSG_DEBUG("Adding recovered cell " << cell->ID() << " to cluster inputs.");
332 }
333
334 int sampling = cell->caloDDE()->getSampling();
335 double scale_factor = 1.0;
336 int layer_idx = -1;
337
338 if (sampling == CaloCell_ID::EMB2 || sampling == CaloCell_ID::EME2) {
339 scale_factor = array_layer_scales[2]; layer_idx = 2;
340 } else if (sampling == CaloCell_ID::EMB3 || sampling == CaloCell_ID::EME3) {
341 scale_factor = array_layer_scales[3]; layer_idx = 3;
342 } else {
343 // Fallback
344 if (sampling == CaloCell_ID::PreSamplerB || sampling == CaloCell_ID::PreSamplerE) {
345 scale_factor = array_layer_scales[0]; layer_idx = 0;
346 } else if (sampling == CaloCell_ID::EMB1 || sampling == CaloCell_ID::EME1) {
347 scale_factor = array_layer_scales[1]; layer_idx = 1;
348 } else {
349 continue;
350 }
351 }
352
353 double final_E = cell->e() * scale_factor;
354
355 cells_E.push_back(final_E);
356 cells_eta.push_back(cell->eta());
357 cells_phi.push_back(cell->phi());
358 cells_x.push_back(cell->x());
359 cells_y.push_back(cell->y());
360 cells_z.push_back(cell->z());
361 cells_layer.push_back(layer_idx);
362
363 switch(layer_idx) {
364 case 0: sum_cell_E_L0 += final_E; break;
365 case 1: sum_cell_E_L1 += final_E; break;
366 case 2: sum_cell_E_L2 += final_E; break;
367 case 3: sum_cell_E_L3 += final_E; break;
368 case 4: sum_cell_E_Gap += final_E; break;
369 }
370 }
371
372 // --- 5. Calculate Derived Features (Post-Loop) ---
373 const size_t nCells = cells_E.size();
374 if (nCells == 0) return 0.0f;
375
376 double sum_cell_E_total = sum_cell_E_L0 + sum_cell_E_L1 + sum_cell_E_L2 + sum_cell_E_L3;
377 const double cluster_eta = clus.eta();
378 const double cluster_phi = clus.phi();
379
380 std::vector<float> cells_deta, cells_dphi, cells_eFrac;
381 cells_deta.reserve(nCells);
382 cells_dphi.reserve(nCells);
383 cells_eFrac.reserve(nCells);
384
385 for (size_t i = 0; i < nCells; ++i) {
386 float deta = cells_eta[i] - cluster_eta;
387 float dphi = cells_phi[i] - cluster_phi;
388 dphi = std::fmod(dphi + 3.0f * M_PI, 2.0f * M_PI) - M_PI;
389
390 cells_deta.push_back(deta);
391 cells_dphi.push_back(dphi);
392
393 float eFrac_layer = 0.0f;
394 switch (cells_layer[i]) {
395 case 0: eFrac_layer = (sum_cell_E_L0 != 0) ? (cells_E[i] / sum_cell_E_L0) : 0.0f; break;
396 case 1: eFrac_layer = (sum_cell_E_L1 != 0) ? (cells_E[i] / sum_cell_E_L1) : 0.0f; break;
397 case 2: eFrac_layer = (sum_cell_E_L2 != 0) ? (cells_E[i] / sum_cell_E_L2) : 0.0f; break;
398 case 3: eFrac_layer = (sum_cell_E_L3 != 0) ? (cells_E[i] / sum_cell_E_L3) : 0.0f; break;
399 case 4: eFrac_layer = (sum_cell_E_Gap != 0) ? (cells_E[i] / sum_cell_E_Gap) : 0.0f; break;
400 }
401 cells_eFrac.push_back(eFrac_layer);
402 }
403
404 // --- 6. Prepare GNN Inputs and Run Inference ---
405 double ratio_L1_L2 = (sum_cell_E_L2 != 0) ? (sum_cell_E_L1 / sum_cell_E_L2) : 0.0;
406 double main_layers_sum = sum_cell_E_L1 + sum_cell_E_L2 + sum_cell_E_L3;
407 double ratio_L0_total = (main_layers_sum != 0) ? (sum_cell_E_L0 / main_layers_sum) : 0.0;
408 double ratio_Tile_total = (main_layers_sum != 0) ? (sum_cell_E_Gap / main_layers_sum) : 0.0;
409
410 std::map<std::string, FlavorTagInference::Inputs> gnn_input;
411
412 // Cluster Features
413 std::vector<float> cluster_feats = {
414 static_cast<float>(sum_cell_E_total),
415 static_cast<float>(sum_cell_E_L0),
416 static_cast<float>(sum_cell_E_L1),
417 static_cast<float>(sum_cell_E_L2),
418 static_cast<float>(sum_cell_E_L3),
419 static_cast<float>(sum_cell_E_Gap),
420 static_cast<float>(cluster_eta),
421 static_cast<float>(cluster_phi),
422 static_cast<float>(ratio_L1_L2),
423 static_cast<float>(ratio_L0_total),
424 static_cast<float>(ratio_Tile_total)
425 };
426
427 // For converted photons, append conversion-specific features in order: convR, convEtOverPt, convPtRatio, conversionType.
429 const xAOD::Photon* photon = dynamic_cast<const xAOD::Photon*>(eg);
430 if (photon) {
431 // - convR
432 float convR = 799.0f;
433 if (egammaMVAFunctions::compute_ptconv(photon) > 3 * GeV) {
435 }
436
437 // - convEtOverPt
438 float convEtOverPt = 0.0f;
439 float ptconv = egammaMVAFunctions::compute_ptconv(photon);
440 if (xAOD::EgammaHelpers::numberOfSiTracks(photon) == 2 && ptconv > 0.0f) {
441 float eacc = (m_useLayerCorrected ?
442 (raw_Es1 * array_layer_scales[1] + raw_Es2 * array_layer_scales[2] + raw_Es3 * array_layer_scales[3]) :
443 (raw_Es1 + raw_Es2 + raw_Es3));
444 float cl_eta = egammaMVAFunctions::compute_cl_eta(*clusterForTransformer);
445 convEtOverPt = std::max(0.0f, eacc / (std::cosh(cl_eta) * ptconv));
446 }
447 convEtOverPt = std::min(convEtOverPt, 2.0f);
448
449 // - convPtRatio
450 float convPtRatio = 1.0f;
451 if (xAOD::EgammaHelpers::numberOfSiTracks(photon) == 2) {
452 float pt1 = egammaMVAFunctions::compute_pt1conv(photon);
453 float pt2 = egammaMVAFunctions::compute_pt2conv(photon);
454 if ((pt1 + pt2) > 0.0f) {
455 convPtRatio = std::max(pt1, pt2) / (pt1 + pt2);
456 }
457 }
458
459 // - conversionType
460 float conversionType = static_cast<float>(photon->conversionType());
461 // must push back in this order as the model expects features in this order
462 cluster_feats.push_back(convR);
463 cluster_feats.push_back(convEtOverPt);
464 cluster_feats.push_back(convPtRatio);
465 cluster_feats.push_back(conversionType);
466 } else {
467 cluster_feats.push_back(0.0f);
468 cluster_feats.push_back(0.0f);
469 cluster_feats.push_back(0.0f);
470 cluster_feats.push_back(0.0f);
471 }
472 }
473
474
475 gnn_input["cluster_features"] = FlavorTagInference::Inputs(cluster_feats, {1, (int64_t)cluster_feats.size()});
476
477 // Cell Features
478 std::vector<float> cell_feats_flat;
479 cell_feats_flat.reserve(nCells * m_num_cell_features);
480 for (size_t i = 0; i < nCells; ++i) {
481 cell_feats_flat.push_back(cells_eFrac[i]);
482 cell_feats_flat.push_back(cells_deta[i]);
483 cell_feats_flat.push_back(cells_dphi[i]);
484 cell_feats_flat.push_back(cells_x[i]);
485 cell_feats_flat.push_back(cells_y[i]);
486 cell_feats_flat.push_back(cells_z[i]);
487 cell_feats_flat.push_back(static_cast<float>(cells_layer[i]));
488 }
489 gnn_input["cell_features"] = FlavorTagInference::Inputs(cell_feats_flat, {(int64_t)nCells, m_num_cell_features});
490
491 // Run Inference
492 auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input);
493
494 float el_gnn_score = 0.0f;
495 if (out_vf.empty() || out_vf.begin()->second.empty()) {
496 ATH_MSG_DEBUG("GNN inference output is empty!");
497 } else {
498 el_gnn_score = out_vf.begin()->second.front();
499 }
500
501 // what to do if the Transformer response is 0;
502 if (el_gnn_score == 0.0f) {
503 return m_clusterEif0 ? clus.e() : 0.0f;
504 }
505
506 return el_gnn_score * static_cast<float>(sum_cell_E_total);
507}
#define M_PI
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Define macros for attributes used to control the static checker.
Data object for each calorimeter readout cell.
Definition CaloCell.h:57
std::vector< const CaloCell * > addedCells
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
std::unique_ptr< const FlavorTagInference::SaltModel > m_saltModel
Gaudi::Property< bool > m_isMC
Layer calibration related properties.
Gaudi::Property< std::string > m_unconvertedPhotonModelFile
StatusCode setupTransformerModel(const std::string &fileName)
a utility to set up the transformer model, separated from initialize for better readability
Gaudi::Property< std::string > m_layerCalibTune
Gaudi::Property< bool > m_useLayerCorrected
Gaudi::Property< bool > m_useFixForMissingCells
Gaudi::Property< std::string > m_convertedPhotonModelFile
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
Gaudi::Property< std::string > m_folder
string with folder for weight files
Gaudi::Property< bool > m_useExtraLayerScales
egammaTransformerCalibTool(const std::string &type)
Gaudi::Property< std::string > m_electronModelFile
Gaudi::Property< bool > m_useSaccCorrection
std::unique_ptr< egammaLayerRecalibTool > m_layerRecalibTool
float getEnergy(const xAOD::CaloCluster &clus, const xAOD::Egamma *eg, const egammaMVACalib::GlobalEventInfo &gei=egammaMVACalib::GlobalEventInfo()) const override final
returns the calibrated energy
ToolHandle< IegammaCellRecoveryTool > m_egammaCellRecoveryTool
Pointer to the egammaCellRecoveryTool.
Gaudi::Property< bool > m_clusterEif0
const CaloClusterCellLink * getCellLinks() const
Get a pointer to the CaloClusterCellLink object (const version)
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
virtual double e() const
The total energy of the particle.
float energyBE(const unsigned layer) const
Get the energy in one layer of the EM Calo.
virtual double phi() const
The azimuthal angle ($\phi$) of the particle.
const xAOD::CaloCluster * caloCluster(size_t index=0) const
Pointer to the xAOD::CaloCluster/s that define the electron candidate.
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
float compute_ptconv(const xAOD::Photon *ph)
This ptconv is the old one used by MVACalib.
float compute_pt2conv(const xAOD::Photon *ph)
float compute_cl_eta(const xAOD::CaloCluster &cluster)
float compute_pt1conv(const xAOD::Photon *ph)
std::size_t numberOfSiTracks(const xAOD::Photon *eg)
return the number of Si tracks in the conversion
float conversionRadius(const xAOD::Vertex *vx)
return the conversion radius or 9999.
EventInfo_v1 EventInfo
Definition of the latest event info version.
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
Egamma_v1 Egamma
Definition of the current "egamma version".
Definition Egamma.h:17
Photon_v1 Photon
Definition of the current "egamma version".
A structure holding some global event information.
std::array< float, 4 > scaleEs
const xAOD::EventInfo * eventInfo