ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
TauHFVetoTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 using namespace TauAnalysisTools;
8 
9 TauHFVetoTool::TauHFVetoTool(const std::string& name) : asg::AsgTool(name) {
10 }
11 
13  ATH_CHECK(m_onnxTool_bveto1p.retrieve());
14  ATH_CHECK(m_onnxTool_bveto3p.retrieve());
15  ATH_CHECK(m_onnxTool_cveto1p.retrieve());
16  ATH_CHECK(m_onnxTool_cveto3p.retrieve());
17  return StatusCode::SUCCESS;
18 }
19 
20 const xAOD::Jet* TauHFVetoTool::findClosestPFlowJet(const xAOD::TauJet* xTau, const xAOD::JetContainer* vPFlowJets) const {
21  // loop through the jets and find the closest one to the tau
22  const xAOD::Jet* xClosestJet = nullptr;
23  double dMinDeltaR = 999.;
24  for (const xAOD::Jet* xJet : *vPFlowJets){
25  double dDeltaR = xTau->p4().DeltaR(xJet->p4());
26  if (dDeltaR < dMinDeltaR){
27  dMinDeltaR = dDeltaR;
28  xClosestJet = xJet;
29  }
30  }
31  return xClosestJet;
32 }
33 
35  // Define Decorators for the auxiliary variables
36  SG::Decorator<float> acc_bVetoScore("bVetoScore");
37  SG::Decorator<float> acc_cVetoScore("cVetoScore");
38  for (const xAOD::TauJet* xTau : *Taus){
39  const xAOD::Jet* xAuxJet = findClosestPFlowJet(xTau, PFlowJets);
40  auto input = assembleInputValues(xTau, xAuxJet);
41  int prongness = xTau->nTracksCharged();
42  ATH_CHECK(bVetoScore(prongness, input, acc_bVetoScore(*xTau)));
43  ATH_CHECK(cVetoScore(prongness, input, acc_cVetoScore(*xTau)));
44  }
45  return StatusCode::SUCCESS;
46 }
47 
48 StatusCode TauHFVetoTool::inference(const ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool> onnxTool, const std::vector<float>& inputValues, float& output) const {
49  auto inputData = inputValues;
50  std::vector<float> outputScores;
51  std::vector<Ort::Value> inputTensors;
52  std::vector<Ort::Value> outputTensors;
53 
54  ATH_CHECK(onnxTool->addInput(inputTensors, inputData, 0, 1));
55  ATH_CHECK(onnxTool->addOutput(outputTensors, outputScores, 0, 1));
56  ATH_CHECK(onnxTool->inference(inputTensors, outputTensors));
57 
58  output = outputScores[1];
59 
60  return StatusCode::SUCCESS;
61 }
62 
63 std::vector<float> TauHFVetoTool::assembleInputValues(const xAOD::TauJet* xTau, const xAOD::Jet* xAuxJet) const {
64  std::vector<float> inputValues;
65  // gather charged tau tracks
66  std::vector<const xAOD::TauTrack*> tracks = xTau->tracks(xAOD::TauJetParameters::classifiedCharged);
67  // sort tracks according to pT, descending
68  if (tracks.size() > 1)
69  std::sort(tracks.begin(), tracks.end(), [](const xAOD::TauTrack* a, const xAOD::TauTrack* b) {
70  return a->pt() > b->pt();
71  });
72  // gather necessary info related to the auxiliary jet
74  double dl1dv01_pb(0);
75  btag->pb("DL1dv01", dl1dv01_pb);
76  double dl1dv01_pc(0);
77  btag->pc("DL1dv01", dl1dv01_pc);
78  // cast the double to float
79  float dl1dv01_pb_f = static_cast<float>(dl1dv01_pb);
80  float dl1dv01_pc_f = static_cast<float>(dl1dv01_pc);
81  float AbsDEtaLeadTrk = tracks.size() > 0 ? std::abs(tracks[0]->eta() - xTau->eta()) : -999;
82  float AbsDPhiLeadTrk = tracks.size() > 0 ? std::abs(tracks[0]->p4().DeltaPhi(xTau->p4())) : -999;
83  // declare accessors for taus and auxiliary jets
84  SG::ConstAccessor <float> acc_jetRNNtrans("RNNJetScoreSigTrans");
85  SG::ConstAccessor <float> acc_eleRNNtrans("RNNEleScoreSigTrans_v1");
86  SG::ConstAccessor <float> acc_etOverPtLeadTrk("etOverPtLeadTrk");
87  SG::ConstAccessor <float> acc_dRmax("dRmax");
88  SG::ConstAccessor <float> acc_auxJetWidth("Width");
90  // assemble input values according to prongness
91  if (xTau->nTracksCharged() == 1) {
92  // assemble inputValues
93  inputValues = {
94  acc_jetRNNtrans(*xTau), // jetRNNtrans
95  acc_eleRNNtrans(*xTau), // eleRNNtrans
96  AbsDEtaLeadTrk, // AbsDEtaLeadTrk
97  AbsDPhiLeadTrk, // AbsDPhiLeadTrk
98  acc_etOverPtLeadTrk(*xTau), // etOverPtLeadTrk
99  acc_dRmax(*xTau), // dRmax
100  static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[0]->track())), // trk0d0sig
101  static_cast<float>(tracks[0]->track()->z0()), // trk0z0
102  dl1dv01_pb_f, // DL1dv01_pb
103  dl1dv01_pc_f, // DL1dv01_pc
104  acc_auxJetWidth(*xAuxJet), // auxJetWidth
105  static_cast<float>(xTau->p4().DeltaR(xAuxJet->p4())), // dRJet
106  static_cast<float>(xTau->pt() / xAuxJet->pt()), // ptRatio
107  static_cast<float>(acc_GhostTrack(*xAuxJet).size()) // jetNtrk
108  };
109  } else if (xTau->nTracksCharged() == 3) {
110  // assemble inputValues
111  inputValues = {
112  acc_jetRNNtrans(*xTau), // jetRNNtrans
113  acc_eleRNNtrans(*xTau), // eleRNNtrans
114  AbsDEtaLeadTrk, // AbsDEtaLeadTrk
115  AbsDPhiLeadTrk, // AbsDPhiLeadTrk
116  acc_etOverPtLeadTrk(*xTau), // etOverPtLeadTrk
117  acc_dRmax(*xTau), // dRmax
118  static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[0]->track())), // trk0d0sig
119  static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[1]->track())), // trk1d0sig
120  static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[2]->track())), // trk2d0sig
121  static_cast<float>(tracks[0]->track()->z0()), // trk0z0
122  static_cast<float>(tracks[1]->track()->z0()), // trk1z0
123  static_cast<float>(tracks[2]->track()->z0()), // trk2z0
124  dl1dv01_pb_f, // DL1dv01_pb
125  dl1dv01_pc_f, // DL1dv01_pc
126  acc_auxJetWidth(*xAuxJet), // auxJetWidth
127  static_cast<float>(xTau->p4().DeltaR(xAuxJet->p4())), // dRJet
128  static_cast<float>(xTau->pt() / xAuxJet->pt()), // ptRatio
129  static_cast<float>(acc_GhostTrack(*xAuxJet).size()) // jetNtrk
130  };
131  }
132  return inputValues;
133 }
134 
135 StatusCode TauHFVetoTool::bVetoScore(const int& prongness, const std::vector<float>& input, float& output) const {
136  if (prongness == 1) {
137  return inference(m_onnxTool_bveto1p, input, output);
138  }
139  else if (prongness == 3) {
140  return inference(m_onnxTool_bveto3p, input, output);
141  }
142  else {
143  output = -999;
144  return StatusCode::SUCCESS;
145  }
146 }
147 
148 StatusCode TauHFVetoTool::cVetoScore(const int& prongness, const std::vector<float>& input, float& output) const {
149  if (prongness == 1) {
150  return inference(m_onnxTool_cveto1p, input, output);
151  }
152  else if (prongness == 3) {
153  return inference(m_onnxTool_cveto3p, input, output);
154  }
155  else {
156  output = -999;
157  return StatusCode::SUCCESS;
158  }
159 }
TauAnalysisTools::TauHFVetoTool::findClosestPFlowJet
virtual const xAOD::Jet * findClosestPFlowJet(const xAOD::TauJet *xTau, const xAOD::JetContainer *vPFlowJets) const override
Definition: TauHFVetoTool.cxx:20
TauAnalysisTools
Definition: TruthCollectionMakerTau.h:16
TauAnalysisTools::TauHFVetoTool::initialize
virtual StatusCode initialize() override
Initialise the tool: retrieves the four ONNX inference tool handles (b-veto and c-veto, 1-prong and 3-prong).
Definition: TauHFVetoTool.cxx:12
TauAnalysisTools::TauHFVetoTool::cVetoScore
virtual StatusCode cVetoScore(const int &prongness, const std::vector< float > &input, float &output) const
Definition: TauHFVetoTool.cxx:148
xAOD::TauJet_v3::eta
virtual double eta() const
The pseudorapidity ($\eta$) of the particle.
xAOD::TrackingHelpers::d0significance
double d0significance(const xAOD::TrackParticle *tp, double d0_uncert_beam_spot_2)
Definition: TrackParticlexAODHelpers.cxx:42
xAOD::TauJetParameters::classifiedCharged
@ classifiedCharged
Definition: TauDefs.h:406
asg
Definition: DataHandleTestTool.h:28
SG::ConstAccessor< float >
TauAnalysisTools::TauHFVetoTool::inference
virtual StatusCode inference(const ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool, const std::vector< float > &inputValues, float &output) const
Definition: TauHFVetoTool.cxx:48
TauAnalysisTools::TauHFVetoTool::bVetoScore
virtual StatusCode bVetoScore(const int &prongness, const std::vector< float > &input, float &output) const
Definition: TauHFVetoTool.cxx:135
xAOD::TauJet_v3::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
TauAnalysisTools::TauHFVetoTool::m_onnxTool_cveto3p
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_cveto3p
Definition: TauHFVetoTool.h:43
xAOD::TauJet_v3::nTracksCharged
size_t nTracksCharged() const
Definition: TauJet_v3.cxx:532
xAOD::BTagging_v1::pc
bool pc(const std::string &taggername, double &value) const
Definition: BTagging_v1.cxx:367
SG::Decorator< float >
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
xAOD::TauJet_v3
Class describing a tau jet.
Definition: TauJet_v3.h:41
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
xAOD::BTagging_v1
Definition: BTagging_v1.h:39
DataVector
Derived DataVector<T>.
Definition: DataVector.h:794
TauAnalysisTools::TauHFVetoTool::assembleInputValues
virtual std::vector< float > assembleInputValues(const xAOD::TauJet *xTau, const xAOD::Jet *xAuxJet) const override
Definition: TauHFVetoTool.cxx:63
TauAnalysisTools::TauHFVetoTool::TauHFVetoTool
TauHFVetoTool(const std::string &name)
Definition: TauHFVetoTool.cxx:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
TauAnalysisTools::TauHFVetoTool::m_onnxTool_cveto1p
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_cveto1p
Definition: TauHFVetoTool.h:42
TauHFVetoTool.h
xAOD::BTaggingUtilities::getBTagging
const BTagging * getBTagging(const SG::AuxElement &part)
Access the default xAOD::BTagging object associated to an object.
Definition: BTaggingUtilities.cxx:37
xAOD::Jet_v1
Class describing a jet.
Definition: Jet_v1.h:57
xAOD::Jet_v1::p4
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition: Jet_v1.cxx:71
a
TList * a
Definition: liststreamerinfos.cxx:10
TauAnalysisTools::TauHFVetoTool::m_onnxTool_bveto3p
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_bveto3p
Definition: TauHFVetoTool.h:41
xAOD::TauTrack_v1
Definition: TauTrack_v1.h:27
xAOD::TauJet_v3::p4
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition: TauJet_v3.cxx:97
TauAnalysisTools::TauHFVetoTool::m_onnxTool_bveto1p
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_bveto1p
Definition: TauHFVetoTool.h:40
TauAnalysisTools::TauHFVetoTool::applyHFvetoBDTs
virtual StatusCode applyHFvetoBDTs(const xAOD::TauJetContainer *Taus, const xAOD::JetContainer *PFlowJets) const override
Definition: TauHFVetoTool.cxx:34
xAOD::Jet_v1::pt
virtual double pt() const
The transverse momentum ($p_T$) of the particle.
Definition: Jet_v1.cxx:44
xAOD::TauJet_v3::tracks
std::vector< const TauTrack * > tracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Get the v<const pointer> to a given tauTrack collection associated with this tau.
Definition: TauJet_v3.cxx:493
xAOD::BTagging_v1::pb
bool pb(const std::string &taggername, double &value) const
Definition: BTagging_v1.cxx:360