49 std::vector<double> &clusterE_ML_vec,
50 std::vector<double> &clusterE_ML_Unc_vec)
const
55 double clusterEta = 0;
56 double cluster_SIGNIFICANCE = 0;
57 double cluster_time = 0;
58 double cluster_SECOND_TIME = 0;
59 double cluster_CENTER_LAMBDA = 0;
60 double cluster_CENTER_MAG = 0;
61 double cluster_ENG_FRAC_EM_INCL = 0;
62 double cluster_FIRST_ENG_DENS = 0;
63 double cluster_LONGITUDINAL = 0;
64 double cluster_LATERAL = 0;
65 double cluster_PTD = 0;
66 double cluster_ISOLATION = 0;
68 std::vector<float> transformedFeatures;
77 cluster_SECOND_TIME /= (Gaudi::Units::nanosecond * Gaudi::Units::nanosecond);
79 cluster_CENTER_LAMBDA /= Gaudi::Units::millimeter;
82 cluster_FIRST_ENG_DENS /= (Gaudi::Units::GeV / Gaudi::Units::millimeter3);
87 cluster_time = cluster->time() / Gaudi::Units::nanosecond;
91 return StatusCode::FAILURE;
95 for (
size_t s = CaloSampling::PreSamplerB; s < CaloSampling::Unknown; s++)
97 if (s == CaloSampling::EMB1 || s == CaloSampling::EMB2 || s == CaloSampling::EMB3 || s == CaloSampling::EME1 || s == CaloSampling::EME2 || s == CaloSampling::EME3 || s == CaloSampling::FCAL0)
102 cluster_ENG_FRAC_EM_INCL = e_EM / cluster->rawE();
104 std::vector<float> rawValues = {
105 static_cast<float>(clusterE),
106 static_cast<float>(clusterEta),
107 static_cast<float>(cluster_SIGNIFICANCE),
108 static_cast<float>(cluster_time),
109 static_cast<float>(cluster_SECOND_TIME),
110 static_cast<float>(cluster_CENTER_LAMBDA),
111 static_cast<float>(cluster_CENTER_MAG),
112 static_cast<float>(cluster_ENG_FRAC_EM_INCL),
113 static_cast<float>(cluster_FIRST_ENG_DENS),
114 static_cast<float>(cluster_LONGITUDINAL),
115 static_cast<float>(cluster_LATERAL),
116 static_cast<float>(cluster_PTD),
117 static_cast<float>(cluster_ISOLATION),
118 static_cast<float>(nPrimVtx),
125 const float raw = rawValues.at(i);
126 float transformed = transform.processor(raw, transform.parameters);
127 transformedFeatures.push_back(transformed);
131 int numClusters = clusters.size();
132 std::vector<int64_t> inputShape = {numClusters, 15};
135 inputData[
"features"] = std::make_pair(
136 inputShape, std::move(transformedFeatures));
140 outputData[
"mus"] = std::make_pair(
141 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
142 outputData[
"sigmas"] = std::make_pair(
143 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
144 outputData[
"alphas"] = std::make_pair(
145 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
149 std::vector<float> &onnx_mus = std::get<std::vector<float>>(outputData[
"mus"].second);
150 std::vector<float> &onnx_sigma2s = std::get<std::vector<float>>(outputData[
"sigmas"].second);
151 std::vector<float> &onnx_alphas = std::get<std::vector<float>>(outputData[
"alphas"].second);
153 if (msgLvl(MSG::DEBUG)) {
155 int nan_in_sigma2s = 0;
156 int nan_in_alphas = 0;
158 for (
float val : onnx_mus)
162 for (
float val : onnx_sigma2s)
166 for (
float val : onnx_alphas)
171 ATH_MSG_DEBUG(nan_in_mus <<
" NaN value found in `mus` output layer during ONNX inference");
174 ATH_MSG_DEBUG(nan_in_sigma2s <<
" NaN value found in `sigmas` output layer during ONNX inference");
177 ATH_MSG_DEBUG(nan_in_alphas <<
" NaN value found in `alphas` output layer during ONNX inference");
180 clusterE_ML_vec.clear();
181 clusterE_ML_Unc_vec.clear();
182 clusterE_ML_vec.reserve(numClusters);
183 clusterE_ML_Unc_vec.reserve(numClusters);
185 for (
int i = 0; i < numClusters; ++i)
187 bool calibrateCluster =
true;
188 for (
size_t j=0; j<3; ++j) {
189 if (std::isnan(onnx_mus[i*3+j]) || std::isnan(onnx_sigma2s[i*3+j]) || std::isnan(onnx_alphas[i*3+j])) {
190 calibrateCluster =
false;
196 if (calibrateCluster) {
197 std::vector<float> current_mus = {onnx_mus[i * 3], onnx_mus[i * 3 + 1], onnx_mus[i * 3 + 2]};
198 std::vector<float> current_sigma2s = {onnx_sigma2s[i * 3], onnx_sigma2s[i * 3 + 1], onnx_sigma2s[i * 3 + 2]};
199 std::vector<float> current_alphas = {onnx_alphas[i * 3], onnx_alphas[i * 3 + 1], onnx_alphas[i * 3 + 2]};
202 r = std::pow(10, mode);
204 s = std::abs(std::log(10) *
r) * onnx_s;
206 if (!std::isfinite(
r) || std::abs(
r) < 1e-6) {
207 ATH_MSG_WARNING(
"ML-correction factor to cluster energy (used as denominator) is " <<
r <<
"; The ML-correction factor is reset to 1. Uncertainty is set to 0.");
215 clusterE_ML_vec.push_back(cluster_energy);
216 clusterE_ML_Unc_vec.push_back(
static_cast<double>(s));
219 return StatusCode::SUCCESS;