51 double clusterEta = 0;
52 double cluster_SIGNIFICANCE = 0;
53 double cluster_time = 0;
54 double cluster_SECOND_TIME = 0;
55 double cluster_CENTER_LAMBDA = 0;
56 double cluster_CENTER_MAG = 0;
57 double cluster_ENG_FRAC_EM_INCL = 0;
58 double cluster_FIRST_ENG_DENS = 0;
59 double cluster_LONGITUDINAL = 0;
60 double cluster_LATERAL = 0;
61 double cluster_PTD = 0;
62 double cluster_ISOLATION = 0;
63 double clusterE_TRUTH = 0;
65 std::vector<float> transformedFeatures;
72 ok = cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ENG_CALIB_TOT, clusterE_TRUTH);
74 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SIGNIFICANCE, cluster_SIGNIFICANCE);
75 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::SECOND_TIME, cluster_SECOND_TIME);
82 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LONGITUDINAL, cluster_LONGITUDINAL);
83 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::LATERAL, cluster_LATERAL);
84 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::PTD, cluster_PTD);
85 ok &= cluster->retrieveMoment(xAOD::CaloCluster::MomentType::ISOLATION, cluster_ISOLATION);
88 ATH_MSG_WARNING(
"CaloClusterMLCalibToolLite: retrieveMoment failed for "<<cluster);
98 cluster_ENG_FRAC_EM_INCL = e_EM / cluster->rawE();
100 std::vector<float> rawValues;
102 rawValues.push_back(clusterE);
103 rawValues.push_back(clusterEta);
104 rawValues.push_back(cluster_SIGNIFICANCE);
105 rawValues.push_back(cluster_time);
106 rawValues.push_back(cluster_SECOND_TIME);
107 rawValues.push_back(cluster_CENTER_LAMBDA);
108 rawValues.push_back(cluster_CENTER_MAG);
109 rawValues.push_back(cluster_ENG_FRAC_EM_INCL);
110 rawValues.push_back(cluster_FIRST_ENG_DENS);
111 rawValues.push_back(cluster_LONGITUDINAL);
112 rawValues.push_back(cluster_LATERAL);
113 rawValues.push_back(cluster_PTD);
114 rawValues.push_back(cluster_ISOLATION);
115 rawValues.push_back(nPrimVtx);
116 rawValues.push_back(avgMu);
121 const float &raw = rawValues.at(
i);
123 transformedFeatures.push_back(transformed);
128 std::vector<int64_t> inputShape = {numClusters, 15};
131 inputData[
"features"] = std::make_pair(
132 inputShape, std::move(transformedFeatures));
136 outputData[
"0"] = std::make_pair(
137 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
138 outputData[
"1"] = std::make_pair(
139 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
140 outputData[
"2"] = std::make_pair(
141 std::vector<int64_t>{numClusters, 3}, std::vector<float>{});
145 std::vector<float> &onnx_mus = std::get<std::vector<float>>(outputData[
"0"].second);
146 std::vector<float> &onnx_sigma2s = std::get<std::vector<float>>(outputData[
"1"].second);
147 std::vector<float> &onnx_alphas = std::get<std::vector<float>>(outputData[
"2"].second);
149 for (
float val : onnx_mus)
153 ATH_MSG_WARNING(
"NaN value found in `mus` output layer during ONNX inference");
156 for (
float val : onnx_sigma2s)
160 ATH_MSG_WARNING(
"NaN value found in `sigma2s` output layer during ONNX inference");
163 for (
float val : onnx_alphas)
167 ATH_MSG_WARNING(
"NaN value found in `alphas` output layer during ONNX inference");
171 clusterE_ML_vec.clear();
172 clusterE_ML_Unc_vec.clear();
173 clusterE_ML_vec.reserve(numClusters);
174 clusterE_ML_Unc_vec.reserve(numClusters);
176 for (
int i = 0;
i < numClusters; ++
i)
178 std::vector<float> current_mus = {onnx_mus[
i * 3], onnx_mus[
i * 3 + 1], onnx_mus[
i * 3 + 2]};
179 std::vector<float> current_sigma2s = {onnx_sigma2s[
i * 3], onnx_sigma2s[
i * 3 + 1], onnx_sigma2s[
i * 3 + 2]};
180 std::vector<float> current_alphas = {onnx_alphas[
i * 3], onnx_alphas[
i * 3 + 1], onnx_alphas[
i * 3 + 2]};
185 float s = std::abs(
std::log(10) *
r) * onnx_s;
188 clusterE_ML_Unc_vec.push_back(
static_cast<double>(
s));
191 return StatusCode::SUCCESS;