ATLAS Offline Software
Loading...
Searching...
No Matches
CaloMuonScoreTool.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef CALOTRKMUIDTOOLS_CALOMUONSCORETOOL_H
6#define CALOTRKMUIDTOOLS_CALOMUONSCORETOOL_H
7
8#include <memory>
9#include <vector>
10
13#include "GaudiKernel/ServiceHandle.h"
14#include "GaudiKernel/ToolHandle.h"
17
36
37class CaloMuonScoreTool : public AthAlgTool, virtual public ICaloMuonScoreTool {
38public:
39 CaloMuonScoreTool(const std::string &type, const std::string &name, const IInterface *parent);
40 virtual ~CaloMuonScoreTool() = default;
41
42 virtual StatusCode initialize() override;
43
44 // Compute the muon score given a track particle
45 float getMuonScore(const xAOD::TrackParticle *trk, const CaloCellContainer *cells = nullptr,
46 const CaloExtensionCollection *extensionCache = nullptr) const override;
47
48private:
49 // run the ONNX inference on the input tensor
50 float runOnnxInference(std::vector<float> &tensor) const;
51
52 // unwrap phi values by mapping phi vectors that span the boundary
53 // between pi and -pi to their 2*pi complement
54 std::vector<float> unwrapPhiAngles(const std::vector<float> &v) const;
55
56 // fill vectors from the particle cell association
57 void fillInputVectors(std::unique_ptr<const Rec::ParticleCellAssociation> &association, std::vector<float> &eta,
58 std::vector<float> &phi, std::vector<float> &energy, std::vector<int> &samplingId) const;
59
60 // Compute the median of a vector of floats (can be even or odd in length)
63 float getMedian(std::vector<float> v) const;
64
65 // Given a vector of bins, return the index of the matching bin
66 int getBin(const float low_edge, const float up_edge, const int n_bins, float val) const;
67
68 // Given a calo sampling ID (as integer), return the corresponding "RGB"-like channel ID (0,1,2,3,4,5,6)
69 int channelForSamplingId(int &samplingId) const;
70
71 // for a given particle, consume vectors for eta, phi, energy, sampling ID, and return the input tensor to be used in ONNX
72 std::vector<float> getInputTensor(std::vector<float> &eta, std::vector<float> &phi, std::vector<float> &energy,
73 std::vector<int> &sampling) const;
74
75 Gaudi::Property<float> m_CaloCellAssociationConeSize{this, "CaloCellAssociationConeSize", 0.2,
76 "Size of the cone within which calo cells are associated with a track particle"};
77 Gaudi::Property<int> m_etaBins{this, "etaBins", 30, "Number of bins in eta"};
78 Gaudi::Property<int> m_phiBins{this, "phiBins", 30, "Number of bins in phi"};
79 Gaudi::Property<float> m_etaCut{
80 this, "etaCut", 0.25,
81 "Eta cut on the calorimeter cells associated with the track particle after centering of the calorimeter image"};
82 Gaudi::Property<float> m_phiCut{
83 this, "phiCut", 0.25,
84 "Phi cut on the calorimeter cells associated with the track particle after centering of the calorimeter image"};
85 Gaudi::Property<int> m_nChannels{this, "nChannels", 7, "Number of colour channels in the convolutional neural network"};
86
87 ToolHandle<Rec::IParticleCaloCellAssociationTool> m_caloCellAssociationTool{this, "ParticleCaloCellAssociationTool", ""};
88
90 ServiceHandle<AthOnnx::IOnnxRuntimeSvc> m_svc{this, "ONNXRuntimeSvc", "AthOnnx::OnnxRuntimeSvc", "CaloMuonScoreTool ONNXRuntimeSvc"};
91
92 std::unique_ptr<Ort::Session> m_session;
93
94 std::vector<const char *> m_input_node_names;
95
96 std::vector<const char *> m_output_node_names;
97
98 std::vector<int64_t> m_input_node_dims;
99
100 // This path needs to point to the ATLAS calibration area (https://atlas-groupdata.web.cern.ch/atlas-groupdata/)
101 // It needs to be a full path relative to the root of the calibration area, e.g. `CaloTrkMuIdTools/nnBased_201022/CaloMuonCNN_1.onnx`
102 Gaudi::Property<std::string> m_modelFileName{this, "ModelFileName", "CaloTrkMuIdTools/nnBased_201022/CaloMuonCNN_1.onnx"};
103
104 Gaudi::Property<double> m_CaloMuonEtaCut{this, "CaloMuonEtaCut", 1.0,
105 "Eta cut (absolute value) up to which a track particle's muon score will be calculated"};
106};
107
108#endif
Scalar eta() const
pseudorapidity method
Scalar phi() const
phi method
DataVector< Trk::CaloExtension > CaloExtensionCollection
AthAlgTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor with parameters:
Container class for CaloCell.
Gaudi::Property< float > m_phiCut
ServiceHandle< AthOnnx::IOnnxRuntimeSvc > m_svc
Handle to AthOnnx::IOnnxRuntimeSvc.
float runOnnxInference(std::vector< float > &tensor) const
std::unique_ptr< Ort::Session > m_session
virtual StatusCode initialize() override
int getBin(const float low_edge, const float up_edge, const int n_bins, float val) const
std::vector< const char * > m_input_node_names
float getMedian(std::vector< float > v) const
--> Copy is necessary as the elements are reordered for the moment which would then break association ...
Gaudi::Property< int > m_nChannels
float getMuonScore(const xAOD::TrackParticle *trk, const CaloCellContainer *cells=nullptr, const CaloExtensionCollection *extensionCache=nullptr) const override
Gaudi::Property< float > m_etaCut
Gaudi::Property< std::string > m_modelFileName
int channelForSamplingId(int &samplingId) const
std::vector< float > getInputTensor(std::vector< float > &eta, std::vector< float > &phi, std::vector< float > &energy, std::vector< int > &sampling) const
Gaudi::Property< int > m_phiBins
virtual ~CaloMuonScoreTool()=default
std::vector< const char * > m_output_node_names
Gaudi::Property< float > m_CaloCellAssociationConeSize
std::vector< float > unwrapPhiAngles(const std::vector< float > &v) const
ToolHandle< Rec::IParticleCaloCellAssociationTool > m_caloCellAssociationTool
void fillInputVectors(std::unique_ptr< const Rec::ParticleCellAssociation > &association, std::vector< float > &eta, std::vector< float > &phi, std::vector< float > &energy, std::vector< int > &samplingId) const
Gaudi::Property< int > m_etaBins
std::vector< int64_t > m_input_node_dims
CaloMuonScoreTool(const std::string &type, const std::string &name, const IInterface *parent)
Gaudi::Property< double > m_CaloMuonEtaCut
TrackParticle_v1 TrackParticle
Reference the current persistent version: