ATLAS Offline Software
Loading...
Searching...
No Matches
TrigTauPrecisionIDHypoTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
8#include "GaudiKernel/SystemOfUnits.h"
9
11
12
13using namespace TrigCompositeUtils;
14
15TrigTauPrecisionIDHypoTool::TrigTauPrecisionIDHypoTool(const std::string& type, const std::string& name, const IInterface* parent)
16 : base_class(type, name, parent),
17 m_decisionId(HLT::Identifier::fromToolName(name))
18{
19
20}
21
22
24{
25 ATH_MSG_DEBUG(name() << ": in initialize()");
26
27 ATH_MSG_DEBUG("TrigTauPrecisionIDHypoTool will cut on:");
28 ATH_MSG_DEBUG(" - PtMin: " << m_ptMin.value());
29 ATH_MSG_DEBUG(" - NTracksMin: " << m_numTrackMin.value());
30 ATH_MSG_DEBUG(" - NTracksMax: " << m_numTrackMax.value());
31 ATH_MSG_DEBUG(" - NIsoTracksMax: " << m_numIsoTrackMax.value());
32 if(m_trackPtCut >= 0) ATH_MSG_DEBUG(" - trackPtCut: " << m_trackPtCut.value());
33 ATH_MSG_DEBUG(" - IDMethod: " << m_idMethod.value());
34 ATH_MSG_DEBUG(" - IDWP: " << m_idWP.value());
35 ATH_MSG_DEBUG(" - HighPtSelectionTrkThr: " << m_highPtTrkThr.value());
36 ATH_MSG_DEBUG(" - HighPtSelectionIDThr: " << m_highPtIdThr.value());
37 ATH_MSG_DEBUG(" - HighPtIDWP: " << m_highPtIdWP.value());
38 ATH_MSG_DEBUG(" - HighPtSelectionJetThr: " << m_highPtJetThr.value());
39
41 ATH_MSG_ERROR("Invalid tool configuration!");
42 return StatusCode::FAILURE;
43 }
44
46 ATH_MSG_ERROR("Invalid IDMethod value, " << m_idMethod.value());
47 return StatusCode::FAILURE;
48 } else if(m_idMethod == IDMethod::Disabled && (!m_idWP.empty() || !m_highPtIdWP.empty())) {
49 ATH_MSG_ERROR("Must not set IDWP or HighPtIDWP if using IDMethod=0");
50 return StatusCode::FAILURE;
51 } else if(m_idMethod != IDMethod::Disabled) {
52 if(m_idWP.empty()) {
53 ATH_MSG_ERROR("Must provide an ID WP");
54 return StatusCode::FAILURE;
55 }
56
57 // Fallback to default ID WP
58 if(m_highPtIdWP.empty()) m_highPtIdWP = m_idWP;
59
60 // Initialize WP accessors
63 }
64
65 // Parse RNN ID WPs
71 else {
72 ATH_MSG_ERROR("Invalid RNN ID WP: " << m_idWP.value());
73 return StatusCode::FAILURE;
74 }
75
80 else {
81 ATH_MSG_ERROR("Invalid High-pT RNN ID WP: " << m_highPtIdWP.value());
82 return StatusCode::FAILURE;
83 }
84 }
85
86 // Now create the "cache" of TauID score accessors for the Monitoring...
87 ATH_MSG_DEBUG("TauID score monitoring: ");
88 for(const auto& [key, p] : m_monitoredIdScores) {
89 ATH_MSG_DEBUG(" - IDName: " << key);
90 ATH_MSG_DEBUG(" - IDScoreName: " << p.first);
91 ATH_MSG_DEBUG(" - IDScoreSigTransName: " << p.second);
92 if(p.first.empty() || p.second.empty()) {
93 ATH_MSG_ERROR("Invalid score variable names; skipping this entry for the monitoring!");
94 continue;
95 }
96
97 m_monitoredIdAccessors.emplace(key, std::make_pair(SG::ConstAccessor<float>(p.first), SG::ConstAccessor<float>(p.second)));
98 }
99
100 return StatusCode::SUCCESS;
101}
102
103
105{
106 ATH_MSG_DEBUG(name() << ": in execute()");
107
108 auto NInputTaus = Monitored::Scalar<int>("NInputTaus", -1);
109 auto passedCuts = Monitored::Scalar<int>("CutCounter", 0);
110 auto PtAccepted = Monitored::Scalar<float>("PtAccepted", -1);
111 auto NTracksAccepted = Monitored::Scalar<int>("NTracksAccepted", -1);
112 auto NIsoTracksAccepted = Monitored::Scalar<int>("NIsoTracksAccepted", -1);
113
114 std::map<std::string, Monitored::Scalar<float>> monitoredIdVariables;
115 for(const auto& [key, p] : m_monitoredIdScores) {
116 monitoredIdVariables.emplace(key + "_TauJetScoreAccepted_0p", Monitored::Scalar<float>(key + "_TauJetScoreAccepted_0p", -1));
117 monitoredIdVariables.emplace(key + "_TauJetScoreTransAccepted_0p", Monitored::Scalar<float>(key + "_TauJetScoreTransAccepted_0p", -1));
118 monitoredIdVariables.emplace(key + "_TauJetScoreAccepted_1p", Monitored::Scalar<float>(key + "_TauJetScoreAccepted_1p", -1));
119 monitoredIdVariables.emplace(key + "_TauJetScoreTransAccepted_1p", Monitored::Scalar<float>(key + "_TauJetScoreTransAccepted_1p", -1));
120 monitoredIdVariables.emplace(key + "_TauJetScoreAccepted_mp", Monitored::Scalar<float>(key + "_TauJetScoreAccepted_mp", -1));
121 monitoredIdVariables.emplace(key + "_TauJetScoreTransAccepted_mp", Monitored::Scalar<float>(key + "_TauJetScoreTransAccepted_mp", -1));
122 }
123
124 std::vector<std::reference_wrapper<Monitored::IMonitoredVariable>> monVars = {
125 std::ref(NInputTaus), std::ref(passedCuts), std::ref(PtAccepted), std::ref(NTracksAccepted), std::ref(NIsoTracksAccepted)
126 };
127 for(auto& [key, var] : monitoredIdVariables) monVars.push_back(std::ref(var));
128 auto monitorIt = Monitored::Group(m_monTool, monVars);
129
130
131 // Tau pass flag
132 bool pass = false;
133
134 if(m_acceptAll) {
135 pass = true;
136 ATH_MSG_DEBUG("AcceptAll property is set: taking all events");
137 }
138
139 // Debugging location of the TauJet RoI
140 ATH_MSG_DEBUG("Input RoI eta: " << input.roi->eta() << ", phi: " << input.roi->phi() << ", z: " << input.roi->zed());
141
142 const xAOD::TauJetContainer* TauContainer = input.tauContainer;
143 NInputTaus = TauContainer->size();
144 // There should only be a single TauJet in the TauJetContainer; just in case we still run the loop
145 for(const xAOD::TauJet* Tau : *TauContainer) {
146 ATH_MSG_DEBUG(" New HLT TauJet candidate:");
147
148 float pT = Tau->pt();
149
150 //---------------------------------------------------------
151 // Calibrated tau pT cut ('idperf' step)
152 //---------------------------------------------------------
153 ATH_MSG_DEBUG(" pT: " << pT / Gaudi::Units::GeV);
154
155 if(!(pT > m_ptMin)) continue;
156 passedCuts++;
157 PtAccepted = pT / Gaudi::Units::GeV;
158
159
160 //---------------------------------------------------------
161 // Track counting ('perf' step)
162 //---------------------------------------------------------
163 int numTrack = 0, numIsoTrack = 0;
164 if(m_trackPtCut > 0) {
165 // Raise the track pT threshold when counting tracks in the 'perf' step, to reduce sensitivity to pileup tracks
166 // Overrides the default 1 GeV cut by the InDetTrackSelectorTool used during the TauJet construction
167 for(const auto* track : Tau->tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedCharged)) {
168 if(track->pt() > m_trackPtCut) numTrack++;
169 }
170 for(const auto* track : Tau->tracks(xAOD::TauJetParameters::TauTrackFlag::classifiedIsolation)) {
171 if(track->pt() > m_trackPtCut) numIsoTrack++;
172 }
173 } else {
174 // Use the default 1 GeV selection in the InDetTrackSelectorTool, executed during the TauJet construction
175 numTrack = Tau->nTracks();
176 numIsoTrack = Tau->nTracksIsolation();
177 }
178
179 ATH_MSG_DEBUG(" N Tracks: " << numTrack);
180 ATH_MSG_DEBUG(" N Iso Tracks: " << numIsoTrack);
181
182 // Apply track multiplicity cuts, except for idperf
183 if(!m_acceptAll) {
184 // NTrackMin and NIsoTracksMax
185 if(pT < m_highPtTrkThr) {
186 if(numTrack < m_numTrackMin) continue;
187 if(numIsoTrack > m_numIsoTrackMax) continue;
188 }
189 // NTrackMax
190 if(pT < m_highPtJetThr) {
191 if(numTrack > m_numTrackMax) continue;
192 }
193 }
194 // Note: we disabled the track selection for high pT taus
195
196 passedCuts++;
197 NTracksAccepted = numTrack;
198 NIsoTracksAccepted = numIsoTrack;
199
200
201 //---------------------------------------------------------
202 // ID WP selection (ID step)
203 //---------------------------------------------------------
204 int local_id_wp = IDWP::Standard;
205 if(pT > m_highPtIdThr) local_id_wp = IDWP::HighPt; // Set ID to HighPt WP
206 if(pT > m_highPtJetThr) local_id_wp = IDWP::None; // Disable the ID WP cut
207
208 if(!m_acceptAll && m_idMethod != IDMethod::Disabled && local_id_wp != IDWP::None) {
209 if(m_idMethod == IDMethod::Decorator) { // Decorated scores (e.g. for GNTau)
210 if(local_id_wp == IDWP::Standard) {
211 if(!m_id_wp_acc.isAvailable(*Tau)) ATH_MSG_ERROR("The TauID '" << m_idWP << "' variable is not available!");
212 if(!m_id_wp_acc(*Tau)) continue;
213 } else if(local_id_wp == IDWP::HighPt) {
214 if(!m_highpt_id_wp_acc.isAvailable(*Tau)) ATH_MSG_ERROR("The HighPt TauID '" << m_highPtIdWP << "' variable is not available!");
215 if(!m_highpt_id_wp_acc(*Tau)) continue;
216 }
217
218 } else if(m_idMethod == IDMethod::RNN) { // Legacy RNN/DeepSet scores
219 if(!Tau->hasDiscriminant(xAOD::TauJetParameters::RNNJetScoreSigTrans)) {
220 ATH_MSG_ERROR(" RNNJetScoreSigTrans not available. Make sure the TauWPDecorator is run for the RNN Tau ID!");
221 }
222
223 if(local_id_wp == IDWP::Standard && !Tau->isTau(static_cast<xAOD::TauJetParameters::IsTauFlag>(m_rnn_id_wp))) continue;
224 else if(local_id_wp == IDWP::HighPt && !Tau->isTau(static_cast<xAOD::TauJetParameters::IsTauFlag>(m_rnn_highpt_id_wp))) continue;
225 }
226 }
227
228 // TauID Score monitoring
229 for(const auto& [key, p] : m_monitoredIdAccessors) {
230 if(!p.first.isAvailable(*Tau))
231 ATH_MSG_WARNING("TauID Score " << m_monitoredIdScores.value().at(key).first << " is not available. Make sure the correct inferences are included in the chain reconstruction sequence!");
232
233 if(!p.second.isAvailable(*Tau))
234 ATH_MSG_WARNING("TauID ScoreSigTrans " << m_monitoredIdScores.value().at(key).second << " is not available. Make sure the correct inferences are included in the chain reconstruction sequence!");
235
236 ATH_MSG_DEBUG(" TauID \"" << key << "\" ScoreSigTrans: " << p.second(*Tau));
237
238 // Monitor ID scores
239 if(Tau->nTracks() == 0) {
240 monitoredIdVariables.at(key + "_TauJetScoreAccepted_0p") = p.first(*Tau);
241 monitoredIdVariables.at(key + "_TauJetScoreTransAccepted_0p") = p.second(*Tau);
242 } else if(Tau->nTracks() == 1) {
243 monitoredIdVariables.at(key + "_TauJetScoreAccepted_1p") = p.first(*Tau);
244 monitoredIdVariables.at(key + "_TauJetScoreTransAccepted_1p") = p.second(*Tau);
245 } else { // MP tau
246 monitoredIdVariables.at(key + "_TauJetScoreAccepted_mp") = p.first(*Tau);
247 monitoredIdVariables.at(key + "_TauJetScoreTransAccepted_mp") = p.second(*Tau);
248 }
249 }
250
251 passedCuts++;
252
253
254 //---------------------------------------------------------
255 // At least one Tau passed all the cuts. Accept the event!
256 //---------------------------------------------------------
257 pass = true;
258
259 ATH_MSG_DEBUG(" Pass hypo tool: " << pass);
260 }
261
262 return pass;
263}
264
265
266StatusCode TrigTauPrecisionIDHypoTool::decide(std::vector<ITrigTauJetHypoTool::ToolInfo>& input) const {
267 for(ITrigTauJetHypoTool::ToolInfo& i : input) {
268 if(passed(m_decisionId.numeric(), i.previousDecisionIDs)) {
269 if(decide(i)) {
270 addDecisionID(m_decisionId, i.decision);
271 }
272 }
273 }
274
275 return StatusCode::SUCCESS;
276}
#define ATH_MSG_ERROR(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
Header file to be included by clients of the Monitored infrastructure.
size_type size() const noexcept
Returns the number of elements in the collection.
Group of local monitoring quantities that retains their correlation when filling histograms.
Declare a monitored scalar variable.
Helper class to provide constant type-safe access to aux data.
virtual StatusCode decide(std::vector< ITrigTauJetHypoTool::ToolInfo > &input) const override
Gaudi::Property< std::string > m_highPtIdWP
TrigTauPrecisionIDHypoTool(const std::string &type, const std::string &name, const IInterface *parent)
Gaudi::Property< std::map< std::string, std::pair< std::string, std::string > > > m_monitoredIdScores
Gaudi::Property< int > m_numIsoTrackMax
Gaudi::Property< std::string > m_idWP
Gaudi::Property< float > m_trackPtCut
std::map< std::string, std::pair< SG::ConstAccessor< float >, SG::ConstAccessor< float > > > m_monitoredIdAccessors
Gaudi::Property< float > m_highPtJetThr
Gaudi::Property< float > m_highPtIdThr
ToolHandle< GenericMonitoringTool > m_monTool
Gaudi::Property< float > m_highPtTrkThr
SG::ConstAccessor< char > m_highpt_id_wp_acc
SG::ConstAccessor< char > m_id_wp_acc
virtual StatusCode initialize() override
It used to be a useful piece of code for replacing the actual SG with another store of similar functionality ...
bool passed(DecisionID id, const DecisionIDContainer &idSet)
checks if required decision ID is in the set of IDs in the container
void addDecisionID(DecisionID id, Decision *d)
Appends the decision (given as ID) to the decision object.
@ RNNJetScoreSigTrans
RNN score which is signal transformed/flattened.
Definition TauDefs.h:92
IsTauFlag
Enum for IsTau flags.
Definition TauDefs.h:116
TauJet_v3 TauJet
Definition of the current "tau version".
TauJetContainer_v3 TauJetContainer
Definition of the current "taujet container version".