55 double clusterEta = 0;
56 double cluster_SIGNIFICANCE = 0;
57 double cluster_time = 0;
58 double cluster_SECOND_TIME = 0;
59 double cluster_CENTER_LAMBDA = 0;
60 double cluster_CENTER_MAG = 0;
61 double cluster_ENG_FRAC_EM_INCL = 0;
62 double cluster_FIRST_ENG_DENS = 0;
63 double cluster_LONGITUDINAL = 0;
64 double cluster_LATERAL = 0;
65 double cluster_PTD = 0;
66 double cluster_ISOLATION = 0;
68 std::vector<float> transformedFeatures;
75 ok = cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SIGNIFICANCE, cluster_SIGNIFICANCE);
76 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SECOND_TIME, cluster_SECOND_TIME);
83 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LONGITUDINAL, cluster_LONGITUDINAL);
84 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LATERAL, cluster_LATERAL);
85 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::PTD, cluster_PTD);
86 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ISOLATION, cluster_ISOLATION);
91 return StatusCode::FAILURE;
102 cluster_ENG_FRAC_EM_INCL = e_EM / cluster->rawE();
104 std::vector<float> rawValues = {
105 static_cast<float>(clusterE),
106 static_cast<float>(clusterEta),
107 static_cast<float>(cluster_SIGNIFICANCE),
108 static_cast<float>(cluster_time),
109 static_cast<float>(cluster_SECOND_TIME),
110 static_cast<float>(cluster_CENTER_LAMBDA),
111 static_cast<float>(cluster_CENTER_MAG),
112 static_cast<float>(cluster_ENG_FRAC_EM_INCL),
113 static_cast<float>(cluster_FIRST_ENG_DENS),
114 static_cast<float>(cluster_LONGITUDINAL),
115 static_cast<float>(cluster_LATERAL),
116 static_cast<float>(cluster_PTD),
117 static_cast<float>(cluster_ISOLATION),
118 static_cast<float>(nPrimVtx),
125 const float raw = rawValues.at(
i);
127 transformedFeatures.push_back(transformed);
132 std::vector<int64_t> inputShape = {numClusters, 15};
135 inputData[
"features"] = std::make_pair(
136 inputShape, std::move(transformedFeatures));
140 outputData[
"mus"] = std::make_pair(
141 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
142 outputData[
"sigmas"] = std::make_pair(
143 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
144 outputData[
"alphas"] = std::make_pair(
145 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
149 std::vector<float> &onnx_mus = std::get<std::vector<float>>(outputData[
"mus"].second);
150 std::vector<float> &onnx_sigma2s = std::get<std::vector<float>>(outputData[
"sigmas"].second);
151 std::vector<float> &onnx_alphas = std::get<std::vector<float>>(outputData[
"alphas"].second);
155 int nan_in_sigma2s = 0;
156 int nan_in_alphas = 0;
158 for (
float val : onnx_mus)
162 for (
float val : onnx_sigma2s)
166 for (
float val : onnx_alphas)
171 ATH_MSG_DEBUG(nan_in_mus <<
" NaN value found in `mus` output layer during ONNX inference");
174 ATH_MSG_DEBUG(nan_in_sigma2s <<
" NaN value found in `sigmas` output layer during ONNX inference");
177 ATH_MSG_DEBUG(nan_in_alphas <<
" NaN value found in `alphas` output layer during ONNX inference");
180 clusterE_ML_vec.clear();
181 clusterE_ML_Unc_vec.clear();
182 clusterE_ML_vec.reserve(numClusters);
183 clusterE_ML_Unc_vec.reserve(numClusters);
185 for (
int i = 0;
i < numClusters; ++
i)
187 bool calibrateCluster =
true;
188 for (
size_t j=0; j<3; ++j) {
189 if (std::isnan(onnx_mus[
i*3+j]) || std::isnan(onnx_sigma2s[
i*3+j]) || std::isnan(onnx_alphas[
i*3+j])) {
190 calibrateCluster =
false;
196 if (calibrateCluster) {
197 std::vector<float> current_mus = {onnx_mus[
i * 3], onnx_mus[
i * 3 + 1], onnx_mus[
i * 3 + 2]};
198 std::vector<float> current_sigma2s = {onnx_sigma2s[
i * 3], onnx_sigma2s[
i * 3 + 1], onnx_sigma2s[
i * 3 + 2]};
199 std::vector<float> current_alphas = {onnx_alphas[
i * 3], onnx_alphas[
i * 3 + 1], onnx_alphas[
i * 3 + 2]};
206 if (!std::isfinite(
r) || std::abs(
r) < 1
e-6) {
207 ATH_MSG_WARNING(
"ML-correction factor to cluster energy (used as denominator) is " <<
r <<
"; The ML-correction factor is reset to 1. Uncertainty is set to 0.");
215 clusterE_ML_vec.push_back(cluster_energy);
216 clusterE_ML_Unc_vec.push_back(
static_cast<double>(
s));
219 return StatusCode::SUCCESS;