ATLAS Offline Software
Loading...
Searching...
No Matches
TauConfigurationTools.py
Go to the documentation of this file.
1# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2from typing import Any
3
4from AthenaConfiguration.AthConfigFlags import AthConfigFlags
5
6from AthenaCommon.Logging import logging
# Module-level logger for these Tau Trigger configuration helpers
log = logging.getLogger(__name__)
8
9# This file contains helper functions for the Tau Trigger signature configuration
10
11
12
def getMenuAlgs(
    flags: AthConfigFlags,
    key: str | None = None,
    alt_key: str | None = None,
    algs: dict[str, list[str]] | list[str] | None = None,
    mc_algs: dict[str, list[str]] | list[str] | None = None,
    dev_algs: dict[str, list[str]] | list[str] | None = None,
) -> list[str] | None:
    '''
    Get the list of algorithms for a specific menu key; if not found, the alternate key will be tried if provided.

    Each of algs, mc_algs and dev_algs is either a dict mapping menu keys to lists of algorithm names,
    or a plain list that applies to every key:
      - algs: algorithms to run in all menus
      - mc_algs: additional algorithms, only if the menu name (flags.Trigger.triggerMenuSetup) contains 'MC_' or 'Dev_'
      - dev_algs: additional algorithms, only if the menu name contains 'Dev_'

    Returns a new list (the configuration lists passed in are never modified), or None if no
    algorithms are configured for key (nor for alt_key, when given).
    '''
    menu = flags.Trigger.triggerMenuSetup
    use_mc_algs = any(pfx in menu for pfx in ['MC_', 'Dev_'])
    use_dev_algs = 'Dev_' in menu

    def _getExtraAlgs(extra_algs, menu_key: str | None):
        '''Additional algorithms from an MC/Dev configuration for menu_key, or None if there are none'''
        if not extra_algs: return None
        if isinstance(extra_algs, list): return extra_algs
        return extra_algs[menu_key] if menu_key in extra_algs else None

    def _getAlgs(menu_key: str | None) -> list[str] | None:
        '''Collect the default + MC/Dev algorithms for menu_key, or None if none are configured'''
        if algs is None: base = None
        elif isinstance(algs, dict): base = algs.get(menu_key)
        else: base = algs  # list

        # Always work on a copy: extending the caller's list in place would leak the MC/Dev
        # algorithms into the input configuration (and into any later call sharing it)
        ret = None if base is None else list(base)
        for enabled, extra_algs in ((use_mc_algs, mc_algs), (use_dev_algs, dev_algs)):
            extra = _getExtraAlgs(extra_algs, menu_key) if enabled else None
            if extra is None: continue
            if ret is None: ret = []
            ret.extend(extra)
        return ret

    ret_algs = _getAlgs(key)
    # None (not an empty list) means "nothing configured for this key", so the fallback is tried
    if ret_algs is None and alt_key is not None:
        ret_algs = _getAlgs(alt_key)
    return ret_algs
40
41
42
def useBuiltInTauJetRNNScore(tau_id: str) -> bool:
    '''
    Tell whether a TauID reads its scores from the built-in TauJet RNN aux variables.

    Returns True for the "legacy" algorithms (DeepSet, RNNLLP), whose score and WP variables are the
    TauJet built-in ones; every other algorithm uses the decorator-based variables instead.
    '''
    legacy_algorithms = ('DeepSet', 'RNNLLP')
    return tau_id in legacy_algorithms
47
48
def getTauIDScoreVariables(tau_id: str) -> tuple[str, str]:
    '''
    Return the (score, score_sig_trans) variable name pair for a given TauID/Sequence configuration.

    Algorithms flagged by useBuiltInTauJetRNNScore use the built-in TauJet RNN variables;
    all others use "<tau_id>_Score" and "<tau_id>_ScoreSigTrans".
    '''
    if not useBuiltInTauJetRNNScore(tau_id):
        return (f'{tau_id}_Score', f'{tau_id}_ScoreSigTrans')

    # Support for "legacy" algorithms, where the scores are stored in the built-in TauJet aux variables
    return ('RNNJetScore', 'RNNJetScoreSigTrans')
