def trigTauJetONNXEvaluatorCfg(flags, tau_id=''):
    '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...

    Builds a ``TauGNNEvaluator`` tool named ``TrigTau_TauJetONNXEvaluator_<tau_id>``
    from the settings stored under ``flags.Trigger.Offline.Tau.<tau_id>``.

    Args:
        flags: configuration flags. This function reads ``flags.Detector.EnableTRT``
            and the ``tau_id`` attribute of ``flags.Trigger.Offline.Tau``.
        tau_id: name of the TauID configuration to look up under
            ``flags.Trigger.Offline.Tau``. It is also used as the prefix of the
            output variable names (``<tau_id>_Score``, ``<tau_id>_ProbTau``,
            ``<tau_id>_ProbJet``).

    Returns:
        A ComponentAccumulator with the configured tool set as its private tool.

    Raises:
        ValueError: if ``flags.Trigger.Offline.Tau`` has no attribute named ``tau_id``.
    '''
    tau_flags = flags.Trigger.Offline.Tau

    # Resolve the per-ID flags first, so that an invalid tau_id is rejected
    # before any component is created.
    try:
        id_flags = getattr(tau_flags, tau_id)
    except AttributeError as err:
        # A failed attribute lookup raises AttributeError (never NameError), so
        # this is the exception that has to be translated for the caller.
        raise ValueError(f'Invalid ONNX TauID configuration: {tau_id}') from err

    acc = ComponentAccumulator()
    acc.setPrivateTools(CompFactory.TauGNNEvaluator(
        name=f'TrigTau_TauJetONNXEvaluator_{tau_id}',
        useTRT=flags.Detector.EnableTRT,

        # Network files (ONNXConfig entries 0, 1, 2 feed the 0P, 1P and 3P
        # properties respectively) and the layer/node names to use
        NetworkFile0P=id_flags.ONNXConfig[0],
        NetworkFile1P=id_flags.ONNXConfig[1],
        NetworkFile3P=id_flags.ONNXConfig[2],
        InputLayerScalar='tau_vars',
        InputLayerTracks='track_vars',
        InputLayerClusters='cluster_vars',
        NodeNameTau='pTau',
        NodeNameJet='pJet',

        # Track/cluster input settings
        MaxTracks=id_flags.MaxTracks,
        MaxClusters=id_flags.MaxClusters,
        MaxClusterDR=1.0,
        VertexCorrection=False,
        TrackClassification=False,

        # Output variable names, all prefixed with the TauID name
        OutputVarname=f'{tau_id}_Score',
        OutputDiscriminant=id_flags.OutputDiscriminant,
        OutputPTau=f'{tau_id}_ProbTau',
        OutputPJet=f'{tau_id}_ProbJet',

        # Optional flag: fall back to 0 when this TauID configuration does not define it
        MinProngTrackPt=getattr(id_flags, 'MinProngTrackPt', 0),
    ))

    return acc
153