def trigTauTrackFinderCfg(flags, name='', TrackParticlesContainer=''):
    '''Tau track association

    Configures a TauTrackFinder private tool and hands it the helper tools
    built below (IP estimator, track-to-vertex tool, calo extension tool and
    track selector), each obtained from its own ``...Cfg`` function.

    Args:
        flags: configuration flags, passed unchanged to every helper config.
        name: instance name given to the TauTrackFinder tool.
        TrackParticlesContainer: value assigned to ``Key_trackPartInputContainer``.

    Returns:
        The ComponentAccumulator holding TauTrackFinder as its private tool.
    '''
    acc = ComponentAccumulator()

    from TrkConfig.TrkVertexFitterUtilsConfig import AtlasTrackToVertexIPEstimatorCfg
    ip_estimator = acc.popToolsAndMerge(AtlasTrackToVertexIPEstimatorCfg(flags))

    from TrackToVertex.TrackToVertexConfig import TrackToVertexCfg
    track_to_vertex = acc.popToolsAndMerge(TrackToVertexCfg(flags))

    from TrackToCalo.TrackToCaloConfig import ParticleCaloExtensionToolCfg
    calo_extension = acc.popToolsAndMerge(ParticleCaloExtensionToolCfg(flags))

    from InDetConfig.InDetTrackSelectorToolConfig import TrigTauInDetTrackSelectorToolCfg
    track_selector = acc.popToolsAndMerge(TrigTauInDetTrackSelectorToolCfg(flags))

    # NOTE(review): rebuilt from a listing whose embedded line numbers skip
    # entries inside this call; check version control for tool properties
    # that may have been dropped here.
    acc.setPrivateTools(CompFactory.TauTrackFinder(
        # Forward the caller-supplied instance name; the parameter was
        # otherwise accepted and silently ignored.
        name=name,
        TrackSelectorToolTau=track_selector,
        TrackToVertexTool=track_to_vertex,
        Key_trackPartInputContainer=TrackParticlesContainer,
        maxDeltaZ0wrtLeadTrk=0.75*mm,
        removeTracksOutsideZ0wrtLeadTrk=True,
        ParticleCaloExtensionTool=calo_extension,
        BypassExtrapolator=True,
        tauParticleCache="",
        TrackToVertexIPEstimator=ip_estimator,
    ))

    return acc
def tauVertexVariablesCfg(flags, name=''):
    '''Vertex variables calculation

    Configures a TauVertexVariables private tool with a vertex fitter and a
    seed finder, each obtained from its own ``...Cfg`` function.

    Args:
        flags: configuration flags, passed unchanged to the helper configs.
        name: instance name given to the TauVertexVariables tool.

    Returns:
        The ComponentAccumulator holding TauVertexVariables as its private tool.
    '''
    acc = ComponentAccumulator()

    from TrkConfig.TrkVertexFittersConfig import TauAdaptiveVertexFitterCfg
    vertex_fitter = acc.popToolsAndMerge(TauAdaptiveVertexFitterCfg(flags))

    from TrkConfig.TrkVertexSeedFinderToolsConfig import CrossDistancesSeedFinderCfg
    seed_finder = acc.popToolsAndMerge(CrossDistancesSeedFinderCfg(flags))

    # NOTE(review): rebuilt from a listing whose embedded line numbers skip
    # one entry inside this call; confirm against version control.
    acc.setPrivateTools(CompFactory.TauVertexVariables(
        # Forward the caller-supplied instance name; the parameter was
        # otherwise accepted and silently ignored.
        name=name,
        VertexFitter=vertex_fitter,
        SeedFinder=seed_finder,
    ))

    return acc
