ATLAS Offline Software
Loading...
Searching...
No Matches
CaloMuonScoreTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3*/
4
5#include "CaloMuonScoreTool.h"
6
7#include <cmath>
8#include <iostream>
9#include <map>
10#include <string>
11
14#include "GaudiKernel/SystemOfUnits.h"
20
22// CaloMuonScoreTool constructor
24CaloMuonScoreTool::CaloMuonScoreTool(const std::string &type, const std::string &name, const IInterface *parent) :
25 AthAlgTool(type, name, parent) {
26 declareInterface<ICaloMuonScoreTool>(this);
27}
28
// Initialize the tool. Steps, in order:
//   - retrieve the ONNX runtime service m_svc (checked with ATH_CHECK);
//   - resolve m_modelFileName through PathResolverFindCalibFile;
//   - create the ONNX runtime session (1 intra-op thread, ORT_ENABLE_BASIC
//     graph optimisation);
//   - cache the model's input/output node names and the input shape
//     (m_input_node_dims) used later by runOnnxInference().
// Returns StatusCode::FAILURE if m_modelFileName is empty or could not be resolved.
//
// NOTE(review): m_input_node_dims is overwritten for every input node, so only the
// shape of the last input is kept. This is fine for a single-input model -- confirm
// if the model ever gains more inputs.
30// CaloMuonScoreTool::initialize
33 ATH_MSG_INFO("Initializing " << name());
34
35 ATH_CHECK(m_svc.retrieve());
37
38 std::string model_file_name = PathResolverFindCalibFile(m_modelFileName);
39
40 if (m_modelFileName.empty() || model_file_name.empty()) {
41 ATH_MSG_FATAL("Could not find the requested ONNX model file: " << m_modelFileName);
43 "Please make sure it exists in the ATLAS calibration area (https://atlas-groupdata.web.cern.ch/atlas-groupdata/), and provide "
44 "a model file name relative to the root of the calibration area.");
45
46 return StatusCode::FAILURE;
47 }
48
49 // initialise session
50 Ort::SessionOptions session_options;
// allocator used below when querying the input/output node names
51 Ort::AllocatorWithDefaultOptions allocator;
52 session_options.SetIntraOpNumThreads(1);
53 session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
54
55 m_session = std::make_unique<Ort::Session>(m_svc->env(), model_file_name.c_str(), session_options);
56
57 ATH_MSG_INFO("Created ONNX runtime session with model " << model_file_name);
58
59 size_t num_input_nodes = m_session->GetInputCount();
60 m_input_node_names.resize(num_input_nodes);
61
62 for (std::size_t i = 0; i < num_input_nodes; i++) {
63 // print input node names
// release() hands back the raw name buffer, which is never freed here. The raw
// const char* is stored in m_input_node_names and passed to Run() later, so the
// buffer must stay alive -- presumably a deliberate, one-off leak (NOTE(review)).
64 char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
65 ATH_MSG_INFO("Input " << i << " : "
66 << " name= " << input_name);
67 m_input_node_names[i] = input_name;
68 // print input node types
69 Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
70 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
71 ONNXTensorElementDataType type = tensor_info.GetElementType();
72 ATH_MSG_INFO("Input " << i << " : "
73 << " type= " << type);
74
75 // print input shapes/dims
76 m_input_node_dims = tensor_info.GetShape();
77 ATH_MSG_INFO("Input " << i << " : num_dims= " << m_input_node_dims.size());
78 for (std::size_t j = 0; j < m_input_node_dims.size(); j++) {
// negative (presumably dynamic/symbolic) dimensions are fixed to 1, so the stored
// shape can be passed directly to CreateTensor in runOnnxInference()
79 if (m_input_node_dims[j] < 0) m_input_node_dims[j] = 1;
80 ATH_MSG_INFO("Input " << i << " : dim " << j << "= " << m_input_node_dims[j]);
81 }
82 }
83
84 // output nodes
// output shapes are only logged; output_node_dims is not kept after this function
85 std::vector<int64_t> output_node_dims;
86 size_t num_output_nodes = m_session->GetOutputCount();
87 ATH_MSG_INFO("Have output nodes " << num_output_nodes);
88 m_output_node_names.resize(num_output_nodes);
89
90 for (std::size_t i = 0; i < num_output_nodes; i++) {
91 // print output node names
// same ownership pattern as for the input names: the buffer is kept alive for
// m_output_node_names and never freed here
92 char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
93 ATH_MSG_INFO("Output " << i << " : "
94 << " name= " << output_name);
95 m_output_node_names[i] = output_name;
96
97 Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
98 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
99 ONNXTensorElementDataType type = tensor_info.GetElementType();
100 ATH_MSG_INFO("Output " << i << " : "
101 << " type= " << type);
102
103 // print output shapes/dims
104 output_node_dims = tensor_info.GetShape();
105 ATH_MSG_INFO("Output " << i << " : num_dims= " << output_node_dims.size());
106 for (std::size_t j = 0; j < output_node_dims.size(); j++) {
107 if (output_node_dims[j] < 0) output_node_dims[j] = 1;
108 ATH_MSG_INFO("Output" << i << " : dim " << j << "= " << output_node_dims[j]);
109 }
110 }
111
112 return StatusCode::SUCCESS;
113}
114
116// CaloMuonScoreTool::unwrapPhiAngles
118std::vector<float> CaloMuonScoreTool::unwrapPhiAngles(const std::vector<float> &in) const {
119 std::vector<float> out(in.size());
120
121 out[0] = in[0];
122
123 for (unsigned int i = 1; i < out.size(); i++) {
124 float d = xAOD::P4Helpers::deltaPhi(in[i], in[i - 1]);
125 out[i] = out[i - 1] + d;
126 }
127
128 return out;
129}
130
132// CaloMuonScoreTool::fillInputVectors
134void CaloMuonScoreTool::fillInputVectors(std::unique_ptr<const Rec::ParticleCellAssociation> &association, std::vector<float> &eta,
135 std::vector<float> &phi, std::vector<float> &energy, std::vector<int> &samplingId) const {
136 int cell_count = 0;
137
138 for (auto cluster : association->data()) {
139 eta.push_back(cluster->eta());
140 phi.push_back(cluster->phi());
141 samplingId.push_back(cluster->caloDDE()->getSampling());
142 energy.push_back(cluster->energy());
143
144 cell_count++;
145 }
146
147 ATH_MSG_DEBUG("Iterated over " << cell_count << " calo cells");
148
149 return;
150}
151
// Compute the calorimeter muon score of a track particle.
//
// The score is the output of runOnnxInference() applied to the image built by
// getInputTensor() from the calo cells associated to the track. The association
// uses m_caloCellAssociationTool with cone size m_CaloCellAssociationConeSize;
// cells and extensionCache are forwarded unchanged to particleCellAssociation().
//
// Returns -1 when:
//   - |eta| of the track exceeds m_CaloMuonEtaCut;
//   - no cell association could be obtained;
//   - the association contains no cells.
//
// NOTE(review): trk is dereferenced without a null check -- callers are assumed
// to pass a valid track.
153// CaloMuonScoreTool::getMuonScore
156 const CaloExtensionCollection *extensionCache) const {
157 ATH_MSG_DEBUG("in CaloMuonScoreTool::getMuonScore()");
158
159 double track_eta = trk->eta();
160
161 // only tracks with |eta| <= m_CaloMuonEtaCut are scored; all others get -1
162 if (std::abs(track_eta) > m_CaloMuonEtaCut) {
163 ATH_MSG_DEBUG("Skip calculation of muon score for track particle due to failed eta cut of " << m_CaloMuonEtaCut
164 << " (eta=" << track_eta << ")");
165 return -1;
166 }
167
168 ATH_MSG_DEBUG("Calculating muon score for track particle with eta=" << track_eta);
169
170 ATH_MSG_DEBUG("Finding calo cell association for track particle within cone of delta R=" << m_CaloCellAssociationConeSize);
171
172 // - associate calocells to trackparticle
173 std::unique_ptr<const Rec::ParticleCellAssociation> association =
174 m_caloCellAssociationTool->particleCellAssociation(*trk, m_CaloCellAssociationConeSize, cells, extensionCache);
175 if (!association) {
176 ATH_MSG_VERBOSE("Could not get particleCellAssociation");
177 return -1.;
178 }
179 ATH_MSG_VERBOSE(" particleCellAssociation done " << association.get());
180
181 // create input vectors from calo cell association
182 std::vector<float> eta, phi, energy;
183 std::vector<int> sampling;
184
185 fillInputVectors(association, eta, phi, energy, sampling);
186
187 // if any of the vectors are empty, return.
188 // They are filled in the same loop in `fillInputVectors`, so it is enough to check one
189 if (eta.empty()) {
190 ATH_MSG_VERBOSE("Input vectors for CaloMuonScore are empty");
191 return -1.;
192 }
193
194 // create tensor from vectors
195 std::vector<float> inputTensor = getInputTensor(eta, phi, energy, sampling);
196
197 // run inference on input tensor
198 float outputScore = runOnnxInference(inputTensor);
199 ATH_MSG_DEBUG("Computed CaloMuonScore: " << outputScore);
200
201 return outputScore;
202}
203
// Run the ONNX model on a flattened input tensor and return its score.
//
// @param tensor  flattened model input, as produced by getInputTensor().
//                Its storage is handed to Ort::Value::CreateTensor by pointer
//                (tensor.data()). Presumably ORT wraps that buffer without
//                copying, so it must stay alive until session->Run returns
//                -- confirm against the ONNX Runtime API docs.
// @return        the first element of the first output tensor (the model is
//                a binary classifier).
//
// NOTE(review): tensor.size() is never checked against
// m_etaBins * m_phiBins * m_nChannels, which is the element count passed to
// CreateTensor.
// NOTE(review): the tensor shape is m_input_node_dims, which initialize() takes
// from the last model input, with negative dimensions replaced by 1.
205// CaloMuonScoreTool::runOnnxInference
207float CaloMuonScoreTool::runOnnxInference(std::vector<float> &tensor) const {
208 // create input tensor object from data values
209 ATH_MSG_DEBUG("in CaloMuonScoreTool::runOnnxInference()");
210
211 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
212 int input_tensor_size(m_etaBins * m_phiBins * m_nChannels);
213 Ort::Value input_tensor =
214 Ort::Value::CreateTensor<float>(memory_info, tensor.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
215
216 // score model & input tensor, get back output tensor
217
218 // Ort::Session::Run is non-const.
219 // However, the onnx authors claim that it is safe to call
220 // from multiple threads:
221 // https://github.com/Microsoft/onnxruntime/issues/114
222 Ort::Session* session ATLAS_THREAD_SAFE = m_session.get();
223 auto output_tensors = session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(),
225
226 // Get pointer to output tensor float values
227 float *output_score_array = output_tensors.front().GetTensorMutableData<float>();
228
229 // Binary classification - the score is just the first element of the output tensor
230 float output_score = output_score_array[0];
231
232 return output_score;
233}
234
236// CaloMuonScoreTool::channelForSamplingId
238int CaloMuonScoreTool::channelForSamplingId(int &samplingId) const {
239 // List of 7 central calo sampling IDs: [0,1,2,3,12,13,14]
240 switch (samplingId) {
241 case 0: return 0;
242 case 1: return 1;
243 case 2: return 2;
244 case 3: return 3;
245 case 12: return 4;
246 case 13: return 5;
247 case 14: return 6;
248 default: return -1;
249 }
250}
251
253// CaloMuonScoreTool::getMedian
255float CaloMuonScoreTool::getMedian(std::vector<float> v) const {
256 if (v.empty()) return 0.0;
257
258 int n = v.size() / 2;
259 std::nth_element(v.begin(), v.begin() + n, v.end());
260 float med = v[n];
261
262 if (v.size() % 2 == 1) return med;
263
264 auto max_it = std::max_element(v.begin(), v.begin() + n);
265
266 return (*max_it + med) / 2.0;
267}
268
269int CaloMuonScoreTool::getBin(const float low_edge, const float up_edge, const int n_bins, float val) const {
270 if (val < low_edge || val >= up_edge)
271 return -1;
272 const float bin_width = (up_edge - low_edge) / (n_bins - 1);
273 float interval = val - low_edge;
274 return std::ceil(interval / bin_width);
275
276
277}
278
280// CaloMuonScoreTool::getInputTensor
282std::vector<float> CaloMuonScoreTool::getInputTensor(std::vector<float> &eta, std::vector<float> &phi, std::vector<float> &energy,
283 std::vector<int> &sampling) const {
284 int n_cells = eta.size();
285
286 // make sure the vector of phi values does not contain discontinuities around the
287 // boundary between pi and -pi
288 std::vector<float> unwrappedPhi = unwrapPhiAngles(phi);
289
290 float median_eta = getMedian(eta);
291 float median_phi = getMedian(unwrappedPhi);
292
293 // initialise output matrix of zeros
294 std::vector<float> tensor(m_etaBins * m_phiBins * m_nChannels, 0.);
295
296 int skipped_cells = 0;
297
298 for (int i = 0; i < n_cells; i++) {
299 // take eta and phi values, and shift them by their repsective median
300 float shifted_eta = eta[i] - median_eta;
301 float shifted_phi = unwrappedPhi[i] - median_phi;
302
303 int eta_bin = getBin(-m_etaCut, m_etaCut, m_etaBins, shifted_eta);
304 int phi_bin = getBin(-m_phiCut, m_phiCut, m_phiBins, shifted_phi);
305 // the cell lies outside the acceptable range
306 if (eta_bin == -1 || phi_bin == -1) {
307 skipped_cells++;
308 ATH_MSG_DEBUG("Skipping cell because eta or phi bin lies outside of range. Eta bin: " << eta_bin << " phi bin: " << phi_bin);
309 continue;
310 }
311
312 int channel = channelForSamplingId(sampling[i]);
313
314 // this really should not happen, but let's skip this cell if it does
315 if (channel == -1) {
316 skipped_cells++;
317 ATH_MSG_DEBUG("Skipping cell because sampling ID does not correspond to low-eta layers. Sampling ID: " << sampling[i]);
318 continue;
319 }
320
321 // 3D array flattening in row-major style: https://en.wikipedia.org/wiki/Row-_and_column-major_order#Explanation_and_example
322 int tensor_idx = eta_bin * m_phiBins * m_nChannels + phi_bin * m_nChannels + channel;
323
324 tensor[tensor_idx] += energy[i];
325 }
326
327 ATH_MSG_DEBUG("Skipped " << skipped_cells << " out of " << n_cells << " cells");
328
329 return tensor;
330}
Scalar eta() const
pseudorapidity method
Scalar phi() const
phi method
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_VERBOSE(x)
#define ATH_MSG_DEBUG(x)
DataVector< Trk::CaloExtension > CaloExtensionCollection
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
#define ATLAS_THREAD_SAFE
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
Container class for CaloCell.
Gaudi::Property< float > m_phiCut
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
Handle to AthOnnx::IOnnxRuntimeSvc.
float runOnnxInference(std::vector< float > &tensor) const
std::unique_ptr< Ort::Session > m_session
virtual StatusCode initialize() override
int getBin(const float low_edge, const float up_edge, const int n_bins, float val) const
std::vector< const char * > m_input_node_names
float getMedian(std::vector< float > v) const
--> Copy is necessary as the elements are reordered for the moment, which would otherwise break the association ...
Gaudi::Property< int > m_nChannels
float getMuonScore(const xAOD::TrackParticle *trk, const CaloCellContainer *cells=nullptr, const CaloExtensionCollection *extensionCache=nullptr) const override
Gaudi::Property< float > m_etaCut
Gaudi::Property< std::string > m_modelFileName
int channelForSamplingId(int &samplingId) const
std::vector< float > getInputTensor(std::vector< float > &eta, std::vector< float > &phi, std::vector< float > &energy, std::vector< int > &sampling) const
Gaudi::Property< int > m_phiBins
std::vector< const char * > m_output_node_names
Gaudi::Property< float > m_CaloCellAssociationConeSize
std::vector< float > unwrapPhiAngles(const std::vector< float > &v) const
ToolHandle< Rec::IParticleCaloCellAssociationTool > m_caloCellAssociationTool
void fillInputVectors(std::unique_ptr< const Rec::ParticleCellAssociation > &association, std::vector< float > &eta, std::vector< float > &phi, std::vector< float > &energy, std::vector< int > &samplingId) const
Gaudi::Property< int > m_etaBins
std::vector< int64_t > m_input_node_dims
CaloMuonScoreTool(const std::string &type, const std::string &name, const IInterface *parent)
Gaudi::Property< double > m_CaloMuonEtaCut
virtual double eta() const override final
The pseudorapidity ( ) of the particle.
double deltaPhi(double phiA, double phiB)
delta Phi in range [-pi,pi[
TrackParticle_v1 TrackParticle
Reference the current persistent version: