ATLAS Offline Software
Loading...
Searching...
No Matches
TrigTauRecToolsConfig.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
6
def trigTauVertexFinderCfg(flags, name=''):
    '''Configure a TauVertexFinder private tool for trigger tau reconstruction.

    The tool is set up with TJVA disabled (UseTJVA=False), with tracks
    associated via 'GhostTrack', and with no TJVA track selector, no input
    track-particle / vertex containers and no TVA tool configured.

    :param flags: configuration flags (not read directly here).
    :param name: instance name of the TauVertexFinder tool.
    :returns: ComponentAccumulator holding the TauVertexFinder as private tool.
    '''
    acc = ComponentAccumulator()

    acc.setPrivateTools(CompFactory.TauVertexFinder(
        name = name,
        UseTJVA = False,
        AssociatedTracks = 'GhostTrack',
        # All TJVA/TVA-related inputs are left empty since TJVA is disabled
        InDetTrackSelectionToolForTJVA = '',
        Key_trackPartInputContainer = '',
        Key_vertexInputContainer = '',
        TVATool = '',
    ))

    return acc
22
def trigTauTrackFinderCfg(flags, name='', TrackParticlesContainer=''):
    '''Tau track association.

    Configure a TauTrackFinder private tool for the trigger, wiring in the
    track-to-vertex IP estimator, TrackToVertex, particle calo-extension and
    trigger-tau InDet track selector tools built by their own config fragments.

    :param flags: configuration flags, forwarded to the sub-tool configs.
    :param name: instance name of the TauTrackFinder tool.
    :param TrackParticlesContainer: key of the input track particle container.
    :returns: ComponentAccumulator holding the TauTrackFinder as private tool,
              with the sub-tools' accumulators merged in.
    '''
    # Local import so the 0.75*mm Z0 cut below always resolves, independently
    # of what the module-level imports provide.
    from AthenaCommon.SystemOfUnits import mm

    acc = ComponentAccumulator()

    from TrkConfig.TrkVertexFitterUtilsConfig import AtlasTrackToVertexIPEstimatorCfg
    AtlasTrackToVertexIPEstimator = acc.popToolsAndMerge(AtlasTrackToVertexIPEstimatorCfg(flags))

    from TrackToVertex.TrackToVertexConfig import TrackToVertexCfg
    TrackToVertexTool = acc.popToolsAndMerge(TrackToVertexCfg(flags))

    from TrackToCalo.TrackToCaloConfig import ParticleCaloExtensionToolCfg
    ParticleCaloExtensionTool = acc.popToolsAndMerge(ParticleCaloExtensionToolCfg(flags))

    from InDetConfig.InDetTrackSelectorToolConfig import TrigTauInDetTrackSelectorToolCfg
    TrigTauInDetTrackSelectorTool = acc.popToolsAndMerge(TrigTauInDetTrackSelectorToolCfg(flags))

    acc.setPrivateTools(CompFactory.TauTrackFinder(
        name = name,
        MaxJetDrTau = 0.2,
        MaxJetDrWide = 0.4,
        TrackSelectorToolTau = TrigTauInDetTrackSelectorTool,
        TrackToVertexTool = TrackToVertexTool,
        Key_trackPartInputContainer = TrackParticlesContainer,
        # Drop tracks whose z0 differs from the leading track's by more than this
        maxDeltaZ0wrtLeadTrk = 0.75*mm,
        removeTracksOutsideZ0wrtLeadTrk = True,
        ParticleCaloExtensionTool = ParticleCaloExtensionTool,
        BypassExtrapolator = True,
        tauParticleCache = "",
        TrackToVertexIPEstimator = AtlasTrackToVertexIPEstimator,
    ))

    return acc
55
def tauVertexVariablesCfg(flags, name=''):
    '''Vertex variables calculation.

    Build a TauVertexVariables private tool using the tau adaptive vertex
    fitter and the cross-distances seed finder.

    :param flags: configuration flags, forwarded to the fitter/seed-finder configs.
    :param name: instance name of the TauVertexVariables tool.
    :returns: ComponentAccumulator holding the TauVertexVariables as private tool.
    '''
    from TrkConfig.TrkVertexFittersConfig import TauAdaptiveVertexFitterCfg
    from TrkConfig.TrkVertexSeedFinderToolsConfig import CrossDistancesSeedFinderCfg

    result = ComponentAccumulator()

    # Sub-tools are merged in the same order: fitter first, then seed finder
    fitter = result.popToolsAndMerge(TauAdaptiveVertexFitterCfg(flags))
    seedFinder = result.popToolsAndMerge(CrossDistancesSeedFinderCfg(flags))

    vertexVariablesTool = CompFactory.TauVertexVariables(
        name=name,
        VertexFitter=fitter,
        SeedFinder=seedFinder,
    )
    result.setPrivateTools(vertexVariablesTool)
    return result