def trigTauJetLVNNEvaluatorCfg(flags, tau_id='', use_taujet_rnnscore=True):
    '''TauJet identification inference based on LVNN models, for RNNs and DeepSets

    Args:
        flags: configuration flags; the ID settings are read from the
            ``tau_id`` attribute of ``flags.Trigger.Offline.Tau``.
        tau_id: name of the TauID configuration to use.
        use_taujet_rnnscore: if True, the score is written to 'RNNJetScore'
            (legacy chains) and the tool name gets an '_RNNJetScore' suffix;
            otherwise the score is written to '<tau_id>_Score'.

    Returns:
        The ComponentAccumulator holding TauJetRNNEvaluator as its private tool.

    Raises:
        ValueError: if ``tau_id`` is not an attribute of
            ``flags.Trigger.Offline.Tau``.
    '''
    acc = ComponentAccumulator()

    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        # A failed getattr raises AttributeError, never NameError, so this is
        # the exception that must be translated for the message to be seen.
        raise ValueError(f'Invalid LVNN (RNN, DeepSet) TauID configuration: {tau_id}') from err

    # For legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains
    output_variable = 'RNNJetScore' if use_taujet_rnnscore else f'{tau_id}_Score'

    # This is to prevent clashes between the same ID running over the legacy
    # chains, and the new decorator-based chains
    sfx = '_RNNJetScore' if use_taujet_rnnscore else ''

    # One network file per prong category, in the order 0P, 1P, 3P.
    network_0p, network_1p, network_3p = (id_flags.NetworkConfig[i] for i in range(3))

    acc.setPrivateTools(CompFactory.TauJetRNNEvaluator(
        name=f'TrigTau_TauJetLVNNEvaluator_{tau_id}{sfx}',
        useTRT=flags.Detector.EnableTRT,

        NetworkFile0P=network_0p,
        NetworkFile1P=network_1p,
        NetworkFile3P=network_3p,
        InputLayerScalar='scalar',
        InputLayerTracks='tracks',
        InputLayerClusters='clusters',
        OutputLayer='rnnid_output',
        OutputNode='sig_prob',

        MaxTracks=id_flags.MaxTracks,
        MaxClusters=id_flags.MaxClusters,

        VertexCorrection=False,
        TrackClassification=False,

        # Decorated TauJet variable names:
        OutputVarname=output_variable,
    ))

    return acc
