ATLAS Offline Software
Loading...
Searching...
No Matches
TauHFVetoTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7using namespace TauAnalysisTools;
8
9TauHFVetoTool::TauHFVetoTool(const std::string& name) : asg::AsgTool(name) {
10}
11
13 ATH_CHECK(m_onnxTool_bveto1p.retrieve());
14 ATH_CHECK(m_onnxTool_bveto3p.retrieve());
15 ATH_CHECK(m_onnxTool_cveto1p.retrieve());
16 ATH_CHECK(m_onnxTool_cveto3p.retrieve());
17 return StatusCode::SUCCESS;
18}
19
21 // loop through the jets and find the closest one to the tau
22 const xAOD::Jet* xClosestJet = nullptr;
23 double dMinDeltaR = 999.;
24 for (const xAOD::Jet* xJet : *vPFlowJets){
25 double dDeltaR = xTau->p4().DeltaR(xJet->p4());
26 if (dDeltaR < dMinDeltaR){
27 dMinDeltaR = dDeltaR;
28 xClosestJet = xJet;
29 }
30 }
31 return xClosestJet;
32}
33
34StatusCode TauHFVetoTool::applyHFvetoBDTs(const xAOD::TauJetContainer* Taus, const xAOD::JetContainer* PFlowJets) const {
35 // Define Decorators for the auxiliary variables
36 SG::Decorator<float> acc_bVetoScore("bVetoScore");
37 SG::Decorator<float> acc_cVetoScore("cVetoScore");
38 for (const xAOD::TauJet* xTau : *Taus){
39 const xAOD::Jet* xAuxJet = findClosestPFlowJet(xTau, PFlowJets);
40 auto input = assembleInputValues(xTau, xAuxJet);
41 int prongness = xTau->nTracksCharged();
42 ATH_CHECK(bVetoScore(prongness, input, acc_bVetoScore(*xTau)));
43 ATH_CHECK(cVetoScore(prongness, input, acc_cVetoScore(*xTau)));
44 }
45 return StatusCode::SUCCESS;
46}
47
48StatusCode TauHFVetoTool::inference(const ToolHandle<AthOnnx::IOnnxRuntimeInferenceTool> onnxTool, const std::vector<float>& inputValues, float& output) const {
49 auto inputData = inputValues;
50 std::vector<float> outputScores;
51 std::vector<Ort::Value> inputTensors;
52 std::vector<Ort::Value> outputTensors;
53
54 ATH_CHECK(onnxTool->addInput(inputTensors, inputData, 0, 1));
55 ATH_CHECK(onnxTool->addOutput(outputTensors, outputScores, 0, 1));
56 ATH_CHECK(onnxTool->inference(inputTensors, outputTensors));
57
58 output = outputScores[1];
59
60 return StatusCode::SUCCESS;
61}
62
// Build the flat list of float inputs for one tau and its auxiliary jet.
//
// The layout depends on xTau->nTracksCharged():
//   == 1 : 14 values (tau ID scores, leading-track angular distances,
//          etOverPtLeadTrk, dRmax, d0 significance and z0 of the leading
//          track, DL1dv01 pb/pc, jet width, tau-jet dR, tau/jet pT ratio,
//          number of "GhostTrack" links on the jet)
//   == 3 : 18 values (same as above, with d0 significance and z0 for the
//          three leading tracks)
//   else : the returned vector is empty.
//
// xTau    : tau whose tracks and decorations are read.
// xAuxJet : jet associated to the tau; dereferenced without a null check in
//           both the 1-prong and 3-prong branches.
63std::vector<float> TauHFVetoTool::assembleInputValues(const xAOD::TauJet* xTau, const xAOD::Jet* xAuxJet) const {
64 std::vector<float> inputValues;
65 // gather charged tau tracks
66 std::vector<const xAOD::TauTrack*> tracks = xTau->tracks(xAOD::TauJetParameters::classifiedCharged);
67 // sort tracks according to pT, descending
68 if (tracks.size() > 1)
69 std::sort(tracks.begin(), tracks.end(), [](const xAOD::TauTrack* a, const xAOD::TauTrack* b) {
70 return a->pt() > b->pt();
71 });
72 // gather necessary info related to the auxiliary jet
// NOTE(review): the declaration of `btag` is not visible in this listing
// (listing line 73 is missing). Presumably it is the b-tagging object of
// xAuxJet - confirm, and check whether it can be null before it is used below.
74 double dl1dv01_pb(0);
// The bool result of pb()/pc() is not checked here.
// NOTE(review): presumably the 0 initialiser is kept if the lookup fails - confirm.
75 btag->pb("DL1dv01", dl1dv01_pb);
76 double dl1dv01_pc(0);
77 btag->pc("DL1dv01", dl1dv01_pc);
78 // cast the double to float
79 float dl1dv01_pb_f = static_cast<float>(dl1dv01_pb);
80 float dl1dv01_pc_f = static_cast<float>(dl1dv01_pc);
// Angular distances between the highest-pT charged track and the tau axis;
// -999 is used as the placeholder when the tau has no charged tracks.
81 float AbsDEtaLeadTrk = tracks.size() > 0 ? std::abs(tracks[0]->eta() - xTau->eta()) : -999;
82 float AbsDPhiLeadTrk = tracks.size() > 0 ? std::abs(tracks[0]->p4().DeltaPhi(xTau->p4())) : -999;
83 // declare accessors for taus and auxiliary jets
84 SG::ConstAccessor <float> acc_jetRNNtrans("RNNJetScoreSigTrans");
85 SG::ConstAccessor <float> acc_eleRNNtrans("RNNEleScoreSigTrans_v1");
86 SG::ConstAccessor <float> acc_etOverPtLeadTrk("etOverPtLeadTrk");
87 SG::ConstAccessor <float> acc_dRmax("dRmax");
88 SG::ConstAccessor <float> acc_auxJetWidth("Width");
89 SG::ConstAccessor <std::vector<ElementLink<DataVector<xAOD::IParticle> > >> acc_GhostTrack("GhostTrack");
90 // assemble input values according to prongness
// NOTE(review): tracks[0] is indexed unconditionally below; this assumes
// `tracks` holds at least nTracksCharged() entries with a non-null track() -
// TODO confirm.
91 if (xTau->nTracksCharged() == 1) {
92 // assemble inputValues
93 inputValues = {
94 acc_jetRNNtrans(*xTau), // jetRNNtrans
95 acc_eleRNNtrans(*xTau), // eleRNNtrans
96 AbsDEtaLeadTrk, // AbsDEtaLeadTrk
97 AbsDPhiLeadTrk, // AbsDPhiLeadTrk
98 acc_etOverPtLeadTrk(*xTau), // etOverPtLeadTrk
99 acc_dRmax(*xTau), // dRmax
100 static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[0]->track())), // trk0d0sig
101 static_cast<float>(tracks[0]->track()->z0()), // trk0z0
102 dl1dv01_pb_f, // DL1dv01_pb
103 dl1dv01_pc_f, // DL1dv01_pc
104 acc_auxJetWidth(*xAuxJet), // auxJetWidth
105 static_cast<float>(xTau->p4().DeltaR(xAuxJet->p4())), // dRJet
106 static_cast<float>(xTau->pt() / xAuxJet->pt()), // ptRatio
107 static_cast<float>(acc_GhostTrack(*xAuxJet).size()) // jetNtrk
108 };
// NOTE(review): tracks[0..2] are indexed unconditionally in this branch; same
// assumption as above about the size of `tracks` - TODO confirm.
109 } else if (xTau->nTracksCharged() == 3) {
110 // assemble inputValues
111 inputValues = {
112 acc_jetRNNtrans(*xTau), // jetRNNtrans
113 acc_eleRNNtrans(*xTau), // eleRNNtrans
114 AbsDEtaLeadTrk, // AbsDEtaLeadTrk
115 AbsDPhiLeadTrk, // AbsDPhiLeadTrk
116 acc_etOverPtLeadTrk(*xTau), // etOverPtLeadTrk
117 acc_dRmax(*xTau), // dRmax
118 static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[0]->track())), // trk0d0sig
119 static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[1]->track())), // trk1d0sig
120 static_cast<float>(xAOD::TrackingHelpers::d0significance(tracks[2]->track())), // trk2d0sig
121 static_cast<float>(tracks[0]->track()->z0()), // trk0z0
122 static_cast<float>(tracks[1]->track()->z0()), // trk1z0
123 static_cast<float>(tracks[2]->track()->z0()), // trk2z0
124 dl1dv01_pb_f, // DL1dv01_pb
125 dl1dv01_pc_f, // DL1dv01_pc
126 acc_auxJetWidth(*xAuxJet), // auxJetWidth
127 static_cast<float>(xTau->p4().DeltaR(xAuxJet->p4())), // dRJet
128 static_cast<float>(xTau->pt() / xAuxJet->pt()), // ptRatio
129 static_cast<float>(acc_GhostTrack(*xAuxJet).size()) // jetNtrk
130 };
131 }
// For any other prongness inputValues is returned empty.
132 return inputValues;
133}
134
135StatusCode TauHFVetoTool::bVetoScore(const int& prongness, const std::vector<float>& input, float& output) const {
136 if (prongness == 1) {
137 return inference(m_onnxTool_bveto1p, input, output);
138 }
139 else if (prongness == 3) {
140 return inference(m_onnxTool_bveto3p, input, output);
141 }
142 else {
143 output = -999;
144 return StatusCode::SUCCESS;
145 }
146}
147
148StatusCode TauHFVetoTool::cVetoScore(const int& prongness, const std::vector<float>& input, float& output) const {
149 if (prongness == 1) {
150 return inference(m_onnxTool_cveto1p, input, output);
151 }
152 else if (prongness == 3) {
153 return inference(m_onnxTool_cveto3p, input, output);
154 }
155 else {
156 output = -999;
157 return StatusCode::SUCCESS;
158 }
159}
Scalar eta() const
pseudorapidity method
#define ATH_CHECK
Evaluate an expression and check for errors.
static Double_t a
Helper class to provide type-safe access to aux data.
Definition Decorator.h:59
virtual StatusCode applyHFvetoBDTs(const xAOD::TauJetContainer *Taus, const xAOD::JetContainer *PFlowJets) const override
virtual std::vector< float > assembleInputValues(const xAOD::TauJet *xTau, const xAOD::Jet *xAuxJet) const override
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_cveto3p
virtual StatusCode cVetoScore(const int &prongness, const std::vector< float > &input, float &output) const
virtual StatusCode inference(const ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool, const std::vector< float > &inputValues, float &output) const
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_bveto1p
TauHFVetoTool(const std::string &name)
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_bveto3p
virtual const xAOD::Jet * findClosestPFlowJet(const xAOD::TauJet *xTau, const xAOD::JetContainer *vPFlowJets) const override
virtual StatusCode initialize() override
Dummy implementation of the initialisation function.
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_onnxTool_cveto1p
virtual StatusCode bVetoScore(const int &prongness, const std::vector< float > &input, float &output) const
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
bool pc(const std::string &taggername, double &value) const
bool pb(const std::string &taggername, double &value) const
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition Jet_v1.cxx:71
virtual double pt() const
The transverse momentum ( $p_T$ ) of the particle.
Definition Jet_v1.cxx:44
virtual FourMom_t p4() const
The full 4-momentum of the particle.
Definition TauJet_v3.cxx:96
virtual double pt() const
The transverse momentum ( $p_T$ ) of the particle.
size_t nTracksCharged() const
virtual double eta() const
The pseudorapidity ( $\eta$ ) of the particle.
std::vector< const TauTrack * > tracks(TauJetParameters::TauTrackFlag flag=TauJetParameters::TauTrackFlag::classifiedCharged) const
Get the v<const pointer> to a given tauTrack collection associated with this tau.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
const BTagging * getBTagging(const SG::AuxElement &part)
Access the default xAOD::BTagging object associated to an object.
double d0significance(const xAOD::TrackParticle *tp, double d0_uncert_beam_spot_2)
Jet_v1 Jet
Definition of the current "jet version".
BTagging_v1 BTagging
Definition of the current "BTagging version".
Definition BTagging.h:17
TauTrack_v1 TauTrack
Definition of the current version.
Definition TauTrack.h:16
TauJet_v3 TauJet
Definition of the current "tau version".
JetContainer_v1 JetContainer
Definition of the current "jet container version".
TauJetContainer_v3 TauJetContainer
Definition of the current "taujet container version".