3 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
4 from AthenaConfiguration.ComponentFactory
import CompFactory
8 '''Algorithm that overwrites numTrack() and charge() of tauJets in container'''
11 acc.setPrivateTools(CompFactory.TauVertexFinder(
14 AssociatedTracks =
'GhostTrack',
15 InDetTrackSelectionToolForTJVA =
'',
16 Key_trackPartInputContainer =
'',
17 Key_vertexInputContainer =
'',
24 '''Tau track association'''
27 from TrkConfig.TrkVertexFitterUtilsConfig
import AtlasTrackToVertexIPEstimatorCfg
30 from TrackToVertex.TrackToVertexConfig
import TrackToVertexCfg
33 from TrackToCalo.TrackToCaloConfig
import ParticleCaloExtensionToolCfg
36 from InDetConfig.InDetTrackSelectorToolConfig
import TrigTauInDetTrackSelectorToolCfg
39 acc.setPrivateTools(CompFactory.TauTrackFinder(
43 TrackSelectorToolTau = TrigTauInDetTrackSelectorTool,
44 TrackToVertexTool = TrackToVertexTool,
45 Key_trackPartInputContainer = TrackParticlesContainer,
46 maxDeltaZ0wrtLeadTrk = 0.75*mm,
47 removeTracksOutsideZ0wrtLeadTrk =
True,
48 ParticleCaloExtensionTool = ParticleCaloExtensionTool,
49 BypassExtrapolator =
True,
50 tauParticleCache =
"",
51 TrackToVertexIPEstimator = AtlasTrackToVertexIPEstimator,
57 '''Vertex variables calculation'''
60 from TrkConfig.TrkVertexFittersConfig
import TauAdaptiveVertexFitterCfg
63 from TrkConfig.TrkVertexSeedFinderToolsConfig
import CrossDistancesSeedFinderCfg
66 acc.setPrivateTools(CompFactory.TauVertexVariables(
68 VertexFitter = TauAdaptiveVertexFitter,
69 SeedFinder = CrossDistancesSeedFinder,
75 '''TauJet identification inference based on LVNN models, for RNNs and DeepSets'''
# Look up the per-ID flag category for this TauID configuration.
# getattr() reports a missing attribute with AttributeError, not NameError,
# so AttributeError is what has to be translated into the ValueError below.
# NOTE(review): assumes the flags container raises AttributeError for an
# unknown tau_id -- confirm against the flags implementation.
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid LVNN (RNN, DeepSet) TauID configuration: {tau_id}') from err
82 output_variable =
'RNNJetScore' if use_taujet_rnnscore
else f
'{tau_id}_Score'
85 sfx =
'_RNNJetScore' if use_taujet_rnnscore
else ''
87 acc.setPrivateTools(CompFactory.TauJetRNNEvaluator(
88 name = f
'TrigTau_TauJetLVNNEvaluator_{tau_id}{sfx}',
91 NetworkFile0P = id_flags.NetworkConfig[0],
92 NetworkFile1P = id_flags.NetworkConfig[1],
93 NetworkFile3P = id_flags.NetworkConfig[2],
94 InputLayerScalar =
'scalar',
95 InputLayerTracks =
'tracks',
96 InputLayerClusters =
'clusters',
97 OutputLayer =
'rnnid_output',
98 OutputNode =
'sig_prob',
101 MaxTracks = id_flags.MaxTracks,
102 MaxClusters = id_flags.MaxClusters,
104 VertexCorrection =
False,
105 TrackClassification =
False,
108 OutputVarname = output_variable,
114 '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...'''
# Look up the per-ID flag category for this TauID configuration.
# getattr() reports a missing attribute with AttributeError, not NameError,
# so AttributeError is what has to be translated into the ValueError below.
# NOTE(review): assumes the flags container raises AttributeError for an
# unknown tau_id -- confirm against the flags implementation.
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid ONNX TauID configuration: {tau_id}') from err
120 acc.setPrivateTools(CompFactory.TauGNNEvaluator(
121 name = f
'TrigTau_TauJetONNXEvaluator_{tau_id}',
124 NetworkFile0P = id_flags.ONNXConfig[0],
125 NetworkFile1P = id_flags.ONNXConfig[1],
126 NetworkFile3P = id_flags.ONNXConfig[2],
127 InputLayerScalar =
'tau_vars',
128 InputLayerTracks =
'track_vars',
129 InputLayerClusters =
'cluster_vars',
130 NodeNameTau =
'pTau',
131 NodeNameJet =
'pJet',
134 MaxTracks = id_flags.MaxTracks,
135 MaxClusters = id_flags.MaxClusters,
137 VertexCorrection =
False,
138 TrackClassification =
False,
141 OutputVarname = f
'{tau_id}_Score',
142 OutputDiscriminant = id_flags.OutputDiscriminant,
143 OutputPTau = f
'{tau_id}_ProbTau',
144 OutputPJet = f
'{tau_id}_ProbJet',
147 MinProngTrackPt = id_flags.MinProngTrackPt
if hasattr(id_flags,
'MinProngTrackPt')
else 0,
154 TauJet signal transformed score and ID WPs decorator tool,
155 for the legacy mediumRNN/tightRNN_tracktwoMVA/tracktwoLLP/trackLRT chains ONLY!
# Look up the per-ID flag category for this TauID configuration.
# getattr() reports a missing attribute with AttributeError, not NameError,
# so AttributeError is what has to be translated into the ValueError below.
# NOTE(review): assumes the flags container raises AttributeError for an
# unknown tau_id -- confirm against the flags implementation.
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid TauID configuration: {tau_id}') from err
165 import PyUtils.RootUtils
as ru
166 ROOT = ru.import_root()
168 cppyy.load_library(
'libxAODTau_cDict')
170 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigVeryLoose,
171 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigLoose,
172 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigMedium,
173 ROOT.xAOD.TauJetParameters.IsTauFlag.JetRNNSigTight,
176 acc.setPrivateTools(CompFactory.TauWPDecorator(
177 name=f
'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
178 flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
179 flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
180 flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
181 CutEnumVals=WPEnumVals,
182 SigEff0P=id_flags.TargetEff[0],
183 SigEff1P=id_flags.TargetEff[1],
184 SigEff3P=id_flags.TargetEff[2],
185 ScoreName=
'RNNJetScore',
186 NewScoreName=
'RNNJetScoreSigTrans',
193 '''TauJet signal transformed score and ID WPs decorator tool'''
# Look up the per-ID flag category for this TauID configuration.
# getattr() reports a missing attribute with AttributeError, not NameError,
# so AttributeError is what has to be translated into the ValueError below.
# NOTE(review): assumes the flags container raises AttributeError for an
# unknown tau_id -- confirm against the flags implementation.
try:
    id_flags = getattr(flags.Trigger.Offline.Tau, tau_id)
except AttributeError as err:
    raise ValueError(f'Invalid TauID configuration: {tau_id}') from err
199 acc.setPrivateTools(CompFactory.TauWPDecorator(
200 name=f
'TrigTau_TauWPDecoratorRNN_{precision_seq_name}_{tau_id}',
201 flatteningFile0Prong=id_flags.ScoreFlatteningConfig[0],
202 flatteningFile1Prong=id_flags.ScoreFlatteningConfig[1],
203 flatteningFile3Prong=id_flags.ScoreFlatteningConfig[2],
204 TauContainerName=tauContainerName,
205 DecorWPNames=[f
'{tau_id}_{wp}' for wp
in id_flags.WPNames],
206 DecorWPCutEffs0P=id_flags.TargetEff[0],
207 DecorWPCutEffs1P=id_flags.TargetEff[1],
208 DecorWPCutEffs3P=id_flags.TargetEff[2],
209 ScoreName=f
'{tau_id}_Score',
210 NewScoreName=f
'{tau_id}_ScoreSigTrans',