ATLAS Offline Software
Loading...
Searching...
No Matches
TauConfigurationTools.py
Go to the documentation of this file.
1# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2from typing import Any
3
4from AthenaConfiguration.AthConfigFlags import AthConfigFlags
5from AthenaConfiguration.AccumulatorCache import AccumulatorCache
6
7from AthenaCommon.Logging import logging
8log = logging.getLogger(__name__)
9
10# This file contains helper functions for the Tau Trigger signature configuration
11
12
13
17 flags: AthConfigFlags,
18 key: str | None = None,
19 alt_key: str | None = None,
20 algs: dict[str, list[str]] | list[str] | None = None,
21 mc_algs: dict[str, list[str]] | list[str] | None = None,
22 dev_algs: dict[str, list[str]] | list[str] | None = None,
23) -> list[str]:
24 '''Get the list of algorithms for a specific menu key; if not found, the alternate key will be tried if provided.'''
25 def _getAlgs(key: str | None):
26 if algs is None: ret = None
27 elif isinstance(algs, dict): ret = algs[key] if key in algs else None
28 else: ret = algs # list
29 if any(pfx in flags.Trigger.triggerMenuSetup for pfx in ['MC_', 'Dev_']) and mc_algs and (isinstance(mc_algs, list) or key in mc_algs):
30 if ret is None: ret = []
31 ret += mc_algs[key] if isinstance(mc_algs, dict) else mc_algs
32 if 'Dev_' in flags.Trigger.triggerMenuSetup and dev_algs and (isinstance(dev_algs, list) or key in dev_algs):
33 if ret is None: ret = []
34 ret += dev_algs[key] if isinstance(dev_algs, dict) else dev_algs
35 return ret
36
37 ret_algs = _getAlgs(key)
38 if ret_algs is None and alt_key is not None:
39 ret_algs = _getAlgs(alt_key)
40 return ret_algs
41
42
43
def useBuiltInTauJetRNNScore(tau_id: str) -> bool:
    '''
    Return True for "legacy" TauID algorithms, whose RNN score and WP variables are stored in the
    TauJet's built-in aux variables instead of the decorator-based variables.
    '''
    legacy_tau_ids = ('DeepSet', 'RNNLLP')
    return tau_id in legacy_tau_ids
48
49
def getTauIDScoreVariables(tau_id: str) -> tuple[str, str]:
    '''
    Return the (score, score_sig_trans) variable name pair for a given TauID/Sequence configuration.

    Non-legacy algorithms use the '<tau_id>_Score' / '<tau_id>_ScoreSigTrans' variables;
    "legacy" algorithms (see useBuiltInTauJetRNNScore) use the built-in TauJet RNN score variables.
    '''
    if not useBuiltInTauJetRNNScore(tau_id):
        return (f'{tau_id}_Score', f'{tau_id}_ScoreSigTrans')

    # "Legacy" algorithms: the scores are stored in the built-in TauJet aux variables
    return ('RNNJetScore', 'RNNJetScoreSigTrans')
56
57
@AccumulatorCache # called many times and looping over flags is slow
def getTauIDAlgorithm(flags: AthConfigFlags, selection: str,
                      name_mapping: tuple[tuple[str, str], ...] | None = None) -> str:
    '''
    Resolve the TauID algorithm name from a chain selection string.

    The selection is first matched by suffix against the TauID configurations in
    flags.Trigger.Offline.Tau, and then against the short names of name_mapping
    (a tuple of (short_name, long_name) pairs, e.g. (('DS', 'DeepSet'),)), in which case the
    long name is returned. In both cases the longest candidate is tried first.
    If nothing matches, the selection is returned unchanged.

    NOTE(review): the cache decorator presumably requires hashable arguments, hence the
    tuple-of-tuples type for name_mapping.
    '''
    # Longest names first, so that the most complete match wins
    known_ids = sorted(flags.Trigger.Offline.Tau, key=len, reverse=True)
    found = next((tau_id for tau_id in known_ids if selection.endswith(tau_id)), None)
    if found is not None:
        return found

    if not name_mapping:
        return selection

    # Remap short names (e.g. DS -> DeepSet), again longest first
    ordered_pairs = sorted(name_mapping, key=lambda pair: len(pair[0]), reverse=True)
    return next((long_name for short_name, long_name in ordered_pairs if selection.endswith(short_name)), selection)
74
75
76
77
80
def getHitZAlgs(flags: AthConfigFlags, precision_sequence: str, alt_precision_sequence: str | None = None) -> list[str]:
    '''
    Get the list of HitZ algorithms for the CaloHits reco sequence.
    The configuration for each algorithm is contained in flags.Trigger.Offline.Tau.<alg>.

    The lists are looked up with precision_sequence as key, falling back to
    alt_precision_sequence (see getMenuAlgs).
    '''
    # HitZ algorithms run in all menus
    common_algs = ['HitZ']
    # Additional HitZ algorithms, ONLY if we're using the MC (or Dev) menu
    mc_only_algs: dict[str, list[str]] = {}
    # Additional HitZ algorithms, ONLY if we're using the Dev menu
    dev_only_algs: dict[str, list[str]] = {}

    return getMenuAlgs(flags, key=precision_sequence, alt_key=alt_precision_sequence,
                       algs=common_algs, mc_algs=mc_only_algs, dev_algs=dev_only_algs)
100
101
102def getHitZConfig(flags: AthConfigFlags, chainPart: dict[str, Any]) -> tuple[str, float] | None:
103 '''
104 Get the HLT HitZ configuration tuple: (algorithm name, sigma cut value in mm)
105 '''
106 if not chainPart['hitz']: return None
107
108 import re
109 # Match strings of the form: '10mmX5mmHitZ', '5mmHitZ', 'HitZ', etc...
110 match = re.match(r'((?P<sigma>(\d|p)+)mm)?(X(\d|p)+mm)?(?P<alg>.+)', chainPart['hitz'])
111 if match:
112 alg = match.group('alg')
113
114 alg_flags = getattr(flags.Trigger.Offline.Tau, alg, None)
115 if alg_flags is None:
116 raise ValueError(f'HitZ algorithm "{alg}" configuration not found in flags.Trigger.Offline.Tau.{alg}')
117
118 sigma = match.group('sigma')
119 if sigma is None: sigma = alg_flags.DefaultMaxZ0Sigma # mm (default value)
120 else: sigma = float(sigma.replace('p', '.')) # mm
121
122 return (alg, sigma)
123
124 raise ValueError(f'Invalid HitZ configuration string: {chainPart["hitz"]}')
125
126
def getHitZVariables(alg: str) -> tuple[str, str]:
    '''Return the (z, sigma) variable name pair for a given HitZ algorithm: ('<alg>_z0', '<alg>_z0_sigma')'''
    z_name = f'{alg}_z0'
    return (z_name, f'{z_name}_sigma')
130
131
def getCaloHitsPreselAlgs(flags: AthConfigFlags, precision_sequence: str, alt_precision_sequence: str | None = None) -> list[str]:
    '''
    Get the list of CaloHits preselection TauID inferences to be executed for the CaloHits reco sequence.
    The configuration for each algorithm is contained in flags.Trigger.Offline.Tau.<alg>.

    The lists are looked up with precision_sequence as key, falling back to
    alt_precision_sequence (see getMenuAlgs).
    '''
    # Inferences run in all menus
    common_algs: list[str] = []
    # Additional inferences, ONLY if we're using the MC (or Dev) menu
    mc_only_algs: list[str] = []
    # Additional inferences, ONLY if we're using the Dev menu
    dev_only_algs: list[str] = []

    return getMenuAlgs(flags, key=precision_sequence, alt_key=alt_precision_sequence,
                       algs=common_algs, mc_algs=mc_only_algs, dev_algs=dev_only_algs)
