ATLAS Offline Software
src/DiTauOnnxDiscriminantTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 // Core include(s):
8 #include "AthLinks/ElementLink.h"
9 
10 
11 // EDM include(s):
12 
13 
14 
15 
16 using TrackParticleLinks_t = std::vector<ElementLink<xAOD::TrackParticleContainer>>;
17 
18 //=================================PUBLIC-PART==================================
19 //______________________________________________________________________________
20 DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool( const std::string& type, const std::string& name, const IInterface * parent) :
22 {
23  declareInterface<DiTauToolBase > (this);
24 }
25 
26 //______________________________________________________________________________
28 
29 //______________________________________________________________________________
31 {
32  ATH_MSG_INFO( "Initializing DiTauOnnxDiscriminantTool" );
33  ATH_MSG_INFO( "onnxModelPath: " << m_onnxModelPath );
34 
35  auto model_path = PathResolverFindCalibFile (m_onnxModelPath);
36  if (model_path.empty()) {
37  ATH_MSG_ERROR("Could not find model file: " << m_onnxModelPath);
38  return StatusCode::FAILURE;
39  }
40  m_ort_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "OnnxUtil");
41  Ort::SessionOptions session_options;
42  session_options.SetIntraOpNumThreads(1);
43  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
44  session_options.DisableCpuMemArena();
45  m_ort_session = std::make_unique<Ort::Session>(*m_ort_env, model_path.c_str(), session_options);
46  return StatusCode::SUCCESS;
47 }
48 
50 {
51  ATH_MSG_INFO( "Finalizing DiTauOnnxDiscriminantTool" );
52  m_ort_session.reset();
53  m_ort_env.reset();
54  return StatusCode::SUCCESS;
55 }
56 
58 {
59  static const SG::Accessor<float> omni_scoreDec("omni_score");
60  xAOD::DiTauJet* xDitau = data->xAODDiTau;
61  ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
62  float score = GetDiTauObjOnnxScore(*xDitau);
63  ATH_MSG_DEBUG("DiTau ID score: " << score);
64  omni_scoreDec(*xDitau) = score;
65  return StatusCode::SUCCESS;
66 }
67 
68 StatusCode DiTauOnnxDiscriminantTool::executeObj( xAOD::DiTauJet &xDiTau, const EventContext& /*ctx*/) const
69 {
70  static const SG::Accessor<float> omni_scoreDec("omni_score");
71  ATH_MSG_DEBUG("Inferencing omni DiTau ID score...");
72  float score = GetDiTauObjOnnxScore(xDiTau);
73  ATH_MSG_DEBUG("DiTau ID score: " << score);
74  omni_scoreDec(xDiTau) = score;
75  return StatusCode::SUCCESS;
76 }
77 
78 std::vector<float> DiTauOnnxDiscriminantTool::flatten(const std::vector<std::vector<float>> &vec_2d) const{
79  std::vector<float> flattened;
80  flattened.reserve(vec_2d.size() * (vec_2d.empty() ? 0 : vec_2d[0].size()));
81  for (const auto &inner : vec_2d) {
82  flattened.insert(flattened.end(), inner.begin(), inner.end());
83  }
84  return flattened;
85 }
86 
87 std::vector<float> DiTauOnnxDiscriminantTool::extract_points(const std::vector<std::vector<float>> &track_features) const{
88  std::vector<float> points;
89  points.reserve(track_features.size() * 2);
90  for (const auto &track : track_features) {
91  points.push_back(track[0]); // delta_eta
92  points.push_back(track[1]); // delta_phi
93  }
94  return points;
95 }
96 
97 std::vector<float> DiTauOnnxDiscriminantTool::create_mask(const std::vector<std::vector<float>> &track_features) const{
98  std::vector<float> mask;
99  mask.reserve(track_features.size());
100  std::transform(track_features.begin(), track_features.end(), std::back_inserter(mask), [](const auto &track) {
101  return std::abs(track[2]) > 1e-6 ? 1.0f : 0.0f;
102  });
103  return mask;
104 }
105 
106 Ort::Value DiTauOnnxDiscriminantTool::create_tensor(std::vector<float> &data, const std::vector<int64_t> &shape) const{
107  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
108  return Ort::Value::CreateTensor<float>(memory_info, data.data(), data.size(),shape.data(), shape.size());
109 }
110 
112  std::vector<Ort::Value> input_tensors;
113  input_tensors.reserve(m_input_node_names.size());
114  input_tensors.emplace_back(create_tensor(inputs.input_features, inputs.input_features_shape));
115  input_tensors.emplace_back(create_tensor(inputs.input_points, inputs.input_points_shape));
116  input_tensors.emplace_back(create_tensor(inputs.input_mask, inputs.input_mask_shape));
117  input_tensors.emplace_back(create_tensor(inputs.input_jet, inputs.input_jet_shape));
118  input_tensors.emplace_back(create_tensor(inputs.input_time, inputs.input_time_shape));
119 
120  std::vector<const char *> input_node_names;
121  input_node_names.reserve(m_input_node_names.size());
122  std::transform(m_input_node_names.begin(), m_input_node_names.end(), std::back_inserter(input_node_names), [](const std::string &name) { return name.c_str(); });
123 
124  std::vector<const char *> output_node_names;
125  output_node_names.reserve(m_output_node_names.size());
126  std::transform(m_output_node_names.begin(), m_output_node_names.end(), std::back_inserter(output_node_names), [](const std::string &name) { return name.c_str(); });
127 
128  auto output_tensors = m_ort_session->Run(Ort::RunOptions{nullptr}, input_node_names.data(), input_tensors.data(), input_node_names.size(), output_node_names.data(), output_node_names.size());
129 
131  for (size_t i = 0; i < output_tensors.size(); ++i) {
132  const auto &tensor = output_tensors[i];
133  const size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
134  const float *data = tensor.GetTensorData<float>();
135  (i == 0 ? output.output_1 : output.output_2) = std::vector<float>(data, data + length);
136  }
137  return output;
138 }
139 
141 
142  // do the calculation only for ditau with at least 2 subjets
143  if(n_subjets(ditau)<2){
144  return m_dDefault;
145  }
146 
147  DitauTrackingInfo ditauTrackingInfo;
148  if(!(getTrackingInfo(ditau, ditauTrackingInfo))){
149  return m_dDefault;
150  }
151 
152  // Accessors for reading the necessary features from the xAOD::TrackParticle object
153  static const SG::ConstAccessor< uint8_t > numberOfInrmstPxlLyrHitsAcc ("numberOfInnermostPixelLayerHits");
154  static const SG::ConstAccessor< uint8_t > numberOfPixelHitsAcc ("numberOfPixelHits");
155  static const SG::ConstAccessor< uint8_t > numberOfSCTHitsAcc ("numberOfSCTHits");
156  static const SG::ConstAccessor< float > z0Acc ("z0");
157  static const SG::ConstAccessor< float > d0Acc ("d0");
158  // Input features for Ditau tagger ONNX model
159  std::vector<float> jet_vars = {
160  R_max(ditau, ditauTrackingInfo, 0),
161  R_max(ditau, ditauTrackingInfo, 1),
162  R_tracks(ditau, ditauTrackingInfo, 1),
163  R_isotrack(ditau, ditauTrackingInfo),
164  d0_leadtrack(ditau, ditauTrackingInfo, 0),
165  d0_leadtrack(ditau, ditauTrackingInfo, 1),
166  f_core(ditau,0),
167  f_core(ditau,1),
168  f_subjet(ditau,1),
169  f_subjets(ditau),
170  f_isotracks(ditau, ditauTrackingInfo),
171  mass_core(ditau, ditauTrackingInfo, 0),
172  mass_core(ditau, ditauTrackingInfo, 1),
173  mass_tracks(ditau, ditauTrackingInfo, 0),
174  static_cast<float>( n_track(ditau)),
175  };
176  std::vector<int64_t> jet_shape = {1, static_cast<int64_t>(jet_vars.size())};
177 
178  const TrackParticleLinks_t &vTauTracks = ditau.trackLinks();
179  std::vector<std::vector<float>> track_features(m_maxTracks, std::vector<float>(11, 0.0f));
180 
181  float jet_eta = ditau.eta();
182  float jet_phi = ditau.phi();
183  size_t num_tracks = std::min(static_cast<size_t>(m_maxTracks), vTauTracks.size());
184 
185  for (size_t i = 0; i < num_tracks; ++i) {
186  const ElementLink<xAOD::TrackParticleContainer> &trackLink = vTauTracks[i];
187  if (!trackLink.isValid()) continue;
188  const xAOD::TrackParticle *xTrack = *trackLink;
189  float track_eta = xTrack->eta();
190  float track_phi = xTrack->phi();
191  float delta_eta = track_eta - jet_eta;
192  float delta_phi = std::remainder(track_phi - jet_phi, 2 * M_PI);
193  float delta_R = std::hypot(delta_eta, delta_phi);
194  float track_pt = static_cast<float>(xTrack->pt());
195  float pt_log = std::log(track_pt + 1e-8f);
196  float jet_pt = ditau_pt(ditau); //ditau_ptAcc(ditau);
197  float pt_ratio = track_pt / jet_pt;
198  float pt_ratio_log = (pt_ratio <= 1.0f) ? std::log(1.0f - pt_ratio + 1e-8f) : 0.0f;
199  float track_charge = xTrack->charge();
200 
201  track_features[i] = {
202  delta_eta,
203  delta_phi,
204  pt_log,
205  d0Acc(*xTrack),
206  pt_ratio_log,
207  z0Acc(*xTrack),
208  delta_R,
209  static_cast<float>(numberOfInrmstPxlLyrHitsAcc(*xTrack)),
210  static_cast<float>(numberOfPixelHitsAcc(*xTrack)),
211  static_cast<float>(numberOfSCTHitsAcc(*xTrack)),
212  track_charge
213  };
214  }
215  std::vector<int64_t> track_shape = {1, static_cast<int64_t>(m_maxTracks), 11};
216 
217  // Actual ONNX inference
219  flatten(track_features),
220  track_shape,
221  extract_points(track_features),
222  {1, track_shape[1], 2},
223  create_mask(track_features),
224  {1, track_shape[1]},
225  std::move(jet_vars),
226  std::move(jet_shape),
227  {0.0f},
228  {1, 1}
229  };
230  auto output = run_inference(inputs);
231  return output.output_1[1];
232 }
233 
234 // Aux variables calculation
236  int nSubjet = 0;
237  while (xDiTau.subjetPt(nSubjet) > 0. ){
238  nSubjet++;
239  }
240  return nSubjet;
241 }
242 
244 {
245  return xDiTau.subjetPt(0)+xDiTau.subjetPt(1);
246 }
247 
248 float DiTauOnnxDiscriminantTool::f_core(const xAOD::DiTauJet& xDiTau, int iSubjet) const
249 {
250  return xDiTau.fCore(iSubjet);
251 }
252 
253 float DiTauOnnxDiscriminantTool::f_subjet(const xAOD::DiTauJet& xDiTau, int iSubjet) const {
254  return xDiTau.subjetPt(iSubjet) / xDiTau.pt();
255 }
256 
258 {
259  return (xDiTau.subjetPt(0) + xDiTau.subjetPt(1))/ xDiTau.pt();
260 }
261 
262 float DiTauOnnxDiscriminantTool::R_max(const xAOD::DiTauJet&, const DitauTrackingInfo& ditauInfo, int iSubjet) const
263 {
264  const SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
265  float Rmax = 0;
266  for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
267  if (subjetInfo.subjet_p4.DeltaR(xTrack->p4()) > Rmax) {
268  Rmax = subjetInfo.subjet_p4.DeltaR(xTrack->p4());
269  }
270  }
271  return Rmax;
272 }
273 
275  return xDiTau.nTracks();
276 }
277 
279 {
280  float R_sum = 0;
281  float pt = 0;
282  for (int i = 0; i < 2; i++) {
283  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(i);
284  for (const xAOD::TrackParticle* xTrack: subjetInfo.vIsoTracks) {
285  R_sum += subjetInfo.subjet_p4.DeltaR(xTrack->p4()) * xTrack->pt();
286  pt += xTrack->pt();
287  }
288  }
289  if (pt == 0) {
290  return m_dDefault;
291  }
292  return R_sum / pt;
293 }
294 
295 float DiTauOnnxDiscriminantTool::R_tracks(const xAOD::DiTauJet&, const DitauTrackingInfo& ditauInfo, int iSubjet) const {
296  float R_sum = 0;
297  float pt = 0;
298 
299  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
300  for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
301  R_sum += subjetInfo.subjet_p4.DeltaR(xTrack->p4()) * xTrack->pt();
302  pt += xTrack->pt();
303  }
304  if (pt == 0) {
305  return m_dDefault;
306  }
307  return R_sum / pt;
308 }
309 
310 float DiTauOnnxDiscriminantTool::mass_core(const xAOD::DiTauJet&, const DitauTrackingInfo& ditauInfo, int iSubjet) const {
311  TLorentzVector allCoreTracks_p4;
312  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
313  for (const xAOD::TrackParticle* xTrack: subjetInfo.vCoreTracks) {
314  allCoreTracks_p4 += xTrack->p4();
315  }
316  float mass = allCoreTracks_p4.M();
317  if (mass < 0) {
318  return m_dDefault;
319  }
320  return mass;
321 }
322 
323 float DiTauOnnxDiscriminantTool::mass_tracks(const xAOD::DiTauJet&, const DitauTrackingInfo& ditauInfo, int iSubjet) const {
324  TLorentzVector allTracks_p4;
325  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
326  for (const xAOD::TrackParticle* xTrack: subjetInfo.vTracks) {
327  allTracks_p4 += xTrack->p4();
328  }
329  float mass = allTracks_p4.M();
330  if (mass < 0) {
331  return m_dDefault;
332  }
333  return mass;
334 }
335 
336 float DiTauOnnxDiscriminantTool::d0_leadtrack(const xAOD::DiTauJet&, const DitauTrackingInfo& ditauInfo, int iSubjet) const {
337  SubjetTrackingInfo subjetInfo = ditauInfo.vSubjetInfo.at(iSubjet);
338  if (!subjetInfo.leadTrack) {
339  return m_dDefault;
340  }
341  return subjetInfo.leadTrack->d0();
342 }
343 
344 float DiTauOnnxDiscriminantTool::f_isotracks(const xAOD::DiTauJet& xDiTau, const DitauTrackingInfo& ditauInfo) const {
345  float iso_pt = 0;
346  for (const xAOD::TrackParticle* xTrack: ditauInfo.vIsoTracks) {
347  iso_pt += xTrack->pt();
348  }
349  if( xDiTau.pt() == 0.){
350  return m_dDefault;
351  } else {
352  return iso_pt / xDiTau.pt();
353  }
354 }
355 
357  static const SG::ConstAccessor<std::vector<ElementLink<xAOD::TrackParticleContainer>>> trackLinksAcc("trackLinks");
358  static const SG::ConstAccessor<std::vector<ElementLink<xAOD::TrackParticleContainer>>> isoTrackLinksAcc("isoTrackLinks");
359  static const SG::ConstAccessor<float> R_subjetAcc("R_subjet");
360  static const SG::ConstAccessor<float> R_coreAcc("R_core");
361 
362 
363  if (!trackLinksAcc.isAvailable(xDiTau) || !isoTrackLinksAcc.isAvailable(xDiTau)) {
364  ATH_MSG_WARNING("Track " << (!trackLinksAcc.isAvailable(xDiTau) ? "DiTauJet.trackLinks" : "DiTauJet.isoTrackLinks") << " links not available.");
365  return StatusCode::FAILURE;
366  }
367 
368  int nSubjets = n_subjets(xDiTau);
369  float Rsubjet = R_subjetAcc(xDiTau);
370  float RCore = R_coreAcc(xDiTau);
371 
372  trackingInfo.nSubjets = nSubjets;
373  trackingInfo.vSubjetInfo.clear();
374  trackingInfo.vIsoTracks.clear();
375  trackingInfo.vTracks.clear();
376 
377  // Get the track links from the DiTauJet and store them in the tracking info
378  std::vector<ElementLink<xAOD::TrackParticleContainer>> isoTrackLinks = xDiTau.isoTrackLinks();
379  for (const auto &trackLink: isoTrackLinks) {
380  if (!trackLink.isValid()) {
381  ATH_MSG_WARNING("Iso track link is not valid");
382  continue;
383  }
384  const xAOD::TrackParticle* xTrack = *trackLink;
385  trackingInfo.vIsoTracks.push_back(xTrack);
386  }
387  std::vector<ElementLink<xAOD::TrackParticleContainer>> trackLinks = xDiTau.trackLinks();
388  for (const auto &trackLink : trackLinks) {
389  if (!trackLink.isValid()) {
390  ATH_MSG_WARNING("track link is not valid");
391  continue;
392  }
393  const xAOD::TrackParticle* xTrack = *trackLink;
394  trackingInfo.vTracks.push_back(xTrack);
395  }
396  // store subjet p4
397  for (int i=0; i<nSubjets; ++i){
398  SubjetTrackingInfo subjetTrackingInfo;
399  TLorentzVector subjet_p4 = TLorentzVector();
400  subjet_p4.SetPtEtaPhiE( xDiTau.subjetPt(i), xDiTau.subjetEta(i), xDiTau.subjetPhi(i), xDiTau.subjetE(i));
401  subjetTrackingInfo.subjet_p4 = subjet_p4;
402  trackingInfo.vSubjetInfo.push_back(subjetTrackingInfo);
403  }
404  for (const auto track : trackingInfo.vTracks) {
405  float dRMin = 999;
406  int inSubjet = -1;
407  for (int i=0; i<nSubjets; ++i){
408  float dRTrackSubjet = trackingInfo.vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
409  if (dRTrackSubjet < Rsubjet && dRTrackSubjet < dRMin){
410  dRMin = dRTrackSubjet;
411  inSubjet = i;
412  }
413  }
414  if (inSubjet >= 0){
415  trackingInfo.vSubjetInfo[inSubjet].vTracks.push_back(track);
416  }
417  }
418  // find leading track in subjets
419  for (int i=0; i<nSubjets; ++i){
420  float ptLeadTrack = 0;
421  for (const auto track : trackingInfo.vSubjetInfo[i].vTracks){
422  if (track->pt() > ptLeadTrack){
423  ptLeadTrack = track->pt();
424  trackingInfo.vSubjetInfo[i].leadTrack = track;
425  }
426  }
427  }
428  // find core track in subjets
429  for (int i=0; i<nSubjets; ++i){
430  for (const auto track : trackingInfo.vSubjetInfo[i].vTracks){
431  auto subjetTrackingInfo = trackingInfo.vSubjetInfo[i];
432  if (subjetTrackingInfo.subjet_p4.DeltaR(track->p4()) < RCore){
433  trackingInfo.vSubjetInfo[i].vCoreTracks.push_back(track);
434  }
435  }
436  }
437  //find isotracks in subjets
438  for (const auto track : trackingInfo.vIsoTracks){
439  float RIso = 0.4;
440  float dRMin = 999;
441  int inSubjet = -1;
442  for (int i=0; i<nSubjets; ++i){
443  float dRTrackSubjet = trackingInfo.vSubjetInfo[i].subjet_p4.DeltaR(track->p4());
444  if (dRTrackSubjet > Rsubjet && dRTrackSubjet < RIso && dRTrackSubjet < dRMin){
445  dRMin = dRTrackSubjet;
446  inSubjet = i;
447  }
448  }
449  if (inSubjet >= 0){
450  trackingInfo.vSubjetInfo[inSubjet].vIsoTracks.push_back(track);
451  }
452  }
453  return StatusCode::SUCCESS;
454 }
455 
456 
457 
xAOD::TrackParticle_v1::pt
virtual double pt() const override final
The transverse momentum ( ) of the particle.
Definition: TrackParticle_v1.cxx:74
AllowedVariables::e
e
Definition: AsgElectronSelectorTool.cxx:37
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::vTracks
std::vector< const xAOD::TrackParticle * > vTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:59
xAOD::DiTauJet_v1::pt
virtual double pt() const
The transverse momentum ( ) of the particle.
DiTauOnnxDiscriminantTool::InferenceOutput
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:95
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
DiTauOnnxDiscriminantTool::f_isotracks
float f_isotracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:344
DiTauOnnxDiscriminantTool::m_onnxModelPath
Gaudi::Property< std::string > m_onnxModelPath
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:87
DiTauOnnxDiscriminantTool::extract_points
std::vector< float > extract_points(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:87
xAOD::DiTauJet_v1::fCore
float fCore(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:167
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
DiTauOnnxDiscriminantTool::executeObj
virtual StatusCode executeObj(xAOD::DiTauJet &xDiTau, const EventContext &ctx) const override
Execute - called for each Ditau jet.
Definition: src/DiTauOnnxDiscriminantTool.cxx:68
DiTauOnnxDiscriminantTool::f_subjet
float f_subjet(const xAOD::DiTauJet &xDiTau, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:253
Base_Fragment.mass
mass
Definition: Sherpa_i/share/common/Base_Fragment.py:59
SG::Accessor< float >
xAOD::TrackParticle_v1::charge
float charge() const
Returns the charge.
Definition: TrackParticle_v1.cxx:151
DiTauOnnxDiscriminantTool::run_inference
InferenceOutput run_inference(OnnxInputs &inputs) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:111
DiTauOnnxDiscriminantTool::n_subjets
int n_subjets(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:235
xAOD::TrackParticle_v1::eta
virtual double eta() const override final
The pseudorapidity ( ) of the particle.
Definition: TrackParticle_v1.cxx:78
eFEXNTuple.delta_R
def delta_R(eta1, phi1, eta2, phi2)
Definition: eFEXNTuple.py:20
min
constexpr double min()
Definition: ap_fixedTest.cxx:26
DiTauOnnxDiscriminantTool::DitauTrackingInfo::nSubjets
int nSubjets
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:67
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::vCoreTracks
std::vector< const xAOD::TrackParticle * > vCoreTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:61
DiTauOnnxDiscriminantTool::R_isotrack
float R_isotrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:278
xAOD::DiTauJet_v1::subjetPhi
float subjetPhi(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:111
DiTauOnnxDiscriminantTool::m_input_node_names
const std::vector< std::string > m_input_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:92
DiTauOnnxDiscriminantTool::d0_leadtrack
float d0_leadtrack(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:336
test_pyathena.pt
pt
Definition: test_pyathena.py:11
M_PI
#define M_PI
Definition: ActiveFraction.h:11
DiTauOnnxDiscriminantTool::m_dDefault
float m_dDefault
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:55
TrackParticleLinks_t
std::vector< ElementLink< xAOD::TrackParticleContainer > > TrackParticleLinks_t
Definition: src/DiTauOnnxDiscriminantTool.cxx:16
SG::ConstAccessor< uint8_t >
xAOD::DiTauJet_v1::eta
virtual double eta() const
The pseudorapidity ( ) of the particle.
DiTauOnnxDiscriminantTool::create_tensor
Ort::Value create_tensor(std::vector< float > &data, const std::vector< int64_t > &shape) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:106
DiTauOnnxDiscriminantTool::flatten
std::vector< float > flatten(const std::vector< std::vector< float >> &vec_2d) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
python.CaloAddPedShiftConfig.type
type
Definition: CaloAddPedShiftConfig.py:42
xAOD::TrackParticle_v1::d0
float d0() const
Returns the parameter.
DiTauOnnxDiscriminantTool::DitauTrackingInfo::vSubjetInfo
std::vector< SubjetTrackingInfo > vSubjetInfo
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:68
python.utils.AtlRunQueryLookup.mask
string mask
Definition: AtlRunQueryLookup.py:459
DiTauOnnxDiscriminantTool::~DiTauOnnxDiscriminantTool
virtual ~DiTauOnnxDiscriminantTool()
DiTauOnnxDiscriminantTool::f_subjets
float f_subjets(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:257
xAOD::DiTauJet_v1::subjetE
float subjetE(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:121
DiTauOnnxDiscriminantTool::n_track
int n_track(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:274
xAOD::TrackParticle_v1::p4
virtual FourMom_t p4() const override final
The full 4-momentum of the particle.
Definition: TrackParticle_v1.cxx:130
DiTauOnnxDiscriminantTool.h
DiTauOnnxDiscriminantTool::mass_core
float mass_core(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:310
xAOD::DiTauJet_v1::phi
virtual double phi() const
The azimuthal angle ( ) of the particle.
DiTauOnnxDiscriminantTool::R_tracks
float R_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:295
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
DiTauOnnxDiscriminantTool::GetDiTauObjOnnxScore
float GetDiTauObjOnnxScore(const xAOD::DiTauJet &ditau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:140
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
Amg::transform
Amg::Vector3D transform(Amg::Vector3D &v, Amg::Transform3D &tr)
Transform a point from a Trasformation3D.
Definition: GeoPrimitivesHelpers.h:156
test_pyathena.parent
parent
Definition: test_pyathena.py:15
hist_file_dump.f
f
Definition: hist_file_dump.py:140
DiTauOnnxDiscriminantTool::finalize
virtual StatusCode finalize() override
Finalizer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:49
DiTauOnnxDiscriminantTool::DitauTrackingInfo::vIsoTracks
std::vector< const xAOD::TrackParticle * > vIsoTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:66
xAOD::DiTauJet_v1::subjetEta
float subjetEta(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:101
merge.output
output
Definition: merge.py:16
xAOD::DiTauJet_v1::nTracks
size_t nTracks() const
Definition: DiTauJet_v1.cxx:224
DiTauOnnxDiscriminantTool::OnnxInputs
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:100
DiTauOnnxDiscriminantTool::m_ort_session
std::unique_ptr< Ort::Session > m_ort_session
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:91
remainder
std::vector< std::string > remainder(const std::vector< std::string > &v1, const std::vector< std::string > &v2)
Definition: compareFlatTrees.cxx:44
DiTauOnnxDiscriminantTool::getTrackingInfo
StatusCode getTrackingInfo(const xAOD::DiTauJet &xDiTau, DitauTrackingInfo &trackingInfo) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:356
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::vIsoTracks
std::vector< const xAOD::TrackParticle * > vIsoTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:60
DiTauOnnxDiscriminantTool::DiTauOnnxDiscriminantTool
DiTauOnnxDiscriminantTool(const std::string &type, const std::string &name, const IInterface *parent)
Definition: src/DiTauOnnxDiscriminantTool.cxx:20
DiTauOnnxDiscriminantTool::create_mask
std::vector< float > create_mask(const std::vector< std::vector< float >> &track_features) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:97
DiTauToolBase
The base class for all tau tools.
Definition: DiTauToolBase.h:21
DiTauOnnxDiscriminantTool::m_output_node_names
const std::vector< std::string > m_output_node_names
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:93
TauGNNUtils::Variables::Track::pt_log
bool pt_log(const xAOD::TauJet &, const xAOD::TauTrack &track, float &out)
Definition: TauGNNUtils.cxx:341
DiTauOnnxDiscriminantTool::mass_tracks
float mass_tracks(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:323
DiTauOnnxDiscriminantTool::execute
virtual StatusCode execute(DiTauCandidateData *data, const EventContext &ctx) const override
Execute - called for each Ditau candidate.
Definition: src/DiTauOnnxDiscriminantTool.cxx:57
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
DiTauOnnxDiscriminantTool::R_max
float R_max(const xAOD::DiTauJet &xDiTau, const DitauTrackingInfo &ditauInfo, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:262
xAOD::score
@ score
Definition: TrackingPrimitives.h:514
DiTauOnnxDiscriminantTool::m_ort_env
std::unique_ptr< Ort::Env > m_ort_env
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:90
python.general.flattened
def flattened(l)
Definition: general.py:125
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::subjet_p4
TLorentzVector subjet_p4
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:58
ATH_MSG_WARNING
#define ATH_MSG_WARNING(x)
Definition: AthMsgStreamMacros.h:32
DiTauOnnxDiscriminantTool::SubjetTrackingInfo::leadTrack
const xAOD::TrackParticle * leadTrack
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:62
eFEXNTuple.delta_phi
def delta_phi(phi1, phi2)
Definition: eFEXNTuple.py:14
DiTauOnnxDiscriminantTool::initialize
virtual StatusCode initialize() override
Tool initializer.
Definition: src/DiTauOnnxDiscriminantTool.cxx:30
DiTauOnnxDiscriminantTool::SubjetTrackingInfo
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:57
python.CaloCondTools.log
log
Definition: CaloCondTools.py:20
xAOD::DiTauJet_v1
Definition: DiTauJet_v1.h:31
xAOD::DiTauJet_v1::isoTrackLinks
const TrackParticleLinks_t & isoTrackLinks() const
SG::ConstAccessor::isAvailable
bool isAvailable(const ELT &e) const
Test to see if this variable exists in the store.
Trk::jet_phi
@ jet_phi
Definition: JetVtxParamDefs.h:28
DiTauCandidateData
Definition: DiTauCandidateData.h:15
xAOD::DiTauJet_v1::subjetPt
float subjetPt(unsigned int numSubjet) const
Definition: DiTauJet_v1.cxx:91
DiTauOnnxDiscriminantTool::DitauTrackingInfo
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:64
xAOD::track
@ track
Definition: TrackingPrimitives.h:513
xAOD::TrackParticle_v1
Class describing a TrackParticle.
Definition: TrackParticle_v1.h:43
DiTauOnnxDiscriminantTool::ditau_pt
float ditau_pt(const xAOD::DiTauJet &xDiTau) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:243
xAOD::DiTauJet_v1::trackLinks
const TrackParticleLinks_t & trackLinks() const
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
DiTauOnnxDiscriminantTool::f_core
float f_core(const xAOD::DiTauJet &xDiTau, int iSubjet) const
Definition: src/DiTauOnnxDiscriminantTool.cxx:248
jobOptions.points
points
Definition: jobOptions.GenevaPy8_Zmumu.py:97
DiTauOnnxDiscriminantTool::m_maxTracks
Gaudi::Property< size_t > m_maxTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:88
xAOD::TrackParticle_v1::phi
virtual double phi() const override final
The azimuthal angle ( ) of the particle (has range to .)
DiTauOnnxDiscriminantTool::DitauTrackingInfo::vTracks
std::vector< const xAOD::TrackParticle * > vTracks
Definition: DiTauRec/DiTauOnnxDiscriminantTool.h:65