73
def trigTauJetLVNNEvaluatorCfg(flags, tau_id='', use_taujet_rnnscore=True):
    '''TauJet identification inference based on LVNN models, for RNNs and DeepSets.

    Configure a TauJetRNNEvaluator private tool from the TauID flag namespace
    ``flags.Trigger.Offline.Tau.<tau_id>``, which must provide
    ``NetworkConfig`` (0-, 1- and 3-prong network files, in that order),
    ``MaxTracks`` and ``MaxClusters``.

    :param flags: configuration flags.
    :param tau_id: name of the TauID flag namespace under Trigger.Offline.Tau.
    :param use_taujet_rnnscore: if True, write the score to 'RNNJetScore' and
        suffix the tool name with '_RNNJetScore'; otherwise write it to
        '<tau_id>_Score'.
    :returns: ComponentAccumulator holding the evaluator as private tool.
    :raises ValueError: if ``flags.Trigger.Offline.Tau`` has no attribute ``tau_id``.
    '''
    acc = ComponentAccumulator()

    # A missing flag surfaces as AttributeError from getattr (not NameError)
    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        raise ValueError(f'Invalid LVNN (RNN, DeepSet) TauID configuration: {tau_id}') from err

    # For legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains
    output_variable = 'RNNJetScore' if use_taujet_rnnscore else f'{tau_id}_Score'

    # This is to prevent clashes between the same ID running over the legacy chains, and the new decorator-based chains
    sfx = '_RNNJetScore' if use_taujet_rnnscore else ''

    acc.setPrivateTools(CompFactory.TauJetRNNEvaluator(
        name = f'TrigTau_TauJetLVNNEvaluator_{tau_id}{sfx}',
        useTRT = flags.Detector.EnableTRT,

        # Network config:
        NetworkFile0P = id_flags.NetworkConfig[0],
        NetworkFile1P = id_flags.NetworkConfig[1],
        NetworkFile3P = id_flags.NetworkConfig[2],
        InputLayerScalar = 'scalar',
        InputLayerTracks = 'tracks',
        InputLayerClusters = 'clusters',
        OutputLayer = 'rnnid_output',
        OutputNode = 'sig_prob',

        # Inputs:
        MaxTracks = id_flags.MaxTracks,
        MaxClusters = id_flags.MaxClusters,
        MaxClusterDR = 1.0,
        VertexCorrection = False,
        TrackClassification = False,

        # Decorated TauJet variable names:
        OutputVarname = output_variable,
    ))

    return acc
