ATLAS Offline Software
Loading...
Searching...
No Matches
CaloClusterMLCalibToolLite.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
9
10#include "GaudiKernel/SystemOfUnits.h"
11
// Constructor: forwards type, name and parent unchanged to base_class; the constructor body itself is empty.
12CaloClusterMLCalibToolLite::CaloClusterMLCalibToolLite(const std::string &type, const std::string &name, const IInterface *parent) : base_class(type, name, parent) {}
13
15
17{
// Builds one preprocessing transform per input feature and then retrieves the ONNX inference tool.
// Returns FAILURE as soon as a transformation lookup fails; SUCCESS is returned after m_onnxTool.retrieve() has been checked.
// NOTE(review): listing lines 16, 19 and 23 are absent from this rendering, so the function signature and the
// definition of `funcIt` are not visible here.
18 ATH_MSG_DEBUG("Initializing " << name() << "...");
// One pass per feature index, 0 .. m_numFeatures-1.
20 for (int i = 0; i < m_numFeatures; i++)
21 {
22 PreprocessTransform transform;
// `funcIt` is tested against TRANSFORMATIONS.end(); presumably it is the result of looking up
// m_preprocessingTransformNames[i] in that map — confirm against the full source.
24 if (funcIt != CaloClusterMLCalib::TRANSFORMATIONS.end())
25 {
26 transform.processor = funcIt->second;
// The configured parameters are stored as double; each one is narrowed to float before being attached to the transform.
27 std::vector<float> floatParams;
28 for (double param : m_preprocessingTransformParams[i])
29 {
30 floatParams.push_back(static_cast<float>(param));
31 }
32 transform.parameters = std::move(floatParams);
// Transforms are appended in feature order, so index i of the vector corresponds to feature i.
33 m_featurePreprocessingTransforms.push_back(std::move(transform));
34 }
35 else
36 {
// NOTE(review): this is reported at WARNING level although the function returns FAILURE immediately afterwards.
37 ATH_MSG_WARNING("Undefined transformation " << m_preprocessingTransformNames[i]);
38 return StatusCode::FAILURE;
39 }
40 }
41
42 ATH_CHECK(m_onnxTool.retrieve());
43 return StatusCode::SUCCESS;
44}
45
47 int nPrimVtx,
48 float avgMu,
49 std::vector<double> &clusterE_ML_vec,
50 std::vector<double> &clusterE_ML_Unc_vec) const
51{
// For every cluster in `clusters`: collects 15 raw inputs (13 per-cluster quantities plus nPrimVtx and avgMu),
// applies the per-feature preprocessing transform, runs m_onnxTool->inference on the flattened feature buffer,
// and fills the two output vectors with one entry per cluster (both are cleared first).
// Returns FAILURE if any retrieveMoment() call reports failure; the inference call is wrapped in ATH_CHECK.
// NOTE(review): listing line 46 (the first line of the signature) is absent from this rendering.
52 ATH_MSG_DEBUG("Executing " << name() << "...");
53
// Scratch variables reused for every cluster; retrieveMoment() writes into them by reference.
54 double clusterE = 0;
55 double clusterEta = 0;
56 double cluster_SIGNIFICANCE = 0;
57 double cluster_time = 0;
58 double cluster_SECOND_TIME = 0;
59 double cluster_CENTER_LAMBDA = 0;
60 double cluster_CENTER_MAG = 0;
61 double cluster_ENG_FRAC_EM_INCL = 0;
62 double cluster_FIRST_ENG_DENS = 0;
63 double cluster_LONGITUDINAL = 0;
64 double cluster_LATERAL = 0;
65 double cluster_PTD = 0;
66 double cluster_ISOLATION = 0;
67
// Flat buffer holding the transformed features of all clusters, appended cluster after cluster.
68 std::vector<float> transformedFeatures;
69 bool ok{}; // for checking return value of cluster->retrieveMoment
70
71 for (const xAOD::CaloCluster *cluster : clusters)
72 {
// Energy and eta are read at the UNCALIBRATED signal state; the energy is divided by Gaudi::Units::GeV.
73 clusterE = cluster->e(xAOD::CaloCluster::UNCALIBRATED) / Gaudi::Units::GeV;
74 clusterEta = cluster->eta(xAOD::CaloCluster::UNCALIBRATED);
// `ok` is assigned by the first retrieval and AND-ed with each later one, so every retrieval is attempted
// and a single failure is enough to make `ok` false.
// NOTE(review): the unit divisions below are applied before `ok` is checked.
75 ok = cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SIGNIFICANCE, cluster_SIGNIFICANCE);
76 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SECOND_TIME, cluster_SECOND_TIME);
77 cluster_SECOND_TIME /= (Gaudi::Units::nanosecond * Gaudi::Units::nanosecond);
78 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::CENTER_LAMBDA, cluster_CENTER_LAMBDA);
79 cluster_CENTER_LAMBDA /= Gaudi::Units::millimeter;
80 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::CENTER_MAG, cluster_CENTER_MAG);
81 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::FIRST_ENG_DENS, cluster_FIRST_ENG_DENS);
82 cluster_FIRST_ENG_DENS /= (Gaudi::Units::GeV / Gaudi::Units::millimeter3);
83 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LONGITUDINAL, cluster_LONGITUDINAL);
84 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LATERAL, cluster_LATERAL);
85 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::PTD, cluster_PTD);
86 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ISOLATION, cluster_ISOLATION);
87 cluster_time = cluster->time() / Gaudi::Units::nanosecond;
88
// Any failed moment retrieval aborts the whole call; the output vectors have not been touched at this point.
89 if (!ok) {
90 ATH_MSG_ERROR("retrieveMoment() failed for " << cluster);
91 return StatusCode::FAILURE;
92 }
93
// Sum the per-sampling energies of EMB1-3, EME1-3 and FCAL0 only; every other sampling in the
// PreSamplerB .. Unknown range is skipped.
94 float e_EM = 0.0;
95 for (size_t s = CaloSampling::PreSamplerB; s < CaloSampling::Unknown; s++)
96 {
97 if (s == CaloSampling::EMB1 || s == CaloSampling::EMB2 || s == CaloSampling::EMB3 || s == CaloSampling::EME1 || s == CaloSampling::EME2 || s == CaloSampling::EME3 || s == CaloSampling::FCAL0)
98 {
99 e_EM += cluster->eSample(static_cast<xAOD::CaloCluster::CaloSample>(s));
100 }
101 }
// NOTE(review): there is no guard against rawE() being zero before this division.
102 cluster_ENG_FRAC_EM_INCL = e_EM / cluster->rawE();
103
// The 15 raw inputs, narrowed to float. The order here fixes which preprocessing transform (by index)
// is applied to which quantity in the loop below.
104 std::vector<float> rawValues = {
105 static_cast<float>(clusterE),
106 static_cast<float>(clusterEta),
107 static_cast<float>(cluster_SIGNIFICANCE),
108 static_cast<float>(cluster_time),
109 static_cast<float>(cluster_SECOND_TIME),
110 static_cast<float>(cluster_CENTER_LAMBDA),
111 static_cast<float>(cluster_CENTER_MAG),
112 static_cast<float>(cluster_ENG_FRAC_EM_INCL),
113 static_cast<float>(cluster_FIRST_ENG_DENS),
114 static_cast<float>(cluster_LONGITUDINAL),
115 static_cast<float>(cluster_LATERAL),
116 static_cast<float>(cluster_PTD),
117 static_cast<float>(nPrimVtx),
118 static_cast<float>(nPrimVtx),
119 avgMu
120 };
121
// NOTE(review): listing line 124 is absent from this rendering, so the declaration of `transform` is not visible;
// presumably it is the i-th entry of m_featurePreprocessingTransforms — confirm against the full source.
// rawValues.at(i) is bounds-checked, so an m_numFeatures larger than the 15 entries above would throw.
122 for (int i = 0; i < m_numFeatures; i++)
123 {
125 const float raw = rawValues.at(i);
126 float transformed = transform.processor(raw, transform.parameters);
127 transformedFeatures.push_back(transformed);
128 }
129 }
130
// NOTE(review): the second dimension is the literal 15, while the loop above appends m_numFeatures values
// per cluster — the two must agree for the buffer to match the declared shape.
131 int numClusters = clusters.size();
132 std::vector<int64_t> inputShape = {numClusters, 15};
133
134 AthInfer::InputDataMap inputData;
135 inputData["features"] = std::make_pair(
136 inputShape, std::move(transformedFeatures));
137
138 AthInfer::OutputDataMap outputData;
139
// Each of the three outputs is declared with shape (numClusters, 3) and an empty float buffer before the call.
140 outputData["mus"] = std::make_pair(
141 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
142 outputData["sigmas"] = std::make_pair(
143 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
144 outputData["alphas"] = std::make_pair(
145 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
146
147 ATH_CHECK(m_onnxTool->inference(inputData, outputData));
148
// std::get selects the std::vector<float> alternative of each output's data member.
// NOTE(review): the "sigmas" output is bound to a variable named onnx_sigma2s and later passed to helpers
// whose parameter is named log_sigma2s — confirm which quantity the model actually emits.
149 std::vector<float> &onnx_mus = std::get<std::vector<float>>(outputData["mus"].second);
150 std::vector<float> &onnx_sigma2s = std::get<std::vector<float>>(outputData["sigmas"].second);
151 std::vector<float> &onnx_alphas = std::get<std::vector<float>>(outputData["alphas"].second);
152
// Diagnostic only: NaNs in the three outputs are counted and reported solely when DEBUG messages are enabled.
153 if (msgLvl(MSG::DEBUG)) {
154 int nan_in_mus = 0;
155 int nan_in_sigma2s = 0;
156 int nan_in_alphas = 0;
157
158 for (float val : onnx_mus)
159 if (std::isnan(val))
160 nan_in_mus++;
161
162 for (float val : onnx_sigma2s)
163 if (std::isnan(val))
164 nan_in_sigma2s++;
165
166 for (float val : onnx_alphas)
167 if (std::isnan(val))
168 nan_in_alphas++;
169
170 if (nan_in_mus > 0)
171 ATH_MSG_DEBUG(nan_in_mus << " NaN value found in `mus` output layer during ONNX inference");
172
173 if (nan_in_sigma2s)
174 ATH_MSG_DEBUG(nan_in_sigma2s << " NaN value found in `sigmas` output layer during ONNX inference");
175
176 if (nan_in_alphas)
177 ATH_MSG_DEBUG(nan_in_alphas << " NaN value found in `alphas` output layer during ONNX inference");
178 }
179
// The caller-provided output vectors are reset here, so previous contents are discarded on every successful call.
180 clusterE_ML_vec.clear();
181 clusterE_ML_Unc_vec.clear();
182 clusterE_ML_vec.reserve(numClusters);
183 clusterE_ML_Unc_vec.reserve(numClusters);
184
185 for (int i = 0; i < numClusters; ++i)
186 {
// A cluster is left uncalibrated (r = 1, s = 0) if any of its three mus, sigmas or alphas values is NaN.
// The outputs are indexed as flat row-major buffers with 3 values per cluster.
187 bool calibrateCluster = true;
188 for (size_t j=0; j<3; ++j) {
189 if (std::isnan(onnx_mus[i*3+j]) || std::isnan(onnx_sigma2s[i*3+j]) || std::isnan(onnx_alphas[i*3+j])) {
190 calibrateCluster = false;
191 break;
192 }
193 }
// r is the factor the energy is divided by below; s is the value stored in clusterE_ML_Unc_vec.
194 float r = 1.;
195 float s = 0.;
196 if (calibrateCluster) {
197 std::vector<float> current_mus = {onnx_mus[i * 3], onnx_mus[i * 3 + 1], onnx_mus[i * 3 + 2]};
198 std::vector<float> current_sigma2s = {onnx_sigma2s[i * 3], onnx_sigma2s[i * 3 + 1], onnx_sigma2s[i * 3 + 2]};
199 std::vector<float> current_alphas = {onnx_alphas[i * 3], onnx_alphas[i * 3 + 1], onnx_alphas[i * 3 + 2]};
200
// r = 10^mode, and s = |ln(10) * r| * onnx_s.
201 float mode = CaloClusterMLCalib::modes(current_mus, current_sigma2s, current_alphas);
202 r = std::pow(10, mode);
203 float onnx_s = CaloClusterMLCalib::sigma_stoch(current_mus, current_sigma2s, current_alphas);
204 s = std::abs(std::log(10) * r) * onnx_s;
205
// r is used as a divisor below, so non-finite or near-zero values fall back to the uncalibrated defaults.
206 if (!std::isfinite(r) || std::abs(r) < 1e-6) {
207 ATH_MSG_WARNING("ML-correction factor to cluster energy (used as denominator) is " << r << "; The ML-correction factor is reset to 1. Uncertainty is set to 0.");
208 r = 1.0;
209 s = 0.0;
210 }
211 }
212
// Unlike the clusterE feature above, the energy here is not divided by Gaudi::Units::GeV.
// NOTE(review): s is stored as-is, without being multiplied by the cluster energy — confirm the intended meaning
// of the uncertainty output.
213 const double cluster_energy = clusters[i]->e(xAOD::CaloCluster::UNCALIBRATED) / static_cast<double>(r);
214
215 clusterE_ML_vec.push_back(cluster_energy);
216 clusterE_ML_Unc_vec.push_back(static_cast<double>(s));
217 }
218
219 return StatusCode::SUCCESS;
220}
221
223{
// Only logs a debug message and returns SUCCESS.
// NOTE(review): listing line 222 (the function signature) is absent from this rendering.
224 ATH_MSG_DEBUG("Finalizing " << name() << "...");
225 return StatusCode::SUCCESS;
226}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
Handle class for adding a decoration to an object.
virtual StatusCode inference(const xAOD::CaloClusterContainer &clusters, int nPrimVtx, float avgMu, std::vector< double > &clusterE_ML_vec, std::vector< double > &clusterE_ML_Unc_vec) const override
Gaudi::Property< std::vector< std::vector< double > > > m_preprocessingTransformParams
ToolHandle< AthInfer::IAthInferenceTool > m_onnxTool
virtual StatusCode initialize() override
CaloClusterMLCalibToolLite(const std::string &type, const std::string &name, const IInterface *parent)
Gaudi::Property< std::vector< std::string > > m_preprocessingTransformNames
virtual StatusCode finalize() override
std::vector< PreprocessTransform > m_featurePreprocessingTransforms
@ PTD
relative spread of pT of constituent cells = sqrt(n)*RMS/Mean
@ SECOND_TIME
Second moment of cell time distribution in cluster.
@ LATERAL
Normalized lateral moment.
@ LONGITUDINAL
Normalized longitudinal moment.
@ FIRST_ENG_DENS
First Moment in E/V.
@ CENTER_LAMBDA
Shower depth at Cluster Centroid.
@ SIGNIFICANCE
Cluster significance.
@ CENTER_MAG
Cluster Centroid ($\sqrt{x^2+y^2+z^2}$)
@ ISOLATION
Energy weighted fraction of non-clustered perimeter cells.
CaloSampling::CaloSample CaloSample
int r
Definition globals.cxx:22
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
const std::map< std::string, TransformFunc > TRANSFORMATIONS
float modes(const std::vector< float > &mus, const std::vector< float > &log_sigma2s, const std::vector< float > &alphas)
float sigma_stoch(const std::vector< float > &mus, const std::vector< float > &log_sigma2s, const std::vector< float > &alphas)
CaloCluster_v1 CaloCluster
Define the latest version of the calorimeter cluster class.
CaloClusterContainer_v1 CaloClusterContainer
Define the latest version of the calorimeter cluster container.