ATLAS Offline Software
Loading...
Searching...
No Matches
DiTauRec/DiTauOnnxDiscriminantTool.h
Go to the documentation of this file.
1// Dear emacs, this is -*- c++ -*-
2
3/*
4 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
5*/
6
7#pragma once
8
9// EDM include(s):
11#include "xAODTau/DiTauJet.h"
12#include "DiTauToolBase.h"
13#include "GaudiKernel/ToolHandle.h"
15
22
25
26#include <onnxruntime_cxx_api.h>
27
28
// NOTE(review): listing line 29, which opens the class (`class DiTauOnnxDiscriminantTool`),
// is missing from this scrape; the class name is given by the constructor declared below.
// DiTauOnnxDiscriminantTool: a DiTauToolBase tool that scores di-tau objects by running an
// ONNX model (m_onnxModelPath) through onnxruntime (Ort::Env / Ort::Session members below).
30 : public DiTauToolBase
31{
32public:
33
// Gaudi-style tool constructor (type, name, parent) - same signature as DiTauToolBase's.
34 DiTauOnnxDiscriminantTool( const std::string& type, const std::string& name, const IInterface * parent);
35
// NOTE(review): listing line 36 is missing here; the tooltips show it declares
// `virtual ~DiTauOnnxDiscriminantTool()`.
37
38 // initialize the tool
39 virtual StatusCode initialize() override;
40
41 //finalize the tool
42 virtual StatusCode finalize() override;
43
44 // calculate ID variables
// Per the tool documentation: called for each di-tau candidate (DiTauCandidateData).
45 virtual StatusCode execute(DiTauCandidateData * data, const EventContext& ctx) const override;
46
47 // calculate ID variables
// Per the tool documentation: called for each already-built xAOD::DiTauJet; the jet is
// passed by non-const reference, so this overload may modify it.
48 virtual StatusCode executeObj(xAOD::DiTauJet& xDiTau, const EventContext& ctx ) const override;
49
50 // calculate the score
// Returns the model score for one di-tau as a float (implementation not in this header).
51 float GetDiTauObjOnnxScore(const xAOD::DiTauJet& ditau) const;
52
53private:
54
// Presumably the fallback value used when an input variable cannot be computed - TODO confirm
// against the implementation file.
55 float m_dDefault = -1234;
56
// NOTE(review): listing line 57 is missing here; it presumably opens `struct SubjetTrackingInfo {`
// (the element type of vSubjetInfo declared below). It holds one subjet's four-momentum and
// the track lists associated with that subjet.
58 TLorentzVector subjet_p4;
59 std::vector<const xAOD::TrackParticle*> vTracks;
60 std::vector<const xAOD::TrackParticle*> vIsoTracks;
61 std::vector<const xAOD::TrackParticle*> vCoreTracks;
// NOTE(review): listing line 62 is missing from this scrape; its content cannot be recovered here.
63 };
// NOTE(review): listing line 64 is missing here; it presumably opens `struct DitauTrackingInfo {`
// (the type passed to the feature helpers and filled by getTrackingInfo below). It groups the
// di-tau-level track lists, the subjet count, and one SubjetTrackingInfo per subjet.
65 std::vector<const xAOD::TrackParticle*> vTracks;
66 std::vector<const xAOD::TrackParticle*> vIsoTracks;
67 int nSubjets = 0;
68 std::vector<SubjetTrackingInfo> vSubjetInfo;
69 };
70
// Helpers computing individual input variables from a di-tau. The ones taking `iSubjet` are
// evaluated for a single subjet index; the ones taking a DitauTrackingInfo read track lists
// prepared by getTrackingInfo. Implementations live outside this header.
71 int n_subjets (const xAOD::DiTauJet& xDiTau) const;
72 float ditau_pt (const xAOD::DiTauJet& xDiTau) const;
73 float f_core (const xAOD::DiTauJet& xDiTau, int iSubjet) const;
74 float f_subjet (const xAOD::DiTauJet& xDiTau, int iSubjet) const;
75 float f_subjets (const xAOD::DiTauJet& xDiTau) const;
76 float R_max (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
77 int n_track (const xAOD::DiTauJet& xDiTau) const;
78 float R_isotrack (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo) const;
79 float R_tracks (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
80 float mass_core (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
81 float mass_tracks (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
82 float d0_leadtrack (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo, int iSubjet) const;
83 float f_isotracks (const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo) const;
84
// Fills `trackingInfo` (output parameter) for the given di-tau; returns a StatusCode.
85 StatusCode getTrackingInfo(const xAOD::DiTauJet& xDiTau, DitauTrackingInfo& trackingInfo) const;
86
// Job-option property "onnxModelPath": model file to load (default shown below).
87 Gaudi::Property<std::string> m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"};
// Job-option property "maxTracks" (default 10); presumably caps the number of tracks fed to
// the model - TODO confirm against the implementation.
88 Gaudi::Property<size_t> m_maxTracks {this, "maxTracks", 10};
89
// onnxruntime environment and session, owned by this tool via unique_ptr.
90 std::unique_ptr<Ort::Env> m_ort_env;
91 std::unique_ptr<Ort::Session> m_ort_session;
// Input/output node names of the ONNX graph; the five input names correspond one-to-one to
// the data members of OnnxInputs below.
92 const std::vector<std::string> m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"};
93 const std::vector<std::string> m_output_node_names = {"output_1", "output_2"};
94
// NOTE(review): listing line 95 is missing here; it presumably opens `struct InferenceOutput {`
// (the return type of run_inference). It holds the values read from the two output nodes.
96 std::vector<float> output_1;
97 std::vector<float> output_2;
98 };
99
// Model inputs: for each graph input, a flat float buffer plus its tensor shape.
100 struct OnnxInputs {
101 std::vector<float> input_features;
102 std::vector<int64_t> input_features_shape;
103 std::vector<float> input_points;
104 std::vector<int64_t> input_points_shape;
105 std::vector<float> input_mask;
106 std::vector<int64_t> input_mask_shape;
107 std::vector<float> input_jet;
108 std::vector<int64_t> input_jet_shape;
109 std::vector<float> input_time;
110 std::vector<int64_t> input_time_shape;
111 };
112
// Builds an Ort::Value from `data` with the given `shape`. `data` is taken by non-const
// reference; if the implementation wraps the buffer without copying (as Ort::Value::CreateTensor
// does), `data` must outlive the returned tensor - TODO confirm in the implementation.
113 Ort::Value create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const;
// NOTE(review): listing line 114 is missing here; the tooltips show it declares
// `InferenceOutput run_inference(OnnxInputs &inputs) const`.
// Helpers turning a 2-D table (presumably one row of features per track - verify) into flat
// model inputs: a flattened copy, the "points" input, and the "mask" input respectively.
115 std::vector<float> flatten(const std::vector<std::vector<float>> &vec_2d) const;
116 std::vector<float> extract_points(const std::vector<std::vector<float>> &track_features) const;
117 std::vector<float> create_mask(const std::vector<std::vector<float>> &track_features) const;
118
119};
120
121
Property holding a SG store/key/clid/attr name from which a ReadDecorHandle is made.
Handle class for reading a decoration on an object.
Property holding a SG store/key/clid from which a ReadHandle is made.
Handle class for reading from StoreGate.
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
float R_max(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
DiTauOnnxDiscriminantTool(const std::string &type, const std::string &name, const IInterface *parent)
InferenceOutput run_inference(OnnxInputs &inputs) const
StatusCode getTrackingInfo(const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
int n_track(const xAOD::DiTauJet &xDiTau) const
float R_isotrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
float mass_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float ditau_pt(const xAOD::DiTauJet &xDiTau) const
virtual StatusCode finalize() override
Finalizer.
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
std::vector< float > create_mask(const std::vector< std::vector< float > > &track_features) const
int n_subjets(const xAOD::DiTauJet &xDiTau) const
const std::vector< std::string > m_output_node_names
float mass_core(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
std::unique_ptr< Ort::Session > m_ort_session
float R_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float f_subjets(const xAOD::DiTauJet &xDiTau) const
float f_core(const xAOD::DiTauJet &xDiTau, int iSubjet) const
std::vector< float > extract_points(const std::vector< std::vector< float > > &track_features) const
std::vector< float > flatten(const std::vector< std::vector< float > > &vec_2d) const
virtual StatusCode execute(DiTauCandidateData *data, const EventContext &ctx) const override
Execute - called for each Ditau candidate.
virtual StatusCode initialize() override
Tool initializer.
virtual StatusCode executeObj(xAOD::DiTauJet &xDiTau, const EventContext &ctx) const override
Execute - called for each Ditau jet.
float d0_leadtrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
const std::vector< std::string > m_input_node_names
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
float f_subjet(const xAOD::DiTauJet &xDiTau, int iSubjet) const
virtual ~DiTauOnnxDiscriminantTool()
float f_isotracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Gaudi::Property< std::string > m_onnxModelPath
DiTauToolBase(const std::string &type, const std::string &name, const IInterface *parent)
TrackParticle_v1 TrackParticle
Reference the current persistent version:
DiTauJet_v1 DiTauJet
Definition of the current version.
Definition DiTauJet.h:17
std::vector< const xAOD::TrackParticle * > vIsoTracks
std::vector< const xAOD::TrackParticle * > vIsoTracks
std::vector< const xAOD::TrackParticle * > vCoreTracks