ATLAS Offline Software
src/DiTauOnnxDiscriminantTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 // Core include(s):
8 #include "AthLinks/ElementLink.h"
9 
10 
11 // EDM include(s):
12 
13 
14 
15 
// Shorthand for the track-link vector returned by xAOD::DiTauJet::trackLinks().
16 using TrackParticleLinks_t = std::vector<ElementLink<xAOD::TrackParticleContainer>>;
17 
18 //=================================PUBLIC-PART==================================
19 //______________________________________________________________________________
// Tool constructor: registers this tool under the DiTauToolBase interface.
// NOTE(review): source line 21 (the initializer list) is missing from this
// listing -- confirm against the real file.
20 DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool( const std::string& type, const std::string& name, const IInterface * parent) :
22 {
23  declareInterface<DiTauToolBase > (this);
24 }
25 
26 //______________________________________________________________________________
28 
29 //______________________________________________________________________________
// initialize(): initializes all feature read-decoration keys, resolves the
// ONNX model file and creates the ONNX Runtime environment and session.
// Returns FAILURE if the model file cannot be resolved.
// NOTE(review): the signature (source line 30) and lines 48-49 are missing from
// this listing; presumably the d0_leadtrack keys are initialized there -- confirm.
31 {
32  ATH_MSG_INFO( "Initializing DiTauOnnxDiscriminantTool" );
33  ATH_MSG_INFO( "onnxModelPath: " << m_onnxModelPath );
// Initialize every ReadDecorHandleKey used to read the di-tau input features.
34  ATH_CHECK( m_ditau_pt_DecorKey.initialize() );
35  ATH_CHECK( m_f_core_lead_DecorKey.initialize() );
36  ATH_CHECK( m_f_core_sublead_DecorKey.initialize() );
37  ATH_CHECK( m_f_subjet_subl_DecorKey.initialize() );
38  ATH_CHECK( m_f_subjets_DecorKey.initialize() );
39  ATH_CHECK( m_R_max_lead_DecorKey.initialize() );
40  ATH_CHECK( m_R_max_sublead_DecorKey.initialize() );
41  ATH_CHECK( m_n_track_DecorKey.initialize() );
42  ATH_CHECK( m_R_track_all_DecorKey.initialize() );
43  ATH_CHECK( m_R_isotrack_DecorKey.initialize() );
44  ATH_CHECK( m_R_track_sublead_DecorKey.initialize() );
45  ATH_CHECK( m_M_core_lead_DecorKey.initialize() );
46  ATH_CHECK( m_M_core_sublead_DecorKey.initialize() );
47  ATH_CHECK( m_M_track_lead_DecorKey.initialize() );
50  ATH_CHECK( m_f_isotracks_DecorKey.initialize() );
// Resolve the configured (logical) model path; an empty result means not found.
51  auto model_path = PathResolverFindCalibFile (m_onnxModelPath);
52  if (model_path.empty()) {
53  ATH_MSG_ERROR("Could not find model file: " << m_onnxModelPath);
54  return StatusCode::FAILURE;
55  }
// ORT environment (warning-level logging) and a session limited to one
// intra-op thread with extended graph optimizations enabled.
56  m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "OnnxUtil");
57  Ort::SessionOptions session_options;
58  session_options.SetIntraOpNumThreads(1);
59  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
60  m_ort_session = std::make_unique<Ort::Session>(*m_ort_env, model_path.c_str(), session_options);
61  return StatusCode::SUCCESS;
62 }
63 
// finalize(): releases the ONNX Runtime objects. The session is destroyed
// before the environment, i.e. in reverse order of their creation in initialize().
// NOTE(review): the signature (source line 64) is missing from this listing.
65 {
66  ATH_MSG_INFO( "Finalizing DiTauOnnxDiscriminantTool" );
67  m_ort_session.reset();
68  m_ort_env.reset();
69  return StatusCode::SUCCESS;
70 }
71 
// execute(): evaluates the ONNX di-tau ID score for the candidate held in
// `data` and stores it on the DiTauJet as the float decoration "omni_score".
// Assumes data and data->xAODDiTau are non-null (not checked here).
// NOTE(review): the signature (source line 72) is missing from this listing.
73 {
74  static const SG::Accessor<float> omni_scoreDec("omni_score");
75  xAOD::DiTauJet* xDitau = data->xAODDiTau;
76  ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
77  float score = GetDiTauObjOnnxScore(*xDitau);
78  ATH_MSG_DEBUG("DiTau ID score: " << score);
79  omni_scoreDec(*xDitau) = score;
80  return StatusCode::SUCCESS;
81 }
82 
83 std::vector<float> DiTauOnnxDiscriminantTool::flatten(const std::vector<std::vector<float>> &vec_2d) const{
84  std::vector<float> flattened;
85  flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
86  for (const auto &inner : vec_2d) {
87  flattened.insert(flattened.end(), inner.begin(), inner.end());
88  }
89  return flattened;
90 }
91 
92 std::vector<float> DiTauOnnxDiscriminantTool::extract_points(const std::vector<std::vector<float>> &track_features) const{
93  std::vector<float> points;
94  points.reserve(track_features.size() * 2);
95  for (const auto &track : track_features) {
96  points.push_back(track[0]); // delta_eta
97  points.push_back(track[1]); // delta_phi
98  }
99  return points;
100 }
101 
102 std::vector<float> DiTauOnnxDiscriminantTool::create_mask(const std::vector<std::vector<float>> &track_features) const{
103  std::vector<float> mask;
104  mask.reserve(track_features.size());
105  std::transform(track_features.begin(), track_features.end(), std::back_inserter(mask), [](const auto &track) {
106  return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
107  });
108  return mask;
109 }
110 
111 Ort::Value DiTauOnnxDiscriminantTool::create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const{
112  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
113  return Ort::Value::CreateTensor<float>(memory_info, data.data(), data.size(),shape.data(), shape.size());
114 }
115 
// run_inference(): wraps the five prepared input buffers as ORT tensors
// (order: features, points, mask, jet, time), runs the session and copies the
// float outputs back into an InferenceOutput.
// The tensors reference the buffers inside `inputs` without copying, which is
// presumably why `inputs` is taken by non-const reference.
// NOTE(review): the signature (source line 116) and the declaration of
// `output` (line 135) are missing from this listing.
117  std::vector<Ort::Value> input_tensors;
118  input_tensors.reserve(m_input_node_names.size());
119  input_tensors.emplace_back(create_tensor(inputs.input_features, inputs.input_features_shape));
120  input_tensors.emplace_back(create_tensor(inputs.input_points, inputs.input_points_shape));
121  input_tensors.emplace_back(create_tensor(inputs.input_mask, inputs.input_mask_shape));
122  input_tensors.emplace_back(create_tensor(inputs.input_jet, inputs.input_jet_shape));
123  input_tensors.emplace_back(create_tensor(inputs.input_time, inputs.input_time_shape));
124 
// ORT's Run() takes C-string name arrays; these point into the member strings.
125  std::vector<const char *> input_node_names;
126  input_node_names.reserve(m_input_node_names.size());
127  std::transform(m_input_node_names.begin(), m_input_node_names.end(), std::back_inserter(input_node_names), [](const std::string &name) { return name.c_str(); });
128 
129  std::vector<const char *> output_node_names;
130  output_node_names.reserve(m_output_node_names.size());
131  std::transform(m_output_node_names.begin(), m_output_node_names.end(), std::back_inserter(output_node_names), [](const std::string &name) { return name.c_str(); });
132 
// NOTE(review): the input count passed to Run is input_node_names.size(), but
// exactly five tensors are built above; if m_input_node_names ever holds more
// than five names this reads past input_tensors. Consider checking the sizes match.
133  auto output_tensors = m_ort_session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
134 
// The first output is copied to output_1; every later output overwrites output_2.
136  for (size_t i = 0; i < output_tensors.size(); ++i) {
137  const auto &tensor = output_tensors[i];
138  const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
139  const float *data = tensor.GetTensorData<float>();
140  (i == 0 ? output.output_1 : output.output_2) = std::vector<float>(data, data + length);
141  }
142  return output;
143 }
144 
// GetDiTauObjOnnxScore(): builds the model inputs for one di-tau candidate and
// returns one entry of the model's first output. The inputs are:
// 15 jet-level features, a zero-padded m_maxTracks x 11 track matrix,
// per-track (delta_eta, delta_phi) points, a track mask and a constant time input.
// NOTE(review): the signature (source line 145), the DiTauJet accessor
// declarations (lines 147-162) and the `OnnxInputs inputs{` line (229) are
// missing from this listing.
146  // Decorators for reading the necessary features from the xAOD::DiTauJet object
163  // Accessors for reading the necessary features from the xAOD::TrackParticle object
164  static const SG::ConstAccessor< uint8_t > numberOfInrmstPxlLyrHitsAcc ("numberOfInnermostPixelLayerHits");
165  static const SG::ConstAccessor< uint8_t > numberOfPixelHitsAcc ("numberOfPixelHits");
166  static const SG::ConstAccessor< uint8_t > numberOfSCTHitsAcc ("numberOfSCTHits");
167  static const SG::ConstAccessor< float > z0Acc ("z0");
168  static const SG::ConstAccessor< float > d0Acc ("d0");
// NOTE(review): the element order below presumably has to match the model's
// training feature order -- confirm before reordering.
169  // Input features for Ditau tagger ONNX model
170  std::vector<float> jet_vars = {
171  R_max_leadAcc (ditau),
172  R_max_sublAcc (ditau),
173  R_tracks_sublAcc (ditau),
174  R_isotrackAcc (ditau),
175  d0_leadtrack_leadAcc (ditau),
176  d0_leadtrack_sublAcc (ditau),
177  f_core_leadAcc (ditau),
178  f_core_sublAcc (ditau),
179  f_subjet_sublAcc (ditau),
180  f_subjetAcc (ditau),
181  f_isotracksAcc (ditau),
182  M_core_leadAcc (ditau),
183  M_core_sublAcc (ditau),
184  M_tracks_leadAcc (ditau),
185  static_cast<float>( n_trackAcc (ditau)),
186  };
187  std::vector<int64_t> jet_shape = {1, static_cast<int64_t>(jet_vars.size())};
188 
// Fixed-size track matrix: one zero-filled row of 11 features per slot. Tracks
// beyond m_maxTracks are dropped, and invalid links leave their row at zero.
189  const TrackParticleLinks_t &vTauTracks = ditau.trackLinks();
190  std::vector<std::vector<float>> track_features(m_maxTracks, std::vector<float>(11, 0.0f));
191 
192  float jet_eta = ditau.eta();
193  float jet_phi = ditau.phi();
194  size_t num_tracks = std::min(static_cast<size_t>(m_maxTracks), vTauTracks.size());
195 
196  for (size_t i = 0; i < num_tracks; ++i) {
197  const ElementLink<xAOD::TrackParticleContainer> &trackLink = vTauTracks[i];
198  if (!trackLink.isValid()) continue;
199  const xAOD::TrackParticle *xTrack = *trackLink;
200  float track_eta = xTrack->eta();
201  float track_phi = xTrack->phi();
202  float delta_eta = track_eta - jet_eta;
// std::remainder with period 2*pi wraps the phi difference into [-pi, pi].
203  float delta_phi = std::remainder(track_phi - jet_phi, 2 * M_PI);
204  float delta_R = std::hypot(delta_eta, delta_phi);
205  float track_pt = static_cast<float>(xTrack->pt());
// The small epsilons in the logs below keep the arguments away from zero.
206  float pt_log = std::log(track_pt + 1e-8f);
// NOTE(review): jet_pt does not depend on the track and could be read once before the loop.
207  float jet_pt = ditau_ptAcc(ditau);
208  float pt_ratio = track_pt / jet_pt;
// log(1 - ratio) is undefined for ratio > 1, so such tracks get 0 instead.
209  float pt_ratio_log = (pt_ratio <= 1.0f) ? std::log(1.0f - pt_ratio + 1e-8f) : 0.0f;
210  float track_charge = xTrack->charge();
211 
// Feature order for each track row; create_mask reads column 2 (pt_log) and
// extract_points reads columns 0-1.
212  track_features[i] = {
213  delta_eta,
214  delta_phi,
215  pt_log,
216  d0Acc(*xTrack),
217  pt_ratio_log,
218  z0Acc(*xTrack),
219  delta_R,
220  static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
221  static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
222  static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
223  track_charge
224  };
225  }
226  std::vector<int64_t> track_shape = {1, static_cast<int64_t>(m_maxTracks), 11};
227 
// The time input is a single constant 0 with shape {1, 1} (presumably a placeholder -- confirm).
228  // Actual ONNX inference
230  flatten(track_features),
231  track_shape,
232  extract_points(track_features),
233  {1, track_shape[1], 2},
234  create_mask(track_features),
235  {1, track_shape[1]},
236  std::move(jet_vars),
237  std::move(jet_shape),
238  {0.0f},
239  {1, 1}
240  };
241  auto output = run_inference(inputs);
// NOTE(review): assumes the first output has at least two entries; index 1 is
// presumably the di-tau signal score -- confirm against the model and consider
// checking the size before indexing.
242  return output.output_1[1];
243 }
xAOD::TrackParticle_v1::pt
virtual double pt() const override final
The transverse momentum ($p_T$) of the particle.
Definition: TrackParticle_v1.cxx:74
AllowedVariables::e
e
Definition: AsgElectronSelectorTool.cxx:37
DiTauOnnxDiscriminantTool::InferenceOutput
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:56
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:48
DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:92
DiTauOnnxDiscriminantTool::m_M_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:94
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
SG::Accessor< float >
xAOD::TrackParticle_v1::charge
float charge() const
Returns the charge.
Definition: TrackParticle_v1.cxx:151
DiTauOnnxDiscriminantTool::m_R_max_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:87
DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:116
xAOD::TrackParticle_v1::eta
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
Definition: TrackParticle_v1.cxx:78
eFEXNTuple.delta_R
def delta_R(eta1, phi1, eta2, phi2)
Definition: eFEXNTuple.py:21
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
DiTauOnnxDiscriminantTool::m_d0_leadtrack_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:97
DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:53
M_PI
#define M_PI
Definition: ActiveFraction.h:11
DiTauOnnxDiscriminantTool::m_R_isotrack_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_isotrack_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:91
DiTauOnnxDiscriminantTool::m_R_track_all_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_all_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:90
TrackParticleLinks_t
std::vector< ElementLink< xAOD::TrackParticleContainer > > TrackParticleLinks_t
Definition: src/DiTauOnnxDiscriminantTool.cxx:16
DiTauOnnxDiscriminantTool::m_M_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:93
SG::ConstAccessor< uint8_t >
xAOD::DiTauJet_v1::eta
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
DiTauOnnxDiscriminantTool::m_d0_leadtrack_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:96
DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:111
DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:83
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
DiTauOnnxDiscriminantTool::m_ditau_pt_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_ditau_pt_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:82
python.utils.AtlRunQueryLookup.mask
string mask
Definition: AtlRunQueryLookup.py:459
DiTauOnnxDiscriminantTool::~DiTauOnnxDiscriminantTool
virtual ~DiTauOnnxDiscriminantTool()
DiTauOnnxDiscriminantTool.h
DiTauOnnxDiscriminantTool::m_f_subjet_subl_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjet_subl_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:85
xAOD::DiTauJet_v1::phi
virtual double phi() const
The azimuthal angle ($\phi$) of the particle.
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
SG::ReadDecorHandle
Handle class for reading a decoration on an object.
Definition: StoreGate/StoreGate/ReadDecorHandle.h:94
DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:145
lumiFormat.i
int i
Definition: lumiFormat.py:85
DiTauOnnxDiscriminantTool::m_f_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:84
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
Amg::transform
Amg::Vector3D transform(Amg::Vector3D &v, Amg::Transform3D &tr)
Transform a point from a Transformation3D.
Definition: GeoPrimitivesHelpers.h:156
test_pyathena.parent
parent
Definition: test_pyathena.py:15
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
hist_file_dump.f
f
Definition: hist_file_dump.py:140
DiTauOnnxDiscriminantTool::finalize
virtual StatusCode finalize() override
Finalizer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:64
DiTauOnnxDiscriminantTool::m_M_track_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_track_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:95
merge.output
output
Definition: merge.py:16
DiTauOnnxDiscriminantTool::OnnxInputs
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:61
DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:52
DiTauOnnxDiscriminantTool::m_n_track_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_n_track_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:89
DiTauOnnxDiscriminantTool::m_f_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:83
remainder
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:44
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool
DiTauOnnxDiscriminantTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition: src/DiTauOnnxDiscriminantTool.cxx:20
DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:102
DiTauToolBase
The base class for all tau tools.
Definition: DiTauToolBase.h:20
DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:54
DiTauOnnxDiscriminantTool::execute
virtual StatusCode execute(DiTauCandidateData *data, const EventContext &ctx) const override
Execute - called for each Ditau candidate.
Definition: src/DiTauOnnxDiscriminantTool.cxx:72
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
xAOD::score
@ score
Definition: TrackingPrimitives.h:514
DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:51
python.general.flattened
def flattened(l)
Definition: general.py:125
TauGNNUtils::Variables::Track::pt_log
bool pt_log(const xAOD::TauJet &, const xAOD::TauTrack &track, double &out)
Definition: TauGNNUtils.cxx:478
eFEXNTuple.delta_phi
def delta_phi(phi1, phi2)
Definition: eFEXNTuple.py:15
DiTauOnnxDiscriminantTool::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:30
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
DiTauOnnxDiscriminantTool::m_f_subjets_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjets_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:86
Trk::jet_phi
@ jet_phi
Definition: JetVtxParamDefs.h:28
DiTauOnnxDiscriminantTool::m_R_track_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:92
DiTauCandidateData
Definition: DiTauCandidateData.h:15
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
DiTauOnnxDiscriminantTool::m_f_isotracks_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_isotracks_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:98
xAOD::DiTauJet_v1::trackLinks
const TrackParticleLinks_t & trackLinks() const
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
jobOptions.points
points
Definition: jobOptions.GenevaPy8_Zmumu.py:97
DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:49
xAOD::TrackParticle_v1::phi
virtual double phi() const override final
The azimuthal angle ($\phi$) of the particle (has range $-\pi$ to $+\pi$).
DiTauOnnxDiscriminantTool::m_R_max_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:88