def trigTauJetONNXEvaluatorCfg(flags, tau_id=''):
    '''TauJet identification inference based on ONNX models, for GNNs, transformers, etc...

    Configures a ``TauGNNEvaluator`` private tool for the TauID named ``tau_id``,
    reading its settings from ``flags.Trigger.Offline.Tau.<tau_id>``.

    :param flags: configuration flags container.
    :param tau_id: name of the TauID flag category under
        ``flags.Trigger.Offline.Tau``. It is also used to build the tool name
        and the output names ``<tau_id>_Score``, ``<tau_id>_ProbTau`` and
        ``<tau_id>_ProbJet``.
    :returns: a ``ComponentAccumulator`` holding the configured
        ``TauGNNEvaluator`` as its private tool.
    :raises ValueError: if ``tau_id`` does not name an entry under
        ``flags.Trigger.Offline.Tau``, or if that entry's ``ONNXConfig`` is
        neither a single string nor a 3-element list/tuple.
    '''
    tau_flags = flags.Trigger.Offline.Tau

    # A missing attribute lookup raises AttributeError, never NameError.
    # NOTE(review): RuntimeError is also caught in case the flags container in
    # use reports unknown flags that way -- confirm against AthConfigFlags.
    try:
        id_flags = getattr(tau_flags, tau_id)
    except (AttributeError, RuntimeError) as err:
        raise ValueError(f'Invalid ONNX TauID configuration: {tau_id}') from err

    onnx_config = id_flags.ONNXConfig
    if isinstance(onnx_config, str):
        # A single (inclusive) network file.
        network_config = {'NetworkFileInclusive': onnx_config}
    elif isinstance(onnx_config, (list, tuple)) and len(onnx_config) == 3:
        # Three network files, mapped in order to the 0P, 1P and 3P inputs.
        network_config = {
            'NetworkFile0P': onnx_config[0],
            'NetworkFile1P': onnx_config[1],
            'NetworkFile3P': onnx_config[2],
        }
    else:
        raise ValueError(f'Invalid {tau_id} ONNX network config file')

    # Only build the accumulator once the configuration has been validated.
    acc = ComponentAccumulator()
    acc.setPrivateTools(CompFactory.TauGNNEvaluator(
        name = f'TrigTau_TauJetONNXEvaluator_{tau_id}',
        useTRT = flags.Detector.EnableTRT,

        # Network files and model input/output node names
        **network_config,
        InputLayerScalar = 'tau_vars',
        InputLayerTracks = 'track_vars',
        InputLayerClusters = 'cluster_vars',
        NodeNameTau = getattr(id_flags, 'NodeNameTau', 'pTau'),
        NodeNameJet = getattr(id_flags, 'NodeNameJet', 'pJet'),

        # Input object selection
        MaxTracks = id_flags.MaxTracks,
        MaxClusters = id_flags.MaxClusters,
        MaxClusterDR = 1.0,
        VertexCorrection = False,
        TrackClassification = False,

        # Output names
        OutputVarname = f'{tau_id}_Score',
        OutputDiscriminant = id_flags.OutputDiscriminant,
        OutputPTau = f'{tau_id}_ProbTau',
        OutputPJet = f'{tau_id}_ProbJet',

        MinProngTrackPt = getattr(id_flags, 'MinProngTrackPt', 0),
    ))

    return acc
162