ATLAS Offline Software
CaloMuonScoreTool Class Reference

#include <CaloMuonScoreTool.h>


Public Member Functions

 CaloMuonScoreTool (const std::string &type, const std::string &name, const IInterface *parent)
 
virtual ~CaloMuonScoreTool ()=default
 
virtual StatusCode initialize () override
 
float getMuonScore (const xAOD::TrackParticle *trk, const CaloCellContainer *cells=nullptr, const CaloExtensionCollection *extensionCache=nullptr) const override
 
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns a handle to the StoreGateSvc (the ServiceHandle behaves like a pointer).
 
const ServiceHandle< StoreGateSvc > & evtStore () const
 The standard StoreGateSvc (event store). Returns a handle to the StoreGateSvc (the ServiceHandle behaves like a pointer).
 
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns a handle to the StoreGateSvc (the ServiceHandle behaves like a pointer).
 
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
 
virtual StatusCode sysStart () override
 Handle START transition.
 
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
 
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
 
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T > &t)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKey &hndl, const std::string &doc, const SG::VarHandleKeyType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleBase &hndl, const std::string &doc, const SG::VarHandleType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKeyArray &hndArr, const std::string &doc, const SG::VarHandleKeyArrayType &)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc, const SG::NotHandleType &)
 Declare a new Gaudi property.
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc="none")
 Declare a new Gaudi property.
 
void updateVHKA (Gaudi::Details::PropertyBase &)
 
MsgStream & msg () const
 
MsgStream & msg (const MSG::Level lvl) const
 
bool msgLvl (const MSG::Level lvl) const
 

Static Public Member Functions

static const InterfaceID & interfaceID ()
 

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
 
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
 
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.
 

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t
 

Private Member Functions

float runOnnxInference (std::vector< float > &tensor) const
 
std::vector< float > unwrapPhiAngles (const std::vector< float > &v) const
 
void fillInputVectors (std::unique_ptr< const Rec::ParticleCellAssociation > &association, std::vector< float > &eta, std::vector< float > &phi, std::vector< float > &energy, std::vector< int > &samplingId) const
 
float getMedian (std::vector< float > v) const
 The argument is taken by value: a copy is necessary because std::nth_element reorders the elements, which would otherwise break the association between the input vectors and the actual cell deposits.
 
int getBin (const float low_edge, const float up_edge, const int n_bins, float val) const
 
int channelForSamplingId (int &samplingId) const
 
std::vector< float > getInputTensor (std::vector< float > &eta, std::vector< float > &phi, std::vector< float > &energy, std::vector< int > &sampling) const
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleKeyArrayType &)
 specialization for handling Gaudi::Property<SG::VarHandleKeyArray>
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &hndl, const SG::VarHandleType &)
 specialization for handling Gaudi::Property<SG::VarHandleBase>
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T > &t, const SG::NotHandleType &)
 specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray>
 

Private Attributes

Gaudi::Property< float > m_CaloCellAssociationConeSize
 
Gaudi::Property< int > m_etaBins {this, "etaBins", 30, "Number of bins in eta"}
 
Gaudi::Property< int > m_phiBins {this, "phiBins", 30, "Number of bins in phi"}
 
Gaudi::Property< float > m_etaCut
 
Gaudi::Property< float > m_phiCut
 
Gaudi::Property< int > m_nChannels {this, "nChannels", 7, "Number of colour channels in the convolutional neural network"}
 
ToolHandle< Rec::IParticleCaloCellAssociationTool > m_caloCellAssociationTool {this, "ParticleCaloCellAssociationTool", ""}
 
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc {this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"}
 Handle to AthOnnx::IOnnxRuntimeSvc.
 
std::unique_ptr< Ort::Session > m_session
 
std::vector< const char * > m_input_node_names
 
std::vector< const char * > m_output_node_names
 
std::vector< int64_t > m_input_node_dims
 
Gaudi::Property< std::string > m_modelFileName {this, "ModelFileName", "CaloTrkMuIdTools/nnBased_201022/CaloMuonCNN_1.onnx"}
 
Gaudi::Property< double > m_CaloMuonEtaCut
 
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
 
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
 
std::vector< SG::VarHandleKeyArray * > m_vhka
 
bool m_varHandleArraysDeclared
 

Detailed Description

Fetch the calorimeter cells around a track particle and compute the muon score.

The muon score is computed by running inference on a convolutional neural network with seven colour channels. Its input is an image of the energy deposits in the calorimeter cells around the track particle, binned in 30 eta bins and 30 phi bins (the default values of etaBins and phiBins). The seven colour channels correspond to the seven central calorimeter layers (CaloSamplingIDs 0, 1, 2, 3, 12, 13 and 14); the score is only computed for track particles with |eta| below CaloMuonEtaCut (1.0 by default).

The convolutional neural network was trained with TensorFlow. For inference, the TensorFlow model was converted to the ONNX format and is evaluated with ONNX Runtime.

Author
Ricardo Woelker ricardo.woelker@cern.ch

Definition at line 37 of file CaloMuonScoreTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ CaloMuonScoreTool()

CaloMuonScoreTool::CaloMuonScoreTool ( const std::string &  type,
const std::string &  name,
const IInterface *  parent 
)

Definition at line 24 of file CaloMuonScoreTool.cxx.

24  :
26  declareInterface<ICaloMuonScoreTool>(this);
27 }

◆ ~CaloMuonScoreTool()

virtual CaloMuonScoreTool::~CaloMuonScoreTool ( )
virtual, default

Member Function Documentation

◆ channelForSamplingId()

int CaloMuonScoreTool::channelForSamplingId ( int &  samplingId) const
private

Definition at line 238 of file CaloMuonScoreTool.cxx.

238  {
239  // List of 7 central calo sampling IDs: [0,1,2,3,12,13,14]
240  switch (samplingId) {
241  case 0: return 0;
242  case 1: return 1;
243  case 2: return 2;
244  case 3: return 3;
245  case 12: return 4;
246  case 13: return 5;
247  case 14: return 6;
248  default: return -1;
249  }
250 }
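The switch is a fixed lookup from the seven central sampling IDs to CNN channel indices. The sketch below expresses the same mapping as a search over a constant table, which keeps the channel order in one place. It is an illustrative, standalone rewrite, not the ATLAS code: the names `kChannelSamplings` and `channelForSampling` are hypothetical, and the sampling ID is passed by value rather than by non-const reference.

```cpp
#include <algorithm>
#include <array>
#include <iterator>

// Sampling IDs in channel order, taken from the comment in channelForSamplingId:
// channels 0..3 <- samplings 0..3, channels 4..6 <- samplings 12..14.
constexpr std::array<int, 7> kChannelSamplings{0, 1, 2, 3, 12, 13, 14};

// Hypothetical helper: returns the CNN channel for a sampling ID, or -1 if the
// sampling is not one of the seven layers used by the network.
int channelForSampling(int samplingId) {
  const auto it = std::find(kChannelSamplings.begin(), kChannelSamplings.end(), samplingId);
  if (it == kChannelSamplings.end()) return -1;
  return static_cast<int>(std::distance(kChannelSamplings.begin(), it));
}
```

With a single table, the number of channels it defines can also be compared directly against the `nChannels` property (7 by default).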

◆ declareGaudiProperty() [1/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyArrayType &  
)
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKeyArray>

Definition at line 170 of file AthCommonDataStore.h.

172  {
173  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
174  hndl.value(),
175  hndl.documentation());
176 
177  }

◆ declareGaudiProperty() [2/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleKeyType &  
)
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158  {
159  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160  hndl.value(),
161  hndl.documentation());
162 
163  }

◆ declareGaudiProperty() [3/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  hndl,
const SG::VarHandleType &  
)
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleBase>

Definition at line 184 of file AthCommonDataStore.h.

186  {
187  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
188  hndl.value(),
189  hndl.documentation());
190  }

◆ declareGaudiProperty() [4/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T > &  t,
const SG::NotHandleType &  
)
inline, private, inherited

specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray>

Definition at line 199 of file AthCommonDataStore.h.

200  {
201  return PBASE::declareProperty(t);
202  }

◆ declareProperty() [1/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleBase &  hndl,
const std::string &  doc,
const SG::VarHandleType &  
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleBase. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 245 of file AthCommonDataStore.h.

249  {
250  this->declare(hndl.vhKey());
251  hndl.vhKey().setOwner(this);
252 
253  return PBASE::declareProperty(name,hndl,doc);
254  }

◆ declareProperty() [2/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKey &  hndl,
const std::string &  doc,
const SG::VarHandleKeyType &  
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleKey. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 221 of file AthCommonDataStore.h.

225  {
226  this->declare(hndl);
227  hndl.setOwner(this);
228 
229  return PBASE::declareProperty(name,hndl,doc);
230  }

◆ declareProperty() [3/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKeyArray &  hndArr,
const std::string &  doc,
const SG::VarHandleKeyArrayType &  
)
inline, inherited

Definition at line 259 of file AthCommonDataStore.h.

263  {
264 
265  // std::ostringstream ost;
266  // ost << Algorithm::name() << " VHKA declareProp: " << name
267  // << " size: " << hndArr.keys().size()
268  // << " mode: " << hndArr.mode()
269  // << " vhka size: " << m_vhka.size()
270  // << "\n";
271  // debug() << ost.str() << endmsg;
272 
273  hndArr.setOwner(this);
274  m_vhka.push_back(&hndArr);
275 
276  Gaudi::Details::PropertyBase* p = PBASE::declareProperty(name, hndArr, doc);
277  if (p != 0) {
278  p->declareUpdateHandler(&AthCommonDataStore<PBASE>::updateVHKA, this);
279  } else {
280  ATH_MSG_ERROR("unable to call declareProperty on VarHandleKeyArray "
281  << name);
282  }
283 
284  return p;
285 
286  }

◆ declareProperty() [4/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc,
const SG::NotHandleType &  
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This is the generic version, for types that do not derive from SG::VarHandleKey. It just forwards to the base class version of declareProperty.

Definition at line 333 of file AthCommonDataStore.h.

337  {
338  return PBASE::declareProperty(name, property, doc);
339  }

◆ declareProperty() [5/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc = "none" 
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This dispatches to either the generic declareProperty or the one for VarHandle/Key/KeyArray.

Definition at line 352 of file AthCommonDataStore.h.

355  {
356  typedef typename SG::HandleClassifier<T>::type htype;
357  return declareProperty (name, property, doc, htype());
358  }
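The dispatch uses tag types: `SG::HandleClassifier<T>::type` names a tag (such as `SG::VarHandleKeyType` or `SG::NotHandleType`), and overload resolution on a default-constructed tag selects the matching `declareProperty`. A minimal standalone sketch of this tag-dispatch pattern follows; `KeyTag`, `NotHandleTag`, `FakeKey`, `Classifier` and `declare` are hypothetical stand-ins, not the SG or Gaudi types, and the classification rule used here (derived from `FakeKey` or not) is a simplification.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for SG::VarHandleKeyType and SG::NotHandleType.
struct KeyTag {};
struct NotHandleTag {};

// Hypothetical stand-in for a VarHandleKey-like class.
struct FakeKey {};

// Classifier: maps a type to a tag. Anything derived from FakeKey gets
// KeyTag; every other type gets NotHandleTag.
template <class T>
struct Classifier {
  using type = std::conditional_t<std::is_base_of_v<FakeKey, T>, KeyTag, NotHandleTag>;
};

// Overloads selected purely by the type of the tag argument.
template <class T>
std::string declare(T&, KeyTag) { return "handle-key version"; }

template <class T>
std::string declare(T&, NotHandleTag) { return "generic version"; }

// Front end: classify T, then dispatch on a default-constructed tag,
// mirroring `declareProperty(name, property, doc, htype())`.
template <class T>
std::string declare(T& property) {
  using htype = typename Classifier<T>::type;
  return declare(property, htype());
}
```

Because the tag is chosen at compile time, the dispatch costs nothing at run time and an unsupported type simply falls through to the generic overload.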

◆ declareProperty() [6/6]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T > &  t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145  {
146  typedef typename SG::HandleClassifier<T>::type htype;
148  }

◆ detStore()

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns a handle to the StoreGateSvc (the ServiceHandle behaves like a pointer).

Definition at line 95 of file AthCommonDataStore.h.

95 { return m_detStore; }

◆ evtStore() [1/2]

ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns a handle to the StoreGateSvc (the ServiceHandle behaves like a pointer).

Definition at line 85 of file AthCommonDataStore.h.

85 { return m_evtStore; }

◆ evtStore() [2/2]

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( ) const
inline, inherited

The standard StoreGateSvc (event store). Returns a handle to the StoreGateSvc (the ServiceHandle behaves like a pointer).

Definition at line 90 of file AthCommonDataStore.h.

90 { return m_evtStore; }

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase &  ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName if it is not given explicitly.

◆ fillInputVectors()

void CaloMuonScoreTool::fillInputVectors ( std::unique_ptr< const Rec::ParticleCellAssociation > &  association,
std::vector< float > &  eta,
std::vector< float > &  phi,
std::vector< float > &  energy,
std::vector< int > &  samplingId 
) const
private

Definition at line 134 of file CaloMuonScoreTool.cxx.

135  {
136  int cell_count = 0;
137 
138  for (auto cluster : association->data()) {
139  eta.push_back(cluster->eta());
140  phi.push_back(cluster->phi());
141  samplingId.push_back(cluster->caloDDE()->getSampling());
142  energy.push_back(cluster->energy());
143 
144  cell_count++;
145  }
146 
147  ATH_MSG_DEBUG("Iterated over " << cell_count << " calo cells");
148 
149  return;
150 }

◆ getBin()

int CaloMuonScoreTool::getBin ( const float  low_edge,
const float  up_edge,
const int  n_bins,
float  val 
) const
private

Definition at line 269 of file CaloMuonScoreTool.cxx.

269  {
270  if (val < low_edge || val >= up_edge)
271  return -1;
272  const float bin_width = (up_edge - low_edge) / (n_bins - 1);
273  float interval = val - low_edge;
274  return std::ceil(interval / bin_width);
275 
276 
277 }
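Note how the binning works as written: the bin width is `(up_edge - low_edge) / (n_bins - 1)` and the index is rounded up, so (ignoring floating-point rounding) values in `[low_edge, up_edge)` map to indices `0` to `n_bins - 1`, with index 0 reached only when `val` equals `low_edge`. The standalone copy below (the name `getBinSketch` is hypothetical) reproduces that rule so the edge behaviour can be checked with the default cut of ±0.25 and 30 bins.

```cpp
#include <cmath>

// Standalone copy of the binning rule in CaloMuonScoreTool::getBin:
// values outside [low_edge, up_edge) return -1; otherwise the index is
// ceil((val - low_edge) / w) with w = (up_edge - low_edge) / (n_bins - 1).
int getBinSketch(float low_edge, float up_edge, int n_bins, float val) {
  if (val < low_edge || val >= up_edge) return -1;
  const float bin_width = (up_edge - low_edge) / (n_bins - 1);
  return static_cast<int>(std::ceil((val - low_edge) / bin_width));
}
```

Any change to this rule would presumably also require re-validating or retraining the model, since the network expects images binned the same way as during training.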

◆ getInputTensor()

std::vector< float > CaloMuonScoreTool::getInputTensor ( std::vector< float > &  eta,
std::vector< float > &  phi,
std::vector< float > &  energy,
std::vector< int > &  sampling 
) const
private

Definition at line 282 of file CaloMuonScoreTool.cxx.

283  {
284  int n_cells = eta.size();
285 
286  // make sure the vector of phi values does not contain discontinuities around the
287  // boundary between pi and -pi
288  std::vector<float> unwrappedPhi = unwrapPhiAngles(phi);
289 
290  float median_eta = getMedian(eta);
291  float median_phi = getMedian(unwrappedPhi);
292 
293  // initialise output matrix of zeros
294  std::vector<float> tensor(m_etaBins * m_phiBins * m_nChannels, 0.);
295 
296  int skipped_cells = 0;
297 
298  for (int i = 0; i < n_cells; i++) {
299  // take eta and phi values, and shift them by their respective median
300  float shifted_eta = eta[i] - median_eta;
301  float shifted_phi = unwrappedPhi[i] - median_phi;
302 
303  int eta_bin = getBin(-m_etaCut, m_etaCut, m_etaBins, shifted_eta);
304  int phi_bin = getBin(-m_phiCut, m_phiCut, m_phiBins, shifted_phi);
305  // the cell lies outside the acceptable range
306  if (eta_bin == -1 || phi_bin == -1) {
307  skipped_cells++;
308  ATH_MSG_DEBUG("Skipping cell because eta or phi bin lies outside of range. Eta bin: " << eta_bin << " phi bin: " << phi_bin);
309  continue;
310  }
311 
312  int channel = channelForSamplingId(sampling[i]);
313 
314  // this really should not happen, but let's skip this cell if it does
315  if (channel == -1) {
316  skipped_cells++;
317  ATH_MSG_DEBUG("Skipping cell because sampling ID does not correspond to low-eta layers. Sampling ID: " << sampling[i]);
318  continue;
319  }
320 
321  // 3D array flattening in row-major style: https://en.wikipedia.org/wiki/Row-_and_column-major_order#Explanation_and_example
322  int tensor_idx = eta_bin * m_phiBins * m_nChannels + phi_bin * m_nChannels + channel;
323 
324  tensor[tensor_idx] += energy[i];
325  }
326 
327  ATH_MSG_DEBUG("Skipped " << skipped_cells << " out of " << n_cells << " cells");
328 
329  return tensor;
330 }
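The index formula `eta_bin * m_phiBins * m_nChannels + phi_bin * m_nChannels + channel` flattens an (eta, phi, channel) array in row-major order, so the channel index varies fastest. This corresponds to a channels-last image layout, the TensorFlow default; that the model's input shape is exactly (1, 30, 30, 7) is an assumption not stated on this page. The sketch below, with hypothetical names `TensorShape`, `flatIndex` and `unflatten`, computes the index and its inverse for the default 30 × 30 × 7 = 6300-element tensor.

```cpp
#include <array>

// Image dimensions; 30, 30 and 7 are the default property values
// (etaBins, phiBins, nChannels).
struct TensorShape {
  int etaBins = 30;
  int phiBins = 30;
  int nChannels = 7;
  int size() const { return etaBins * phiBins * nChannels; }
};

// Row-major (eta, phi, channel) -> flat index; the channel index varies fastest.
int flatIndex(const TensorShape& s, int etaBin, int phiBin, int channel) {
  return etaBin * s.phiBins * s.nChannels + phiBin * s.nChannels + channel;
}

// Inverse mapping: flat index -> {etaBin, phiBin, channel}, handy when
// checking which bin a deposit ended up in.
std::array<int, 3> unflatten(const TensorShape& s, int idx) {
  const int channel = idx % s.nChannels;
  const int phiBin = (idx / s.nChannels) % s.phiBins;
  const int etaBin = idx / (s.nChannels * s.phiBins);
  return {etaBin, phiBin, channel};
}
```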

◆ getMedian()

float CaloMuonScoreTool::getMedian ( std::vector< float >  v) const
private

The argument is taken by value: a copy is necessary because std::nth_element reorders the elements, which would otherwise break the association between the input vectors and the actual cell deposits.

Definition at line 255 of file CaloMuonScoreTool.cxx.

255  {
256  if (v.empty()) return 0.0;
257 
258  int n = v.size() / 2;
259  std::nth_element(v.begin(), v.begin() + n, v.end());
260  float med = v[n];
261 
262  if (v.size() % 2 == 1) return med;
263 
264  auto max_it = std::max_element(v.begin(), v.begin() + n);
265 
266  return (*max_it + med) / 2.0;
267 }
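`getMedian` uses `std::nth_element` instead of a full sort: after the call, the element at position n = size/2 is the one a sort would put there and every element before it is no larger, so for an even size the lower middle value is the maximum of the first n elements. A standalone sketch of the same approach (the name `medianSketch` is hypothetical):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Median via std::nth_element, following the approach of getMedian. The
// argument is taken by value because nth_element reorders the elements.
float medianSketch(std::vector<float> v) {
  if (v.empty()) return 0.0f;  // same convention as getMedian
  const std::size_t n = v.size() / 2;
  std::nth_element(v.begin(), v.begin() + n, v.end());
  const float upper = v[n];  // the element a full sort would place at n
  if (v.size() % 2 == 1) return upper;
  // Even size: all elements before position n are <= upper, so their
  // maximum is the lower of the two middle values.
  const float lower = *std::max_element(v.begin(), v.begin() + n);
  return (lower + upper) / 2.0f;
}
```

`std::nth_element` runs in linear time on average, which is cheaper than sorting every cell coordinate once per track.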

◆ getMuonScore()

float CaloMuonScoreTool::getMuonScore ( const xAOD::TrackParticle *  trk,
const CaloCellContainer *  cells = nullptr,
const CaloExtensionCollection *  extensionCache = nullptr 
) const
override, virtual

Implements ICaloMuonScoreTool.

Definition at line 155 of file CaloMuonScoreTool.cxx.

156  {
157  ATH_MSG_DEBUG("in CaloMuonScoreTool::getMuonScore()");
158 
159  double track_eta = trk->eta();
160 
161  // only calculate the muon score for tracks within the configured |eta| cut
162  if (std::abs(track_eta) > m_CaloMuonEtaCut) {
163  ATH_MSG_DEBUG("Skip calculation of muon score for track particle due to failed eta cut of " << m_CaloMuonEtaCut
164  << " (eta=" << track_eta << ")");
165  return -1;
166  }
167 
168  ATH_MSG_DEBUG("Calculating muon score for track particle with eta=" << track_eta);
169 
170  ATH_MSG_DEBUG("Finding calo cell association for track particle within cone of delta R=" << m_CaloCellAssociationConeSize);
171 
172  // - associate calocells to trackparticle
173  std::unique_ptr<const Rec::ParticleCellAssociation> association =
174  m_caloCellAssociationTool->particleCellAssociation(*trk, m_CaloCellAssociationConeSize, cells, extensionCache);
175  if (!association) {
176  ATH_MSG_VERBOSE("Could not get particleCellAssociation");
177  return -1.;
178  }
179  ATH_MSG_VERBOSE(" particleCellAssociation done " << association.get());
180 
181  // create input vectors from calo cell association
182  std::vector<float> eta, phi, energy;
183  std::vector<int> sampling;
184 
185  fillInputVectors(association, eta, phi, energy, sampling);
186 
187  // if any of the vectors are empty, return.
188  // They are filled in the same loop in `fillInputVectors`, so it is enough to check one
189  if (eta.empty()) {
190  ATH_MSG_VERBOSE("Input vectors for CaloMuonScore are empty");
191  return -1.;
192  }
193 
194  // create tensor from vectors
195  std::vector<float> inputTensor = getInputTensor(eta, phi, energy, sampling);
196 
197  // run inference on input tensor
198  float outputScore = runOnnxInference(inputTensor);
199  ATH_MSG_DEBUG("Computed CaloMuonScore: " << outputScore);
200 
201  return outputScore;
202 }

◆ initialize()

StatusCode CaloMuonScoreTool::initialize ( )
override, virtual

Definition at line 32 of file CaloMuonScoreTool.cxx.

32  {
33  ATH_MSG_INFO("Initializing " << name());
34 
35  ATH_CHECK(m_svc.retrieve());
37 
38  std::string model_file_name = PathResolverFindCalibFile(m_modelFileName);
39 
40  if (m_modelFileName.empty() || model_file_name.empty()) {
41  ATH_MSG_FATAL("Could not find the requested ONNX model file: " << m_modelFileName);
43  "Please make sure it exists in the ATLAS calibration area (https://atlas-groupdata.web.cern.ch/atlas-groupdata/), and provide "
44  "a model file name relative to the root of the calibration area.");
45 
46  return StatusCode::FAILURE;
47  }
48 
49  // initialise session
50  Ort::SessionOptions session_options;
51  Ort::AllocatorWithDefaultOptions allocator;
52  session_options.SetIntraOpNumThreads(1);
53  session_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
54 
55  m_session = std::make_unique<Ort::Session>(m_svc->env(), model_file_name.c_str(), session_options);
56 
57  ATH_MSG_INFO("Created ONNX runtime session with model " << model_file_name);
58 
59  size_t num_input_nodes = m_session->GetInputCount();
60  m_input_node_names.resize(num_input_nodes);
61 
62  for (std::size_t i = 0; i < num_input_nodes; i++) {
63  // print input node names
64  char* input_name = m_session->GetInputNameAllocated(i, allocator).release();
65  ATH_MSG_INFO("Input " << i << " : "
66  << " name= " << input_name);
67  m_input_node_names[i] = input_name;
68  // print input node types
69  Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
70  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
71  ONNXTensorElementDataType type = tensor_info.GetElementType();
72  ATH_MSG_INFO("Input " << i << " : "
73  << " type= " << type);
74 
75  // print input shapes/dims
76  m_input_node_dims = tensor_info.GetShape();
77  ATH_MSG_INFO("Input " << i << " : num_dims= " << m_input_node_dims.size());
78  for (std::size_t j = 0; j < m_input_node_dims.size(); j++) {
79  if (m_input_node_dims[j] < 0) m_input_node_dims[j] = 1;
80  ATH_MSG_INFO("Input " << i << " : dim " << j << "= " << m_input_node_dims[j]);
81  }
82  }
83 
84  // output nodes
85  std::vector<int64_t> output_node_dims;
86  size_t num_output_nodes = m_session->GetOutputCount();
87  ATH_MSG_INFO("Have output nodes " << num_output_nodes);
88  m_output_node_names.resize(num_output_nodes);
89 
90  for (std::size_t i = 0; i < num_output_nodes; i++) {
91  // print output node names
92  char* output_name = m_session->GetOutputNameAllocated(i, allocator).release();
93  ATH_MSG_INFO("Output " << i << " : "
94  << " name= " << output_name);
95  m_output_node_names[i] = output_name;
96 
97  Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
98  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
99  ONNXTensorElementDataType type = tensor_info.GetElementType();
100  ATH_MSG_INFO("Output " << i << " : "
101  << " type= " << type);
102 
103  // print output shapes/dims
104  output_node_dims = tensor_info.GetShape();
105  ATH_MSG_INFO("Output " << i << " : num_dims= " << output_node_dims.size());
106  for (std::size_t j = 0; j < output_node_dims.size(); j++) {
107  if (output_node_dims[j] < 0) output_node_dims[j] = 1;
108  ATH_MSG_INFO("Output" << i << " : dim " << j << "= " << output_node_dims[j]);
109  }
110  }
111 
112  return StatusCode::SUCCESS;
113 }
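ONNX Runtime reports a dynamic dimension, typically a free batch axis, as a negative value (-1), and `initialize()` replaces such entries with 1 so that `m_input_node_dims` describes a single image. That shape is later passed to `CreateTensor` in `runOnnxInference` together with `m_etaBins * m_phiBins * m_nChannels` values, so the two should agree. The sketch below (hypothetical helpers `fixDynamicDims` and `elementCount`) illustrates this for an assumed model input shape of (-1, 30, 30, 7); the real shape is whatever the model file declares.

```cpp
#include <cstdint>
#include <vector>

// Replace dynamic (negative) dimensions by 1, mirroring the loop in
// initialize(), so the shape describes a single input image.
std::vector<std::int64_t> fixDynamicDims(std::vector<std::int64_t> dims) {
  for (auto& d : dims) {
    if (d < 0) d = 1;
  }
  return dims;
}

// Number of values described by a fully static shape.
std::int64_t elementCount(const std::vector<std::int64_t>& dims) {
  std::int64_t n = 1;
  for (const std::int64_t d : dims) n *= d;
  return n;
}
```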

◆ inputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ interfaceID()

static const InterfaceID& ICaloMuonScoreTool::interfaceID ( )
inline, static, inherited

Definition at line 22 of file ICaloMuonScoreTool.h.

22 {return IID_ICaloMuonScoreTool;}

◆ msg() [1/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24  {
25  return this->msgStream();
26  }

◆ msg() [2/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( const MSG::Level  lvl) const
inline, inherited

Definition at line 27 of file AthCommonMsg.h.

27  {
28  return this->msgStream(lvl);
29  }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level  lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30  {
31  return this->msgLevel(lvl);
32  }

◆ outputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ renounce()

std::enable_if_t<std::is_void_v<std::result_of_t<decltype(&T::renounce)(T)> > && !std::is_base_of_v<SG::VarHandleKeyArray, T> && std::is_base_of_v<Gaudi::DataHandle, T>, void> AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T &  h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381  {
382  h.renounce();
383  PBASE::renounce (h);
384  }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray &  handlesArray)
inline, protected, inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364  {
365  handlesArray.renounce();
366  }

◆ runOnnxInference()

float CaloMuonScoreTool::runOnnxInference ( std::vector< float > &  tensor) const
private

Definition at line 207 of file CaloMuonScoreTool.cxx.

207  {
208  // create input tensor object from data values
209  ATH_MSG_DEBUG("in CaloMuonScoreTool::runOnnxInference()");
210 
211  auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
212  int input_tensor_size(m_etaBins * m_phiBins * m_nChannels);
213  Ort::Value input_tensor =
214  Ort::Value::CreateTensor<float>(memory_info, tensor.data(), input_tensor_size, m_input_node_dims.data(), m_input_node_dims.size());
215 
216  // score model & input tensor, get back output tensor
217 
218  // Ort::Session::Run is non-const.
219  // However, the ONNX Runtime authors state that it is safe to call
220  // from multiple threads:
221  // https://github.com/Microsoft/onnxruntime/issues/114
222  Ort::Session* session ATLAS_THREAD_SAFE = m_session.get();
223  auto output_tensors = session->Run(Ort::RunOptions{nullptr}, m_input_node_names.data(), &input_tensor, m_input_node_names.size(),
224  m_output_node_names.data(), m_output_node_names.size());
225 
226  // Get pointer to output tensor float values
227  float *output_score_array = output_tensors.front().GetTensorMutableData<float>();
228 
229  // Binary classification - the score is just the first element of the output tensor
230  float output_score = output_score_array[0];
231 
232  return output_score;
233 }

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override, virtual, inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in DerivationFramework::CfAthAlgTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and asg::AsgMetadataTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override, virtual, inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ unwrapPhiAngles()

std::vector< float > CaloMuonScoreTool::unwrapPhiAngles ( const std::vector< float > &  v) const
private

Definition at line 118 of file CaloMuonScoreTool.cxx.

118  {
119  std::vector<float> out(in.size());
120 
121  out[0] = in[0];
122 
123  for (unsigned int i = 1; i < out.size(); i++) {
124  float d = xAOD::P4Helpers::deltaPhi(in[i], in[i - 1]);
125  out[i] = out[i - 1] + d;
126  }
127 
128  return out;
129 }
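`unwrapPhiAngles` accumulates wrapped differences between consecutive phi values, so a sequence crossing the ±π boundary continues smoothly (for example past π) instead of jumping by about 2π, which keeps the median phi meaningful. The sketch below is standalone: `wrappedDeltaPhi` is a hypothetical stand-in for `xAOD::P4Helpers::deltaPhi`, assumed to return the difference wrapped into [-π, π], and the empty-input guard is an addition of this sketch (in the tool, the function is only reached with non-empty input).

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for xAOD::P4Helpers::deltaPhi (assumed behaviour):
// returns a - b wrapped into [-pi, pi]. std::remainder(x, y) returns
// x - k*y with k = x/y rounded to the nearest integer.
float wrappedDeltaPhi(float a, float b) {
  const double twoPi = 2.0 * std::acos(-1.0);
  return static_cast<float>(std::remainder(static_cast<double>(a) - static_cast<double>(b), twoPi));
}

// Unwrap phi values: each output is the previous output plus the wrapped
// step, so consecutive outputs never differ by more than pi.
std::vector<float> unwrapPhi(const std::vector<float>& in) {
  std::vector<float> out(in.size());
  if (in.empty()) return out;  // guard added in this sketch
  out[0] = in[0];
  for (std::size_t i = 1; i < in.size(); ++i) {
    out[i] = out[i - 1] + wrappedDeltaPhi(in[i], in[i - 1]);
  }
  return out;
}
```

For the input {3.1, -3.1}, the second value becomes about 3.183 (just past π) rather than -3.1, so both cells end up close together after centring on the median.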

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase &  )
inline, inherited

Definition at line 308 of file AthCommonDataStore.h.

308  {
309  // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310  // << " size: " << m_vhka.size() << endmsg;
311  for (auto &a : m_vhka) {
312  std::vector<SG::VarHandleKey*> keys = a->keys();
313  for (auto k : keys) {
314  k->setOwner(this);
315  }
316  }
317  }

Member Data Documentation

◆ m_CaloCellAssociationConeSize

Gaudi::Property<float> CaloMuonScoreTool::m_CaloCellAssociationConeSize
private
Initial value:
{this, "CaloCellAssociationConeSize", 0.2,
"Size of the cone within which calo cells are associated with a track particle"}

Definition at line 75 of file CaloMuonScoreTool.h.

◆ m_caloCellAssociationTool

ToolHandle<Rec::IParticleCaloCellAssociationTool> CaloMuonScoreTool::m_caloCellAssociationTool {this, "ParticleCaloCellAssociationTool", ""}
private

Definition at line 87 of file CaloMuonScoreTool.h.

◆ m_CaloMuonEtaCut

Gaudi::Property<double> CaloMuonScoreTool::m_CaloMuonEtaCut
private
Initial value:
{this, "CaloMuonEtaCut", 1.0,
"Eta cut (absolute value) up to which a track particle's muon score will be calculated"}

Definition at line 104 of file CaloMuonScoreTool.h.

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private, inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_etaBins

Gaudi::Property<int> CaloMuonScoreTool::m_etaBins {this, "etaBins", 30, "Number of bins in eta"}
private

Definition at line 77 of file CaloMuonScoreTool.h.

◆ m_etaCut

Gaudi::Property<float> CaloMuonScoreTool::m_etaCut
private
Initial value:
{
this, "etaCut", 0.25,
"Eta cut on the calorimeter cells associated with the track particle after centering of the calorimeter image"}

Definition at line 79 of file CaloMuonScoreTool.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private, inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_input_node_dims

std::vector<int64_t> CaloMuonScoreTool::m_input_node_dims
private

Definition at line 98 of file CaloMuonScoreTool.h.

◆ m_input_node_names

std::vector<const char *> CaloMuonScoreTool::m_input_node_names
private

Definition at line 94 of file CaloMuonScoreTool.h.

◆ m_modelFileName

Gaudi::Property<std::string> CaloMuonScoreTool::m_modelFileName {this, "ModelFileName", "CaloTrkMuIdTools/nnBased_201022/CaloMuonCNN_1.onnx"}
private

Definition at line 102 of file CaloMuonScoreTool.h.

◆ m_nChannels

Gaudi::Property<int> CaloMuonScoreTool::m_nChannels {this, "nChannels", 7, "Number of colour channels in the convolutional neural network"}
private

Definition at line 85 of file CaloMuonScoreTool.h.
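One way to lay the colour channels out in a flat input buffer. This sketch assumes a channel-major [channel][eta][phi] order and that the energies of cells falling into the same pixel are summed; the real model may expect channels last or a different aggregation. The mapping from calorimeter sampling to channel (done in the class by channelForSamplingId) is not reproduced, so the channel index is taken as given. ImageShape, flatIndex and addCellEnergy are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical image geometry using the documented property defaults.
struct ImageShape {
    int nChannels = 7;
    int etaBins = 30;
    int phiBins = 30;
};

// Flat index into a buffer laid out as [channel][eta][phi] (assumed layout).
std::size_t flatIndex(const ImageShape& s, int channel, int etaBin, int phiBin) {
    assert(channel >= 0 && channel < s.nChannels);
    assert(etaBin >= 0 && etaBin < s.etaBins);
    assert(phiBin >= 0 && phiBin < s.phiBins);
    return (static_cast<std::size_t>(channel) * s.etaBins + etaBin) * s.phiBins + phiBin;
}

// Accumulate one cell's energy into the image; several cells may share a pixel.
void addCellEnergy(std::vector<float>& image, const ImageShape& s,
                   int channel, int etaBin, int phiBin, float energy) {
    image.at(flatIndex(s, channel, etaBin, phiBin)) += energy;
}
```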

◆ m_output_node_names

std::vector<const char *> CaloMuonScoreTool::m_output_node_names
private

Definition at line 96 of file CaloMuonScoreTool.h.
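m_input_node_names and m_output_node_names are vectors of raw const char* because ONNX Runtime's C++ Ort::Session::Run takes node names as const char* const* arrays, and the strings they point to must stay alive for every call. Where the class owns those strings is not documented in this section. The sketch below shows one way to guarantee the lifetime by keeping the std::string storage next to the pointer array; NodeNames is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Owns node-name strings and exposes the const char* array that
// C-style APIs such as Ort::Session::Run expect.
class NodeNames {
public:
    explicit NodeNames(std::vector<std::string> names) : m_storage(std::move(names)) {
        m_ptrs.reserve(m_storage.size());
        for (const std::string& n : m_storage) {
            assert(!n.empty() && "ONNX node names must not be empty");
            m_ptrs.push_back(n.c_str());
        }
    }
    // Copying would leave m_ptrs pointing into the source object's strings.
    NodeNames(const NodeNames&) = delete;
    NodeNames& operator=(const NodeNames&) = delete;
    // Moving a vector transfers its buffer, so the string objects (and the
    // pointers into them) stay where they are.
    NodeNames(NodeNames&&) = default;
    NodeNames& operator=(NodeNames&&) = default;

    const char* const* data() const { return m_ptrs.data(); }
    std::size_t size() const { return m_ptrs.size(); }

private:
    std::vector<std::string> m_storage;
    std::vector<const char*> m_ptrs;
};
```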

◆ m_phiBins

Gaudi::Property<int> CaloMuonScoreTool::m_phiBins {this, "phiBins", 30, "Number of bins in phi"}
private

Definition at line 78 of file CaloMuonScoreTool.h.

◆ m_phiCut

Gaudi::Property<float> CaloMuonScoreTool::m_phiCut
private
Initial value:
{
this, "phiCut", 0.25,
"Phi cut on the calorimeter cells associated with the track particle after centering of the calorimeter image"}

Definition at line 82 of file CaloMuonScoreTool.h.
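The etaCut and phiCut descriptions refer to cuts applied "after centering of the calorimeter image". The class also declares private getMedian and unwrapPhiAngles helpers, which suggests the steps sketched below. This sketch assumes the centre is the median cell position and that φ values are first made contiguous across the ±π boundary by expressing each relative to the first cell; neither detail is confirmed by this section. medianOf, unwrapPhiSketch and centreOnMedian are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Median of a vector; takes a copy because nth_element reorders elements.
float medianOf(std::vector<float> v) {
    assert(!v.empty());
    const std::size_t mid = v.size() / 2;
    std::nth_element(v.begin(), v.begin() + mid, v.end());
    float m = v[mid];
    if (v.size() % 2 == 0) {
        // Even size: average with the largest element of the lower half.
        const float lower = *std::max_element(v.begin(), v.begin() + mid);
        m = 0.5f * (m + lower);
    }
    return m;
}

// Make phi values contiguous across the +/-pi boundary by expressing each
// value relative to the first one (assumed approach).
std::vector<float> unwrapPhiSketch(const std::vector<float>& phi) {
    constexpr double pi = 3.14159265358979323846;
    std::vector<float> out;
    out.reserve(phi.size());
    for (float p : phi) {
        const double d = std::remainder(static_cast<double>(p) - phi.front(), 2.0 * pi);
        out.push_back(static_cast<float>(phi.front() + d));
    }
    return out;
}

// Shift coordinates so that their median sits at zero.
std::vector<float> centreOnMedian(const std::vector<float>& v) {
    const float med = medianOf(v);
    std::vector<float> out;
    out.reserve(v.size());
    for (float x : v) out.push_back(x - med);
    return out;
}
```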

◆ m_session

std::unique_ptr<Ort::Session> CaloMuonScoreTool::m_session
private

Definition at line 92 of file CaloMuonScoreTool.h.

◆ m_svc

ServiceHandle<AthOnnx::IOnnxRuntimeSvc> CaloMuonScoreTool::m_svc {this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"}
private

Handle to AthOnnx::IOnnxRuntimeSvc.

Definition at line 90 of file CaloMuonScoreTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private, inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private, inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files:
CaloMuonScoreTool.h
CaloMuonScoreTool.cxx