ATLAS Offline Software
Loading...
Searching...
No Matches
PFEnergyPredictorTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include "eflowCaloObject.h"
8#include "eflowRecCluster.h"
9
10#include "CaloGeoHelpers/CaloSampling.h"
11
12PFEnergyPredictorTool::PFEnergyPredictorTool(const std::string& type, const std::string& name, const IInterface* parent) : AthAlgTool(type, name, parent)
13{
14
15}
16
17
19{
20 ATH_MSG_DEBUG("Initializing " << name());
21 if(m_model_filepath == "////"){
22 ATH_MSG_WARNING("model not provided tool will not work");
23 return StatusCode::SUCCESS;
24 }
25 ATH_CHECK(m_svc.retrieve());
26 std::string path = m_model_filepath;//Add path resolving code
27
28 Ort::SessionOptions session_options;
29 Ort::AllocatorWithDefaultOptions allocator;
30 session_options.SetIntraOpNumThreads(1);
31 session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
32 m_session = std::make_unique<Ort::Session>(m_svc->env(), path.c_str(), session_options);
33
34 ATH_MSG_INFO("Created ONNX runtime session with model " << path);
35
36 size_t num_input_nodes = m_session->GetInputCount();
37 m_input_node_names.resize(num_input_nodes);
38
39 for (std::size_t i = 0; i < num_input_nodes; i++) {
40 // print input node names
41 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
42 ATH_MSG_INFO("Input " << i << " : "
43 << " name= " << input_name);
44 m_input_node_names[i] = input_name;
45 // print input node types
46 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
47 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
48 ONNXTensorElementDataType type = tensor_info.GetElementType();
49 ATH_MSG_INFO("Input " << i << " : "
50 << " type= " << type);
51
52 // print input shapes/dims
53 m_input_node_dims = tensor_info.GetShape();
54 m_input_node_dims[1] = 5430/5;
55 ATH_MSG_INFO("Input " << i << " : num_dims= " << m_input_node_dims.size());
56 for (std::size_t j = 0; j < m_input_node_dims.size(); j++) {
57 if (m_input_node_dims[j] < 0) m_input_node_dims[j] = 1;
58 ATH_MSG_INFO("Input " << i << " : dim " << j << "= " << m_input_node_dims[j]);
59 }
60 }
61
62 // output nodes
63 std::vector<int64_t> output_node_dims;
64 size_t num_output_nodes = m_session->GetOutputCount();
65 ATH_MSG_INFO("Have output nodes " << num_output_nodes);
66 m_output_node_names.resize(num_output_nodes);
67
68 for (std::size_t i = 0; i < num_output_nodes; i++) {
69 // print output node names
70 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
71 ATH_MSG_INFO("Output " << i << " : "
72 << " name= " << output_name);
73 m_output_node_names[i] = output_name;
74
75 Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
76 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
77 ONNXTensorElementDataType type = tensor_info.GetElementType();
78 ATH_MSG_INFO("Output " << i << " : "
79 << " type= " << type);
80
81 // print output shapes/dims
82 output_node_dims = tensor_info.GetShape();
83 ATH_MSG_INFO("Output " << i << " : num_dims= " << output_node_dims.size());
84 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
85 if (output_node_dims[j] < 0) output_node_dims[j] = 1;
86 ATH_MSG_INFO("Output" << i << " : dim " << j << "= " << output_node_dims[j]);
87 }
88 }
89
90 return StatusCode::SUCCESS;
91}
92
93float PFEnergyPredictorTool::runOnnxInference(std::vector<float> &tensor) const {
94 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
95 auto input_tensor_size = tensor.size();
96
97 Ort::Value input_tensor =
98 Ort::Value::CreateTensor<float>(memory_info, tensor.data(), input_tensor_size,
100
101 auto output_tensors = m_session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(),
103
104 const float *output_score_array = output_tensors.front().GetTensorData<float>();
105
106 // Binary classification - the score is just the first element of the output tensor
107 float output_score = output_score_array[0];
108
109 return output_score;
110}
111
112std::array<double,19> getEtaTrackCalo(const eflowTrackCaloPoints& trackCaloPoints) {
113 return std::array<double,19> { trackCaloPoints.getEta(eflowCalo::EMB1), trackCaloPoints.getEta(eflowCalo::EMB2), trackCaloPoints.getEta(eflowCalo::EMB3),
114 trackCaloPoints.getEta(eflowCalo::EME1), trackCaloPoints.getEta(eflowCalo::EME2), trackCaloPoints.getEta(eflowCalo::EME3),
115 trackCaloPoints.getEta(eflowCalo::HEC1), trackCaloPoints.getEta(eflowCalo::HEC2), trackCaloPoints.getEta(eflowCalo::HEC3),trackCaloPoints.getEta(eflowCalo::HEC4),
116 trackCaloPoints.getTileEta(CaloSampling::TileBar0),trackCaloPoints.getTileEta(CaloSampling::TileBar1),trackCaloPoints.getTileEta(CaloSampling::TileBar2),
117 trackCaloPoints.getTileEta(CaloSampling::TileGap1),trackCaloPoints.getTileEta(CaloSampling::TileGap2),trackCaloPoints.getTileEta(CaloSampling::TileGap3),
118 trackCaloPoints.getTileEta(CaloSampling::TileExt0),trackCaloPoints.getTileEta(CaloSampling::TileExt1),trackCaloPoints.getTileEta(CaloSampling::TileExt2)};
119}
120
121
122std::array<double,19> getPhiTrackCalo(const eflowTrackCaloPoints& trackCaloPoints) {
123 return std::array<double,19> { trackCaloPoints.getPhi(eflowCalo::EMB1), trackCaloPoints.getPhi(eflowCalo::EMB2), trackCaloPoints.getPhi(eflowCalo::EMB3),
124 trackCaloPoints.getPhi(eflowCalo::EME1), trackCaloPoints.getPhi(eflowCalo::EME2), trackCaloPoints.getPhi(eflowCalo::EME3),
125 trackCaloPoints.getPhi(eflowCalo::HEC1), trackCaloPoints.getPhi(eflowCalo::HEC2), trackCaloPoints.getPhi(eflowCalo::HEC3),trackCaloPoints.getPhi(eflowCalo::HEC4),
126 trackCaloPoints.getTilePhi(CaloSampling::TileBar0),trackCaloPoints.getTilePhi(CaloSampling::TileBar1),trackCaloPoints.getTilePhi(CaloSampling::TileBar2),
127 trackCaloPoints.getTilePhi(CaloSampling::TileGap1),trackCaloPoints.getTilePhi(CaloSampling::TileGap2),trackCaloPoints.getTilePhi(CaloSampling::TileGap3),
128 trackCaloPoints.getTilePhi(CaloSampling::TileExt0),trackCaloPoints.getTilePhi(CaloSampling::TileExt1),trackCaloPoints.getTilePhi(CaloSampling::TileExt2)};
129}
130
131
133
134 constexpr std::array<int,19> calo_numbers{1,2,3,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20};
135 constexpr std::array<int,12> fixed_r_numbers = {1,2,3,12,13,14,15,16,17,18,19,20};
136 constexpr std::array<double,12> fixed_r_vals = {1532.18, 1723.89, 1923.02, 2450.00, 2995.00, 3630.00, 3215.00,
137 3630.00, 2246.50, 2450.00, 2870.00, 3480.00
138 };
139 constexpr std::array<int, 7> fixed_z_numbers = {5,6,7,8,9,10,11};
140 constexpr std::array<double, 7> fixed_z_vals = {3790.03, 3983.68, 4195.84, 4461.25, 4869.50, 5424.50, 5905.00};
141 std::unordered_map<int, double> r_calo_dict;//change to flatmap in c++23
142 std::unordered_map<int, double> z_calo_dict;
143 for(size_t i=0; i<fixed_r_vals.size(); i++) r_calo_dict[fixed_r_numbers[i]] = fixed_r_vals[i];
144 for(size_t i=0; i<fixed_z_numbers.size(); i++) z_calo_dict[fixed_z_numbers[i]] = fixed_z_vals[i];
145
146 std::vector<float> inputnn;
147 inputnn.assign(5430, 0.0);
148 std::vector<eflowRecCluster*> matchedClusters;
149 const std::vector<eflowTrackClusterLink*>& links = ptr->getClusterMatches();
150
151 std::array<double, 19> etatotal = getEtaTrackCalo(ptr->getTrackCaloPoints());
152 std::array<double, 19> phitotal = getPhiTrackCalo(ptr->getTrackCaloPoints());
153
154 const std::array<double, 2> track{ptr->getTrack()->eta(), ptr->getTrack()->phi()};
155
156 for(auto *clink : links){
157 auto *cell = clink->getCluster()->getCluster();
158 float clusterE = cell->e()*1e-3;
159 float clusterEta = cell->eta();
160
161 if (clusterE < 0.0 || clusterE > 1e4f || std::abs(clusterEta) > 2.5) continue;
162
163 constexpr bool cutOnR = false;
164 if(cutOnR){
165 std::array<double, 2> p{clink->getCluster()->getCluster()->eta(), clink->getCluster()->getCluster()->phi()};
166 double part1 = p[0] - track[0];
167 double part2 = p[1] - track[1];
168 while(part1 > M_PI) part1 -= 2*M_PI;
169 while(part1 < -M_PI) part1 += 2*M_PI;
170 while(part2 > M_PI) part2 -= 2*M_PI;
171 while(part2 < -M_PI) part2 += 2*M_PI;
172 double R = std::sqrt(part1 * part1 + part2*part2);
173 if(R >= 1.2) continue;
174 }
175
176 matchedClusters.push_back(clink->getCluster());
177 }
178
179 std::vector<std::array<double, 5>> cells;
180
181 const eflowTrackCaloPoints& trackCaloPoints = ptr->getTrackCaloPoints();
182 bool trk_bool_em[2] = {false,false};
183 std::array<double,2> trk_em_eta = {trackCaloPoints.getEta(eflowCalo::EMB2), trackCaloPoints.getEta(eflowCalo::EME2)};
184 std::array<double,2> trk_em_phi = {trackCaloPoints.getPhi(eflowCalo::EMB2), trackCaloPoints.getPhi(eflowCalo::EME2)};
185 double eta_ctr;
186 double phi_ctr;
187 for(int i =0; i<2; i++) {
188 trk_bool_em[i] = std::abs(trk_em_eta[i]) < 2.5 && std::abs(trk_em_phi[i]) <= M_PI;
189 }
190 int nProj_em = (int)trk_bool_em[0] + (int)trk_bool_em[1];
191
192 if(nProj_em ==1) {
193 eta_ctr = trk_bool_em[0] ? trk_em_eta[0] : trk_em_eta[1];
194 phi_ctr = trk_bool_em[0] ? trk_em_phi[0] : trk_em_phi[1];
195 } else if(nProj_em==2) {
196 eta_ctr = (trk_em_eta[0] + trk_em_eta[1]) / 2.0;
197 phi_ctr = (trk_em_phi[0] + trk_em_phi[1]) / 2.0;
198 } else {
199 eta_ctr = ptr->getTrack()->eta();
200 phi_ctr = ptr->getTrack()->phi();
201 }
202
203
204
205 for(auto *cptr : matchedClusters){
206 auto *clustlink = cptr->getCluster();
207
208 for(auto it_cell = clustlink->cell_begin(); it_cell != clustlink->cell_end(); it_cell++){
209 const CaloCell* cell = (*it_cell);
210 float cellE = cell->e()*(it_cell.weight())*1e-3f;
211 if(cellE < 0.005) continue;//Cut from ntuple maker
212 const auto *theDDE=it_cell->caloDDE();
213 double cx=theDDE->x();
214 double cy=theDDE->y();
215
216 cells.emplace_back( std::array<double, 5> { cellE,
217 theDDE->eta() - eta_ctr,
218 theDDE->phi() - phi_ctr,
219 std::hypot(cx,cy), //rperp
220 0.0 } );
221 }
222 }
223
224
225 std::vector<bool> trk_bool(calo_numbers.size(), false);
226 std::vector<std::array<double,4>> trk_full(calo_numbers.size());
227 for(size_t j=0; j<phitotal.size(); j++) {
228 int cnum = calo_numbers[j];
229 double eta = etatotal[j];
230 double phi = phitotal[j];
231 if(std::abs(eta) < 2.5 && std::abs(phi) <= M_PI) {
232 trk_bool[j] = true;
233 trk_full[j][0] = eta;
234 trk_full[j][1] = phi;
235 trk_full[j][3] = cnum;
236 double rPerp =-99999;
237 if(auto itr = r_calo_dict.find(cnum); itr != r_calo_dict.end()) rPerp = itr->second;
238 else if(auto itr = z_calo_dict.find(cnum); itr != z_calo_dict.end())
239 {
240 double z = itr->second;
241 if(eta != 0.0){
242 double aeta = std::abs(eta);
243 rPerp = z*2.*std::exp(aeta)/(std::exp(2.0*aeta)-1.0);
244 }else rPerp =0.0; //Check if this makes sense
245 } else {
246 throw std::runtime_error("Calo sample num not found in dicts..");
247 }
248 trk_full[j][2] = rPerp;
249 } else {
250 trk_full[j].fill(0.0);
251 }
252 }
253 double trackP = std::abs(1. / ptr->getTrack()->qOverP()) * 1e-3;
254 int trk_proj_num = std::accumulate(trk_bool.begin(), trk_bool.end(), 0);
255 if(trk_proj_num ==0) {
256 trk_proj_num =1;
257 std::array<double,5> trk_arr{};
258
259 trk_arr[0] = trackP;
260 trk_arr[1] = ptr->getTrack()->eta() - eta_ctr;
261 trk_arr[2] = ptr->getTrack()->phi() - phi_ctr;
262 trk_arr[3] = 1532.18; // just place it in EMB1
263 trk_arr[4] = 1.;
264
265 cells.emplace_back(trk_arr);
266 } else {
267 for(size_t i =0; i<calo_numbers.size(); i++) {
268 if(!trk_bool[i]) continue;
269 std::array<double,5> trk_arr{};
270 trk_arr[0]= trackP/double(trk_proj_num);
271 trk_arr[1]= trk_full[i][0] - eta_ctr;
272 trk_arr[2]= trk_full[i][1] - phi_ctr;
273 trk_arr[3]= trk_full[i][2];
274 trk_arr[4]= 1.;
275
276 cells.emplace_back(trk_arr);
277 }
278 }
279
280 int index = 0;
281 for(auto &in : cells){
282 std::copy(in.begin(), in.end(), inputnn.begin() + index);
283 index+=5;
284 if(index >= static_cast<int>(inputnn.size()-4)) {
285 ATH_MSG_WARNING("Data exceeded tensor size");
286 break;
287 }
288 }
289
290 //Normalization prior to training
291 NormalizeTensor(inputnn, cells.size() * 5 );
292
293 float predictedEnergy = exp(runOnnxInference(inputnn)) * 1000.0;//Correct to MeV units
294 ATH_MSG_DEBUG("NN Predicted energy " << predictedEnergy);
295 return predictedEnergy;
296
297}
298
299void PFEnergyPredictorTool::NormalizeTensor(std::vector<float> &inputnn, size_t limit) const{
300 size_t i=0;
301 for(i =0;i<limit;i+=5){
302 auto &f = inputnn[i+3];
303 if(f!= 0.0f) f/= 3630.f;
304 auto &e = inputnn[i+0];
305 if(e!= 0.0f){
306 e = std::log(e);
307 e = (e - m_cellE_mean)/m_cellE_std;
308 }
309 auto &eta = inputnn[i+1];
310 if(eta!= 0.0) eta /= 0.7f;
311 auto &phi = inputnn[i+2];
312 if(phi!= 0.0) phi /= m_cellPhi_std;
313 }
314 if(i> inputnn.size()){
315 ATH_MSG_ERROR("Index exceeded tensor MEMORY CORRUPTION");
316 }
317}
318
319
320
322{
323 return StatusCode::SUCCESS;
324}
325
#define M_PI
Scalar eta() const
pseudorapidity method
Scalar phi() const
phi method
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
std::array< double, 19 > getPhiTrackCalo(const eflowTrackCaloPoints &trackCaloPoints)
std::array< double, 19 > getEtaTrackCalo(const eflowTrackCaloPoints &trackCaloPoints)
#define z
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
Data object for each calorimeter readout cell.
Definition CaloCell.h:57
void NormalizeTensor(std::vector< float > &tensor, size_t limit) const
float runOnnxInference(std::vector< float > &tensor) const
std::vector< const char * > m_input_node_names
Gaudi::Property< float > m_cellE_std
Gaudi::Property< std::string > m_model_filepath
virtual StatusCode finalize() override
std::vector< const char * > m_output_node_names
PFEnergyPredictorTool(const std::string &type, const std::string &name, const IInterface *parent)
float nnEnergyPrediction(const eflowRecTrack *ptr) const
virtual StatusCode initialize() override
Gaudi::Property< float > m_cellPhi_std
Gaudi::Property< float > m_cellE_mean
Normalization constants for the inputs to the onnx model.
std::vector< int64_t > m_input_node_dims
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
This class extends the information about a xAOD::Track.
This class stores a map of calorimeter layers and track parameters (the result of the track extrapola...
double getTileEta(CaloCell_ID::CaloSample layer) const
double getPhi(eflowCalo::LAYER layer) const
double getEta(eflowCalo::LAYER layer) const
double getTilePhi(CaloCell_ID::CaloSample layer) const
Definition index.py:1
Definition part1.py:1
Definition part2.py:1