ATLAS Offline Software
ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h
Go to the documentation of this file.
1 // Dear emacs, this is -*- c++ -*-
2 
3 /*
4  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
5 */
6 
7 #ifndef DITAURECTOOLS_DITAUONNXDISCRIMINANTTOOL_H
8 #define DITAURECTOOLS_DITAUONNXDISCRIMINANTTOOL_H
9 
11 #include <onnxruntime_cxx_api.h>
12 #include "AsgTools/AsgTool.h"
17 
18 
19 
20 namespace DiTauRecTools{
21 
24  , public asg::AsgTool
25 {
26 
30 
31 public:
32 
33  DiTauOnnxDiscriminantTool(const std::string& name);
34 
35  virtual ~DiTauOnnxDiscriminantTool();
36 
37  // initialize the tool
38  virtual StatusCode initialize() override;
39 
40  // calculate ID variables
41  virtual StatusCode execute(const xAOD::DiTauJet& xDiTau) const override;
42 
43 private:
44 
45  Gaudi::Property<std::string> m_onnxModelPath{this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"};
46  Gaudi::Property<size_t> m_maxTracks{this, "maxTracks", 10};
47 
48  std::unique_ptr<Ort::Env> m_ort_env;
49  std::unique_ptr<Ort::Session> m_ort_session;
50  const std::vector<std::string> m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"};
51  const std::vector<std::string> m_output_node_names = {"output_1", "output_2"};
52 
53  struct InferenceOutput {
54  std::vector<float> output_1;
55  std::vector<float> output_2;
56  };
57 
58  struct OnnxInputs {
59  std::vector<float> input_features;
60  std::vector<int64_t> input_features_shape;
61  std::vector<float> input_points;
62  std::vector<int64_t> input_points_shape;
63  std::vector<float> input_mask;
64  std::vector<int64_t> input_mask_shape;
65  std::vector<float> input_jet;
66  std::vector<int64_t> input_jet_shape;
67  std::vector<float> input_time;
68  std::vector<int64_t> input_time_shape;
69  };
70 
71  Ort::Value create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const;
73  float nan_to_num(float value, float nan_replacement, float posinf_replacement, float neginf_replacement) const;
74  std::vector<float> flatten(const std::vector<std::vector<float>> &vec_2d) const;
75  std::vector<float> extract_points(const std::vector<std::vector<float>> &track_features) const;
76  std::vector<float> create_mask(const std::vector<std::vector<float>> &track_features) const;
77  float GetDiTauObjOnnxScore(const xAOD::DiTauJet& ditau) const;
78 
79  SG::ReadHandleKey<xAOD::DiTauJetContainer> m_ditauContainerKey {this, "DiTauContainerName", "DiTauJets", "DiTau container name"};
80  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_ditau_pt_DecorKey {this, "DiTauPtDecorName", "DiTauJets.ditau_pt", "Name of the DiTau Pt decoration"};
81  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_core_lead_DecorKey {this, "DiTauFCoreLeadName", "DiTauJets.f_core_lead", "Name of the Ditau leading subjet core energy fraction decoration"};
82  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_core_sublead_DecorKey {this, "DiTauFCoreSubLeadName", "DiTauJets.f_core_subl", "Name of the Ditau subleading subjet core energy fraction decoration"};
83  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_subjet_subl_DecorKey {this, "DiTauSubjetSublName", "DiTauJets.f_subjet_subl", "Name of the Ditau subleading subjet pt fraction decoration"};
84  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_subjets_DecorKey {this, "DiTauSubjetsName", "DiTauJets.f_subjets", "Name of the DiTau subjets fraction decoration"};
85  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_max_lead_DecorKey {this, "DiTauRMaxLeadName", "DiTauJets.R_max_lead", "Name of the Ditau Max dR distance track from leading subjet decoration"};
86  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_max_sublead_DecorKey {this, "DiTauRMaxSubleadName", "DiTauJets.R_max_subl", "Name of the Ditau Max dR distance track from subleading subjet decoration"};
87  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_n_track_DecorKey {this, "DiTauNTrackName", "DiTauJets.n_track", "Name of the Ditau number of tracks decoration"};
88  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_track_all_DecorKey{ this, "DiTauRTrackAllName", "DiTauJets.R_track_all", "Name of the Ditau DeltaR tracks over pt in the large region decoration"};
89  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_isotrack_DecorKey{ this, "DiTauRIsoTrackAllName", "DiTauJets.R_isotrack", "Name of the Ditau DeltaR isolated tracks over pt decoration"};
90  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_track_sublead_DecorKey{ this, "DiTauRTrackSubleadName", "DiTauJets.R_tracks_subl", "Name of the Ditau DeltaR tracks over pt in the large region of the subleading subjet decoration"};
91  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_core_lead_DecorKey{ this, "DiTauMCoreLeadName", "DiTauJets.m_core_lead", "Name of the Ditau mass of tracks in the core region of the leading subjet decoration"};
92  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_core_sublead_DecorKey{ this, "DiTauMCoreSubleadName", "DiTauJets.m_core_subl", "Name of the Ditau mass of tracks in the core region of the leading subjet decoration"};
93  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_track_lead_DecorKey{ this, "DiTauMTrackLeadName", "DiTauJets.m_tracks_lead", "Name of the Ditau mass of tracks in the leading subjet decoration"};
94  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_d0_leadtrack_lead_DecorKey{ this, "DiTauD0LeadTrackLeadName", "DiTauJets.d0_leadtrack_lead", "Name of the DiTau dR between the leading track within the lead subjet with respect to the lead subjet"};
95  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_d0_leadtrack_sublead_DecorKey{ this, "DiTauD0SubleadTrackLeadName", "DiTauJets.d0_leadtrack_subl", "Name of the DiTau dR between the leading track within the sublead subjet with respect to the sublead subjet"};
96  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_isotracks_DecorKey{ this, "DiTauFIsotracks", "DiTauJets.f_isotracks", "Name of the DiTau energy fraction carried by isolated tracks"};
97 
98 };
99 
100 }
101 #endif // DITAURECTOOLS_DITAUONNXDISCRIMINANTTOOL_H
102 
103 
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ditau_pt_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_ditau_pt_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:80
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
PropertyWrapper.h
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:49
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_points
std::vector< float > input_points
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:61
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_max_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:86
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_time_shape
std::vector< int64_t > input_time_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:68
DiTauRecTools::DiTauOnnxDiscriminantTool::InferenceOutput::output_1
std::vector< float > output_1
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:54
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ditauContainerKey
SG::ReadHandleKey< xAOD::DiTauJetContainer > m_ditauContainerKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:79
athena.value
value
Definition: athena.py:124
DiTauRecTools::DiTauOnnxDiscriminantTool::m_M_track_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_track_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:93
DiTauJetContainer.h
DiTauRecTools::DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:51
DiTauRecTools::DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:116
SG::ReadHandleKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Definition: StoreGate/StoreGate/ReadHandleKey.h:39
DiTauRecTools::DiTauOnnxDiscriminantTool::InferenceOutput
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:53
DiTauRecTools::DiTauOnnxDiscriminantTool::initialize
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:26
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:58
DiTauRecTools::DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:102
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_mask
std::vector< float > input_mask
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:63
ReadDecorHandleKey.h
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:81
DiTauRecTools::DiTauOnnxDiscriminantTool::m_d0_leadtrack_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:94
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
DiTauRecTools::DiTauOnnxDiscriminantTool::m_d0_leadtrack_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:95
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_jet
std::vector< float > input_jet
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:65
DiTauRecTools::IDiTauToolBase
Definition: IDiTauToolBase.h:19
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:82
DiTauRecTools::DiTauOnnxDiscriminantTool::m_M_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:91
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_mask_shape
std::vector< int64_t > input_mask_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:64
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_subjet_subl_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjet_subl_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:83
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_time
std::vector< float > input_time
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:67
DiTauRecTools::DiTauOnnxDiscriminantTool
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:25
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_subjets_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjets_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:84
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_track_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:90
DiTauRecTools::DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:111
DiTauRecTools::DiTauOnnxDiscriminantTool::m_n_track_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_n_track_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:87
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_points_shape
std::vector< int64_t > input_points_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:62
DiTauRecTools::DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:45
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:48
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauRecTools
Implementation of boosted di-tau ID.
Definition: DiTauDiscriminantTool.h:31
DiTauRecTools::DiTauOnnxDiscriminantTool::m_M_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:92
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_isotracks_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_isotracks_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:96
DiTauRecTools::DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:83
DiTauRecTools::DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:50
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_features
std::vector< float > input_features
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:59
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_features_shape
std::vector< int64_t > input_features_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:60
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
DiTauRecTools::DiTauOnnxDiscriminantTool::nan_to_num
float nan_to_num(float value, float nan_replacement, float posinf_replacement, float neginf_replacement) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:73
SG::ReadDecorHandleKey
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
Definition: StoreGate/StoreGate/ReadDecorHandleKey.h:85
AsgTool.h
IDiTauToolBase.h
DiTauRecTools::DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:145
DiTauRecTools::DiTauOnnxDiscriminantTool::InferenceOutput::output_2
std::vector< float > output_2
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:55
DiTauRecTools::DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:46
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_isotrack_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_isotrack_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:89
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_max_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:85
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_jet_shape
std::vector< int64_t > input_jet_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:66
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_track_all_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_all_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:88
DiTauRecTools::DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:92
DiTauRecTools::DiTauOnnxDiscriminantTool::execute
virtual StatusCode execute(const xAOD::DiTauJet &xDiTau) const override
Declare the interface that the class provides.
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:63