def trigTauJetONNXEvaluatorCfg(flags, tau_id=''):
    '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...

    Args:
        flags: configuration flags; the ID settings are read from the
            ``tau_id`` attribute of ``flags.Trigger.Offline.Tau``.
        tau_id: name of the TauID configuration to use.

    Returns:
        The ComponentAccumulator holding TauGNNEvaluator as its private tool.

    Raises:
        ValueError: if ``tau_id`` is not an attribute of
            ``flags.Trigger.Offline.Tau``, or if its ``ONNXConfig`` is neither
            a string nor a list/tuple of exactly three entries.
    '''
    acc = ComponentAccumulator()

    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        # A failed getattr raises AttributeError, never NameError, so this is
        # the exception that must be translated for the message to be seen.
        raise ValueError(f'Invalid ONNX TauID configuration: {tau_id}') from err

    # A single string selects one inclusive network; a 3-element sequence
    # selects one network per prong category (0P, 1P, 3P).
    onnx_config = id_flags.ONNXConfig
    if isinstance(onnx_config, str):
        network_config = {'NetworkFileInclusive': onnx_config}
    elif isinstance(onnx_config, (list, tuple)) and len(onnx_config) == 3:
        network_config = {
            'NetworkFile0P': onnx_config[0],
            'NetworkFile1P': onnx_config[1],
            'NetworkFile3P': onnx_config[2],
        }
    else:
        raise ValueError(f'Invalid {tau_id} ONNX network config file')

    # NOTE(review): rebuilt from a listing whose embedded line numbers skip
    # entries inside this call; check version control for tool properties
    # that may have been dropped here.
    acc.setPrivateTools(CompFactory.TauGNNEvaluator(
        name=f'TrigTau_TauJetONNXEvaluator_{tau_id}',
        useTRT=flags.Detector.EnableTRT,

        # Network file properties selected above.
        **network_config,

        InputLayerScalar='tau_vars',
        InputLayerTracks='track_vars',
        InputLayerClusters='cluster_vars',
        # Optional per-ID overrides; fall back to the defaults when unset.
        NodeNameTau=getattr(id_flags, 'NodeNameTau', 'pTau'),
        NodeNameJet=getattr(id_flags, 'NodeNameJet', 'pJet'),

        MaxTracks=id_flags.MaxTracks,
        MaxClusters=id_flags.MaxClusters,

        VertexCorrection=False,
        TrackClassification=False,

        # Decorated TauJet variable names:
        OutputVarname=f'{tau_id}_Score',
        OutputDiscriminant=id_flags.OutputDiscriminant,
        OutputPTau=f'{tau_id}_ProbTau',
        OutputPJet=f'{tau_id}_ProbJet',

        # Tau prongness selection
        MinProngTrackPt=getattr(id_flags, 'MinProngTrackPt', 0),
    ))

    return acc
def trigTauWPDecoratorRNNCfg(flags, tau_id: str, precision_seq_name: str):
    '''
    TauJet signal transformed score and ID WPs decorator tool,
    for the legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains ONLY!

    Args:
        flags: configuration flags; the ID settings are read from the
            ``tau_id`` attribute of ``flags.Trigger.Offline.Tau``.
        tau_id: name of the TauID configuration to use.
        precision_seq_name: sequence name, used only to build the tool name.

    Returns:
        The ComponentAccumulator holding TauWPDecorator as its private tool.

    Raises:
        ValueError: if ``tau_id`` is not an attribute of
            ``flags.Trigger.Offline.Tau``.
    '''
    acc = ComponentAccumulator()

    try:
        id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
    except AttributeError as err:
        # A failed getattr raises AttributeError, never NameError, so this is
        # the exception that must be translated for the message to be seen.
        raise ValueError(f'Invalid TauID configuration: {tau_id}') from err

    # In this version we store the WPs as flags accessible through
    # tau->isTau(xAOD::TauJetParameters::IsTauFlags::JetRNNSig...)
    import PyUtils.RootUtils as ru
    ROOT = ru.import_root()
    import cppyy
    # The dictionary is loaded before the xAOD enum values are looked up.
    cppyy.load_library('libxAODTau_cDict')

    is_tau_flag = ROOT.xAOD.TauJetParameters.IsTauFlag
    WPEnumVals = [
        is_tau_flag.JetRNNSigVeryLoose,
        is_tau_flag.JetRNNSigLoose,
        is_tau_flag.JetRNNSigMedium,
        is_tau_flag.JetRNNSigTight,
    ]

    # NOTE(review): rebuilt from a listing whose embedded line numbers skip
    # entries after the last property below; check version control for tool
    # properties that may have been dropped here.
    acc.setPrivateTools(CompFactory.TauWPDecorator(
        name=f'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
        flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
        flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
        flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
        CutEnumVals=WPEnumVals,
        SigEff0P=id_flags.TargetEff[0],
        SigEff1P=id_flags.TargetEff[1],
        SigEff3P=id_flags.TargetEff[2],
        ScoreName='RNNJetScore',
        NewScoreName='RNNJetScoreSigTrans',
    ))

    return acc
203def trigTauWPDecoratorCfg(flags, tau_id: str, precision_seq_name: str, tauContainerName: str):
204 '''TauJet signal transformed score and ID WPs decorator tool'''
205 acc = ComponentAccumulator()
207 try: id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
208 except NameError: raise ValueError(f'Invalid TauID configuration: {tau_id}')
210 if isinstance(id_flags.ScoreFlatteningConfig, str):
212 'flatteningFile0Prong': id_flags.ScoreFlatteningConfig,
213 'flatteningFile1Prong': id_flags.ScoreFlatteningConfig,
214 'flatteningFile3Prong': id_flags.ScoreFlatteningConfig,
216 elif isinstance(id_flags.ScoreFlatteningConfig, (list, tuple)) and len(id_flags.ScoreFlatteningConfig) == 3:
218 'flatteningFile0Prong': id_flags.ScoreFlatteningConfig[0],
219 'flatteningFile1Prong': id_flags.ScoreFlatteningConfig[1],
220 'flatteningFile3Prong': id_flags.ScoreFlatteningConfig[2],
223 raise ValueError(f'Invalid {tau_id} WP decorator flattening config')
226 acc.setPrivateTools(CompFactory.TauWPDecorator(
227 name=f'TrigTau_TauWPDecorator_{precision_seq_name}_{tau_id}',
229 TauContainerName=tauContainerName,
230 DecorWPNames=[f'{tau_id}_{wp}' for wp in id_flags.TargetWPs],
231 DecorWPCutEffs0P=[eff[0] for eff in id_flags.TargetWPs.values()],
232 DecorWPCutEffs1P=[eff[1] for eff in id_flags.TargetWPs.values()],
233 DecorWPCutEffs3P=[eff[2] for eff in id_flags.TargetWPs.values()],
234 ScoreName=f'{tau_id}_Score',
235 NewScoreName=f'{tau_id}_ScoreSigTrans',