113
def trigTauJetONNXEvaluatorCfg(flags, tau_id=''):
    '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...

    Configure a TauGNNEvaluator private tool from the TauID flag namespace
    ``flags.Trigger.Offline.Tau.<tau_id>``, which must provide
    ``ONNXConfig`` (0-, 1- and 3-prong model files, in that order),
    ``MaxTracks``, ``MaxClusters`` and ``OutputDiscriminant``, and may
    provide ``MinProngTrackPt`` (defaults to 0 when absent).

    The tool decorates '<tau_id>_Score', '<tau_id>_ProbTau' and
    '<tau_id>_ProbJet'.

    :param flags: configuration flags.
    :param tau_id: name of the TauID flag namespace under Trigger.Offline.Tau.
    :returns: ComponentAccumulator holding the evaluator as private tool.
    :raises ValueError: if ``flags.Trigger.Offline.Tau`` has no attribute ``tau_id``.
    '''
    acc = ComponentAccumulator()

    # A missing flag surfaces as AttributeError from getattr (not NameError)
    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        raise ValueError(f'Invalid ONNX TauID configuration: {tau_id}') from err

    acc.setPrivateTools(CompFactory.TauGNNEvaluator(
        name = f'TrigTau_TauJetONNXEvaluator_{tau_id}',
        useTRT = flags.Detector.EnableTRT,

        # Network config:
        NetworkFile0P = id_flags.ONNXConfig[0],
        NetworkFile1P = id_flags.ONNXConfig[1],
        NetworkFile3P = id_flags.ONNXConfig[2],
        InputLayerScalar = 'tau_vars',
        InputLayerTracks = 'track_vars',
        InputLayerClusters = 'cluster_vars',
        NodeNameTau = 'pTau',
        NodeNameJet = 'pJet',

        # Inputs:
        MaxTracks = id_flags.MaxTracks,
        MaxClusters = id_flags.MaxClusters,
        MaxClusterDR = 1.0,
        VertexCorrection = False,
        TrackClassification = False,

        # Decorated TauJet variable names:
        OutputVarname = f'{tau_id}_Score',
        OutputDiscriminant = id_flags.OutputDiscriminant,
        OutputPTau = f'{tau_id}_ProbTau',
        OutputPJet = f'{tau_id}_ProbJet',

        # Tau prongness selection (optional flag, 0 when not defined)
        MinProngTrackPt = getattr(id_flags, 'MinProngTrackPt', 0),
    ))

    return acc
153
def trigTauWPDecoratorRNNCfg(flags, tau_id: str, precision_seq_name: str):
    '''
    TauJet signal transformed score and ID WPs decorator tool,
    for the legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains ONLY!

    Reads 'RNNJetScore', writes the flattened score to 'RNNJetScoreSigTrans'
    and defines the VeryLoose/Loose/Medium/Tight WPs through the
    JetRNNSig* IsTauFlag enum values. Flattening files and target
    efficiencies come from ``flags.Trigger.Offline.Tau.<tau_id>``
    (``ScoreFlatteningConfig`` and ``TargetEff``, each indexed 0/1/2 for
    0-, 1- and 3-prong).

    :param flags: configuration flags.
    :param tau_id: name of the TauID flag namespace under Trigger.Offline.Tau.
    :param precision_seq_name: precision sequence name, used in the tool name.
    :returns: ComponentAccumulator holding the TauWPDecorator as private tool.
    :raises ValueError: if ``flags.Trigger.Offline.Tau`` has no attribute ``tau_id``.
    '''

    acc = ComponentAccumulator()

    # A missing flag surfaces as AttributeError from getattr (not NameError)
    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        raise ValueError(f'Invalid TauID configuration: {tau_id}') from err

    # In this version we store the WPs as flags accessible through
    # tau->isTau(xAOD::TauJetParameters::IsTauFlag::JetRNNSig...)
    import PyUtils.RootUtils as ru
    ROOT = ru.import_root()
    import cppyy
    # Load the xAODTau dictionary so the IsTauFlag enum is visible from Python
    cppyy.load_library('libxAODTau_cDict')
    WPEnumVals = [
        ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigVeryLoose,
        ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigLoose,
        ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigMedium,
        ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigTight,
    ]

    acc.setPrivateTools(CompFactory.TauWPDecorator(
        name=f'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
        flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
        flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
        flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
        CutEnumVals=WPEnumVals,
        SigEff0P=id_flags.TargetEff[0],
        SigEff1P=id_flags.TargetEff[1],
        SigEff3P=id_flags.TargetEff[2],
        ScoreName='RNNJetScore',
        NewScoreName='RNNJetScoreSigTrans',
        DefineWPs=True,
    ))

    return acc
193
def trigTauWPDecoratorCfg(flags, tau_id: str, precision_seq_name: str, tauContainerName: str):
    '''TauJet signal transformed score and ID WPs decorator tool.

    Reads '<tau_id>_Score', writes the flattened score to
    '<tau_id>_ScoreSigTrans' and decorates one WP per entry of
    ``WPNames`` as '<tau_id>_<wp>'. Flattening files and per-prong WP
    efficiencies come from ``flags.Trigger.Offline.Tau.<tau_id>``
    (``ScoreFlatteningConfig`` and ``TargetEff``, each indexed 0/1/2 for
    0-, 1- and 3-prong).

    :param flags: configuration flags.
    :param tau_id: name of the TauID flag namespace under Trigger.Offline.Tau.
    :param precision_seq_name: precision sequence name, used in the tool name.
    :param tauContainerName: TauJet container the decorations are attached to.
    :returns: ComponentAccumulator holding the TauWPDecorator as private tool.
    :raises ValueError: if ``flags.Trigger.Offline.Tau`` has no attribute ``tau_id``.
    '''
    acc = ComponentAccumulator()

    # A missing flag surfaces as AttributeError from getattr (not NameError)
    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        raise ValueError(f'Invalid TauID configuration: {tau_id}') from err

    acc.setPrivateTools(CompFactory.TauWPDecorator(
        # NOTE(review): same name pattern as trigTauWPDecoratorRNNCfg; presumably
        # harmless since both are private tools -- confirm no clash is possible.
        name=f'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
        flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
        flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
        flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
        TauContainerName=tauContainerName,
        DecorWPNames=[f'{tau_id}_{wp}' for wp in id_flags.WPNames],
        DecorWPCutEffs0P=id_flags.TargetEff[0],
        DecorWPCutEffs1P=id_flags.TargetEff[1],
        DecorWPCutEffs3P=id_flags.TargetEff[2],
        ScoreName=f'{tau_id}_Score',
        NewScoreName=f'{tau_id}_ScoreSigTrans',
        DefineWPs=True,
    ))

    return acc
trigTauVertexFinderCfg(flags, name='')