ATLAS Offline Software
DiTauRec/DiTauOnnxDiscriminantTool.h
Go to the documentation of this file.
1 // Dear emacs, this is -*- c++ -*-
2 
3 /*
4  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
5 */
6 
7 #pragma once
8 
9 // EDM include(s):
10 #include "xAODTau/TauxAODHelpers.h"
11 #include "xAODTau/DiTauJet.h"
12 #include "DiTauToolBase.h"
13 #include "GaudiKernel/ToolHandle.h"
15 
21 
24 
25 #include <onnxruntime_cxx_api.h>
26 
27 
29  : public DiTauToolBase
30 {
// DiTauToolBase tool that evaluates an ONNX Runtime model on di-tau
// candidates (see GetDiTauObjOnnxScore, which returns a float score per
// DiTauJet). The model file comes from the "onnxModelPath" property.
// Presumably the model inputs are built from per-track features plus the
// DiTauJets decorations declared below. Confirm the exact feature assembly
// in src/DiTauOnnxDiscriminantTool.cxx.
// NOTE(review): this listing omits source lines 28 (class head), 35 and 75.
// The Doxygen index shows them as the class name, a virtual destructor and
// InferenceOutput run_inference(OnnxInputs &inputs) const.
31 public:
32 
// Standard Gaudi tool constructor: tool type, instance name, parent component.
33  DiTauOnnxDiscriminantTool( const std::string& type, const std::string& name, const IInterface * parent);
34 
36 
37  // initialize the tool
38  virtual StatusCode initialize() override;
39 
40  //finalize the tool
41  virtual StatusCode finalize() override;
42 
43  // calculate ID variables
// const: per-candidate work must not mutate tool state.
44  virtual StatusCode execute(DiTauCandidateData * data, const EventContext& ctx) const override;
45 
46 private:
47 
// Path of the ONNX model file.
// NOTE(review): the default points into the TrigTauRec area. Presumably it
// is resolved to a full path (e.g. via PathResolver) in initialize(); confirm.
48  Gaudi::Property<std::string> m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"};
// Maximum number of tracks per candidate (default 10). This presumably
// fixes the track dimension of the per-track input tensors; confirm in .cxx.
49  Gaudi::Property<size_t> m_maxTracks {this, "maxTracks", 10};
50 
// Owned ONNX Runtime environment and inference session. They are held via
// unique_ptr, so they are created at run time (presumably in initialize()).
51  std::unique_ptr<Ort::Env> m_ort_env;
52  std::unique_ptr<Ort::Session> m_ort_session;
// Model input/output node names. They are presumably passed to the ORT
// session when running, so they must match the node names in the ONNX graph.
// The five inputs mirror the five tensors held in OnnxInputs.
53  const std::vector<std::string> m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"};
54  const std::vector<std::string> m_output_node_names = {"output_1", "output_2"};
55 
// Flattened float values read back from the two model outputs.
56  struct InferenceOutput {
57  std::vector<float> output_1;
58  std::vector<float> output_2;
59  };
60 
// Flattened data plus tensor shape for each of the five model inputs.
61  struct OnnxInputs {
62  std::vector<float> input_features;
63  std::vector<int64_t> input_features_shape;
64  std::vector<float> input_points;
65  std::vector<int64_t> input_points_shape;
66  std::vector<float> input_mask;
67  std::vector<int64_t> input_mask_shape;
68  std::vector<float> input_jet;
69  std::vector<int64_t> input_jet_shape;
70  std::vector<float> input_time;
71  std::vector<int64_t> input_time_shape;
72  };
73 
// Wraps `data` with the given shape as an Ort::Value tensor.
// NOTE(review): `data` is a non-const reference, which suggests the tensor
// refers to the caller's buffer instead of copying it. If so, the vector must
// outlive the returned Ort::Value. Confirm in the .cxx.
74  Ort::Value create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const;
// Concatenates a 2D vector into a single 1D vector (presumably row-major).
76  std::vector<float> flatten(const std::vector<std::vector<float>> &vec_2d) const;
// Builds the "input_points" data from per-track features (columns: see .cxx).
77  std::vector<float> extract_points(const std::vector<std::vector<float>> &track_features) const;
// Builds the "input_mask" data from per-track features.
78  std::vector<float> create_mask(const std::vector<std::vector<float>> &track_features) const;
// Returns the ONNX discriminant score for one di-tau jet.
79  float GetDiTauObjOnnxScore(const xAOD::DiTauJet& ditau) const;
80 
81  // ReadDecorHandleKeys for the DiTau decorations
// Each key declares a read dependency on one DiTauJets decoration. These are
// presumably the jet-level model features; confirm in the .cxx.
82  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_ditau_pt_DecorKey { this, "DiTauPtDecorName", "DiTauJets.ditau_pt", "Name of the DiTau Pt decoration"};
83  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_core_lead_DecorKey { this, "DiTauFCoreLeadName", "DiTauJets.f_core_lead", "Name of the Ditau leading subjet core energy fraction decoration"};
84  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_core_sublead_DecorKey { this, "DiTauFCoreSubLeadName", "DiTauJets.f_core_subl", "Name of the Ditau subleading subjet core energy fraction decoration"};
85  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_subjet_subl_DecorKey { this, "DiTauSubjetSublName", "DiTauJets.f_subjet_subl", "Name of the Ditau subleading subjet pt fraction decoration"};
86  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_subjets_DecorKey { this, "DiTauSubjetsName", "DiTauJets.f_subjets", "Name of the DiTau subjets fraction decoration"};
87  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_max_lead_DecorKey { this, "DiTauRMaxLeadName", "DiTauJets.R_max_lead", "Name of the Ditau Max dR distance track from leading subjet decoration"};
88  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_max_sublead_DecorKey { this, "DiTauRMaxSubleadName", "DiTauJets.R_max_subl", "Name of the Ditau Max dR distance track from subleading subjet decoration"};
89  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_n_track_DecorKey { this, "DiTauNTrackName", "DiTauJets.n_track", "Name of the Ditau number of tracks decoration"};
90  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_track_all_DecorKey { this, "DiTauRTrackAllName", "DiTauJets.R_track_all", "Name of the Ditau DeltaR tracks over pt in the large region decoration"};
91  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_isotrack_DecorKey { this, "DiTauRIsoTrackAllName", "DiTauJets.R_isotrack", "Name of the Ditau DeltaR isolated tracks over pt decoration"};
92  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_track_sublead_DecorKey { this, "DiTauRTrackSubleadName", "DiTauJets.R_tracks_subl", "Name of the Ditau DeltaR tracks over pt in the large region of the subleading subjet decoration"};
93  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_core_lead_DecorKey { this, "DiTauMCoreLeadName", "DiTauJets.m_core_lead", "Name of the Ditau mass of tracks in the core region of the leading subjet decoration"};
// NOTE(review): this key reads "m_core_subl" (subleading subjet), but its
// description says "leading subjet". This is likely a copy-paste error from
// the line above; fix the description text in a code change.
94  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_core_sublead_DecorKey { this, "DiTauMCoreSubleadName", "DiTauJets.m_core_subl", "Name of the Ditau mass of tracks in the core region of the leading subjet decoration"};
95  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_track_lead_DecorKey { this, "DiTauMTrackLeadName", "DiTauJets.m_tracks_lead", "Name of the Ditau mass of tracks in the leading subjet decoration"};
96  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_d0_leadtrack_lead_DecorKey { this, "DiTauD0LeadTrackLeadName", "DiTauJets.d0_leadtrack_lead", "Name of the DiTau dR between the leading track within the lead subjet with respect to the lead subjet"};
// NOTE(review): the property name "DiTauD0SubleadTrackLeadName" does not
// follow the pattern of the key it configures ("d0_leadtrack_subl"). It is
// left as is because renaming would break existing job options.
97  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_d0_leadtrack_sublead_DecorKey { this, "DiTauD0SubleadTrackLeadName", "DiTauJets.d0_leadtrack_subl", "Name of the DiTau dR between the leading track within the sublead subjet with respect to the sublead subjet"};
98  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_isotracks_DecorKey { this, "DiTauFIsotracks", "DiTauJets.f_isotracks", "Name of the DiTau energy fraction carried by isolated tracks"};
99 };
DiTauOnnxDiscriminantTool::InferenceOutput
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:56
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
PropertyWrapper.h
DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:48
DiTauOnnxDiscriminantTool::OnnxInputs::input_time_shape
std::vector< int64_t > input_time_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:71
DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:92
DiTauOnnxDiscriminantTool::m_M_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:94
DiTauToolBase.h
DiTauOnnxDiscriminantTool::OnnxInputs::input_mask_shape
std::vector< int64_t > input_mask_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:67
DiTauOnnxDiscriminantTool::m_R_max_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:87
DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:116
DiTauOnnxDiscriminantTool::m_d0_leadtrack_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:97
DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:53
DiTauOnnxDiscriminantTool::m_R_isotrack_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_isotrack_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:91
DiTauOnnxDiscriminantTool::m_R_track_all_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_all_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:90
DiTauOnnxDiscriminantTool::m_M_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:93
DiTauOnnxDiscriminantTool::OnnxInputs::input_time
std::vector< float > input_time
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:70
DiTauOnnxDiscriminantTool::m_d0_leadtrack_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:96
DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:111
DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:83
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
DiTauOnnxDiscriminantTool::m_ditau_pt_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_ditau_pt_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:82
DiTauOnnxDiscriminantTool::~DiTauOnnxDiscriminantTool
virtual ~DiTauOnnxDiscriminantTool()
ReadDecorHandleKey.h
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
DiTauOnnxDiscriminantTool::m_f_subjet_subl_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjet_subl_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:85
DiTauOnnxDiscriminantTool::InferenceOutput::output_2
std::vector< float > output_2
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:58
DiTauOnnxDiscriminantTool::OnnxInputs::input_features_shape
std::vector< int64_t > input_features_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:63
DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:145
DiTauOnnxDiscriminantTool::OnnxInputs::input_mask
std::vector< float > input_mask
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:66
DiTauOnnxDiscriminantTool::m_f_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:84
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
TauxAODHelpers.h
test_pyathena.parent
parent
Definition: test_pyathena.py:15
DiTauOnnxDiscriminantTool::OnnxInputs::input_features
std::vector< float > input_features
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:62
DiTauOnnxDiscriminantTool::finalize
virtual StatusCode finalize() override
Finalizer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:64
DiTauJet.h
DiTauOnnxDiscriminantTool::m_M_track_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_track_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:95
DiTauOnnxDiscriminantTool::OnnxInputs::input_jet_shape
std::vector< int64_t > input_jet_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:69
DiTauOnnxDiscriminantTool::OnnxInputs
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:61
PathResolver.h
DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:52
DiTauOnnxDiscriminantTool::m_n_track_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_n_track_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:89
DiTauOnnxDiscriminantTool::m_f_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_lead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:83
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool
DiTauOnnxDiscriminantTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition: src/DiTauOnnxDiscriminantTool.cxx:20
ReadHandle.h
Handle class for reading from StoreGate.
DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:102
DiTauToolBase
The base class for all tau tools.
Definition: DiTauToolBase.h:20
DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:54
DiTauOnnxDiscriminantTool::OnnxInputs::input_points
std::vector< float > input_points
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:64
DiTauOnnxDiscriminantTool::execute
virtual StatusCode execute(DiTauCandidateData *data, const EventContext &ctx) const override
Execute - called for each Ditau candidate.
Definition: src/DiTauOnnxDiscriminantTool.cxx:72
DiTauOnnxDiscriminantTool::OnnxInputs::input_points_shape
std::vector< int64_t > input_points_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:65
TrackParticle.h
DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:51
DiTauOnnxDiscriminantTool::OnnxInputs::input_jet
std::vector< float > input_jet
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:68
DiTauOnnxDiscriminantTool::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:30
DiTauOnnxDiscriminantTool
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:30
ReadDecorHandle.h
Handle class for reading a decoration on an object.
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
DiTauOnnxDiscriminantTool::InferenceOutput::output_1
std::vector< float > output_1
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:57
DiTauOnnxDiscriminantTool::m_f_subjets_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjets_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:86
DiTauOnnxDiscriminantTool::m_R_track_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:92
DiTauCandidateData
Definition: DiTauCandidateData.h:15
DiTauOnnxDiscriminantTool::m_f_isotracks_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_isotracks_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:98
SG::ReadDecorHandleKey
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
Definition: StoreGate/StoreGate/ReadDecorHandleKey.h:85
TrackParticleContainer.h
DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:49
DiTauOnnxDiscriminantTool::m_R_max_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_sublead_DecorKey
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:88