3 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory
import CompFactory
8 '''Algorithm that overwrites numTrack() and charge() of tauJets in container'''
11 acc.setPrivateTools(CompFactory.TauVertexFinder(
14 AssociatedTracks =
'GhostTrack',
15 InDetTrackSelectionToolForTJVA =
'',
16 Key_trackPartInputContainer =
'',
17 Key_vertexInputContainer =
'',
24 '''Tau track association'''
27 from TrkConfig.TrkVertexFitterUtilsConfig
import AtlasTrackToVertexIPEstimatorCfg
30 from TrackToVertex.TrackToVertexConfig
import TrackToVertexCfg
33 from TrackToCalo.TrackToCaloConfig
import ParticleCaloExtensionToolCfg
36 from InDetConfig.InDetTrackSelectorToolConfig
import TrigTauInDetTrackSelectorToolCfg
39 acc.setPrivateTools(CompFactory.TauTrackFinder(
43 TrackSelectorToolTau = TrigTauInDetTrackSelectorTool,
44 TrackToVertexTool = TrackToVertexTool,
45 Key_trackPartInputContainer = TrackParticlesContainer,
46 maxDeltaZ0wrtLeadTrk = 0.75*mm,
47 removeTracksOutsideZ0wrtLeadTrk =
True,
48 ParticleCaloExtensionTool = ParticleCaloExtensionTool,
49 BypassExtrapolator =
True,
50 tauParticleCache =
"",
51 TrackToVertexIPEstimator = AtlasTrackToVertexIPEstimator,
57 '''Vertex variables calculation'''
60 from TrkConfig.TrkVertexFittersConfig
import TauAdaptiveVertexFitterCfg
63 from TrkConfig.TrkVertexSeedFinderToolsConfig
import CrossDistancesSeedFinderCfg
66 acc.setPrivateTools(CompFactory.TauVertexVariables(
68 VertexFitter = TauAdaptiveVertexFitter,
69 SeedFinder = CrossDistancesSeedFinder,
75 '''TauJet identification inference based on LVNN models, for RNNs and DeepSets'''
# Resolve the per-ID flag category flags.Trigger.Offline.Tau.<tau_id>.
# getattr() signals a missing attribute with AttributeError (the same contract
# hasattr() relies on); it never raises NameError, so catching NameError here
# let an unknown tau_id escape as a raw AttributeError instead of this ValueError.
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid LVNN (RNN, DeepSet) TauID configuration: {tau_id}') from err
82 output_variable =
'RNNJetScore' if use_taujet_rnnscore
else f
'{tau_id}_Score'
85 sfx =
'_RNNJetScore' if use_taujet_rnnscore
else ''
87 acc.setPrivateTools(CompFactory.TauJetRNNEvaluator(
88 name = f
'TrigTau_TauJetLVNNEvaluator_{tau_id}{sfx}',
89 useTRT = flags.Detector.EnableTRT,
92 NetworkFile0P = id_flags.NetworkConfig[0],
93 NetworkFile1P = id_flags.NetworkConfig[1],
94 NetworkFile3P = id_flags.NetworkConfig[2],
95 InputLayerScalar =
'scalar',
96 InputLayerTracks =
'tracks',
97 InputLayerClusters =
'clusters',
98 OutputLayer =
'rnnid_output',
99 OutputNode =
'sig_prob',
102 MaxTracks = id_flags.MaxTracks,
103 MaxClusters = id_flags.MaxClusters,
105 VertexCorrection =
False,
106 TrackClassification =
False,
109 OutputVarname = output_variable,
115 '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...'''
# Resolve the per-ID flag category flags.Trigger.Offline.Tau.<tau_id>.
# A missing attribute is reported by getattr() as AttributeError, not NameError;
# catch that so an unknown tau_id yields the intended ValueError (cause chained).
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid ONNX TauID configuration: {tau_id}') from err
121 acc.setPrivateTools(CompFactory.TauGNNEvaluator(
122 name = f
'TrigTau_TauJetONNXEvaluator_{tau_id}',
123 useTRT = flags.Detector.EnableTRT,
126 NetworkFile0P = id_flags.ONNXConfig[0],
127 NetworkFile1P = id_flags.ONNXConfig[1],
128 NetworkFile3P = id_flags.ONNXConfig[2],
129 InputLayerScalar =
'tau_vars',
130 InputLayerTracks =
'track_vars',
131 InputLayerClusters =
'cluster_vars',
132 NodeNameTau =
'pTau',
133 NodeNameJet =
'pJet',
136 MaxTracks = id_flags.MaxTracks,
137 MaxClusters = id_flags.MaxClusters,
139 VertexCorrection =
False,
140 TrackClassification =
False,
143 OutputVarname = f
'{tau_id}_Score',
144 OutputDiscriminant = id_flags.OutputDiscriminant,
145 OutputPTau = f
'{tau_id}_ProbTau',
146 OutputPJet = f
'{tau_id}_ProbJet',
149 MinProngTrackPt = id_flags.MinProngTrackPt
if hasattr(id_flags,
'MinProngTrackPt')
else 0,
156 TauJet signal transformed score and ID WPs decorator tool,
157 for the legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains ONLY!
# Resolve the per-ID flag category flags.Trigger.Offline.Tau.<tau_id>.
# A missing attribute is reported by getattr() as AttributeError, not NameError;
# catch that so an unknown tau_id yields the intended ValueError (cause chained).
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid TauID configuration: {tau_id}') from err
167 import PyUtils.RootUtils
as ru
168 ROOT = ru.import_root()
170 cppyy.load_library(
'libxAODTau_cDict')
172 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigVeryLoose,
173 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigLoose,
174 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigMedium,
175 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigTight,
178 acc.setPrivateTools(CompFactory.TauWPDecorator(
179 name=f
'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
180 flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
181 flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
182 flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
183 CutEnumVals=WPEnumVals,
184 SigEff0P=id_flags.TargetEff[0],
185 SigEff1P=id_flags.TargetEff[1],
186 SigEff3P=id_flags.TargetEff[2],
187 ScoreName=
'RNNJetScore',
188 NewScoreName=
'RNNJetScoreSigTrans',
195 '''TauJet signal transformed score and ID WPs decorator tool'''
# Resolve the per-ID flag category flags.Trigger.Offline.Tau.<tau_id>.
# A missing attribute is reported by getattr() as AttributeError, not NameError;
# catch that so an unknown tau_id yields the intended ValueError (cause chained).
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid TauID configuration: {tau_id}') from err
201 acc.setPrivateTools(CompFactory.TauWPDecorator(
202 name=f
'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
203 flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
204 flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
205 flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
206 TauContainerName=tauContainerName,
207 DecorWPNames=[f
'{tau_id}_{wp}' for wp
in id_flags.WPNames],
208 DecorWPCutEffs0P=id_flags.TargetEff[0],
209 DecorWPCutEffs1P=id_flags.TargetEff[1],
210 DecorWPCutEffs3P=id_flags.TargetEff[2],
211 ScoreName=f
'{tau_id}_Score',
212 NewScoreName=f
'{tau_id}_ScoreSigTrans',