55
56
def getTauIDAlgorithm(flags: AthConfigFlags, selection: str, name_mapping: dict[str, str] | None = None) -> str:
    '''
    Resolve the TauID algorithm name from a selection string.

    The selection suffix is first compared against the TauID configurations in flags.Trigger.Offline.Tau,
    then against the short names in name_mapping (e.g. {'DS': 'DeepSet'}), which are translated to their
    full names. In both passes longer names are tried first, so the most specific suffix wins.
    If nothing matches, the selection is returned unchanged.
    '''
    # Longest names first, so that a shorter name which is a suffix of a longer one cannot shadow it
    for tau_id in sorted(flags.Trigger.Offline.Tau, key=len, reverse=True):
        if selection.endswith(tau_id):
            return tau_id

    # Remap short names (e.g. DS -> DeepSet), again longest first
    for short_name in sorted(name_mapping or {}, key=len, reverse=True):
        if selection.endswith(short_name):
            return name_mapping[short_name]

    return selection
70
71
72
73
76
def getHitZAlgs(flags: AthConfigFlags, precision_sequence: str, alt_precision_sequence: str | None = None) -> list[str]:
    '''
    Get the list of HitZ algorithms for the CaloHits reco sequence.
    The configuration for each algorithm is contained in flags.Trigger.Offline.Tau.<alg>.

    HitZ is currently only scheduled in the Dev menu; in any other menu no algorithm is
    configured and getMenuAlgs returns None.
    '''
    hitz_all_menus: dict[str, list[str]] = {}  # Default HitZ algorithms to run in all menus
    hitz_mc_menus: dict[str, list[str]] = {}   # Additional HitZ algorithms, ONLY for the MC (or Dev) menu
    hitz_dev_menu = ['HitZ']                   # Additional HitZ algorithms, ONLY for the Dev menu

    return getMenuAlgs(
        flags,
        key=precision_sequence,
        alt_key=alt_precision_sequence,
        algs=hitz_all_menus,
        mc_algs=hitz_mc_menus,
        dev_algs=hitz_dev_menu,
    )
96
97
def getHitZConfig(flags: AthConfigFlags, chainPart: dict[str, Any]) -> tuple[str, float] | None:
    '''
    Get the HLT HitZ configuration tuple: (algorithm name, sigma cut value in mm)

    The chainPart 'hitz' string has the form '[<sigma>mm][X<value>mm]<alg>' (e.g. '10mmX5mmHitZ',
    '5mmHitZ', 'HitZ'), where 'p' stands for the decimal point (e.g. '1p5mm' -> 1.5 mm).
    The optional 'X<value>mm' part is accepted but not used here. Without an explicit sigma,
    flags.Trigger.Offline.Tau.<alg>.DefaultMaxZ0Sigma is used.

    Returns None if the chain has no HitZ configuration.
    Raises ValueError if the string cannot be parsed, or the algorithm has no configuration flags.
    '''
    hitz = chainPart['hitz']
    if not hitz:
        return None

    import re
    # Match strings of the form: '10mmX5mmHitZ', '5mmHitZ', 'HitZ', etc...
    parsed = re.match(r'((?P<sigma>(\d|p)+)mm)?(X(\d|p)+mm)?(?P<alg>.+)', hitz)
    if not parsed:
        raise ValueError(f'Invalid HitZ configuration string: {chainPart["hitz"]}')

    alg = parsed.group('alg')
    alg_flags = getattr(flags.Trigger.Offline.Tau, alg, None)
    if alg_flags is None:
        raise ValueError(f'HitZ algorithm "{alg}" configuration not found in flags.Trigger.Offline.Tau.{alg}')

    sigma_str = parsed.group('sigma')
    # Cut value in mm: either the algorithm default, or the value encoded in the chain name
    sigma = alg_flags.DefaultMaxZ0Sigma if sigma_str is None else float(sigma_str.replace('p', '.'))

    return (alg, sigma)