151
152
def getChainCaloHitsPreselConfigName(flags: AthConfigFlags, chainPart: dict[str, Any]) -> str:
    '''
    Clean the CaloHits preselection configuration for a chainPart dict.

    Returns 'idperf' when there is no preselection (empty selection or 'idperfCHP');
    otherwise the TauID algorithm name resolved by getTauIDAlgorithm, where the
    'CHTP' short name maps to 'GNCaloHitsTauPresel'.
    '''
    sel = chainPart['calohitsPresel']

    if not sel or sel == 'idperfCHP': return 'idperf' # No preselection

    return getTauIDAlgorithm(
        flags,
        sel,
        # The trailing comma is required: name_mapping must be a tuple of (short, long) pairs.
        # Without it this is a single flat 2-tuple of strings, and unpacking it fails.
        name_mapping=(('CHTP', 'GNCaloHitsTauPresel'),),
    )
164
165
166def getChainCaloHitsSeqName(chainPart: dict[str, Any]) -> str | None:
167 '''Get the HLT Tau CaloHits sequence name suffix'''
168 if not chainPart['hitz'] and not chainPart['calohitsPresel']: return None
169
170 parts = []
171
172 # HitZ RoI updating selection
173 if chainPart['hitz']: parts.append(chainPart['hitz'])
174
175 if not parts: parts = ['CaloHitsBase']
176 return '_'.join(parts)
177
178
179
180
183
def getPrecisionSequenceTauIDs(flags: AthConfigFlags, precision_sequence: str, alt_precision_sequence: str | None = None) -> list[str]:
    '''
    Get the list of precision TauID inferences to be executed for each HLT tau trigger reco sequence
    The configuration for each algorithm is contained in flags.Trigger.Offline.Tau.<alg>.

    The lists are looked up with precision_sequence as key (e.g. 'MVA', 'LLP', 'LRT'),
    falling back to alt_precision_sequence (see getMenuAlgs).
    '''
    # Tau ID algorithms run in all menus
    common_ids: dict[str, list[str]] = {
        'MVA': ['GNTau', 'MesonCuts', 'GNTauDev1'],
        'LLP': ['RNNLLP'],
        'LRT': ['RNNLLP'],
    }
    # Additional Tau ID algorithms, ONLY if we're using the MC (or Dev) menu
    mc_only_ids: dict[str, list[str]] = {'MVA': ['DeepSet']}
    # Additional Tau ID algorithms, ONLY if we're using the Dev menu
    dev_only_ids: dict[str, list[str]] = {}

    return getMenuAlgs(flags, key=precision_sequence, alt_key=alt_precision_sequence,
                       algs=common_ids, mc_algs=mc_only_ids, dev_algs=dev_only_ids)
210
211
# The following functions are only required while we still have triggers
# with the RNN/DeepSet naming scheme in the Menu (e.g. mediumRNN_tracktwoMVA/LLP)

# Legacy RNN/DeepSet working-point selections; getChainIDConfigName maps them to
# 'DeepSet' (tracktwoMVA) or 'RNNLLP' (tracktwoLLP / trackLRT)
rnn_wps = ['verylooseRNN', 'looseRNN', 'mediumRNN', 'tightRNN']
# Presumably the selections without any TauID requirement (inferred from the name;
# not referenced in this part of the file) -- TODO confirm
noid_selections = ['perf', 'idperf']
# Legacy meson selections; getChainIDConfigName maps them to 'MesonCuts' for tracktwoMVA chains
meson_selections = ['kaonpi1', 'kaonpi2', 'dipion1', 'dipion2', 'dipion3', 'dipion4', 'dikaonmass', 'singlepion']
217
def getChainIDConfigName(flags: AthConfigFlags, chainPart: dict[str, Any]) -> str:
    '''
    Clean the ID configuration for a chainPart dict.

    Legacy trigger names are mapped to their algorithm:
      - tracktwoMVA + RNN WP (rnn_wps)           -> 'DeepSet'
      - tracktwoMVA + meson selection            -> 'MesonCuts'
      - tracktwoLLP / trackLRT + RNN WP          -> 'RNNLLP'
    Every other selection is resolved with getTauIDAlgorithm().
    '''
    sel = chainPart['selection']
    reco = chainPart['reconstruction']

    # Support for the Legacy trigger names:
    if reco == 'tracktwoMVA':
        if sel in rnn_wps: return 'DeepSet'
        if sel in meson_selections: return 'MesonCuts'
    elif reco in ['tracktwoLLP', 'trackLRT'] and sel in rnn_wps:
        return 'RNNLLP'

    # NOTE: the former trailing 'return sel' was unreachable after this unconditional return
    return getTauIDAlgorithm(flags, sel)
