13 flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
15 create algorithm which dumps GNN training information to ROOT file
17 acc = ComponentAccumulator()
21 Output=[f
"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
25 kwargs.setdefault(
"NtupleFileName", flags.Tracking.GNN.DumpObjects.NtupleFileName)
26 kwargs.setdefault(
"NtupleTreeName", flags.Tracking.GNN.DumpObjects.NtupleTreeName)
27 kwargs.setdefault(
"rootFile",
True)
29 acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
33 """Sets up a GNNTrackFinderTool tool and returns it."""
# Configures InDet.SiGNNTrackFinderTool as the accumulator's private tool.
# Every property default below is read from flags.Tracking.GNN.TrackFinder;
# values already present in kwargs take precedence (kwargs.setdefault).
34 acc = ComponentAccumulator()
37 kwargs.setdefault(
"embeddingDim", flags.Tracking.GNN.TrackFinder.embeddingDim)
38 kwargs.setdefault(
"rVal", flags.Tracking.GNN.TrackFinder.rVal)
39 kwargs.setdefault(
"knnVal", flags.Tracking.GNN.TrackFinder.knnVal)
40 kwargs.setdefault(
"filterCut", flags.Tracking.GNN.TrackFinder.filterCut)
41 kwargs.setdefault(
"inputMLModelDir", flags.Tracking.GNN.TrackFinder.inputMLModelDir)
42 kwargs.setdefault(
"ccCut", flags.Tracking.GNN.TrackFinder.ccCut)
43 kwargs.setdefault(
"walkMin", flags.Tracking.GNN.TrackFinder.walkMin)
44 kwargs.setdefault(
"walkMax", flags.Tracking.GNN.TrackFinder.walkMax)
# Per-stage input feature names and their scale factors (embedding, filter,
# GNN), passed through unchanged from the flags.
45 kwargs.setdefault(
"EmbeddingFeatureNames", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureNames)
46 kwargs.setdefault(
"EmbeddingFeatureScales", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureScales)
47 kwargs.setdefault(
"FilterFeatureNames", flags.Tracking.GNN.TrackFinder.FilterFeatureNames)
48 kwargs.setdefault(
"FilterFeatureScales", flags.Tracking.GNN.TrackFinder.FilterFeatureScales)
49 kwargs.setdefault(
"GNNFeatureNames", flags.Tracking.GNN.TrackFinder.GNNFeatureNames)
50 kwargs.setdefault(
"GNNFeatureScales", flags.Tracking.GNN.TrackFinder.GNNFeatureScales)
# Three ONNX Runtime inference tools (Embedding, Filtering, GNN) are attached
# below. Each one loads a fixed file name from the model directory.
52 from AthOnnxComps.OnnxRuntimeInferenceConfig
import OnnxRuntimeInferenceToolCfg
# One execution-provider setting is shared by all three inference tools.
53 ort_exe_provider = flags.Tracking.GNN.TrackFinder.ORTExeProvider
# Model paths are built from kwargs["inputMLModelDir"], not from the flag
# directly. A caller override of inputMLModelDir therefore also redirects
# embedding.onnx, filtering.onnx and gnn.onnx.
# NOTE(review): setdefault evaluates its default argument eagerly. Each
# OnnxRuntimeInferenceToolCfg(...) is therefore built and merged into acc even
# when the caller already passed "Embedding"/"Filtering"/"GNN". Guard with
# `if key not in kwargs:` if that matters.
54 kwargs.setdefault(
"Embedding", acc.popToolsAndMerge(
55 OnnxRuntimeInferenceToolCfg(flags, str(Path(kwargs[
"inputMLModelDir"]) /
"embedding.onnx"),
56 ort_exe_provider, name=
"Embedding")
58 kwargs.setdefault(
"Filtering", acc.popToolsAndMerge(
59 OnnxRuntimeInferenceToolCfg(flags, str(Path(kwargs[
"inputMLModelDir"]) /
"filtering.onnx"),
60 ort_exe_provider, name=
"Filtering")
62 kwargs.setdefault(
"GNN", acc.popToolsAndMerge(
63 OnnxRuntimeInferenceToolCfg(flags, str(Path(kwargs[
"inputMLModelDir"]) /
"gnn.onnx"),
64 ort_exe_provider, name=
"GNN")
# Register the configured tool as the accumulator's private tool. Callers
# retrieve it with acc.popToolsAndMerge(...).
67 acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
72 """Sets up a GNNTrackFinderTritonTool tool and returns it."""
# Configures InDet.GNNTrackFinderTritonTool as the accumulator's private tool.
# Attaches a "TritonTool" built by TritonToolCfg, using the model name, url and
# port from flags.Tracking.GNN.Triton. "FeatureNames" comes from
# flags.Tracking.GNN.spacepointFeatures.
73 from AthTritonComps.TritonToolConfig
import TritonToolCfg
75 acc = ComponentAccumulator()
# NOTE(review): setdefault evaluates its default eagerly. TritonToolCfg(...) is
# therefore built and merged into acc even when the caller supplies
# "TritonTool".
77 kwargs.setdefault(
"TritonTool", acc.popToolsAndMerge(
78 TritonToolCfg(flags, model_name=flags.Tracking.GNN.Triton.model,
79 url=flags.Tracking.GNN.Triton.url,
80 port=flags.Tracking.GNN.Triton.port,
84 kwargs.setdefault(
"FeatureNames", flags.Tracking.GNN.spacepointFeatures)
85 acc.setPrivateTools(CompFactory.InDet.GNNTrackFinderTritonTool(name, **kwargs))
90 """Sets up an ActsGnnModuleMapFinderTool and returns it."""
# Configures InDet.ActsGnnModuleMapFinderTool as the accumulator's private
# tool. Property defaults are read from flags.Tracking.GNN.ActsPipeline;
# values already present in kwargs take precedence.
91 acc = ComponentAccumulator()
93 kwargs.setdefault(
"moduleMapPath", flags.Tracking.GNN.ActsPipeline.moduleMapPath)
94 kwargs.setdefault(
"gnnPath", flags.Tracking.GNN.ActsPipeline.gnnPath)
95 kwargs.setdefault(
"edgeCut", flags.Tracking.GNN.ActsPipeline.edgeCut)
# numTrtContexts: presumably the number of TensorRT execution contexts.
# TODO confirm against the C++ tool's property documentation.
96 kwargs.setdefault(
"numTrtContexts", flags.Tracking.GNN.ActsPipeline.numTrtContexts)
97 kwargs.setdefault(
"minCandidateMeasurements", flags.Tracking.GNN.ActsPipeline.minCandidateMeasurements)
100 acc.setPrivateTools(CompFactory.InDet.ActsGnnModuleMapFinderTool(name, **kwargs))
142 """Sets up a GNNTrackMaker algorithm and returns it."""
# Configures the InDet.SiSPGNNTrackMaker event algorithm. It sets up:
# - the seed fitter, ITk track fitter and track summary tool;
# - the GNN track finder or reader selected by flags.Tracking.GNN.ToolType;
# - the ITk eta-dependent cuts service;
# - cut values taken from flags.Tracking.GNN.
144 acc = ComponentAccumulator()
148 kwargs.setdefault(
"SeedFitterTool", SeedFitterTool)
150 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
# NOTE(review): the fitter is built and merged unconditionally, even when the
# caller already supplies "TrackFitter".
151 InDetTrackFitter = acc.popToolsAndMerge(ITkTrackFitterCfg(flags))
152 kwargs.setdefault(
"TrackFitter", InDetTrackFitter)
# The summary tool is only configured when the caller did not pass one.
154 if "TrackSummaryTool" not in kwargs:
155 from TrkConfig.TrkTrackSummaryToolConfig
import ITkTrackSummaryToolCfg
158 "TrackSummaryTool", acc.popToolsAndMerge(ITkTrackSummaryToolCfg(flags))
# Exactly one of GNNTrackFinderTool / GNNTrackReaderTool receives a tool; the
# other is set to None. The TrackFinder, Triton and ActsPipeline modes all feed
# GNNTrackFinderTool. Only TrackReader uses GNNTrackReaderTool.
161 if flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackFinder:
163 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
164 kwargs.setdefault(
"GNNTrackReaderTool",
None)
165 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackReader:
167 kwargs.setdefault(
"GNNTrackReaderTool", InDetGNNTrackReader)
168 kwargs.setdefault(
"GNNTrackFinderTool",
None)
169 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.Triton:
171 kwargs.setdefault(
"GNNTrackReaderTool",
None)
172 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
173 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.ActsPipeline:
175 kwargs.setdefault(
"GNNTrackReaderTool",
None)
176 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
# NOTE(review): this raise is presumably the fall-through of the ToolType
# chain above. The message only names GNNTrackFinder/GNNTrackReader, although
# four modes are supported. Consider reporting the actual
# flags.Tracking.GNN.ToolType value and listing the valid modes.
178 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
180 kwargs.setdefault(
"areInputClusters", flags.Tracking.GNN.useClusterTracks)
181 kwargs.setdefault(
"doRecoTrackCuts", flags.Tracking.GNN.doRecoTrackCuts)
# NOTE(review): saveEdgeScore is read from the ActsPipeline flags for every
# ToolType, not only ActsPipeline. Confirm this is intended.
182 kwargs.setdefault(
"saveEdgeScore", flags.Tracking.GNN.ActsPipeline.saveEdgeScore)
# NOTE(review): likely typo. The membership test checks
# "InDetEtaDependentCutSvc" (no 's'), but the key set below is
# "InDetEtaDependentCutsSvc". A caller passing "InDetEtaDependentCutsSvc"
# therefore does not skip this branch: ITkEtaDependentCutsSvcCfg is still
# merged, and setdefault then keeps the caller's value. Make both strings
# match the real property name.
185 if "InDetEtaDependentCutSvc" not in kwargs:
186 from InDetConfig.InDetEtaDependentCutsConfig
import ITkEtaDependentCutsSvcCfg
187 acc.merge(ITkEtaDependentCutsSvcCfg(flags))
188 kwargs.setdefault(
"InDetEtaDependentCutsSvc", acc.getService(
"ITkEtaDependentCutsSvc"+flags.Tracking.ActiveConfig.extension))
# Cut values forwarded to the algorithm, read from flags.Tracking.GNN.
190 kwargs.setdefault(
"minClusters", flags.Tracking.GNN.minClusters)
191 kwargs.setdefault(
"pTmin", flags.Tracking.GNN.pTmin)
192 kwargs.setdefault(
"etamax", flags.Tracking.GNN.etamax)
193 kwargs.setdefault(
"minPixelClusters", flags.Tracking.GNN.minPixelClusters)
194 kwargs.setdefault(
"minStripClusters", flags.Tracking.GNN.minStripClusters)
196 acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs))
200 """Sets up a GNN for seeding algorithm and returns it."""
# Configures the InDet.GNNSeedingTrackMaker event algorithm. Besides the
# algorithm's tool properties, it does the following:
# - merges the ITk pixel/strip SiDetElementBoundaryLinks conditions algorithms;
# - merges the AtlasFieldCache conditions algorithm;
# - registers the ROT creator, pattern propagator and pattern updator as
#   public tools.
201 acc = ComponentAccumulator()
203 from InDetConfig.SiCombinatorialTrackFinderToolConfig
import SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg, SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
204 acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags))
205 acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags))
208 from MagFieldServices.MagFieldServicesConfig
import (
209 AtlasFieldCacheCondAlgCfg)
210 acc.merge(AtlasFieldCacheCondAlgCfg(flags))
# Each of the next three tools is added as a public tool (addPublicTool) and
# also set as a kwargs default. The ROT creator name carries the active
# tracking-pass extension.
212 from TrkConfig.TrkRIO_OnTrackCreatorConfig
import ITkRotCreatorCfg
213 ITkRotCreator = acc.popToolsAndMerge(ITkRotCreatorCfg(
214 flags, name=
"ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
215 acc.addPublicTool(ITkRotCreator)
216 kwargs.setdefault(
"RIOonTrackTool", ITkRotCreator)
218 from TrkConfig.TrkExRungeKuttaPropagatorConfig
import (
219 RungeKuttaPropagatorCfg)
220 ITkPatternPropagator = acc.popToolsAndMerge(
221 RungeKuttaPropagatorCfg(flags, name=
"ITkPatternPropagator"))
222 acc.addPublicTool(ITkPatternPropagator)
223 kwargs.setdefault(
"PropagatorTool", ITkPatternPropagator)
225 from TrkConfig.TrkMeasurementUpdatorConfig
import KalmanUpdator_xkCfg
226 ITkPatternUpdator = acc.popToolsAndMerge(
227 KalmanUpdator_xkCfg(flags, name=
"ITkPatternUpdator"))
228 acc.addPublicTool(ITkPatternUpdator)
229 kwargs.setdefault(
"UpdatorTool", ITkPatternUpdator)
# Boundary-check and pixel/strip conditions-summary tools (private tools).
231 from InDetConfig.InDetBoundaryCheckToolConfig
import ITkBoundaryCheckToolCfg
232 kwargs.setdefault(
"BoundaryCheckTool", acc.popToolsAndMerge(
233 ITkBoundaryCheckToolCfg(flags)))
235 from PixelConditionsTools.ITkPixelConditionsSummaryConfig
import (
236 ITkPixelConditionsSummaryCfg)
237 kwargs.setdefault(
"PixelSummaryTool", acc.popToolsAndMerge(
238 ITkPixelConditionsSummaryCfg(flags)))
240 from SCT_ConditionsTools.ITkStripConditionsToolsConfig
import (
241 ITkStripConditionsSummaryToolCfg)
242 kwargs.setdefault(
"StripSummaryTool", acc.popToolsAndMerge(
243 ITkStripConditionsSummaryToolCfg(flags)))
# NOTE(review): the finder/reader selection here uses the boolean flags
# useTrackFinder / useTrackReader. The SiSPGNNTrackMaker configuration in this
# file switches on flags.Tracking.GNN.ToolType instead. Confirm the two
# mechanisms stay in sync.
245 if flags.Tracking.GNN.useTrackFinder:
247 kwargs.setdefault(
"GNNTrackReaderTool",
None)
248 elif flags.Tracking.GNN.useTrackReader:
250 kwargs.setdefault(
"GNNTrackFinderTool",
None)
252 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
254 kwargs.setdefault(
"SeedFitterTool", acc.popToolsAndMerge(
SeedFitterToolCfg(flags)))
256 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
257 kwargs.setdefault(
"TrackFitter", acc.popToolsAndMerge(ITkTrackFitterCfg(flags)))
259 from InDetConfig.SiDetElementsRoadToolConfig
import ITkSiDetElementsRoadMaker_xkCfg
260 kwargs.setdefault(
"RoadTool", acc.popToolsAndMerge(ITkSiDetElementsRoadMaker_xkCfg(flags)))
# Pattern-recognition cuts take element [0] of each flags.Tracking.ActiveConfig
# list (presumably the first eta region; TODO confirm).
264 kwargs.setdefault(
"nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
265 kwargs.setdefault(
"nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
266 kwargs.setdefault(
"nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
267 kwargs.setdefault(
"nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])
269 kwargs.setdefault(
"pTmin", flags.Tracking.ActiveConfig.minPT[0])
270 kwargs.setdefault(
"pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
271 kwargs.setdefault(
"Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
272 kwargs.setdefault(
"Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
# NOTE(review): Xi2maxMultiTracks reuses Xi2max[0] rather than a dedicated
# flag. Confirm this is intended. Multi-track production is disabled by
# default just below.
273 kwargs.setdefault(
"Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
274 kwargs.setdefault(
"doMultiTracksProd",
False)
276 acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))