121
122
def getHitZVariables(alg: str) -> tuple[str, str]:
    '''Return the (z, sigma) variable name pair for a given HitZ algorithm: ("<alg>_z0", "<alg>_z0_sigma")'''
    z_var = f'{alg}_z0'
    return (z_var, f'{z_var}_sigma')
126
127
def getCaloHitsPreselAlgs(flags: AthConfigFlags, precision_sequence: str, alt_precision_sequence: str | None = None) -> list[str]:
    '''
    Get the list of CaloHits preselection TauID inferences to be executed for the CaloHits reco sequence.
    The configuration for each algorithm is contained in flags.Trigger.Offline.Tau.<alg>.

    No preselection inference is configured at the moment, so an empty list is returned in every menu.
    '''
    presel_all_menus: list[str] = []  # Default inferences to run in all menus
    presel_mc_menus: list[str] = []   # Additional inferences, ONLY for the MC (or Dev) menu
    presel_dev_menu: list[str] = []   # Additional inferences, ONLY for the Dev menu

    return getMenuAlgs(
        flags,
        key=precision_sequence,
        alt_key=alt_precision_sequence,
        algs=presel_all_menus,
        mc_algs=presel_mc_menus,
        dev_algs=presel_dev_menu,
    )
147
148
def getChainCaloHitsPreselConfigName(flags: AthConfigFlags, chainPart: dict[str, Any]) -> str:
    '''
    Clean the CaloHits preselection configuration for a chainPart dict.

    Returns 'idperf' when there is no preselection (empty selection or 'idperfCHP'). Otherwise returns the
    TauID algorithm resolved by getTauIDAlgorithm, with the 'CHTP' short name mapped to 'GNCaloHitsTauPresel'.
    '''
    presel = chainPart['calohitsPresel']

    if presel and presel != 'idperfCHP':
        return getTauIDAlgorithm(flags, presel, name_mapping={'CHTP': 'GNCaloHitsTauPresel'})

    return 'idperf'  # No preselection