237
238
def getChainPrecisionSeqName(chainPart: dict[str, Any], include_calohits_seq_name: bool = False) -> str:
    '''
    Get the HLT Tau Precision sequence name suffix.
    This is also used for the HLT_TrigTauRecMerged_... and HLT_tautrack_... EDM collection names.

    Legacy reconstruction names are shortened (tracktwoMVA -> MVA, tracktwoLLP -> LLP, trackLRT -> LRT)
    and never get a CaloHits suffix. Any other reconstruction name is returned as-is, followed by
    '_<CaloHits sequence name>' if include_calohits_seq_name is set and the chain has a CaloHits sequence.
    '''
    reco = chainPart['reconstruction']

    # Support for the Legacy trigger names:
    legacy_name = {'tracktwoMVA': 'MVA', 'tracktwoLLP': 'LLP', 'trackLRT': 'LRT'}.get(reco)
    if legacy_name is not None:
        return legacy_name

    if not include_calohits_seq_name:
        return reco

    calohits_seq = getChainCaloHitsSeqName(chainPart)
    return f'{reco}_{calohits_seq}' if calohits_seq else reco
256
257
258
259
262
def getChainSequenceConfigName(chainPart: dict[str, Any]) -> str:
    '''
    Get the HLT Tau signature global menu sequence name (e.g. ptonly, tracktwo, trackLRT, etc...)

    Chains using HitZ or a CaloHits preselection get a 'CaloHits_' prefix.
    '''
    uses_calohits = chainPart['hitz'] or chainPart['calohitsPresel']
    prefix = ['CaloHits'] if uses_calohits else []
    return '_'.join([*prefix, chainPart['reconstruction']])
273
str getChainSequenceConfigName(dict[str, Any] chainPart)
Global Tau menu sequence.
str getChainIDConfigName(AthConfigFlags flags, dict[str, Any] chainPart)
str getChainPrecisionSeqName(dict[str, Any] chainPart, bool include_calohits_seq_name=False)
tuple[str, str] getTauIDScoreVariables(str tau_id)
list[str] getHitZAlgs(AthConfigFlags flags, str precision_sequence, str|None alt_precision_sequence=None)
CaloHits sequence algorithms.
str getTauIDAlgorithm(AthConfigFlags flags, str selection, tuple[tuple[str, str],...]|None name_mapping=None)
list[str] getPrecisionSequenceTauIDs(AthConfigFlags flags, str precision_sequence, str|None alt_precision_sequence=None)
Precision sequence TauIDs.
str|None getChainCaloHitsSeqName(dict[str, Any] chainPart)
str getChainCaloHitsPreselConfigName(AthConfigFlags flags, dict[str, Any] chainPart)
tuple[str, float]|None getHitZConfig(AthConfigFlags flags, dict[str, Any] chainPart)
list[str] getCaloHitsPreselAlgs(AthConfigFlags flags, str precision_sequence, str|None alt_precision_sequence=None)
list[str] getMenuAlgs(AthConfigFlags flags, str|None key=None, str|None alt_key=None, dict[str, list[str]]|list[str]|None algs=None, dict[str, list[str]]|list[str]|None mc_algs=None, dict[str, list[str]]|list[str]|None dev_algs=None)
Global helper functions.