3 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory
import CompFactory
8 '''Algorithm that overwrites numTrack() and charge() of tauJets in container'''
11 acc.setPrivateTools(CompFactory.TauVertexFinder(
14 AssociatedTracks =
'GhostTrack',
15 InDetTrackSelectionToolForTJVA =
'',
16 Key_trackPartInputContainer =
'',
17 Key_vertexInputContainer =
'',
24 '''Tau track association'''
27 from TrkConfig.TrkVertexFitterUtilsConfig
import AtlasTrackToVertexIPEstimatorCfg
30 from TrackToVertex.TrackToVertexConfig
import TrackToVertexCfg
33 from TrackToCalo.TrackToCaloConfig
import ParticleCaloExtensionToolCfg
36 from InDetConfig.InDetTrackSelectorToolConfig
import TrigTauInDetTrackSelectorToolCfg
39 acc.setPrivateTools(CompFactory.TauTrackFinder(
43 TrackSelectorToolTau = TrigTauInDetTrackSelectorTool,
44 TrackToVertexTool = TrackToVertexTool,
45 Key_trackPartInputContainer = TrackParticlesContainer,
46 maxDeltaZ0wrtLeadTrk = 0.75*mm,
47 removeTracksOutsideZ0wrtLeadTrk =
True,
48 ParticleCaloExtensionTool = ParticleCaloExtensionTool,
49 BypassSelector =
False,
50 BypassExtrapolator =
True,
51 tauParticleCache =
"",
52 TrackToVertexIPEstimator = AtlasTrackToVertexIPEstimator,
58 '''Vertex variables calculation'''
61 from TrkConfig.TrkVertexFittersConfig
import TauAdaptiveVertexFitterCfg
64 from TrkConfig.TrkVertexSeedFinderToolsConfig
import CrossDistancesSeedFinderCfg
67 acc.setPrivateTools(CompFactory.TauVertexVariables(
69 VertexFitter = TauAdaptiveVertexFitter,
70 SeedFinder = CrossDistancesSeedFinder,
76 '''TauJet identification inference based on LVNN models, for RNNs and DeepSets'''
79 try: id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
80 except NameError:
raise ValueError(f
'Invalid LVNN (RNN, DeepSet) TauID configuration: {tau_id}')
83 output_variable =
'RNNJetScore' if use_taujet_rnnscore
else f
'{tau_id}_Score'
86 sfx =
'_RNNJetScore' if use_taujet_rnnscore
else ''
88 acc.setPrivateTools(CompFactory.TauJetRNNEvaluator(
89 name = f
'TrigTau_TauJetLVNNEvaluator_{tau_id}{sfx}',
92 NetworkFile0P = id_flags.NetworkConfig[0],
93 NetworkFile1P = id_flags.NetworkConfig[1],
94 NetworkFile3P = id_flags.NetworkConfig[2],
95 InputLayerScalar =
'scalar',
96 InputLayerTracks =
'tracks',
97 InputLayerClusters =
'clusters',
98 OutputLayer =
'rnnid_output',
99 OutputNode =
'sig_prob',
102 MaxTracks = id_flags.MaxTracks,
103 MaxClusters = id_flags.MaxClusters,
105 VertexCorrection =
False,
106 TrackClassification =
False,
109 OutputVarname = output_variable,
115 '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...'''
118 try: id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
119 except NameError:
raise ValueError(f
'Invalid ONNX TauID configuration: {tau_id}')
121 acc.setPrivateTools(CompFactory.TauGNNEvaluator(
122 name = f
'TrigTau_TauJetONNXEvaluator_{tau_id}',
125 NetworkFile0P = id_flags.ONNXConfig[0],
126 NetworkFile1P = id_flags.ONNXConfig[1],
127 NetworkFile3P = id_flags.ONNXConfig[2],
128 InputLayerScalar =
'tau_vars',
129 InputLayerTracks =
'track_vars',
130 InputLayerClusters =
'cluster_vars',
131 NodeNameTau =
'pTau',
132 NodeNameJet =
'pJet',
135 MaxTracks = id_flags.MaxTracks,
136 MaxClusters = id_flags.MaxClusters,
138 VertexCorrection =
False,
139 TrackClassification =
False,
142 OutputVarname = f
'{tau_id}_Score',
143 OutputPTau = f
'{tau_id}_ProbTau',
144 OutputPJet = f
'{tau_id}_ProbJet',
147 MinProngTrackPt = id_flags.MinProngTrackPt
if hasattr(id_flags,
'MinProngTrackPt')
else 0,
154 TauJet signal transformed score and ID WPs decorator tool,
155 for the legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains ONLY!
160 try: id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
161 except NameError:
raise ValueError(f
'Invalid TauID configuration: {tau_id}')
165 import PyUtils.RootUtils
as ru
166 ROOT = ru.import_root()
168 cppyy.load_library(
'libxAODTau_cDict')
170 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigVeryLoose,
171 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigLoose,
172 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigMedium,
173 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigTight,
176 acc.setPrivateTools(CompFactory.TauWPDecorator(
177 name=f
'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
178 flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
179 flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
180 flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
181 CutEnumVals=WPEnumVals,
182 SigEff0P=id_flags.TargetEff[0],
183 SigEff1P=id_flags.TargetEff[1],
184 SigEff3P=id_flags.TargetEff[2],
185 ScoreName=
'RNNJetScore',
186 NewScoreName=
'RNNJetScoreSigTrans',
193 '''TauJet signal transformed score and ID WPs decorator tool'''
196 try: id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
197 except NameError:
raise ValueError(f
'Invalid TauID configuration: {tau_id}')
199 acc.setPrivateTools(CompFactory.TauWPDecorator(
200 name=f
'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
201 flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
202 flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
203 flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
204 DecorWPNames=[f
'{tau_id}_{wp}' for wp
in id_flags.WPNames],
205 DecorWPCutEffs0P=id_flags.TargetEff[0],
206 DecorWPCutEffs1P=id_flags.TargetEff[1],
207 DecorWPCutEffs3P=id_flags.TargetEff[2],
208 ScoreName=f
'{tau_id}_Score',
209 NewScoreName=f
'{tau_id}_ScoreSigTrans',