ATLAS Offline Software
Loading...
Searching...
No Matches
AsgForwardElectronCalibrationTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
22
23
25
27#include "xAODEgamma/Electron.h"
31#include "CaloGeoHelpers/CaloSampling.h"
32
33#include "lwtnn/parse_json.hh"
34
35#include <fstream>
36#include <cmath>
37
38
39
40
41
42//=============================================================================
43// Standard constructor
44//=============================================================================
45AsgForwardElectronCalibrationTool::AsgForwardElectronCalibrationTool(
46 const std::string& myname)
47 : AsgTool(myname)
48{
49}
50
51//=============================================================================
52// Standard destructor
53//=============================================================================
55
56// ============================================================================
57// Initialise
58// ============================================================================
60{
61 // Sanity checks
62 if (m_modelFiles.size() != 3) {
63 ATH_MSG_ERROR("Exactly 3 model files expected (one per eta bin): "
64 "[2.5,2.7], [2.7,3.2], [3.2,4.0]). Only got "
65 << m_modelFiles.size());
66 return StatusCode::FAILURE;
67 }
68
69 // Fixed variable names - must match the lwtnn JSON input
70 m_variables = {
71 // Calo eta and Phi
72 "x1_calo_eta",
73 "x2_calo_phi",
74 // ITk eta and Phi
75 "x3_track_eta",
76 "x4_track_phi",
77 // HGTD time
78 "x5_time",
79 // ITk hit counts
80 "x6_pixels",
81 "x7_strips",
82 // Shower shape moments
83 "x8_ENG_FRAC_MAX",
84 "x9_LONGITUDINAL",
85 "x10_SECOND_LAMBDA",
86 "x11_LATERAL",
87 "x12_SECOND_R",
88 "x13_CENTER_LAMBDA",
89 "x14_SECOND_ENG_DENS",
90 // Track-cluster matching variables
91 "x15_delta_eta2",
92 "x16_delta_phi2",
93 "x17_delta_phi_rescaled2",
94 "x18_delta_phi_last",
95 // Calo energy fractions in each layer
96 "x19_calo_frac_EM_1",
97 "x20_calo_frac_EM_2",
98 "x21_calo_frac_EM_3",
99 "x22_calo_frac_HAD_0",
100 "x23_calo_frac_HAD_1",
101 "x24_calo_frac_HAD_2",
102 "x25_calo_frac_HAD_3"
103 };
104 // More informations on Table 3 in Internal Note: https://cds.cern.ch/record/2922184
105
106 // Load one JSON per eta bin
107 m_graphs.reserve(3);
108 for (const auto& model : m_modelFiles) {
109 const std::string path = PathResolverFindCalibFile(model);
110 if (path.empty()) {
111 ATH_MSG_ERROR("Could not locate: " << model);
112 return StatusCode::FAILURE;
113 }
114 std::ifstream dnn_json(path);
115 auto parsed = lwt::parse_json_graph(dnn_json);
116 m_graphs.emplace_back(
117 std::make_unique<lwt::LightweightGraph>(parsed));
118 ATH_MSG_INFO("Loaded calibration model for bin "
119 << m_graphs.size() - 1 << ": " << path);
120 }
121
122 ATH_MSG_INFO("AsgForwardElectronCalibrationTool initialised");
123 return StatusCode::SUCCESS;
124}
125
126// ============================================================================
127// Calibration
128// ============================================================================
129double AsgForwardElectronCalibrationTool::calibrate(const EventContext& /*ctx*/,
130 const xAOD::Electron* eg) const
131{
132 if (!eg) {
133 ATH_MSG_ERROR("Failed, no Electron object.");
134 return -999.;
135 }
136
137 const xAOD::CaloCluster* cluster = eg->caloCluster();
138 if (!cluster) {
139 ATH_MSG_WARNING("Failed, no cluster.");
140 return -999.;
141 }
142
143 // ITk electron must have a track
144 const xAOD::TrackParticle* track = eg->trackParticle();
145 if (!track) {
146 ATH_MSG_WARNING("Failed, no track.");
147 return -999.;
148 }
149
150 const double absEta = std::abs(cluster->eta());
151 const int etaBin = getEtaBin(absEta);
152 if (etaBin < 0) {
153 ATH_MSG_WARNING("Electron |eta|=" << absEta
154 << " is outside allowed range.");
155 return -999.;
156 }
157
158
159 // Get input variables
160 std::vector<float> inputs;
161 if (!getInputs(eg, inputs)) return -999.;
162
163
164 // Compute the DNN output
165 std::map<std::string, std::map<std::string, double>> inputMap;
166 for (size_t i = 0; i < m_variables.size(); ++i)
167 inputMap["node_0"][m_variables[i]] = static_cast<float>(inputs[i]);
168
169 const auto outputs = m_graphs[etaBin]->compute(inputMap);
170 const double rawOut = outputs.begin()->second;
171 const double calibratedPt = unscalePt(softplus(rawOut));
172
173 // Round to 1 MeV precision due to lwtnn mismatch
174 double calibratedPt_rounded = std::round(calibratedPt);
175
176 return calibratedPt_rounded;
177}
178
179// ============================================================================
180// getEtaBin
181// ============================================================================
183{
184 // Convention: x1 < |eta| <= x2
185 if (absEta > 2.5 && absEta <= 2.7) return 0;
186 if (absEta > 2.7 && absEta <= 3.2) return 1;
187 if (absEta > 3.2 && absEta <= 4.0) return 2;
188 return -1;
189}
190
191// ============================================================================
192// getInputs
193// ============================================================================
195 std::vector<float>& inputs) const
196{
// Fill `inputs` with the 25 DNN features x1..x25, in the same order as the
// names in m_variables (matching "xN_" prefixes). Returns false, with
// `inputs` left empty, if the electron has no cluster or no track.
// Missing values are not fatal: an absent shower moment or track-calo match
// value falls back to 0 (with a warning), a track without a valid time gives
// -99, and a zero cluster energy makes every energy fraction 0.
// NOTE(review): `eg` is dereferenced without a null check; calibrate() checks
// it before calling — other callers must do the same.
197 inputs.clear();
198 inputs.reserve(25);
199
200 const xAOD::CaloCluster* cluster = eg->caloCluster();
201 const xAOD::TrackParticle* track = eg->trackParticle();
202
203 if (!cluster) { ATH_MSG_ERROR("No CaloCluster."); return false; }
204 if (!track) { ATH_MSG_ERROR("No TrackParticle."); return false; }
205
206 // x1, x2 = calo eta and phi
207 inputs.push_back(static_cast<float>(cluster->eta()));
208 inputs.push_back(static_cast<float>(cluster->phi()));
209
210 // x3, x4 = track eta and phi
211 inputs.push_back(static_cast<float>(track->eta()));
212 inputs.push_back(static_cast<float>(track->phi()));
213
214 // x5 = HGTD time
// Tracks without a valid time get the sentinel -99 instead of a real time.
215 if (track->hasValidTime()) inputs.push_back(static_cast<float>(track->time()));
216 else
217 {
218 ATH_MSG_DEBUG("No valid time for the track while doing track->time()" );
219 inputs.push_back(-99);
220 }
221
222 // x6, x7 = ITk hit counts
// Pixel and strip counts are read through the pixel / SCT summary types.
223 inputs.push_back(static_cast<float>(eg->trackParticleSummaryIntValue(xAOD::numberOfPixelHits)));
224 inputs.push_back(static_cast<float>(eg->trackParticleSummaryIntValue(xAOD::numberOfSCTHits)));
225
226 // x8 to x14 = 7 calorimeter shower-shape moments
// Helper: read one cluster moment; on failure warn and return 0.
227 auto getMoment = [&](xAOD::CaloCluster::MomentType type,
228 const char* name) -> float {
229 double val{0.};
230 if (!cluster->retrieveMoment(type, val))
231 ATH_MSG_WARNING("Could not retrieve calo moment: " << name);
232 return val;
233 };
234
235 inputs.push_back(getMoment(xAOD::CaloCluster::ENG_FRAC_MAX, "ENG_FRAC_MAX"));
236 inputs.push_back(getMoment(xAOD::CaloCluster::LONGITUDINAL, "LONGITUDINAL"));
237 inputs.push_back(getMoment(xAOD::CaloCluster::SECOND_LAMBDA, "SECOND_LAMBDA"));
238 inputs.push_back(getMoment(xAOD::CaloCluster::LATERAL, "LATERAL"));
239 inputs.push_back(getMoment(xAOD::CaloCluster::SECOND_R, "SECOND_R"));
240 inputs.push_back(getMoment(xAOD::CaloCluster::CENTER_LAMBDA, "CENTER_LAMBDA"));
241 inputs.push_back(getMoment(xAOD::CaloCluster::SECOND_ENG_DENS, "SECOND_ENG_DENS"));
242
243 // x15 to x18 = track-calo matching
// Helper `getMatch`: read one track-calo match value; on failure warn and return 0.
// NOTE(review): the line opening this lambda (original line 244) is missing
// from this listing; judging from its uses below it takes an
// xAOD::EgammaParameters match type plus a label — confirm against the source.
245 const char* name) -> float {
246 float val{0.f};
247 if (!eg->trackCaloMatchValue(val, type))
248 ATH_MSG_WARNING("Could not retrieve track-calo match: " << name);
249 return val;
250 };
251
252 inputs.push_back(getMatch(xAOD::EgammaParameters::deltaEta2,
253 "delta_eta2"));
254 inputs.push_back(getMatch(xAOD::EgammaParameters::deltaPhi2,
255 "delta_phi2"));
256 inputs.push_back(getMatch(xAOD::EgammaParameters::deltaPhiRescaled2,
257 "delta_phi_rescaled2"));
// NOTE(review): the opening of this call (original line 258) is missing from
// this listing; given the label it presumably fetches
// xAOD::EgammaParameters::deltaPhiFromLastMeasurement — confirm.
259 "delta_phi_last"));
260
261 // x19 to x25 = Derived calorimeter energy fractions
262 // Prevent division by zero
// With zero cluster energy inv_E is 0, so every fraction below becomes 0.
263 const double caloE = cluster->e();
264 const double inv_E = (caloE != 0.) ? 1. / caloE : 0.;
265
266 using CS = CaloSampling::CaloSample;
267
268 // EM fractions:
269 // f_i^EM = energyBE(i) / caloE for i in {1,2,3}
// energyBE(i) is the cluster energy in EM calorimeter layer i.
270 inputs.push_back(static_cast<float>(cluster->energyBE(1) * inv_E));
271 inputs.push_back(static_cast<float>(cluster->energyBE(2) * inv_E));
272 inputs.push_back(static_cast<float>(cluster->energyBE(3) * inv_E));
273
274 // HAD fractions:
275 // f_0^HAD = HEC0 / caloE
276 // f_i^HAD = (HEC_i + FCAL_(i-1)) / caloE for i in {1,2,3}
// Each HEC layer i >= 1 is combined with the FCAL layer i-1 into one feature.
277 inputs.push_back(static_cast<float>(
278 cluster->eSample(static_cast<CS>(CaloSampling::HEC0)) * inv_E));
279 inputs.push_back(static_cast<float>(
280 (cluster->eSample(static_cast<CS>(CaloSampling::HEC1)) +
281 cluster->eSample(static_cast<CS>(CaloSampling::FCAL0))) * inv_E));
282 inputs.push_back(static_cast<float>(
283 (cluster->eSample(static_cast<CS>(CaloSampling::HEC2)) +
284 cluster->eSample(static_cast<CS>(CaloSampling::FCAL1))) * inv_E));
285 inputs.push_back(static_cast<float>(
286 (cluster->eSample(static_cast<CS>(CaloSampling::HEC3)) +
287 cluster->eSample(static_cast<CS>(CaloSampling::FCAL2))) * inv_E));
288
289 return true;
290}
291
292// ============================================================================
293// Helpers
294// ============================================================================
295
297{
298 // Avoids overflow for large x
299 return (x > 20.) ? x : std::log1p(std::exp(x));
300}
301
303{
304 // Inverse MinMaxScaler with feature_range=(0,1)
305 // Trained with:
306 // pTMin = 10 GeV
307 // pTMax = 255 GeV
308 return x * (m_pTMax - m_pTMin) + m_pTMin;
309}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define x
int getEtaBin(double absEta) const
Select the eta bin index for |eta|, three eta bins: Bin 0: 2.5 < |eta| <= 2.7 Bin 1: 2....
double unscalePt(double x) const
Undo MinMax pT scaling used during training.
Gaudi::Property< std::vector< std::string > > m_modelFiles
One lwtnn JSON file / DNN per eta bin.
double calibrate(const EventContext &ctx, const xAOD::Electron *eg) const override
Return the DNN-calibrated pT in MeV. Returns -999 on error.
bool getInputs(const xAOD::Electron *eg, std::vector< float > &inputs) const
Get the 25 input variables. Returns false on failure.
virtual StatusCode initialize() override
Gaudi Service Interface method implementations.
Gaudi::Property< double > m_pTMin
pT scaling range used [MeV]
double softplus(double x) const
Softplus: log(1 + exp(x)).
std::vector< std::string > m_variables
Input variable names.
virtual ASG_TOOL_CLASS1(AsgForwardElectronCalibrationTool, IForwardElectronCalib) public ~AsgForwardElectronCalibrationTool()
Standard constructor.
std::vector< std::unique_ptr< lwt::LightweightGraph > > m_graphs
bool retrieveMoment(MomentType type, double &value) const
Retrieve individual moment.
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
virtual double e() const
The total energy of the particle.
float eSample(const CaloSample sampling) const
float energyBE(const unsigned layer) const
Get the energy in one layer of the EM Calo.
virtual double phi() const
The azimuthal angle ($\phi$) of the particle.
MomentType
Enums to identify different moments.
@ SECOND_ENG_DENS
Second Moment in E/V.
@ SECOND_LAMBDA
Second Moment in $\lambda$.
@ LATERAL
Normalized lateral moment.
@ LONGITUDINAL
Normalized longitudinal moment.
@ ENG_FRAC_MAX
Energy fraction of hottest cell.
@ SECOND_R
Second Moment in $r$.
@ CENTER_LAMBDA
Shower depth at Cluster Centroid.
const xAOD::CaloCluster * caloCluster(size_t index=0) const
Pointer to the xAOD::CaloCluster/s that define the electron candidate.
bool trackCaloMatchValue(float &value, const EgammaParameters::TrackCaloMatchType information) const
Accessor for Track to Calo Match Values.
const xAOD::TrackParticle * trackParticle(size_t index=0) const
Pointer to the xAOD::TrackParticle/s that match the electron candidate.
uint8_t trackParticleSummaryIntValue(const SummaryType information, int index=0) const
Accessor to the matching track(s) int information (index = 0 is the best match) Will lead to an excep...
@ deltaPhiFromLastMeasurement
difference between the cluster phi (sampling 2) and the phi of the track extrapolated from the last m...
@ deltaEta2
difference between the cluster eta (second sampling) and the eta of the track extrapolated to the sec...
@ deltaPhiRescaled2
difference between the cluster phi (second sampling) and the phi of the track extrapolated to the sec...
@ deltaPhi2
difference between the cluster phi (second sampling) and the phi of the track extrapolated to the sec...
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
TrackParticle_v1 TrackParticle
Reference the current persistent version:
@ numberOfSCTHits
number of hits in SCT [uint8_t].
@ numberOfPixelHits
these are the pixel hits, including the b-layer [uint8_t].
Electron_v1 Electron
Definition of the current "egamma version".