160
161
162def getChainCaloHitsSeqName(chainPart: dict[str, Any]) -> str | None:
163 '''Get the HLT Tau CaloHits sequence name suffix'''
164 if not chainPart['hitz'] and not chainPart['calohitsPresel']: return None
165
166 parts = []
167
168 # HitZ RoI updating selection
169 if chainPart['hitz']: parts.append(chainPart['hitz'])
170
171 if not parts: parts = ['CaloHitsBase']
172 return '_'.join(parts)
173
174
175
176
179
180def getPrecisionSequenceTauIDs(flags: AthConfigFlags, precision_sequence: str, alt_precision_sequence: str | None = None) -> list[str]:
181 '''
182 Get the list of precision TauID inferences to be executed for each HLT tau trigger reco sequence
183 The configuration for each algorithm is contained in flags.Trigger.Offline.Tau.<alg>.
184 '''
185 return getMenuAlgs(
186 flags,
187 key=precision_sequence,
188 alt_key=alt_precision_sequence,
189
190 # Default Tau ID algorithms to run in all menus
191 algs={
192 'MVA': ['GNTau', 'MesonCuts', 'GNTauDev1'],
193 'LLP': ['RNNLLP'],
194 'LRT': ['RNNLLP'],
195 },
196
197 # Additional Tau ID algorithms to run ONLY if we're using the MC (or Dev) menu
198 mc_algs={
199 'MVA': ['DeepSet'],
200 },
201
202 # Additional Tau ID algorithms to run ONLY if we're using the Dev menu
203 dev_algs={
204 }
205 )
206
207
208# The following functions are only required while we still have triggers
209# with the RNN/DeepSet naming scheme in the Menu (e.g. mediumRNN_tracktwoMVA/LLP)
210rnn_wps = ['verylooseRNN', 'looseRNN', 'mediumRNN', 'tightRNN']
211noid_selections = ['perf', 'idperf']
212meson_selections = ['kaonpi1', 'kaonpi2', 'dipion1', 'dipion2', 'dipion3', 'dipion4', 'dikaonmass', 'singlepion']
213
214def getChainIDConfigName(flags: AthConfigFlags, chainPart: dict[str, Any]) -> str:
215 '''Clean the ID configuration for a chainPart dict'''
216 sel = chainPart['selection']
217
218 # Support for the Legacy trigger names:
219 if chainPart['reconstruction'] == 'tracktwoMVA':
220 if sel in rnn_wps:
221 return 'DeepSet'
222 elif sel in meson_selections:
223 return 'MesonCuts'
224 elif chainPart['reconstruction'] in ['tracktwoLLP', 'trackLRT'] and sel in rnn_wps:
225 return 'RNNLLP'
226
227 return getTauIDAlgorithm(
228 flags,
229 sel,
230 name_mapping={},
231 )
232
233 return sel
234
235
236def getChainPrecisionSeqName(chainPart: dict[str, Any], include_calohits_seq_name: bool = False) -> str:
237 '''
238 Get the HLT Tau Precision sequence name suffix.
239 This is also used for the HLT_TrigTauRecMerged_... and HLT_tautrack_... EDM collection names.
240 '''
241 ret = chainPart['reconstruction']
242
243 # Support for the Legacy trigger names:
244 if ret == 'tracktwoMVA': return 'MVA'
245 elif ret == 'tracktwoLLP': return 'LLP'
246 elif ret == 'trackLRT': return 'LRT'
247
248 if include_calohits_seq_name:
249 calohits_seq = getChainCaloHitsSeqName(chainPart)
250 ret += f'_{calohits_seq}' if calohits_seq else ''
251
252 return ret
253
254
255
256
259
def getChainSequenceConfigName(chainPart: dict[str, Any]) -> str:
    '''
    Get the HLT Tau signature global menu sequence name (e.g. ptonly, tracktwo, trackLRT, etc...)

    Chains with a HitZ or CaloHits preselection configuration get a 'CaloHits_' prefix
    in front of the reconstruction name.
    '''
    uses_calohits = chainPart['hitz'] or chainPart['calohitsPresel']
    reco = chainPart['reconstruction']

    return f'CaloHits_{reco}' if uses_calohits else reco
270
str getChainSequenceConfigName(dict[str, Any] chainPart)
Global Tau menu sequence.
str getChainIDConfigName(AthConfigFlags flags, dict[str, Any] chainPart)
str getChainPrecisionSeqName(dict[str, Any] chainPart, bool include_calohits_seq_name=False)
tuple[str, str] getTauIDScoreVariables(str tau_id)
list[str] getHitZAlgs(AthConfigFlags flags, str precision_sequence, str|None alt_precision_sequence=None)
CaloHits sequence algorithms.
list[str] getPrecisionSequenceTauIDs(AthConfigFlags flags, str precision_sequence, str|None alt_precision_sequence=None)
Precision sequence TauIDs.
str|None getChainCaloHitsSeqName(dict[str, Any] chainPart)
str getTauIDAlgorithm(AthConfigFlags flags, str selection, dict[str, str]|None name_mapping=None)
str getChainCaloHitsPreselConfigName(AthConfigFlags flags, dict[str, Any] chainPart)
tuple[str, float]|None getHitZConfig(AthConfigFlags flags, dict[str, Any] chainPart)
list[str] getCaloHitsPreselAlgs(AthConfigFlags flags, str precision_sequence, str|None alt_precision_sequence=None)
list[str] getMenuAlgs(AthConfigFlags flags, str|None key=None, str|None alt_key=None, dict[str, list[str]]|list[str]|None algs=None, dict[str, list[str]]|list[str]|None mc_algs=None, dict[str, list[str]]|list[str]|None dev_algs=None)
Global helper functions.