ATLAS Offline Software
DiTauRec/DiTauOnnxDiscriminantTool.h
Go to the documentation of this file.
1 // Dear emacs, this is -*- c++ -*-
2 
3 /*
4  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
5 */
6 
7 #pragma once
8 
9 // EDM include(s):
10 #include "xAODTau/TauxAODHelpers.h"
11 #include "xAODTau/DiTauJet.h"
12 #include "DiTauToolBase.h"
13 #include "GaudiKernel/ToolHandle.h"
15 
22 
25 
26 #include <onnxruntime_cxx_api.h>
27 
28 
30  : public DiTauToolBase
31 {
32 public:
33 
34  DiTauOnnxDiscriminantTool( const std::string& type, const std::string& name, const IInterface * parent);
35 
37 
38  // initialize the tool
39  virtual StatusCode initialize() override;
40 
41  //finalize the tool
42  virtual StatusCode finalize() override;
43 
44  // calculate ID variables
45  virtual StatusCode execute(DiTauCandidateData * data, const EventContext& ctx) const override;
46 
47  // calculate ID variables
48  virtual StatusCode executeObj(xAOD::DiTauJet& xDiTau, const EventContext& ctx ) const override;
49 
50  // calculate the score
51  float GetDiTauObjOnnxScore(const xAOD::DiTauJet& ditau) const;
52 
53 private:
54 
55  float m_dDefault = -1234;
56 
58  TLorentzVector subjet_p4;
59  std::vector<const xAOD::TrackParticle*> vTracks;
60  std::vector<const xAOD::TrackParticle*> vIsoTracks;
61  std::vector<const xAOD::TrackParticle*> vCoreTracks;
62  const xAOD::TrackParticle* leadTrack = nullptr;
63  };
65  std::vector<const xAOD::TrackParticle*> vTracks;
66  std::vector<const xAOD::TrackParticle*> vIsoTracks;
67  int nSubjets = 0;
68  std::vector<SubjetTrackingInfo> vSubjetInfo;
69  };
70 
71  int n_subjets (const xAOD::DiTauJet& xDiTau) const;
72  float ditau_pt (const xAOD::DiTauJet& xDiTau) const;
73  float f_core (const xAOD::DiTauJet& xDiTau, int iSubjet) const;
74  float f_subjet (const xAOD::DiTauJet& xDiTau, int iSubjet) const;
75  float f_subjets (const xAOD::DiTauJet& xDiTau) const;
76  float R_max (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
77  int n_track (const xAOD::DiTauJet& xDiTau) const;
78  float R_isotrack (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo) const;
79  float R_tracks (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
80  float mass_core (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
81  float mass_tracks (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
82  float d0_leadtrack (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
83  float f_isotracks (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo) const;
84 
85  StatusCode getTrackingInfo(const xAOD::DiTauJet& xDiTau, DitauTrackingInfo& trackingInfo) const;
86 
87  Gaudi::Property<std::string> m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"};
88  Gaudi::Property<size_t> m_maxTracks {this, "maxTracks", 10};
89 
90  std::unique_ptr<Ort::Env> m_ort_env;
91  std::unique_ptr<Ort::Session> m_ort_session;
92  const std::vector<std::string> m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"};
93  const std::vector<std::string> m_output_node_names = {"output_1", "output_2"};
94 
95  struct InferenceOutput {
96  std::vector<float> output_1;
97  std::vector<float> output_2;
98  };
99 
100  struct OnnxInputs {
101  std::vector<float> input_features;
102  std::vector<int64_t> input_features_shape;
103  std::vector<float> input_points;
104  std::vector<int64_t> input_points_shape;
105  std::vector<float> input_mask;
106  std::vector<int64_t> input_mask_shape;
107  std::vector<float> input_jet;
108  std::vector<int64_t> input_jet_shape;
109  std::vector<float> input_time;
110  std::vector<int64_t> input_time_shape;
111  };
112 
113  Ort::Value create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const;
115  std::vector<float> flatten(const std::vector<std::vector<float>> &vec_2d) const;
116  std::vector<float> extract_points(const std::vector<std::vector<float>> &track_features) const;
117  std::vector<float> create_mask(const std::vector<std::vector<float>> &track_features) const;
118 
119 };
120 
121 
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::vTracks
std::vector< const xAOD::TrackParticle * > vTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:59
DiTauOnnxDiscriminantTool::InferenceOutput
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:95
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
PropertyWrapper.h
DiTauOnnxDiscriminantTool::f_isotracks
float f_isotracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:344
DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:87
DiTauOnnxDiscriminantTool::OnnxInputs::input_time_shape
std::vector< int64_t > input_time_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:110
DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:87
DiTauOnnxDiscriminantTool::executeObj
virtual StatusCode executeObj(xAOD::DiTauJet &xDiTau, const EventContext &ctx) const override
Execute - called for each Ditau jet.
Definition: src/DiTauOnnxDiscriminantTool.cxx:68
DiTauOnnxDiscriminantTool::f_subjet
float f_subjet(const xAOD::DiTauJet &xDiTau, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:253
DiTauToolBase.h
DiTauOnnxDiscriminantTool::OnnxInputs::input_mask_shape
std::vector< int64_t > input_mask_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:106
DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:111
DiTauOnnxDiscriminantTool::n_subjets
int n_subjets(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:235
DiTauOnnxDiscriminantTool::DitauTrackingInfo::nSubjets
int nSubjets
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:67
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::vCoreTracks
std::vector< const xAOD::TrackParticle * > vCoreTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:61
DiTauOnnxDiscriminantTool::R_isotrack
float R_isotrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:278
DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:92
DiTauOnnxDiscriminantTool::d0_leadtrack
float d0_leadtrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:336
DiTauOnnxDiscriminantTool::m_dDefault
float m_dDefault
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:55
DiTauOnnxDiscriminantTool::OnnxInputs::input_time
std::vector< float > input_time
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:109
DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:106
DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
DiTauOnnxDiscriminantTool::DitauTrackingInfo::vSubjetInfo
std::vector< SubjetTrackingInfo > vSubjetInfo
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:68
DiTauOnnxDiscriminantTool::~DiTauOnnxDiscriminantTool
virtual ~DiTauOnnxDiscriminantTool()
DiTauOnnxDiscriminantTool::f_subjets
float f_subjets(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:257
ReadDecorHandleKey.h
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
DiTauOnnxDiscriminantTool::n_track
int n_track(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:274
DiTauOnnxDiscriminantTool::mass_core
float mass_core(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:310
DiTauOnnxDiscriminantTool::R_tracks
float R_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:295
DiTauOnnxDiscriminantTool::InferenceOutput::output_2
std::vector< float > output_2
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:97
DiTauOnnxDiscriminantTool::OnnxInputs::input_features_shape
std::vector< int64_t > input_features_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:102
DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:140
DiTauOnnxDiscriminantTool::OnnxInputs::input_mask
std::vector< float > input_mask
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:105
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ReadHandleKey.h
Property holding a SG store/key/clid from which a ReadHandle is made.
TauxAODHelpers.h
test_pyathena.parent
parent
Definition: test_pyathena.py:15
DiTauOnnxDiscriminantTool::OnnxInputs::input_features
std::vector< float > input_features
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:101
DiTauOnnxDiscriminantTool::finalize
virtual StatusCode finalize() override
Finalizer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:49
DiTauOnnxDiscriminantTool::DitauTrackingInfo::vIsoTracks
std::vector< const xAOD::TrackParticle * > vIsoTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:66
DiTauJet.h
DiTauOnnxDiscriminantTool::OnnxInputs::input_jet_shape
std::vector< int64_t > input_jet_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:108
DiTauOnnxDiscriminantTool::OnnxInputs
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:100
PathResolver.h
DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:91
DiTauOnnxDiscriminantTool::getTrackingInfo
StatusCode getTrackingInfo(const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:356
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::vIsoTracks
std::vector< const xAOD::TrackParticle * > vIsoTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:60
DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool
DiTauOnnxDiscriminantTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition: src/DiTauOnnxDiscriminantTool.cxx:20
ReadHandle.h
Handle class for reading from StoreGate.
DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:97
DiTauToolBase
The base class for all tau tools.
Definition: DiTauToolBase.h:21
DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:93
DiTauOnnxDiscriminantTool::mass_tracks
float mass_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:323
DiTauOnnxDiscriminantTool::OnnxInputs::input_points
std::vector< float > input_points
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:103
DiTauOnnxDiscriminantTool::execute
virtual StatusCode execute(DiTauCandidateData *data, const EventContext &ctx) const override
Execute - called for each Ditau candidate.
Definition: src/DiTauOnnxDiscriminantTool.cxx:57
DiTauOnnxDiscriminantTool::OnnxInputs::input_points_shape
std::vector< int64_t > input_points_shape
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:104
DiTauOnnxDiscriminantTool::R_max
float R_max(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:262
TrackParticle.h
WriteDecorHandleKey.h
DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:90
DiTauOnnxDiscriminantTool::OnnxInputs::input_jet
std::vector< float > input_jet
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:107
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::subjet_p4
TLorentzVector subjet_p4
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:58
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::leadTrack
const xAOD::TrackParticle * leadTrack
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:62
DiTauOnnxDiscriminantTool::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:30
DiTauOnnxDiscriminantTool::SubjetTrackingInfo
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:57
DiTauOnnxDiscriminantTool
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:31
ReadDecorHandle.h
Handle class for reading a decoration on an object.
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
DiTauOnnxDiscriminantTool::InferenceOutput::output_1
std::vector< float > output_1
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:96
DiTauCandidateData
Definition: DiTauCandidateData.h:15
DiTauOnnxDiscriminantTool::DitauTrackingInfo
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:64
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
DiTauOnnxDiscriminantTool::ditau_pt
float ditau_pt(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:243
DiTauOnnxDiscriminantTool::f_core
float f_core(const xAOD::DiTauJet &xDiTau, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:248
TrackParticleContainer.h
DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:88
DiTauOnnxDiscriminantTool::DitauTrackingInfo::vTracks
std::vector< const xAOD::TrackParticle * > vTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:65