ATLAS Offline Software
Loading...
Searching...
No Matches
DiTauOnnxDiscriminantTool Class Reference

#include <DiTauOnnxDiscriminantTool.h>

Inheritance diagram for DiTauOnnxDiscriminantTool:
Collaboration diagram for DiTauOnnxDiscriminantTool:

Classes

struct  DitauTrackingInfo
struct  InferenceOutput
struct  OnnxInputs
struct  SubjetTrackingInfo

Public Member Functions

 DiTauOnnxDiscriminantTool (const std::string &type, const std::string &name, const IInterface *parent)
virtual ~DiTauOnnxDiscriminantTool ()
virtual StatusCode initialize () override
 Tool initializer.
virtual StatusCode finalize () override
 Finalizer.
virtual StatusCode execute (DiTauCandidateData *data, const EventContext &ctx) const override
 Execute - called for each Ditau candidate.
virtual StatusCode executeObj (xAOD::DiTauJet &xDiTau, const EventContext &ctx) const override
 Execute - called for each Ditau jet.
float GetDiTauObjOnnxScore (const xAOD::DiTauJet &ditau) const
virtual StatusCode eventInitialize (DiTauCandidateData *data)
 Event initializer - called at the beginning of each event.
template<class T>
bool retrieveTool (T &tool)
 Convenience function to retrieve a tool handle; logs a FATAL message and returns false if retrieval fails.
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm.
virtual StatusCode sysStart () override
 Handle START transition.
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles.
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles.
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
void updateVHKA (Gaudi::Details::PropertyBase &)
MsgStream & msg () const
bool msgLvl (const MSG::Level lvl) const

Static Public Member Functions

static const InterfaceID & interfaceID ()
 InterfaceID implementation needed for ToolHandle.

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed.

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t

Private Member Functions

int n_subjets (const xAOD::DiTauJet &xDiTau) const
float ditau_pt (const xAOD::DiTauJet &xDiTau) const
float f_core (const xAOD::DiTauJet &xDiTau, int iSubjet) const
float f_subjet (const xAOD::DiTauJet &xDiTau, int iSubjet) const
float f_subjets (const xAOD::DiTauJet &xDiTau) const
float R_max (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
int n_track (const xAOD::DiTauJet &xDiTau) const
float R_isotrack (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
float R_tracks (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float mass_core (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float mass_tracks (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float d0_leadtrack (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float f_isotracks (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
StatusCode getTrackingInfo (const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
Ort::Value create_tensor (std::vector< float > &data, const std::vector< int64_t > &shape) const
InferenceOutput run_inference (OnnxInputs &inputs) const
std::vector< float > flatten (const std::vector< std::vector< float > > &vec_2d) const
std::vector< float > extract_points (const std::vector< std::vector< float > > &track_features) const
std::vector< float > create_mask (const std::vector< std::vector< float > > &track_features) const
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey>

Private Attributes

float m_dDefault = -1234
Gaudi::Property< std::string > m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"}
Gaudi::Property< size_t > m_maxTracks {this, "maxTracks", 10}
std::unique_ptr< Ort::Env > m_ort_env
std::unique_ptr< Ort::Session > m_ort_session
const std::vector< std::string > m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"}
const std::vector< std::string > m_output_node_names = {"output_1", "output_2"}
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default)
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default)
std::vector< SG::VarHandleKeyArray * > m_vhka
bool m_varHandleArraysDeclared

Detailed Description

Definition at line 29 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ DiTauOnnxDiscriminantTool()

DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool ( const std::string & type,
const std::string & name,
const IInterface * parent )

Definition at line 20 of file src/DiTauOnnxDiscriminantTool.cxx.

20 :
21 DiTauToolBase(type, name, parent)
22{
23 declareInterface<DiTauToolBase > (this);
24}
DiTauToolBase(const std::string &type, const std::string &name, const IInterface *parent)

◆ ~DiTauOnnxDiscriminantTool()

DiTauOnnxDiscriminantTool::~DiTauOnnxDiscriminantTool ( )
virtual, default

Member Function Documentation

◆ create_mask()

std::vector< float > DiTauOnnxDiscriminantTool::create_mask ( const std::vector< std::vector< float > > & track_features) const
private

Definition at line 97 of file src/DiTauOnnxDiscriminantTool.cxx.

97 {
98 std::vector<float> mask;
99 mask.reserve(track_features.size());
100 std::transform(track_features.begin(), track_features.end(), std::back_inserter(mask), [](const auto &track) {
101 return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
102 });
103 return mask;
104}

◆ create_tensor()

Ort::Value DiTauOnnxDiscriminantTool::create_tensor ( std::vector< float > & data,
const std::vector< int64_t > & shape ) const
private

Definition at line 106 of file src/DiTauOnnxDiscriminantTool.cxx.

106 {
107 Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108 return Ort::Value::CreateTensor<float>(memory_info, data.data(), data.size(),shape.data(), shape.size());
109}
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11

◆ d0_leadtrack()

float DiTauOnnxDiscriminantTool::d0_leadtrack ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int iSubjet ) const
private

Definition at line 336 of file src/DiTauOnnxDiscriminantTool.cxx.

336 {
337 SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
338 if (!subjetInfo.leadTrack) {
339 return m_dDefault;
340 }
341 return subjetInfo.leadTrack->d0();
342}
float d0() const
Returns the parameter.

◆ declareGaudiProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > & hndl,
const SG::VarHandleKeyType &  )
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

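158 {
159 [line not rendered in this listing: presumably the start of the declareProperty(...) call whose remaining arguments follow]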
160 hndl.value(),
161 hndl.documentation());
162
163 }
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)

◆ declareProperty()

Gaudi::Details::PropertyBase & AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T, V, H > & t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145 {
146 typedef typename SG::HandleClassifier<T>::type htype;
147 [line not rendered in this listing: presumably forwards to declareGaudiProperty() using htype]
148 }
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>

◆ detStore()

const ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

◆ ditau_pt()

float DiTauOnnxDiscriminantTool::ditau_pt ( const xAOD::DiTauJet & xDiTau) const
private

Definition at line 243 of file src/DiTauOnnxDiscriminantTool.cxx.

244{
245 return xDiTau.subjetPt(0)+xDiTau.subjetPt(1);
246}
float subjetPt(unsigned int numSubjet) const

◆ eventInitialize()

StatusCode DiTauToolBase::eventInitialize ( DiTauCandidateData * data)
virtual, inherited

Event initializer - called at the beginning of each event.

Definition at line 32 of file DiTauToolBase.cxx.

33{
34 return StatusCode::SUCCESS;
35}

◆ evtStore()

ServiceHandle< StoreGateSvc > & AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

◆ execute()

StatusCode DiTauOnnxDiscriminantTool::execute ( DiTauCandidateData * data,
const EventContext & ctx ) const
override, virtual

Execute - called for each Ditau candidate.

Reimplemented from DiTauToolBase.

Definition at line 57 of file src/DiTauOnnxDiscriminantTool.cxx.

58{
59 static const SG::Accessor<float> omni_scoreDec("omni_score");
60 xAOD::DiTauJet* xDitau = data->xAODDiTau;
61 ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
62 float score = GetDiTauObjOnnxScore(*xDitau);
63 ATH_MSG_DEBUG("DiTau ID score: " << score);
64 omni_scoreDec(*xDitau) = score;
65 return StatusCode::SUCCESS;
66}
#define ATH_MSG_DEBUG(x)
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
DiTauJet_v1 DiTauJet
Definition of the current version.
Definition DiTauJet.h:17

◆ executeObj()

StatusCode DiTauOnnxDiscriminantTool::executeObj ( xAOD::DiTauJet & xDiTau,
const EventContext & ctx ) const
override, virtual

Execute - called for each Ditau jet.

Reimplemented from DiTauToolBase.

Definition at line 68 of file src/DiTauOnnxDiscriminantTool.cxx.

69{
70 static const SG::Accessor<float> omni_scoreDec("omni_score");
71 ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
72 float score = GetDiTauObjOnnxScore(xDiTau);
73 ATH_MSG_DEBUG("DiTau ID score: " << score);
74 omni_scoreDec(xDiTau) = score;
75 return StatusCode::SUCCESS;
76}

◆ extract_points()

std::vector< float > DiTauOnnxDiscriminantTool::extract_points ( const std::vector< std::vector< float > > & track_features) const
private

Definition at line 87 of file src/DiTauOnnxDiscriminantTool.cxx.

87 {
88 std::vector<float> points;
89 points.reserve(track_features.size() * 2);
90 for (const auto &track : track_features) {
91 points.push_back(track[0]); // delta_eta
92 points.push_back(track[1]); // delta_phi
93 }
94 return points;
95}

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase & ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, adding the StoreName when it is not given explicitly.

◆ f_core()

float DiTauOnnxDiscriminantTool::f_core ( const xAOD::DiTauJet & xDiTau,
int iSubjet ) const
private

Definition at line 248 of file src/DiTauOnnxDiscriminantTool.cxx.

249{
250 return xDiTau.fCore(iSubjet);
251}
float fCore(unsigned int numSubjet) const

◆ f_isotracks()

float DiTauOnnxDiscriminantTool::f_isotracks ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo ) const
private

Definition at line 344 of file src/DiTauOnnxDiscriminantTool.cxx.

344 {
345 float iso_pt = 0;
346 for (const xAOD::TrackParticle* xTrack: ditauInfo.vIsoTracks) {
347 iso_pt += xTrack->pt();
348 }
349 if( xDiTau.pt() == 0.){
350 return m_dDefault;
351 } else {
352 return iso_pt / xDiTau.pt();
353 }
354}
virtual double pt() const
The transverse momentum ( ) of the particle.
virtual double pt() const override final
The transverse momentum ( ) of the particle.
TrackParticle_v1 TrackParticle
Reference the current persistent version:

◆ f_subjet()

float DiTauOnnxDiscriminantTool::f_subjet ( const xAOD::DiTauJet & xDiTau,
int iSubjet ) const
private

Definition at line 253 of file src/DiTauOnnxDiscriminantTool.cxx.

253 {
254 return xDiTau.subjetPt(iSubjet) / xDiTau.pt();
255}

◆ f_subjets()

float DiTauOnnxDiscriminantTool::f_subjets ( const xAOD::DiTauJet & xDiTau) const
private

Definition at line 257 of file src/DiTauOnnxDiscriminantTool.cxx.

258{
259 return (xDiTau.subjetPt(0) + xDiTau.subjetPt(1))/ xDiTau.pt();
260}

◆ finalize()

StatusCode DiTauOnnxDiscriminantTool::finalize ( )
override, virtual

Finalizer.

Reimplemented from DiTauToolBase.

Definition at line 49 of file src/DiTauOnnxDiscriminantTool.cxx.

50{
51 ATH_MSG_INFO( "Finalizing DiTauOnnxDiscriminantTool" );
52 m_ort_session.reset();
53 m_ort_env.reset();
54 return StatusCode::SUCCESS;
55}
#define ATH_MSG_INFO(x)
std::unique_ptr< Ort::Session > m_ort_session

◆ flatten()

std::vector< float > DiTauOnnxDiscriminantTool::flatten ( const std::vector< std::vector< float > > & vec_2d) const
private

Definition at line 78 of file src/DiTauOnnxDiscriminantTool.cxx.

78 {
79 std::vector<float> flattened;
80 flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
81 for (const auto &inner : vec_2d) {
82 flattened.insert(flattened.end(), inner.begin(), inner.end());
83 }
84 return flattened;
85}

◆ GetDiTauObjOnnxScore()

float DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore ( const xAOD::DiTauJet & ditau) const

Definition at line 140 of file src/DiTauOnnxDiscriminantTool.cxx.

140 {
141
142 // do the calculation only for ditau with at least 2 subjets
143 if(n_subjets(ditau)<2){
144 return m_dDefault;
145 }
146
147 DitauTrackingInfo ditauTrackingInfo;
148 if(!(getTrackingInfo(ditau, ditauTrackingInfo))){
149 return m_dDefault;
150 }
151
152 // Accessors for reading the necessary features from the xAOD::TrackParticle object
153 static const SG::ConstAccessor< uint8_t > numberOfInrmstPxlLyrHitsAcc ("numberOfInnermostPixelLayerHits");
154 static const SG::ConstAccessor< uint8_t > numberOfPixelHitsAcc ("numberOfPixelHits");
155 static const SG::ConstAccessor< uint8_t > numberOfSCTHitsAcc ("numberOfSCTHits");
156 static const SG::ConstAccessor< float > z0Acc ("z0");
157 static const SG::ConstAccessor< float > d0Acc ("d0");
158 // Input features for Ditau tagger ONNX model
159 std::vector<float> jet_vars = {
160 R_max(ditau, ditauTrackingInfo, 0),
161 R_max(ditau, ditauTrackingInfo, 1),
162 R_tracks(ditau, ditauTrackingInfo, 1),
163 R_isotrack(ditau, ditauTrackingInfo),
164 d0_leadtrack(ditau, ditauTrackingInfo, 0),
165 d0_leadtrack(ditau, ditauTrackingInfo, 1),
166 f_core(ditau,0),
167 f_core(ditau,1),
168 f_subjet(ditau,1),
169 f_subjets(ditau),
170 f_isotracks(ditau, ditauTrackingInfo),
171 mass_core(ditau, ditauTrackingInfo, 0),
172 mass_core(ditau, ditauTrackingInfo, 1),
173 mass_tracks(ditau, ditauTrackingInfo, 0),
174 static_cast<float>( n_track(ditau)),
175 };
176 std::vector<int64_t> jet_shape = {1, static_cast<int64_t>(jet_vars.size())};
177
178 const TrackParticleLinks_t &vTauTracks = ditau.trackLinks();
179 std::vector<std::vector<float>> track_features(m_maxTracks, std::vector<float>(11, 0.0f));
180
181 float jet_eta = ditau.eta();
182 float jet_phi = ditau.phi();
183 size_t num_tracks = std::min(static_cast<size_t>(m_maxTracks), vTauTracks.size());
184
185 for (size_t i = 0; i < num_tracks; ++i) {
186 const ElementLink<xAOD::TrackParticleContainer> &trackLink = vTauTracks[i];
187 if (!trackLink.isValid()) continue;
188 const xAOD::TrackParticle *xTrack = *trackLink;
189 float track_eta = xTrack->eta();
190 float track_phi = xTrack->phi();
191 float delta_eta = track_eta - jet_eta;
192 float delta_phi = std::remainder(track_phi - jet_phi, 2 * M_PI);
193 float delta_R = std::hypot(delta_eta, delta_phi);
194 float track_pt = static_cast<float>(xTrack->pt());
195 float pt_log = std::log(track_pt + 1e-8f);
196 float jet_pt = ditau_pt(ditau); //ditau_ptAcc(ditau);
197 float pt_ratio = track_pt / jet_pt;
198 float pt_ratio_log = (pt_ratio <= 1.0f) ? std::log(1.0f - pt_ratio + 1e-8f) : 0.0f;
199 float track_charge = xTrack->charge();
200
201 track_features[i] = {
202 delta_eta,
203 delta_phi,
204 pt_log,
205 d0Acc(*xTrack),
206 pt_ratio_log,
207 z0Acc(*xTrack),
208 delta_R,
209 static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
210 static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
211 static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
212 track_charge
213 };
214 }
215 std::vector<int64_t> track_shape = {1, static_cast<int64_t>(m_maxTracks), 11};
216
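217 // Actual ONNX inference
218 [line not rendered in this listing: presumably opens the brace-initialization of the OnnxInputs `inputs` passed to run_inference() below]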
219 flatten(track_features),
220 track_shape,
221 extract_points(track_features),
222 {1, track_shape[1], 2},
223 create_mask(track_features),
224 {1, track_shape[1]},
225 std::move(jet_vars),
226 std::move(jet_shape),
227 {0.0f},
228 {1, 1}
229 };
230 auto output = run_inference(inputs);
231 return output.output_1[1];
232}
#define M_PI
float R_max(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
InferenceOutput run_inference(OnnxInputs &inputs) const
StatusCode getTrackingInfo(const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
int n_track(const xAOD::DiTauJet &xDiTau) const
float R_isotrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
float mass_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float ditau_pt(const xAOD::DiTauJet &xDiTau) const
std::vector< float > create_mask(const std::vector< std::vector< float > > &track_features) const
int n_subjets(const xAOD::DiTauJet &xDiTau) const
float mass_core(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float R_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float f_subjets(const xAOD::DiTauJet &xDiTau) const
float f_core(const xAOD::DiTauJet &xDiTau, int iSubjet) const
std::vector< float > extract_points(const std::vector< std::vector< float > > &track_features) const
std::vector< float > flatten(const std::vector< std::vector< float > > &vec_2d) const
float d0_leadtrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
float f_subjet(const xAOD::DiTauJet &xDiTau, int iSubjet) const
float f_isotracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
virtual double eta() const
The pseudorapidity ( ) of the particle.
virtual double phi() const
The azimuthal angle ( ) of the particle.
const TrackParticleLinks_t & trackLinks() const
virtual double phi() const override final
The azimuthal angle ( ) of the particle (has range to .)
virtual double eta() const override final
The pseudorapidity ( ) of the particle.
float charge() const
Returns the charge.
bool pt_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
delta_phi(phi1, phi2)
Definition eFEXNTuple.py:14
delta_R(eta1, phi1, eta2, phi2)
Definition eFEXNTuple.py:20
output
Definition merge.py:16
std::vector< ElementLink< xAOD::TrackParticleContainer > > TrackParticleLinks_t

◆ getTrackingInfo()

StatusCode DiTauOnnxDiscriminantTool::getTrackingInfo ( const xAOD::DiTauJet & xDiTau,
DitauTrackingInfo & trackingInfo ) const
private

Definition at line 356 of file src/DiTauOnnxDiscriminantTool.cxx.

356 {
357 static const SG::ConstAccessor<std::vector<ElementLink<xAOD::TrackParticleContainer>>> trackLinksAcc("trackLinks");
358 static const SG::ConstAccessor<std::vector<ElementLink<xAOD::TrackParticleContainer>>> isoTrackLinksAcc("isoTrackLinks");
359 static const SG::ConstAccessor<float> R_subjetAcc("R_subjet");
360 static const SG::ConstAccessor<float> R_coreAcc("R_core");
361
362
363 if (!trackLinksAcc.isAvailable(xDiTau) || !isoTrackLinksAcc.isAvailable(xDiTau)) {
364 ATH_MSG_WARNING("Track " << (!trackLinksAcc.isAvailable(xDiTau) ? "DiTauJet.trackLinks" : "DiTauJet.isoTrackLinks") << " links not available.");
365 return StatusCode::FAILURE;
366 }
367
368 int nSubjets = n_subjets(xDiTau);
369 float Rsubjet = R_subjetAcc(xDiTau);
370 float RCore = R_coreAcc(xDiTau);
371
372 trackingInfo.nSubjets = nSubjets;
373 trackingInfo.vSubjetInfo.clear();
374 trackingInfo.vIsoTracks.clear();
375 trackingInfo.vTracks.clear();
376
377 // Get the track links from the DiTauJet and store them in the tracking info
378 std::vector<ElementLink<xAOD::TrackParticleContainer>> isoTrackLinks = xDiTau.isoTrackLinks();
379 for (const auto &trackLink: isoTrackLinks) {
380 if (!trackLink.isValid()) {
381 ATH_MSG_WARNING("Iso track link is not valid");
382 continue;
383 }
384 const xAOD::TrackParticle* xTrack = *trackLink;
385 trackingInfo.vIsoTracks.push_back(xTrack);
386 }
387 std::vector<ElementLink<xAOD::TrackParticleContainer>> trackLinks = xDiTau.trackLinks();
388 for (const auto &trackLink : trackLinks) {
389 if (!trackLink.isValid()) {
390 ATH_MSG_WARNING("track link is not valid");
391 continue;
392 }
393 const xAOD::TrackParticle* xTrack = *trackLink;
394 trackingInfo.vTracks.push_back(xTrack);
395 }
396 // store subjet p4
397 for (int i=0; i<nSubjets; ++i){
398 SubjetTrackingInfo subjetTrackingInfo;
399 TLorentzVector subjet_p4 = TLorentzVector();
400 subjet_p4.SetPtEtaPhiE( xDiTau.subjetPt(i), xDiTau.subjetEta(i), xDiTau.subjetPhi(i), xDiTau.subjetE(i));
401 subjetTrackingInfo.subjet_p4 = subjet_p4;
402 trackingInfo.vSubjetInfo.push_back(subjetTrackingInfo);
403 }
404 for (const auto track : trackingInfo.vTracks) {
405 float dRMin = 999;
406 int inSubjet = -1;
407 for (int i=0; i<nSubjets; ++i){
408 float dRTrackSubjet = trackingInfo.vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
409 if (dRTrackSubjet < Rsubjet && dRTrackSubjet < dRMin){
410 dRMin = dRTrackSubjet;
411 inSubjet = i;
412 }
413 }
414 if (inSubjet >= 0){
415 trackingInfo.vSubjetInfo[inSubjet].vTracks.push_back(track);
416 }
417 }
418 // find leading track in subjets
419 for (int i=0; i<nSubjets; ++i){
420 float ptLeadTrack = 0;
421 for (const auto track : trackingInfo.vSubjetInfo[i].vTracks){
422 if (track->pt() > ptLeadTrack){
423 ptLeadTrack = track->pt();
424 trackingInfo.vSubjetInfo[i].leadTrack = track;
425 }
426 }
427 }
428 // find core track in subjets
429 for (int i=0; i<nSubjets; ++i){
430 for (const auto track : trackingInfo.vSubjetInfo[i].vTracks){
431 auto subjetTrackingInfo = trackingInfo.vSubjetInfo[i];
432 if (subjetTrackingInfo.subjet_p4.DeltaR(track->p4()) < RCore){
433 trackingInfo.vSubjetInfo[i].vCoreTracks.push_back(track);
434 }
435 }
436 }
437 //find isotracks in subjets
438 for (const auto track : trackingInfo.vIsoTracks){
439 float RIso = 0.4;
440 float dRMin = 999;
441 int inSubjet = -1;
442 for (int i=0; i<nSubjets; ++i){
443 float dRTrackSubjet = trackingInfo.vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
444 if (dRTrackSubjet > Rsubjet && dRTrackSubjet < RIso && dRTrackSubjet < dRMin){
445 dRMin = dRTrackSubjet;
446 inSubjet = i;
447 }
448 }
449 if (inSubjet >= 0){
450 trackingInfo.vSubjetInfo[inSubjet].vIsoTracks.push_back(track);
451 }
452 }
453 return StatusCode::SUCCESS;
454}
#define ATH_MSG_WARNING(x)
const TrackParticleLinks_t & isoTrackLinks() const
float subjetEta(unsigned int numSubjet) const
float subjetE(unsigned int numSubjet) const
float subjetPhi(unsigned int numSubjet) const

◆ initialize()

StatusCode DiTauOnnxDiscriminantTool::initialize ( )
override, virtual

Tool initializer.

Reimplemented from DiTauToolBase.

Definition at line 30 of file src/DiTauOnnxDiscriminantTool.cxx.

31{
32 ATH_MSG_INFO( "Initializing DiTauOnnxDiscriminantTool" );
33 ATH_MSG_INFO( "onnxModelPath: " << m_onnxModelPath );
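34
35 [line not rendered in this listing: presumably declares `model_path`, resolved from m_onnxModelPath via PathResolverFindCalibFile()]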
36 if (model_path.empty()) {
37 ATH_MSG_ERROR("Could not find model file: " << m_onnxModelPath);
38 return StatusCode::FAILURE;
39 }
40 m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "OnnxUtil");
41 Ort::SessionOptions session_options;
42 session_options.SetIntraOpNumThreads(1);
43 session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
44 session_options.DisableCpuMemArena();
45 m_ort_session = std::make_unique<Ort::Session>(*m_ort_env, model_path.c_str(), session_options);
46 return StatusCode::SUCCESS;
47}
#define ATH_MSG_ERROR(x)
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Gaudi::Property< std::string > m_onnxModelPath

◆ inputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ interfaceID()

const InterfaceID & DiTauToolBase::interfaceID ( )
static, inherited

InterfaceID implementation needed for ToolHandle.

Definition at line 9 of file DiTauToolBase.cxx.

9 {
10 return DiTauToolBaseID;
11}
static const InterfaceID DiTauToolBaseID("DiTauToolBase", 1, 0)

◆ mass_core()

float DiTauOnnxDiscriminantTool::mass_core ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int iSubjet ) const
private

Definition at line 310 of file src/DiTauOnnxDiscriminantTool.cxx.

310 {
311 TLorentzVector allCoreTracks_p4;
312 SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
313 for (const xAOD::TrackParticle* xTrack: subjetInfo.vCoreTracks) {
314 allCoreTracks_p4 += xTrack->p4();
315 }
316 float mass = allCoreTracks_p4.M();
317 if (mass < 0) {
318 return m_dDefault;
319 }
320 return mass;
321}
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.

◆ mass_tracks()

float DiTauOnnxDiscriminantTool::mass_tracks ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int iSubjet ) const
private

Definition at line 323 of file src/DiTauOnnxDiscriminantTool.cxx.

323 {
324 TLorentzVector allTracks_p4;
325 SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
326 for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
327 allTracks_p4 += xTrack->p4();
328 }
329 float mass = allTracks_p4.M();
330 if (mass < 0) {
331 return m_dDefault;
332 }
333 return mass;
334}

◆ msg()

MsgStream & AthCommonMsg< AlgTool >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24 {
25 return this->msgStream();
26 }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30 {
31 return this->msgLevel(lvl);
32 }

◆ n_subjets()

int DiTauOnnxDiscriminantTool::n_subjets ( const xAOD::DiTauJet & xDiTau) const
private

Definition at line 235 of file src/DiTauOnnxDiscriminantTool.cxx.

235 {
236 int nSubjet = 0;
237 while (xDiTau.subjetPt(nSubjet) > 0. ){
238 nSubjet++;
239 }
240 return nSubjet;
241}

◆ n_track()

int DiTauOnnxDiscriminantTool::n_track ( const xAOD::DiTauJet & xDiTau) const
private

Definition at line 274 of file src/DiTauOnnxDiscriminantTool.cxx.

274 {
275 return xDiTau.nTracks();
276}
size_t nTracks() const

◆ outputHandles()

virtual std::vector< Gaudi::DataHandle * > AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ R_isotrack()

float DiTauOnnxDiscriminantTool::R_isotrack ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo ) const
private

Definition at line 278 of file src/DiTauOnnxDiscriminantTool.cxx.

279{
280 float R_sum = 0;
281 float pt = 0;
282 for (int i = 0; i < 2; i++) {
283 SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(i);
284 for (const xAOD::TrackParticle* xTrack: subjetInfo.vIsoTracks) {
285 R_sum += subjetInfo.subjet_p4.DeltaR(xTrack->p4()) * xTrack->pt();
286 pt += xTrack->pt();
287 }
288 }
289 if (pt == 0) {
290 return m_dDefault;
291 }
292 return R_sum / pt;
293}

◆ R_max()

float DiTauOnnxDiscriminantTool::R_max ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int iSubjet ) const
private

Definition at line 262 of file src/DiTauOnnxDiscriminantTool.cxx.

263{
264 const SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
265 float Rmax = 0;
266 for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
267 if (subjetInfo.subjet_p4.DeltaR(xTrack->p4()) > Rmax) {
268 Rmax = subjetInfo.subjet_p4.DeltaR(xTrack->p4());
269 }
270 }
271 return Rmax;
272}

◆ R_tracks()

float DiTauOnnxDiscriminantTool::R_tracks ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int iSubjet ) const
private

Definition at line 295 of file src/DiTauOnnxDiscriminantTool.cxx.

295 {
296 float R_sum = 0;
297 float pt = 0;
298
299 SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
300 for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
301 R_sum += subjetInfo.subjet_p4.DeltaR(xTrack->p4()) * xTrack->pt();
302 pt += xTrack->pt();
303 }
304 if (pt == 0) {
305 return m_dDefault;
306 }
307 return R_sum / pt;
308}

◆ renounce()

std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T & h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381 {
382 h.renounce();
383 [line not rendered in this listing]
384 }
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce(T &h)

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline, protected, inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364 {
365 [line not rendered in this listing]
366 }

◆ retrieveTool()

template<class T>
bool DiTauToolBase::retrieveTool ( T & tool)
inherited

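Convenience function to retrieve a tool handle; logs a FATAL message and returns false if retrieval fails, otherwise logs at VERBOSE level and returns true.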

Definition at line 59 of file DiTauToolBase.cxx.

59 {
60 if (tool.retrieve().isFailure()) {
61 ATH_MSG_FATAL("Failed to retrieve tool " << tool);
62 return false;
63 } else {
64 ATH_MSG_VERBOSE("Retrieved tool " << tool);
65 }
66 return true;
67}
#define ATH_MSG_FATAL(x)
#define ATH_MSG_VERBOSE(x)

◆ run_inference()

DiTauOnnxDiscriminantTool::InferenceOutput DiTauOnnxDiscriminantTool::run_inference ( OnnxInputs & inputs) const
private

Definition at line 111 of file src/DiTauOnnxDiscriminantTool.cxx.

111 {
112 std::vector<Ort::Value> input_tensors;
113 input_tensors.reserve(m_input_node_names.size());
114 input_tensors.emplace_back(create_tensor(inputs.input_features, inputs.input_features_shape));
115 input_tensors.emplace_back(create_tensor(inputs.input_points, inputs.input_points_shape));
116 input_tensors.emplace_back(create_tensor(inputs.input_mask, inputs.input_mask_shape));
117 input_tensors.emplace_back(create_tensor(inputs.input_jet, inputs.input_jet_shape));
118 input_tensors.emplace_back(create_tensor(inputs.input_time, inputs.input_time_shape));
119
120 std::vector<const char *> input_node_names;
121 input_node_names.reserve(m_input_node_names.size());
122 std::transform(m_input_node_names.begin(), m_input_node_names.end(), std::back_inserter(input_node_names), [](const std::string &name) { return name.c_str(); });
123
124 std::vector<const char *> output_node_names;
125 output_node_names.reserve(m_output_node_names.size());
126 std::transform(m_output_node_names.begin(), m_output_node_names.end(), std::back_inserter(output_node_names), [](const std::string &name) { return name.c_str(); });
127
128 auto output_tensors = m_ort_session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
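129
130 [line not rendered in this listing: presumably declares the InferenceOutput `output` that is filled below and returned]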
131 for (size_t i = 0; i < output_tensors.size(); ++i) {
132 const auto &tensor = output_tensors[i];
133 const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
134 const float *data = tensor.GetTensorData<float>();
135 (i == 0 ? output.output_1 : output.output_2) = std::vector<float>(data, data + length);
136 }
137 return output;
138}
double length(const pvec &v)
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
const std::vector< std::string > m_output_node_names
const std::vector< std::string > m_input_node_names

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override virtual inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in asg::AsgMetadataTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and DerivationFramework::CfAthAlgTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override virtual inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase & )
inline inherited

Definition at line 308 of file AthCommonDataStore.h.

308 {
309 // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310 // << " size: " << m_vhka.size() << endmsg;
311 for (auto &a : m_vhka) {
312 std::vector<SG::VarHandleKey*> keys = a->keys();
313 for (auto k : keys) {
314 k->setOwner(this);
315 }
316 }
317 }
std::vector< SG::VarHandleKeyArray * > m_vhka

Member Data Documentation

◆ m_dDefault

float DiTauOnnxDiscriminantTool::m_dDefault = -1234
private

Definition at line 55 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_input_node_names

const std::vector<std::string> DiTauOnnxDiscriminantTool::m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"}
private

Definition at line 92 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

92{"input_features", "input_points", "input_mask", "input_jet", "input_time"};

◆ m_maxTracks

Gaudi::Property<size_t> DiTauOnnxDiscriminantTool::m_maxTracks {this, "maxTracks", 10}
private

Definition at line 88 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

88{this, "maxTracks", 10};

◆ m_onnxModelPath

Gaudi::Property<std::string> DiTauOnnxDiscriminantTool::m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"}
private

Definition at line 87 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

87{this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"};

◆ m_ort_env

std::unique_ptr<Ort::Env> DiTauOnnxDiscriminantTool::m_ort_env
private

Definition at line 90 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_ort_session

std::unique_ptr<Ort::Session> DiTauOnnxDiscriminantTool::m_ort_session
private

Definition at line 91 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_output_node_names

const std::vector<std::string> DiTauOnnxDiscriminantTool::m_output_node_names = {"output_1", "output_2"}
private

Definition at line 93 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

93{"output_1", "output_2"};

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files: