ATLAS Offline Software
Classes | Public Member Functions | Static Public Member Functions | Protected Member Functions | Private Types | Private Member Functions | Private Attributes | List of all members
DiTauOnnxDiscriminantTool Class Reference

#include <DiTauOnnxDiscriminantTool.h>

Inheritance diagram for DiTauOnnxDiscriminantTool:
Collaboration diagram for DiTauOnnxDiscriminantTool:

Classes

struct  DitauTrackingInfo
 
struct  InferenceOutput
 
struct  OnnxInputs
 
struct  SubjetTrackingInfo
 

Public Member Functions

 DiTauOnnxDiscriminantTool (const std::string &type, const std::string &name, const IInterface *parent)
 
virtual ~DiTauOnnxDiscriminantTool ()
 
virtual StatusCode initialize () override
 Tool initializer. More...
 
virtual StatusCode finalize () override
 Finalizer. More...
 
virtual StatusCode execute (DiTauCandidateData *data, const EventContext &ctx) const override
 Execute - called for each Ditau candidate. More...
 
virtual StatusCode executeObj (xAOD::DiTauJet &xDiTau, const EventContext &ctx) const override
 Execute - called for each Ditau jet. More...
 
float GetDiTauObjOnnxScore (const xAOD::DiTauJet &ditau) const
 
virtual StatusCode eventInitialize (DiTauCandidateData *data)
 Event initializer - called at the beginning of each event. More...
 
template<class T >
bool retrieveTool (T &tool)
 Convenience functions to handle storegate objects. More...
 
ServiceHandle< StoreGateSvc > & evtStore ()
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc. More...
 
const ServiceHandle< StoreGateSvc > & evtStore () const
 The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc. More...
 
const ServiceHandle< StoreGateSvc > & detStore () const
 The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc. More...
 
virtual StatusCode sysInitialize () override
 Perform system initialization for an algorithm. More...
 
virtual StatusCode sysStart () override
 Handle START transition. More...
 
virtual std::vector< Gaudi::DataHandle * > inputHandles () const override
 Return this algorithm's input handles. More...
 
virtual std::vector< Gaudi::DataHandle * > outputHandles () const override
 Return this algorithm's output handles. More...
 
Gaudi::Details::PropertyBase & declareProperty (Gaudi::Property< T, V, H > &t)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKey &hndl, const std::string &doc, const SG::VarHandleKeyType &)
 Declare a new Gaudi property. More...
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleBase &hndl, const std::string &doc, const SG::VarHandleType &)
 Declare a new Gaudi property. More...
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, SG::VarHandleKeyArray &hndArr, const std::string &doc, const SG::VarHandleKeyArrayType &)
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc, const SG::NotHandleType &)
 Declare a new Gaudi property. More...
 
Gaudi::Details::PropertyBase * declareProperty (const std::string &name, T &property, const std::string &doc="none")
 Declare a new Gaudi property. More...
 
void updateVHKA (Gaudi::Details::PropertyBase &)
 
MsgStream & msg () const
 
MsgStream & msg (const MSG::Level lvl) const
 
bool msgLvl (const MSG::Level lvl) const
 

Static Public Member Functions

static const InterfaceID & interfaceID ()
 InterfaceID implementation needed for ToolHandle. More...
 

Protected Member Functions

void renounceArray (SG::VarHandleKeyArray &handlesArray)
 remove all handles from I/O resolution More...
 
std::enable_if_t< std::is_void_v< std::result_of_t< decltype(&T::renounce)(T)> > &&!std::is_base_of_v< SG::VarHandleKeyArray, T > &&std::is_base_of_v< Gaudi::DataHandle, T >, void > renounce (T &h)
 
void extraDeps_update_handler (Gaudi::Details::PropertyBase &ExtraDeps)
 Add StoreName to extra input/output deps as needed. More...
 

Private Types

typedef ServiceHandle< StoreGateSvc > StoreGateSvc_t
 

Private Member Functions

int n_subjets (const xAOD::DiTauJet &xDiTau) const
 
float ditau_pt (const xAOD::DiTauJet &xDiTau) const
 
float f_core (const xAOD::DiTauJet &xDiTau, int iSubjet) const
 
float f_subjet (const xAOD::DiTauJet &xDiTau, int iSubjet) const
 
float f_subjets (const xAOD::DiTauJet &xDiTau) const
 
float R_max (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
 
int n_track (const xAOD::DiTauJet &xDiTau) const
 
float R_isotrack (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
 
float R_tracks (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
 
float mass_core (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
 
float mass_tracks (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
 
float d0_leadtrack (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
 
float f_isotracks (const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
 
StatusCode getTrackingInfo (const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
 
Ort::Value create_tensor (std::vector< float > &data, const std::vector< int64_t > &shape) const
 
InferenceOutput run_inference (OnnxInputs &inputs) const
 
std::vector< float > flatten (const std::vector< std::vector< float >> &vec_2d) const
 
std::vector< float > extract_points (const std::vector< std::vector< float >> &track_features) const
 
std::vector< float > create_mask (const std::vector< std::vector< float >> &track_features) const
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
 specialization for handling Gaudi::Property<SG::VarHandleKey> More...
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyArrayType &)
 specialization for handling Gaudi::Property<SG::VarHandleKeyArray> More...
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &hndl, const SG::VarHandleType &)
 specialization for handling Gaudi::Property<SG::VarHandleBase> More...
 
Gaudi::Details::PropertyBase & declareGaudiProperty (Gaudi::Property< T, V, H > &t, const SG::NotHandleType &)
 specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray> More...
 

Private Attributes

float m_dDefault = -1234
 
Gaudi::Property< std::string > m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"}
 
Gaudi::Property< size_t > m_maxTracks {this, "maxTracks", 10}
 
std::unique_ptr< Ort::Env > m_ort_env
 
std::unique_ptr< Ort::Session > m_ort_session
 
const std::vector< std::string > m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"}
 
const std::vector< std::string > m_output_node_names = {"output_1", "output_2"}
 
StoreGateSvc_t m_evtStore
 Pointer to StoreGate (event store by default) More...
 
StoreGateSvc_t m_detStore
 Pointer to StoreGate (detector store by default) More...
 
std::vector< SG::VarHandleKeyArray * > m_vhka
 
bool m_varHandleArraysDeclared
 

Detailed Description

Definition at line 29 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

Member Typedef Documentation

◆ StoreGateSvc_t

typedef ServiceHandle<StoreGateSvc> AthCommonDataStore< AthCommonMsg< AlgTool > >::StoreGateSvc_t
private, inherited

Definition at line 388 of file AthCommonDataStore.h.

Constructor & Destructor Documentation

◆ DiTauOnnxDiscriminantTool()

DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool ( const std::string &  type,
const std::string &  name,
const IInterface *  parent 
)

Definition at line 20 of file src/DiTauOnnxDiscriminantTool.cxx.

20  :
22 {
23  declareInterface<DiTauToolBase > (this);
24 }

◆ ~DiTauOnnxDiscriminantTool()

DiTauOnnxDiscriminantTool::~DiTauOnnxDiscriminantTool ( )
virtual, default

Member Function Documentation

◆ create_mask()

std::vector< float > DiTauOnnxDiscriminantTool::create_mask ( const std::vector< std::vector< float >> &  track_features) const
private

Definition at line 97 of file src/DiTauOnnxDiscriminantTool.cxx.

97  {
98  std::vector<float> mask;
99  mask.reserve(track_features.size());
100  std::transform(track_features.begin(), track_features.end(), std::back_inserter(mask), [](const auto &track) {
101  return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
102  });
103  return mask;
104 }

◆ create_tensor()

Ort::Value DiTauOnnxDiscriminantTool::create_tensor ( std::vector< float > &  data,
const std::vector< int64_t > &  shape 
) const
private

Definition at line 106 of file src/DiTauOnnxDiscriminantTool.cxx.

106  {
107  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108  return Ort::Value::CreateTensor<float>(memory_info, data.data(), data.size(),shape.data(), shape.size());
109 }

◆ d0_leadtrack()

float DiTauOnnxDiscriminantTool::d0_leadtrack ( const xAOD::DiTauJet &  xDiTau,
const DitauTrackingInfo &  ditauInfo,
int  iSubjet 
) const
private

Definition at line 336 of file src/DiTauOnnxDiscriminantTool.cxx.

336  {
337  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
338  if (!subjetInfo.leadTrack) {
339  return m_dDefault;
340  }
341  return subjetInfo.leadTrack->d0();
342 }

◆ declareGaudiProperty() [1/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > &  hndl,
const SG::VarHandleKeyArrayType & 
)
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKeyArray>

Definition at line 170 of file AthCommonDataStore.h.

172  {
173  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
174  hndl.value(),
175  hndl.documentation());
176 
177  }

◆ declareGaudiProperty() [2/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > &  hndl,
const SG::VarHandleKeyType & 
)
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleKey>

Definition at line 156 of file AthCommonDataStore.h.

158  {
159  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
160  hndl.value(),
161  hndl.documentation());
162 
163  }

◆ declareGaudiProperty() [3/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > &  hndl,
const SG::VarHandleType & 
)
inline, private, inherited

specialization for handling Gaudi::Property<SG::VarHandleBase>

Definition at line 184 of file AthCommonDataStore.h.

186  {
187  return *AthCommonDataStore<PBASE>::declareProperty(hndl.name(),
188  hndl.value(),
189  hndl.documentation());
190  }

◆ declareGaudiProperty() [4/4]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareGaudiProperty ( Gaudi::Property< T, V, H > &  t,
const SG::NotHandleType & 
)
inline, private, inherited

specialization for handling everything that's not a Gaudi::Property<SG::VarHandleKey> or a <SG::VarHandleKeyArray>

Definition at line 199 of file AthCommonDataStore.h.

200  {
201  return PBASE::declareProperty(t);
202  }

◆ declareProperty() [1/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleBase &  hndl,
const std::string &  doc,
const SG::VarHandleType & 
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleBase. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 245 of file AthCommonDataStore.h.

249  {
250  this->declare(hndl.vhKey());
251  hndl.vhKey().setOwner(this);
252 
253  return PBASE::declareProperty(name,hndl,doc);
254  }

◆ declareProperty() [2/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKey &  hndl,
const std::string &  doc,
const SG::VarHandleKeyType & 
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
hndl: Object holding the property value.
doc: Documentation string for the property.

This is the version for types that derive from SG::VarHandleKey. The property value object is put on the input and output lists as appropriate; then we forward to the base class.

Definition at line 221 of file AthCommonDataStore.h.

225  {
226  this->declare(hndl);
227  hndl.setOwner(this);
228 
229  return PBASE::declareProperty(name,hndl,doc);
230  }

◆ declareProperty() [3/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
SG::VarHandleKeyArray &  hndArr,
const std::string &  doc,
const SG::VarHandleKeyArrayType & 
)
inline, inherited

Definition at line 259 of file AthCommonDataStore.h.

263  {
264 
265  // std::ostringstream ost;
266  // ost << Algorithm::name() << " VHKA declareProp: " << name
267  // << " size: " << hndArr.keys().size()
268  // << " mode: " << hndArr.mode()
269  // << " vhka size: " << m_vhka.size()
270  // << "\n";
271  // debug() << ost.str() << endmsg;
272 
273  hndArr.setOwner(this);
274  m_vhka.push_back(&hndArr);
275 
276  Gaudi::Details::PropertyBase* p = PBASE::declareProperty(name, hndArr, doc);
277  if (p != 0) {
278  p->declareUpdateHandler(&AthCommonDataStore<PBASE>::updateVHKA, this);
279  } else {
280  ATH_MSG_ERROR("unable to call declareProperty on VarHandleKeyArray "
281  << name);
282  }
283 
284  return p;
285 
286  }

◆ declareProperty() [4/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc,
const SG::NotHandleType & 
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This is the generic version, for types that do not derive from SG::VarHandleKey. It just forwards to the base class version of declareProperty.

Definition at line 333 of file AthCommonDataStore.h.

337  {
338  return PBASE::declareProperty(name, property, doc);
339  }

◆ declareProperty() [5/6]

Gaudi::Details::PropertyBase* AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( const std::string &  name,
T &  property,
const std::string &  doc = "none" 
)
inline, inherited

Declare a new Gaudi property.

Parameters
name: Name of the property.
property: Object holding the property value.
doc: Documentation string for the property.

This dispatches to either the generic declareProperty or the one for VarHandle/Key/KeyArray.

Definition at line 352 of file AthCommonDataStore.h.

355  {
356  typedef typename SG::HandleClassifier<T>::type htype;
357  return declareProperty (name, property, doc, htype());
358  }

◆ declareProperty() [6/6]

Gaudi::Details::PropertyBase& AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty ( Gaudi::Property< T, V, H > &  t)
inline, inherited

Definition at line 145 of file AthCommonDataStore.h.

145  {
146  typedef typename SG::HandleClassifier<T>::type htype;
148  }

◆ detStore()

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::detStore ( ) const
inline, inherited

The standard StoreGateSvc/DetectorStore. Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 95 of file AthCommonDataStore.h.

95 { return m_detStore; }

◆ ditau_pt()

float DiTauOnnxDiscriminantTool::ditau_pt ( const xAOD::DiTauJet &  xDiTau) const
private

Definition at line 243 of file src/DiTauOnnxDiscriminantTool.cxx.

244 {
245  return xDiTau.subjetPt(0)+xDiTau.subjetPt(1);
246 }

◆ eventInitialize()

StatusCode DiTauToolBase::eventInitialize ( DiTauCandidateData *  data)
virtual, inherited

Event initializer - called at the beginning of each event.

Definition at line 32 of file DiTauToolBase.cxx.

33 {
34  return StatusCode::SUCCESS;
35 }

◆ evtStore() [1/2]

ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( )
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 85 of file AthCommonDataStore.h.

85 { return m_evtStore; }

◆ evtStore() [2/2]

const ServiceHandle<StoreGateSvc>& AthCommonDataStore< AthCommonMsg< AlgTool > >::evtStore ( ) const
inline, inherited

The standard StoreGateSvc (event store). Returns (kind of) a pointer to the StoreGateSvc.

Definition at line 90 of file AthCommonDataStore.h.

90 { return m_evtStore; }

◆ execute()

StatusCode DiTauOnnxDiscriminantTool::execute ( DiTauCandidateData *  data,
const EventContext &  ctx 
) const
override, virtual

Execute - called for each Ditau candidate.

Reimplemented from DiTauToolBase.

Definition at line 57 of file src/DiTauOnnxDiscriminantTool.cxx.

58 {
59  static const SG::Accessor<float> omni_scoreDec("omni_score");
60  xAOD::DiTauJet* xDitau = data->xAODDiTau;
61  ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
62  float score = GetDiTauObjOnnxScore(*xDitau);
63  ATH_MSG_DEBUG("DiTau ID score: " << score);
64  omni_scoreDec(*xDitau) = score;
65  return StatusCode::SUCCESS;
66 }

◆ executeObj()

StatusCode DiTauOnnxDiscriminantTool::executeObj ( xAOD::DiTauJet &  xDiTau,
const EventContext &  ctx 
) const
override, virtual

Execute - called for each Ditau jet.

Reimplemented from DiTauToolBase.

Definition at line 68 of file src/DiTauOnnxDiscriminantTool.cxx.

69 {
70  static const SG::Accessor<float> omni_scoreDec("omni_score");
71  ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
72  float score = GetDiTauObjOnnxScore(xDiTau);
73  ATH_MSG_DEBUG("DiTau ID score: " << score);
74  omni_scoreDec(xDiTau) = score;
75  return StatusCode::SUCCESS;
76 }

◆ extract_points()

std::vector< float > DiTauOnnxDiscriminantTool::extract_points ( const std::vector< std::vector< float >> &  track_features) const
private

Definition at line 87 of file src/DiTauOnnxDiscriminantTool.cxx.

87  {
88  std::vector<float> points;
89  points.reserve(track_features.size() * 2);
90  for (const auto &track : track_features) {
91  points.push_back(track[0]); // delta_eta
92  points.push_back(track[1]); // delta_phi
93  }
94  return points;
95 }

◆ extraDeps_update_handler()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::extraDeps_update_handler ( Gaudi::Details::PropertyBase &  ExtraDeps)
protected, inherited

Add StoreName to extra input/output deps as needed.

Uses the logic of the VarHandleKey to parse the DataObjID keys supplied via the ExtraInputs and ExtraOutputs properties, and adds the StoreName if it is not explicitly given.

◆ f_core()

float DiTauOnnxDiscriminantTool::f_core ( const xAOD::DiTauJet &  xDiTau,
int  iSubjet 
) const
private

Definition at line 248 of file src/DiTauOnnxDiscriminantTool.cxx.

249 {
250  return xDiTau.fCore(iSubjet);
251 }

◆ f_isotracks()

float DiTauOnnxDiscriminantTool::f_isotracks ( const xAOD::DiTauJet &  xDiTau,
const DitauTrackingInfo &  ditauInfo 
) const
private

Definition at line 344 of file src/DiTauOnnxDiscriminantTool.cxx.

344  {
345  float iso_pt = 0;
346  for (const xAOD::TrackParticle* xTrack: ditauInfo.vIsoTracks) {
347  iso_pt += xTrack->pt();
348  }
349  if( xDiTau.pt() == 0.){
350  return m_dDefault;
351  } else {
352  return iso_pt / xDiTau.pt();
353  }
354 }

◆ f_subjet()

float DiTauOnnxDiscriminantTool::f_subjet ( const xAOD::DiTauJet &  xDiTau,
int  iSubjet 
) const
private

Definition at line 253 of file src/DiTauOnnxDiscriminantTool.cxx.

253  {
254  return xDiTau.subjetPt(iSubjet) / xDiTau.pt();
255 }

◆ f_subjets()

float DiTauOnnxDiscriminantTool::f_subjets ( const xAOD::DiTauJet &  xDiTau) const
private

Definition at line 257 of file src/DiTauOnnxDiscriminantTool.cxx.

258 {
259  return (xDiTau.subjetPt(0) + xDiTau.subjetPt(1))/ xDiTau.pt();
260 }

◆ finalize()

StatusCode DiTauOnnxDiscriminantTool::finalize ( )
override, virtual

Finalizer.

Reimplemented from DiTauToolBase.

Definition at line 49 of file src/DiTauOnnxDiscriminantTool.cxx.

50 {
51  ATH_MSG_INFO( "Finalizing DiTauOnnxDiscriminantTool" );
52  m_ort_session.reset();
53  m_ort_env.reset();
54  return StatusCode::SUCCESS;
55 }

◆ flatten()

std::vector< float > DiTauOnnxDiscriminantTool::flatten ( const std::vector< std::vector< float >> &  vec_2d) const
private

Definition at line 78 of file src/DiTauOnnxDiscriminantTool.cxx.

78  {
79  std::vector<float> flattened;
80  flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
81  for (const auto &inner : vec_2d) {
82  flattened.insert(flattened.end(), inner.begin(), inner.end());
83  }
84  return flattened;
85 }

◆ GetDiTauObjOnnxScore()

float DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore ( const xAOD::DiTauJet &  ditau) const

Definition at line 140 of file src/DiTauOnnxDiscriminantTool.cxx.

140  {
141 
142  // do the calculation only for ditau with at least 2 subjets
143  if(n_subjets(ditau)<2){
144  return m_dDefault;
145  }
146 
147  DitauTrackingInfo ditauTrackingInfo;
148  if(!(getTrackingInfo(ditau, ditauTrackingInfo))){
149  return m_dDefault;
150  }
151 
152  // Accessors for reading the necessary features from the xAOD::TrackParticle object
153  static const SG::ConstAccessor< uint8_t > numberOfInrmstPxlLyrHitsAcc ("numberOfInnermostPixelLayerHits");
154  static const SG::ConstAccessor< uint8_t > numberOfPixelHitsAcc ("numberOfPixelHits");
155  static const SG::ConstAccessor< uint8_t > numberOfSCTHitsAcc ("numberOfSCTHits");
156  static const SG::ConstAccessor< float > z0Acc ("z0");
157  static const SG::ConstAccessor< float > d0Acc ("d0");
158  // Input features for Ditau tagger ONNX model
159  std::vector<float> jet_vars = {
160  R_max(ditau, ditauTrackingInfo, 0),
161  R_max(ditau, ditauTrackingInfo, 1),
162  R_tracks(ditau, ditauTrackingInfo, 1),
163  R_isotrack(ditau, ditauTrackingInfo),
164  d0_leadtrack(ditau, ditauTrackingInfo, 0),
165  d0_leadtrack(ditau, ditauTrackingInfo, 1),
166  f_core(ditau,0),
167  f_core(ditau,1),
168  f_subjet(ditau,1),
169  f_subjets(ditau),
170  f_isotracks(ditau, ditauTrackingInfo),
171  mass_core(ditau, ditauTrackingInfo, 0),
172  mass_core(ditau, ditauTrackingInfo, 1),
173  mass_tracks(ditau, ditauTrackingInfo, 0),
174  static_cast<float>( n_track(ditau)),
175  };
176  std::vector<int64_t> jet_shape = {1, static_cast<int64_t>(jet_vars.size())};
177 
178  const TrackParticleLinks_t &vTauTracks = ditau.trackLinks();
179  std::vector<std::vector<float>> track_features(m_maxTracks, std::vector<float>(11, 0.0f));
180 
181  float jet_eta = ditau.eta();
182  float jet_phi = ditau.phi();
183  size_t num_tracks = std::min(static_cast<size_t>(m_maxTracks), vTauTracks.size());
184 
185  for (size_t i = 0; i < num_tracks; ++i) {
186  const ElementLink<xAOD::TrackParticleContainer> &trackLink = vTauTracks[i];
187  if (!trackLink.isValid()) continue;
188  const xAOD::TrackParticle *xTrack = *trackLink;
189  float track_eta = xTrack->eta();
190  float track_phi = xTrack->phi();
191  float delta_eta = track_eta - jet_eta;
192  float delta_phi = std::remainder(track_phi - jet_phi, 2 * M_PI);
193  float delta_R = std::hypot(delta_eta, delta_phi);
194  float track_pt = static_cast<float>(xTrack->pt());
195  float pt_log = std::log(track_pt + 1e-8f);
196  float jet_pt = ditau_pt(ditau); //ditau_ptAcc(ditau);
197  float pt_ratio = track_pt / jet_pt;
198  float pt_ratio_log = (pt_ratio <= 1.0f) ? std::log(1.0f - pt_ratio + 1e-8f) : 0.0f;
199  float track_charge = xTrack->charge();
200 
201  track_features[i] = {
202  delta_eta,
203  delta_phi,
204  pt_log,
205  d0Acc(*xTrack),
206  pt_ratio_log,
207  z0Acc(*xTrack),
208  delta_R,
209  static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
210  static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
211  static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
212  track_charge
213  };
214  }
215  std::vector<int64_t> track_shape = {1, static_cast<int64_t>(m_maxTracks), 11};
216 
217  // Actual ONNX inference
218  OnnxInputs inputs{
219  flatten(track_features),
220  track_shape,
221  extract_points(track_features),
222  {1, track_shape[1], 2},
223  create_mask(track_features),
224  {1, track_shape[1]},
225  std::move(jet_vars),
226  std::move(jet_shape),
227  {0.0f},
228  {1, 1}
229  };
230  auto output = run_inference(inputs);
231  return output.output_1[1];
232 }

◆ getTrackingInfo()

StatusCode DiTauOnnxDiscriminantTool::getTrackingInfo ( const xAOD::DiTauJet &  xDiTau,
DitauTrackingInfo &  trackingInfo 
) const
private

Definition at line 356 of file src/DiTauOnnxDiscriminantTool.cxx.

356  {
357  static const SG::ConstAccessor<std::vector<ElementLink<xAOD::TrackParticleContainer>>> trackLinksAcc("trackLinks");
358  static const SG::ConstAccessor<std::vector<ElementLink<xAOD::TrackParticleContainer>>> isoTrackLinksAcc("isoTrackLinks");
359  static const SG::ConstAccessor<float> R_subjetAcc("R_subjet");
360  static const SG::ConstAccessor<float> R_coreAcc("R_core");
361 
362 
363  if (!trackLinksAcc.isAvailable(xDiTau) || !isoTrackLinksAcc.isAvailable(xDiTau)) {
364  ATH_MSG_WARNING("Track " << (!trackLinksAcc.isAvailable(xDiTau) ? "DiTauJet.trackLinks" : "DiTauJet.isoTrackLinks") << " links not available.");
365  return StatusCode::FAILURE;
366  }
367 
368  int nSubjets = n_subjets(xDiTau);
369  float Rsubjet = R_subjetAcc(xDiTau);
370  float RCore = R_coreAcc(xDiTau);
371 
372  trackingInfo.nSubjets = nSubjets;
373  trackingInfo.vSubjetInfo.clear();
374  trackingInfo.vIsoTracks.clear();
375  trackingInfo.vTracks.clear();
376 
377  // Get the track links from the DiTauJet and store them in the tracking info
378  std::vector<ElementLink<xAOD::TrackParticleContainer>> isoTrackLinks = xDiTau.isoTrackLinks();
379  for (const auto &trackLink: isoTrackLinks) {
380  if (!trackLink.isValid()) {
381  ATH_MSG_WARNING("Iso track link is not valid");
382  continue;
383  }
384  const xAOD::TrackParticle* xTrack = *trackLink;
385  trackingInfo.vIsoTracks.push_back(xTrack);
386  }
387  std::vector<ElementLink<xAOD::TrackParticleContainer>> trackLinks = xDiTau.trackLinks();
388  for (const auto &trackLink : trackLinks) {
389  if (!trackLink.isValid()) {
390  ATH_MSG_WARNING("track link is not valid");
391  continue;
392  }
393  const xAOD::TrackParticle* xTrack = *trackLink;
394  trackingInfo.vTracks.push_back(xTrack);
395  }
396  // store subjet p4
397  for (int i=0; i<nSubjets; ++i){
398  SubjetTrackingInfo subjetTrackingInfo;
399  TLorentzVector subjet_p4 = TLorentzVector();
400  subjet_p4.SetPtEtaPhiE( xDiTau.subjetPt(i), xDiTau.subjetEta(i), xDiTau.subjetPhi(i), xDiTau.subjetE(i));
401  subjetTrackingInfo.subjet_p4 = subjet_p4;
402  trackingInfo.vSubjetInfo.push_back(subjetTrackingInfo);
403  }
404  for (const auto track : trackingInfo.vTracks) {
405  float dRMin = 999;
406  int inSubjet = -1;
407  for (int i=0; i<nSubjets; ++i){
408  float dRTrackSubjet = trackingInfo.vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
409  if (dRTrackSubjet < Rsubjet && dRTrackSubjet < dRMin){
410  dRMin = dRTrackSubjet;
411  inSubjet = i;
412  }
413  }
414  if (inSubjet >= 0){
415  trackingInfo.vSubjetInfo[inSubjet].vTracks.push_back(track);
416  }
417  }
418  // find leading track in subjets
419  for (int i=0; i<nSubjets; ++i){
420  float ptLeadTrack = 0;
421  for (const auto track : trackingInfo.vSubjetInfo[i].vTracks){
422  if (track->pt() > ptLeadTrack){
423  ptLeadTrack = track->pt();
424  trackingInfo.vSubjetInfo[i].leadTrack = track;
425  }
426  }
427  }
428  // find core track in subjets
429  for (int i=0; i<nSubjets; ++i){
430  for (const auto track : trackingInfo.vSubjetInfo[i].vTracks){
431  auto subjetTrackingInfo = trackingInfo.vSubjetInfo[i];
432  if (subjetTrackingInfo.subjet_p4.DeltaR(track->p4()) < RCore){
433  trackingInfo.vSubjetInfo[i].vCoreTracks.push_back(track);
434  }
435  }
436  }
437  //find isotracks in subjets
438  for (const auto track : trackingInfo.vIsoTracks){
439  float RIso = 0.4;
440  float dRMin = 999;
441  int inSubjet = -1;
442  for (int i=0; i<nSubjets; ++i){
443  float dRTrackSubjet = trackingInfo.vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
444  if (dRTrackSubjet > Rsubjet && dRTrackSubjet < RIso && dRTrackSubjet < dRMin){
445  dRMin = dRTrackSubjet;
446  inSubjet = i;
447  }
448  }
449  if (inSubjet >= 0){
450  trackingInfo.vSubjetInfo[inSubjet].vIsoTracks.push_back(track);
451  }
452  }
453  return StatusCode::SUCCESS;
454 }

◆ initialize()

StatusCode DiTauOnnxDiscriminantTool::initialize ( )
override, virtual

Tool initializer.

Reimplemented from DiTauToolBase.

Definition at line 30 of file src/DiTauOnnxDiscriminantTool.cxx.

31 {
32  ATH_MSG_INFO( "Initializing DiTauOnnxDiscriminantTool" );
33  ATH_MSG_INFO( "onnxModelPath: " << m_onnxModelPath );
34 
35  auto model_path = PathResolverFindCalibFile (m_onnxModelPath);
36  if (model_path.empty()) {
37  ATH_MSG_ERROR("Could not find model file: " << m_onnxModelPath);
38  return StatusCode::FAILURE;
39  }
40  m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "OnnxUtil");
41  Ort::SessionOptions session_options;
42  session_options.SetIntraOpNumThreads(1);
43  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
44  session_options.DisableCpuMemArena();
45  m_ort_session = std::make_unique<Ort::Session>(*m_ort_env, model_path.c_str(), session_options);
46  return StatusCode::SUCCESS;
47 }

◆ inputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::inputHandles ( ) const
override, virtual, inherited

Return this algorithm's input handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ interfaceID()

const InterfaceID & DiTauToolBase::interfaceID ( )
static, inherited

InterfaceID implementation needed for ToolHandle.

Definition at line 9 of file DiTauToolBase.cxx.

9  {
10  return DiTauToolBaseID;
11 }

◆ mass_core()

float DiTauOnnxDiscriminantTool::mass_core ( const xAOD::DiTauJet &  xDiTau,
const DitauTrackingInfo &  ditauInfo,
int  iSubjet 
) const
private

Definition at line 310 of file src/DiTauOnnxDiscriminantTool.cxx.

310  {
311  TLorentzVector allCoreTracks_p4;
312  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
313  for (const xAOD::TrackParticle* xTrack: subjetInfo.vCoreTracks) {
314  allCoreTracks_p4 += xTrack->p4();
315  }
316  float mass = allCoreTracks_p4.M();
317  if (mass < 0) {
318  return m_dDefault;
319  }
320  return mass;
321 }

◆ mass_tracks()

float DiTauOnnxDiscriminantTool::mass_tracks ( const xAOD::DiTauJet &  xDiTau,
const DitauTrackingInfo &  ditauInfo,
int  iSubjet 
) const
private

Definition at line 323 of file src/DiTauOnnxDiscriminantTool.cxx.

323  {
324  TLorentzVector allTracks_p4;
325  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
326  for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
327  allTracks_p4 += xTrack->p4();
328  }
329  float mass = allTracks_p4.M();
330  if (mass < 0) {
331  return m_dDefault;
332  }
333  return mass;
334 }

◆ msg() [1/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( ) const
inline, inherited

Definition at line 24 of file AthCommonMsg.h.

24  {
25  return this->msgStream();
26  }

◆ msg() [2/2]

MsgStream& AthCommonMsg< AlgTool >::msg ( const MSG::Level  lvl) const
inline, inherited

Definition at line 27 of file AthCommonMsg.h.

27  {
28  return this->msgStream(lvl);
29  }

◆ msgLvl()

bool AthCommonMsg< AlgTool >::msgLvl ( const MSG::Level  lvl) const
inline, inherited

Definition at line 30 of file AthCommonMsg.h.

30  {
31  return this->msgLevel(lvl);
32  }

◆ n_subjets()

int DiTauOnnxDiscriminantTool::n_subjets ( const xAOD::DiTauJet & xDiTau) const
private

Definition at line 235 of file src/DiTauOnnxDiscriminantTool.cxx.

235  {
236  int nSubjet = 0;
237  while (xDiTau.subjetPt(nSubjet) > 0. ){
238  nSubjet++;
239  }
240  return nSubjet;
241 }

◆ n_track()

int DiTauOnnxDiscriminantTool::n_track ( const xAOD::DiTauJet & xDiTau) const
private

Definition at line 274 of file src/DiTauOnnxDiscriminantTool.cxx.

274  {
275  return xDiTau.nTracks();
276 }

◆ outputHandles()

virtual std::vector<Gaudi::DataHandle*> AthCommonDataStore< AthCommonMsg< AlgTool > >::outputHandles ( ) const
override, virtual, inherited

Return this algorithm's output handles.

We override this to include handle instances from key arrays if they have not yet been declared. See comments on updateVHKA.

◆ R_isotrack()

float DiTauOnnxDiscriminantTool::R_isotrack ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo 
) const
private

Definition at line 278 of file src/DiTauOnnxDiscriminantTool.cxx.

279 {
280  float R_sum = 0;
281  float pt = 0;
282  for (int i = 0; i < 2; i++) {
283  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(i);
284  for (const xAOD::TrackParticle* xTrack: subjetInfo.vIsoTracks) {
285  R_sum += subjetInfo.subjet_p4.DeltaR(xTrack->p4()) * xTrack->pt();
286  pt += xTrack->pt();
287  }
288  }
289  if (pt == 0) {
290  return m_dDefault;
291  }
292  return R_sum / pt;
293 }

◆ R_max()

float DiTauOnnxDiscriminantTool::R_max ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int  iSubjet 
) const
private

Definition at line 262 of file src/DiTauOnnxDiscriminantTool.cxx.

263 {
264  const SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
265  float Rmax = 0;
266  for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
267  if (subjetInfo.subjet_p4.DeltaR(xTrack->p4()) > Rmax) {
268  Rmax = subjetInfo.subjet_p4.DeltaR(xTrack->p4());
269  }
270  }
271  return Rmax;
272 }

◆ R_tracks()

float DiTauOnnxDiscriminantTool::R_tracks ( const xAOD::DiTauJet & xDiTau,
const DitauTrackingInfo & ditauInfo,
int  iSubjet 
) const
private

Definition at line 295 of file src/DiTauOnnxDiscriminantTool.cxx.

295  {
296  float R_sum = 0;
297  float pt = 0;
298 
299  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
300  for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
301  R_sum += subjetInfo.subjet_p4.DeltaR(xTrack->p4()) * xTrack->pt();
302  pt += xTrack->pt();
303  }
304  if (pt == 0) {
305  return m_dDefault;
306  }
307  return R_sum / pt;
308 }

◆ renounce()

std::enable_if_t<std::is_void_v<std::result_of_t<decltype(&T::renounce)(T)> > && !std::is_base_of_v<SG::VarHandleKeyArray, T> && std::is_base_of_v<Gaudi::DataHandle, T>, void> AthCommonDataStore< AthCommonMsg< AlgTool > >::renounce ( T &  h)
inline, protected, inherited

Definition at line 380 of file AthCommonDataStore.h.

381  {
382  h.renounce();
383  PBASE::renounce (h);
384  }

◆ renounceArray()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::renounceArray ( SG::VarHandleKeyArray & handlesArray)
inline, protected, inherited

remove all handles from I/O resolution

Definition at line 364 of file AthCommonDataStore.h.

364  {
365  handlesArray.renounce();
366  }

◆ retrieveTool()

template<class T >
bool DiTauToolBase::retrieveTool ( T &  tool)
inherited

Convenience functions to handle storegate objects.

Definition at line 59 of file DiTauToolBase.cxx.

59  {
60  if (tool.retrieve().isFailure()) {
61  ATH_MSG_FATAL("Failed to retrieve tool " << tool);
62  return false;
63  } else {
64  ATH_MSG_VERBOSE("Retrieved tool " << tool);
65  }
66  return true;
67 }

◆ run_inference()

DiTauOnnxDiscriminantTool::InferenceOutput DiTauOnnxDiscriminantTool::run_inference ( OnnxInputs & inputs) const
private

Definition at line 111 of file src/DiTauOnnxDiscriminantTool.cxx.

111  {
112  std::vector<Ort::Value> input_tensors;
113  input_tensors.reserve(m_input_node_names.size());
114  input_tensors.emplace_back(create_tensor(inputs.input_features, inputs.input_features_shape));
115  input_tensors.emplace_back(create_tensor(inputs.input_points, inputs.input_points_shape));
116  input_tensors.emplace_back(create_tensor(inputs.input_mask, inputs.input_mask_shape));
117  input_tensors.emplace_back(create_tensor(inputs.input_jet, inputs.input_jet_shape));
118  input_tensors.emplace_back(create_tensor(inputs.input_time, inputs.input_time_shape));
119 
120  std::vector<const char *> input_node_names;
121  input_node_names.reserve(m_input_node_names.size());
122  std::transform(m_input_node_names.begin(), m_input_node_names.end(), std::back_inserter(input_node_names), [](const std::string &name) { return name.c_str(); });
123 
124  std::vector<const char *> output_node_names;
125  output_node_names.reserve(m_output_node_names.size());
126  std::transform(m_output_node_names.begin(), m_output_node_names.end(), std::back_inserter(output_node_names), [](const std::string &name) { return name.c_str(); });
127 
128  auto output_tensors = m_ort_session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
129 
130  InferenceOutput output;
131  for (size_t i = 0; i < output_tensors.size(); ++i) {
132  const auto &tensor = output_tensors[i];
133  const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
134  const float *data = tensor.GetTensorData<float>();
135  (i == 0 ? output.output_1 : output.output_2) = std::vector<float>(data, data + length);
136  }
137  return output;
138 }

◆ sysInitialize()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysInitialize ( )
override, virtual, inherited

Perform system initialization for an algorithm.

We override this to declare all the elements of handle key arrays at the end of initialization. See comments on updateVHKA.

Reimplemented in DerivationFramework::CfAthAlgTool, AthCheckedComponent< AthAlgTool >, AthCheckedComponent<::AthAlgTool >, and asg::AsgMetadataTool.

◆ sysStart()

virtual StatusCode AthCommonDataStore< AthCommonMsg< AlgTool > >::sysStart ( )
override, virtual, inherited

Handle START transition.

We override this in order to make sure that conditions handle keys can cache a pointer to the conditions container.

◆ updateVHKA()

void AthCommonDataStore< AthCommonMsg< AlgTool > >::updateVHKA ( Gaudi::Details::PropertyBase &  )
inline, inherited

Definition at line 308 of file AthCommonDataStore.h.

308  {
309  // debug() << "updateVHKA for property " << p.name() << " " << p.toString()
310  // << " size: " << m_vhka.size() << endmsg;
311  for (auto &a : m_vhka) {
312  std::vector<SG::VarHandleKey*> keys = a->keys();
313  for (auto k : keys) {
314  k->setOwner(this);
315  }
316  }
317  }

Member Data Documentation

◆ m_dDefault

float DiTauOnnxDiscriminantTool::m_dDefault = -1234
private

Definition at line 55 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_detStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
private, inherited

Pointer to StoreGate (detector store by default)

Definition at line 393 of file AthCommonDataStore.h.

◆ m_evtStore

StoreGateSvc_t AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
private, inherited

Pointer to StoreGate (event store by default)

Definition at line 390 of file AthCommonDataStore.h.

◆ m_input_node_names

const std::vector<std::string> DiTauOnnxDiscriminantTool::m_input_node_names = {"input_features", "input_points", "input_mask", "input_jet", "input_time"}
private

Definition at line 92 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_maxTracks

Gaudi::Property<size_t> DiTauOnnxDiscriminantTool::m_maxTracks {this, "maxTracks", 10}
private

Definition at line 88 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_onnxModelPath

Gaudi::Property<std::string> DiTauOnnxDiscriminantTool::m_onnxModelPath {this, "onnxModelPath", "TrigTauRec/00-11-02/dev/boosted_ditau_omni_model.onnx"}
private

Definition at line 87 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_ort_env

std::unique_ptr<Ort::Env> DiTauOnnxDiscriminantTool::m_ort_env
private

Definition at line 90 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_ort_session

std::unique_ptr<Ort::Session> DiTauOnnxDiscriminantTool::m_ort_session
private

Definition at line 91 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_output_node_names

const std::vector<std::string> DiTauOnnxDiscriminantTool::m_output_node_names = {"output_1", "output_2"}
private

Definition at line 93 of file DiTauRec/DiTauOnnxDiscriminantTool.h.

◆ m_varHandleArraysDeclared

bool AthCommonDataStore< AthCommonMsg< AlgTool > >::m_varHandleArraysDeclared
private, inherited

Definition at line 399 of file AthCommonDataStore.h.

◆ m_vhka

std::vector<SG::VarHandleKeyArray*> AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
private, inherited

Definition at line 398 of file AthCommonDataStore.h.


The documentation for this class was generated from the following files:
xAOD::TrackParticle_v1::pt
virtual double pt() const override final
The transverse momentum ($p_{T}$) of the particle.
Definition: TrackParticle_v1.cxx:74
AllowedVariables::e
e
Definition: AsgElectronSelectorTool.cxx:37
xAOD::DiTauJet_v1::pt
virtual double pt() const
The transverse momentum ($p_{T}$) of the particle.
DiTauToolBase::DiTauToolBase
DiTauToolBase(const std::string &type, const std::string &name, const IInterface *parent)
Definition: DiTauToolBase.cxx:14
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
DiTauOnnxDiscriminantTool::f_isotracks
float f_isotracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:344
ATH_MSG_FATAL
#define ATH_MSG_FATAL(x)
Definition: AthMsgStreamMacros.h:34
DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:87
DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:87
xAOD::DiTauJet_v1::fCore
float fCore(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:167
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
DiTauOnnxDiscriminantTool::f_subjet
float f_subjet(const xAOD::DiTauJet &xDiTau, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:253
Base_Fragment.mass
mass
Definition: Sherpa_i/share/common/Base_Fragment.py:59
SG::Accessor< float >
xAOD::TrackParticle_v1::charge
float charge() const
Returns the charge.
Definition: TrackParticle_v1.cxx:151
DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:111
DiTauOnnxDiscriminantTool::n_subjets
int n_subjets(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:235
xAOD::TrackParticle_v1::eta
virtual double eta() const override final
The pseudorapidity ($\eta$) of the particle.
Definition: TrackParticle_v1.cxx:78
eFEXNTuple.delta_R
def delta_R(eta1, phi1, eta2, phi2)
Definition: eFEXNTuple.py:20
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
DiTauOnnxDiscriminantTool::R_isotrack
float R_isotrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:278
xAOD::DiTauJet_v1::subjetPhi
float subjetPhi(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:111
DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:92
DiTauOnnxDiscriminantTool::d0_leadtrack
float d0_leadtrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:336
test_pyathena.pt
pt
Definition: test_pyathena.py:11
M_PI
#define M_PI
Definition: ActiveFraction.h:11
AthCommonDataStore< AthCommonMsg< AlgTool > >::m_evtStore
StoreGateSvc_t m_evtStore
Pointer to StoreGate (event store by default)
Definition: AthCommonDataStore.h:390
DiTauOnnxDiscriminantTool::m_dDefault
float m_dDefault
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:55
AthCommonDataStore< AthCommonMsg< AlgTool > >::m_vhka
std::vector< SG::VarHandleKeyArray * > m_vhka
Definition: AthCommonDataStore.h:398
TrackParticleLinks_t
std::vector< ElementLink< xAOD::TrackParticleContainer > > TrackParticleLinks_t
Definition: src/DiTauOnnxDiscriminantTool.cxx:16
SG::ConstAccessor< uint8_t >
read_hist_ntuple.t
t
Definition: read_hist_ntuple.py:5
xAOD::DiTauJet_v1::eta
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
ATH_MSG_VERBOSE
#define ATH_MSG_VERBOSE(x)
Definition: AthMsgStreamMacros.h:28
DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:106
DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
python.utils.AtlRunQueryLookup.mask
string mask
Definition: AtlRunQueryLookup.py:459
DiTauOnnxDiscriminantTool::f_subjets
float f_subjets(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:257
xAOD::DiTauJet_v1::subjetE
float subjetE(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:121
SG::VarHandleKeyArray::setOwner
virtual void setOwner(IDataHandleHolder *o)=0
IDTPMcnv.htype
htype
Definition: IDTPMcnv.py:29
DiTauOnnxDiscriminantTool::n_track
int n_track(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:274
xAOD::TrackParticle_v1::p4
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
Definition: TrackParticle_v1.cxx:130
AthCommonDataStore::declareGaudiProperty
Gaudi::Details::PropertyBase & declareGaudiProperty(Gaudi::Property< T, V, H > &hndl, const SG::VarHandleKeyType &)
specialization for handling Gaudi::Property<SG::VarHandleKey>
Definition: AthCommonDataStore.h:156
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:209
DiTauOnnxDiscriminantTool::mass_core
float mass_core(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:310
xAOD::DiTauJet_v1::phi
virtual double phi() const
The azimuthal angle ($\phi$) of the particle.
AthCommonDataStore
Definition: AthCommonDataStore.h:52
DiTauOnnxDiscriminantTool::R_tracks
float R_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:295
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:140
lumiFormat.i
int i
Definition: lumiFormat.py:85
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
AthCommonDataStore::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T, V, H > &t)
Definition: AthCommonDataStore.h:145
Amg::transform
Amg::Vector3D transform(Amg::Vector3D &v, Amg::Transform3D &tr)
Transform a point using a Transform3D.
Definition: GeoPrimitivesHelpers.h:156
test_pyathena.parent
parent
Definition: test_pyathena.py:15
hist_file_dump.f
f
Definition: hist_file_dump.py:140
AthCommonDataStore< AthCommonMsg< AlgTool > >::m_detStore
StoreGateSvc_t m_detStore
Pointer to StoreGate (detector store by default)
Definition: AthCommonDataStore.h:393
xAOD::DiTauJet_v1::subjetEta
float subjetEta(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:101
SG::VarHandleKeyArray::renounce
virtual void renounce()=0
SG::HandleClassifier::type
std::conditional< std::is_base_of< SG::VarHandleKeyArray, T >::value, VarHandleKeyArrayType, type2 >::type type
Definition: HandleClassifier.h:54
merge.output
output
Definition: merge.py:16
xAOD::DiTauJet_v1::nTracks
size_t nTracks() const
Definition: DiTauJet_v1.cxx:224
DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:91
merge_scale_histograms.doc
string doc
Definition: merge_scale_histograms.py:9
remainder
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:44
DiTauOnnxDiscriminantTool::getTrackingInfo
StatusCode getTrackingInfo(const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:356
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:97
AtlCoolConsole.tool
tool
Definition: AtlCoolConsole.py:452
DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:93
TauGNNUtils::Variables::Track::pt_log
bool pt_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:341
DiTauOnnxDiscriminantTool::mass_tracks
float mass_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:323
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
DiTauOnnxDiscriminantTool::R_max
float R_max(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:262
xAOD::score
@ score
Definition: TrackingPrimitives.h:514
DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:90
a
TList * a
Definition: liststreamerinfos.cxx:10
python.general.flattened
def flattened(l)
Definition: general.py:125
h
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
eFEXNTuple.delta_phi
def delta_phi(phi1, phi2)
Definition: eFEXNTuple.py:14
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
SG::VarHandleBase::vhKey
SG::VarHandleKey & vhKey()
Return a non-const reference to the HandleKey.
Definition: StoreGate/src/VarHandleBase.cxx:629
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
xAOD::DiTauJet_v1::isoTrackLinks
const TrackParticleLinks_t & isoTrackLinks() const
Trk::jet_phi
@ jet_phi
Definition: JetVtxParamDefs.h:28
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:801
xAOD::DiTauJet_v1::subjetPt
float subjetPt(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:91
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
DiTauOnnxDiscriminantTool::ditau_pt
float ditau_pt(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:243
xAOD::DiTauJet_v1::trackLinks
const TrackParticleLinks_t & trackLinks() const
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
DiTauOnnxDiscriminantTool::f_core
float f_core(const xAOD::DiTauJet &xDiTau, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:248
jobOptions.points
points
Definition: jobOptions.GenevaPy8_Zmumu.py:97
DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:88
xAOD::TrackParticle_v1::phi
virtual double phi() const override final
The azimuthal angle ($\phi$) of the particle (has range $-\pi$ to $+\pi$.)
fitman.k
k
Definition: fitman.py:528