ATLAS Offline Software
ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h
Go to the documentation of this file.
1 // Dear emacs, this is -*- c++ -*-
2 
3 /*
4  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
5 */
6 
7 #ifndef DITAURECTOOLS_DITAUONNXDISCRIMINANTTOOL_H
8 #define DITAURECTOOLS_DITAUONNXDISCRIMINANTTOOL_H
9 
11 #include <onnxruntime_cxx_api.h>
12 #include "AsgTools/AsgTool.h"
17 
18 namespace DiTauRecTools{
19 
22  , public asg::AsgTool
23 {
24 
28 
29 public:
30 
31  DiTauOnnxDiscriminantTool(const std::string& name);
32 
33  virtual ~DiTauOnnxDiscriminantTool();
34 
35  // initialize the tool
36  virtual StatusCode initialize() override;
37 
38  // calculate ID variables
39  virtual StatusCode execute(const xAOD::DiTauJet& xDiTau) const override;
40 
41 private:
42 
43  Gaudi::Property<std::string> m_onnxModelPath{this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"};
44  Gaudi::Property<size_t> m_maxTracks{this, "maxTracks", 10};
45 
46  std::unique_ptr<Ort::Env> m_ort_env;
47  std::unique_ptr<Ort::Session> m_ort_session;
48  const std::vector<std::string> m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"};
49  const std::vector<std::string> m_output_node_names = {"output_1", "output_2"};
50 
51  struct InferenceOutput {
52  std::vector<float> output_1;
53  std::vector<float> output_2;
54  };
55 
56  struct OnnxInputs {
57  std::vector<float> input_features;
58  std::vector<int64_t> input_features_shape;
59  std::vector<float> input_points;
60  std::vector<int64_t> input_points_shape;
61  std::vector<float> input_mask;
62  std::vector<int64_t> input_mask_shape;
63  std::vector<float> input_jet;
64  std::vector<int64_t> input_jet_shape;
65  std::vector<float> input_time;
66  std::vector<int64_t> input_time_shape;
67  };
68 
69  Ort::Value create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const;
71  float nan_to_num(float value, float nan_replacement, float posinf_replacement, float neginf_replacement) const;
72  std::vector<float> flatten(const std::vector<std::vector<float>> &vec_2d) const;
73  std::vector<float> extract_points(const std::vector<std::vector<float>> &track_features) const;
74  std::vector<float> create_mask(const std::vector<std::vector<float>> &track_features) const;
75  float GetDiTauObjOnnxScore(const xAOD::DiTauJet& ditau) const;
76 
77  SG::ReadHandleKey<xAOD::DiTauJetContainer> m_ditauContainerKey {this, "DiTauContainerName", "DiTauJets", "DiTau container name"};
78  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_ditau_pt_DecorKey {this, "DiTauPtDecorName", "DiTauJets.ditau_pt", "Name of the DiTau Pt decoration"};
79  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_core_lead_DecorKey {this, "DiTauFCoreLeadName", "DiTauJets.f_core_lead", "Name of the Ditau leading subjet core energy fraction decoration"};
80  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_core_sublead_DecorKey {this, "DiTauFCoreSubLeadName", "DiTauJets.f_core_subl", "Name of the Ditau subleading subjet core energy fraction decoration"};
81  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_subjet_subl_DecorKey {this, "DiTauSubjetSublName", "DiTauJets.f_subjet_subl", "Name of the Ditau subleading subjet pt fraction decoration"};
82  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_subjets_DecorKey {this, "DiTauSubjetsName", "DiTauJets.f_subjets", "Name of the DiTau subjets fraction decoration"};
83  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_max_lead_DecorKey {this, "DiTauRMaxLeadName", "DiTauJets.R_max_lead", "Name of the Ditau Max dR distance track from leading subjet decoration"};
84  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_max_sublead_DecorKey {this, "DiTauRMaxSubleadName", "DiTauJets.R_max_subl", "Name of the Ditau Max dR distance track from subleading subjet decoration"};
85  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_n_track_DecorKey {this, "DiTauNTrackName", "DiTauJets.n_track", "Name of the Ditau number of tracks decoration"};
86  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_track_all_DecorKey{ this, "DiTauRTrackAllName", "DiTauJets.R_track_all", "Name of the Ditau DeltaR tracks over pt in the large region decoration"};
87  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_isotrack_DecorKey{ this, "DiTauRIsoTrackAllName", "DiTauJets.R_isotrack", "Name of the Ditau DeltaR isolated tracks over pt decoration"};
88  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_R_track_sublead_DecorKey{ this, "DiTauRTrackSubleadName", "DiTauJets.R_tracks_subl", "Name of the Ditau DeltaR tracks over pt in the large region of the subleading subjet decoration"};
89  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_core_lead_DecorKey{ this, "DiTauMCoreLeadName", "DiTauJets.m_core_lead", "Name of the Ditau mass of tracks in the core region of the leading subjet decoration"};
90  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_core_sublead_DecorKey{ this, "DiTauMCoreSubleadName", "DiTauJets.m_core_subl", "Name of the Ditau mass of tracks in the core region of the leading subjet decoration"};
91  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_M_track_lead_DecorKey{ this, "DiTauMTrackLeadName", "DiTauJets.m_tracks_lead", "Name of the Ditau mass of tracks in the leading subjet decoration"};
92  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_d0_leadtrack_lead_DecorKey{ this, "DiTauD0LeadTrackLeadName", "DiTauJets.d0_leadtrack_lead", "Name of the DiTau dR between the leading track within the lead subjet with respect to the lead subjet"};
93  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_d0_leadtrack_sublead_DecorKey{ this, "DiTauD0SubleadTrackLeadName", "DiTauJets.d0_leadtrack_subl", "Name of the DiTau dR between the leading track within the sublead subjet with respect to the sublead subjet"};
94  SG::ReadDecorHandleKey<xAOD::DiTauJetContainer> m_f_isotracks_DecorKey{ this, "DiTauFIsotracks", "DiTauJets.f_isotracks", "Name of the DiTau energy fraction carried by isolated tracks"};
95 
96 };
97 
98 }
99 #endif // DITAURECTOOLS_DITAUONNXDISCRIMINANTTOOL_H
100 
101 
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ditau_pt_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_ditau_pt_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:78
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
asg::AsgTool
Base class for the dual-use tool implementation classes.
Definition: AsgTool.h:47
PropertyWrapper.h
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:47
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_points
std::vector< float > input_points
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:59
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_max_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:84
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_time_shape
std::vector< int64_t > input_time_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:66
DiTauRecTools::DiTauOnnxDiscriminantTool::InferenceOutput::output_1
std::vector< float > output_1
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:52
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ditauContainerKey
SG::ReadHandleKey< xAOD::DiTauJetContainer > m_ditauContainerKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:77
athena.value
value
Definition: athena.py:124
DiTauRecTools::DiTauOnnxDiscriminantTool::m_M_track_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_track_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:91
DiTauJetContainer.h
DiTauRecTools::DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:49
DiTauRecTools::DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:117
SG::ReadHandleKey
Property holding a SG store/key/clid from which a ReadHandle is made.
Definition: StoreGate/StoreGate/ReadHandleKey.h:39
DiTauRecTools::DiTauOnnxDiscriminantTool::InferenceOutput
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:51
DiTauRecTools::DiTauOnnxDiscriminantTool::initialize
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:27
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:56
DiTauRecTools::DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:103
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_mask
std::vector< float > input_mask
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:61
ReadDecorHandleKey.h
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:79
DiTauRecTools::DiTauOnnxDiscriminantTool::m_d0_leadtrack_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:92
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
DiTauRecTools::DiTauOnnxDiscriminantTool::m_d0_leadtrack_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_d0_leadtrack_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:93
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_jet
std::vector< float > input_jet
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:63
DiTauRecTools::IDiTauToolBase
Definition: IDiTauToolBase.h:20
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_core_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:80
DiTauRecTools::DiTauOnnxDiscriminantTool::m_M_core_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:89
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_mask_shape
std::vector< int64_t > input_mask_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:62
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_subjet_subl_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjet_subl_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:81
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_time
std::vector< float > input_time
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:65
DiTauRecTools::DiTauOnnxDiscriminantTool
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:23
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_subjets_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_subjets_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:82
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_track_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:88
DiTauRecTools::DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:112
DiTauRecTools::DiTauOnnxDiscriminantTool::m_n_track_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_n_track_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:85
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_points_shape
std::vector< int64_t > input_points_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:60
DiTauRecTools::DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:43
DiTauRecTools::DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:46
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauRecTools
Implementation of boosted di-tau ID.
Definition: DiTauDiscriminantTool.h:31
DiTauRecTools::DiTauOnnxDiscriminantTool::m_M_core_sublead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_M_core_sublead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:90
DiTauRecTools::DiTauOnnxDiscriminantTool::m_f_isotracks_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_f_isotracks_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:94
DiTauRecTools::DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:84
DiTauRecTools::DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:48
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_features
std::vector< float > input_features
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:57
ASG_TOOL_CLASS
#define ASG_TOOL_CLASS(CLASSNAME, INT1)
Definition: AsgToolMacros.h:68
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_features_shape
std::vector< int64_t > input_features_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:58
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
DiTauRecTools::DiTauOnnxDiscriminantTool::nan_to_num
float nan_to_num(float value, float nan_replacement, float posinf_replacement, float neginf_replacement) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:74
SG::ReadDecorHandleKey
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
Definition: StoreGate/StoreGate/ReadDecorHandleKey.h:85
AsgTool.h
IDiTauToolBase.h
DiTauRecTools::DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:146
DiTauRecTools::DiTauOnnxDiscriminantTool::InferenceOutput::output_2
std::vector< float > output_2
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:53
DiTauRecTools::DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:44
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_isotrack_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_isotrack_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:87
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_max_lead_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_max_lead_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:83
DiTauRecTools::DiTauOnnxDiscriminantTool::OnnxInputs::input_jet_shape
std::vector< int64_t > input_jet_shape
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:64
DiTauRecTools::DiTauOnnxDiscriminantTool::m_R_track_all_DecorKey
SG::ReadDecorHandleKey< xAOD::DiTauJetContainer > m_R_track_all_DecorKey
Definition: ools/DiTauRecTools/DiTauOnnxDiscriminantTool.h:86
DiTauRecTools::DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:93
DiTauRecTools::DiTauOnnxDiscriminantTool::execute
virtual StatusCode execute(const xAOD::DiTauJet &xDiTau) const override
Declare the interface that the class provides.
Definition: ools/Root/DiTauOnnxDiscriminantTool.cxx:64