13 flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
15 create algorithm which dumps GNN training information to ROOT file
# Configures the InDet.DumpObjects event algorithm. Several original lines
# (including the `def` line and the lines wrapping `Output=[...]`) are not
# present in this view, so only the visible statements are described here.
17 acc = ComponentAccumulator()
# Output stream spec: stream `name` is bound to file `outfile`, opened with
# ROOT's RECREATE option (an existing file of that name is replaced).
21 Output=[f
"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
# Ntuple file/tree names default to the flags.Tracking.GNN.DumpObjects
# values; setdefault only fills missing keys, so caller kwargs win.
25 kwargs.setdefault(
"NtupleFileName", flags.Tracking.GNN.DumpObjects.NtupleFileName)
26 kwargs.setdefault(
"NtupleTreeName", flags.Tracking.GNN.DumpObjects.NtupleTreeName)
# rootFile defaults to True -- presumably enables writing the ROOT ntuple;
# TODO confirm against the DumpObjects algorithm's property documentation.
27 kwargs.setdefault(
"rootFile",
True)
# Register the configured algorithm as an event algorithm in `acc`.
29 acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
33 """Sets up a GNNTrackFinderTool tool and returns it."""
# Builds an InDet.SiGNNTrackFinderTool and stores it as the accumulator's
# private tool. The enclosing `def` line is not visible in this view.
34 acc = ComponentAccumulator()
# Every hyperparameter below defaults to the matching
# flags.Tracking.GNN.TrackFinder value; explicit caller kwargs take precedence.
37 kwargs.setdefault(
"embeddingDim", flags.Tracking.GNN.TrackFinder.embeddingDim)
38 kwargs.setdefault(
"rVal", flags.Tracking.GNN.TrackFinder.rVal)
39 kwargs.setdefault(
"knnVal", flags.Tracking.GNN.TrackFinder.knnVal)
40 kwargs.setdefault(
"filterCut", flags.Tracking.GNN.TrackFinder.filterCut)
41 kwargs.setdefault(
"inputMLModelDir", flags.Tracking.GNN.TrackFinder.inputMLModelDir)
42 kwargs.setdefault(
"ccCut", flags.Tracking.GNN.TrackFinder.ccCut)
43 kwargs.setdefault(
"walkMin", flags.Tracking.GNN.TrackFinder.walkMin)
44 kwargs.setdefault(
"walkMax", flags.Tracking.GNN.TrackFinder.walkMax)
# Per-stage input feature names and their scale factors (embedding,
# filtering, GNN), also taken from the TrackFinder flags.
45 kwargs.setdefault(
"EmbeddingFeatureNames", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureNames)
46 kwargs.setdefault(
"EmbeddingFeatureScales", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureScales)
47 kwargs.setdefault(
"FilterFeatureNames", flags.Tracking.GNN.TrackFinder.FilterFeatureNames)
48 kwargs.setdefault(
"FilterFeatureScales", flags.Tracking.GNN.TrackFinder.FilterFeatureScales)
49 kwargs.setdefault(
"GNNFeatureNames", flags.Tracking.GNN.TrackFinder.GNNFeatureNames)
50 kwargs.setdefault(
"GNNFeatureScales", flags.Tracking.GNN.TrackFinder.GNNFeatureScales)
52 from AthOnnxComps.OnnxRuntimeInferenceConfig
import OnnxRuntimeInferenceToolCfg
# All three ONNX inference tools share the same ONNX Runtime execution
# provider flag.
53 ort_exe_provider = flags.Tracking.GNN.TrackFinder.ORTExeProvider
# Model files are resolved as <inputMLModelDir>/{embedding,filtering,gnn}.onnx.
# The directory is read from kwargs (already defaulted above), so a caller
# override of "inputMLModelDir" is honoured. `Path` is presumably
# pathlib.Path imported elsewhere in the file -- confirm.
# NOTE(review): each OnnxRuntimeInferenceToolCfg/popToolsAndMerge call is an
# argument to setdefault, so Python evaluates it (and merges it into acc) even
# when the caller already supplied "Embedding"/"Filtering"/"GNN".
54 kwargs.setdefault(
"Embedding", acc.popToolsAndMerge(
55 OnnxRuntimeInferenceToolCfg(flags, str(Path(kwargs[
"inputMLModelDir"]) /
"embedding.onnx"),
56 ort_exe_provider, name=
"Embedding")
58 kwargs.setdefault(
"Filtering", acc.popToolsAndMerge(
59 OnnxRuntimeInferenceToolCfg(flags, str(Path(kwargs[
"inputMLModelDir"]) /
"filtering.onnx"),
60 ort_exe_provider, name=
"Filtering")
62 kwargs.setdefault(
"GNN", acc.popToolsAndMerge(
63 OnnxRuntimeInferenceToolCfg(flags, str(Path(kwargs[
"inputMLModelDir"]) /
"gnn.onnx"),
64 ort_exe_provider, name=
"GNN")
# The configured tool is handed back as the accumulator's private tool.
67 acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
72 """Sets up a GNNTrackFinderTritonTool tool and returns it."""
# Builds an InDet.GNNTrackFinderTritonTool backed by a Triton inference
# client tool. The `def` line and the lines closing the TritonToolCfg call
# are not visible in this view.
73 from AthTritonComps.TritonToolConfig
import TritonToolCfg
75 acc = ComponentAccumulator()
# Model name, server URL and port come from flags.Tracking.GNN.Triton.
# NOTE(review): as with the ONNX tools, the TritonToolCfg call is a setdefault
# argument and is therefore evaluated even if "TritonTool" is passed in.
77 kwargs.setdefault(
"TritonTool", acc.popToolsAndMerge(
78 TritonToolCfg(flags, model_name=flags.Tracking.GNN.Triton.model,
79 url=flags.Tracking.GNN.Triton.url,
80 port=flags.Tracking.GNN.Triton.port,
85 acc.setPrivateTools(CompFactory.InDet.GNNTrackFinderTritonTool(name, **kwargs))
127 """Sets up a GNNTrackMaker algorithm and returns it."""
# First part of the InDet.SiSPGNNTrackMaker configuration. The `def` line and
# the lines that create SeedFitterTool and the GNN finder/reader tools are not
# visible in this view.
129 acc = ComponentAccumulator()
133 kwargs.setdefault(
"SeedFitterTool", SeedFitterTool)
135 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
# NOTE(review): the ITk track fitter is built and merged unconditionally,
# even when the caller already passed "TrackFitter".
136 InDetTrackFitter = acc.popToolsAndMerge(ITkTrackFitterCfg(flags))
137 kwargs.setdefault(
"TrackFitter", InDetTrackFitter)
# The track summary tool is only configured when the caller did not pass one.
139 if "TrackSummaryTool" not in kwargs:
140 from TrkConfig.TrkTrackSummaryToolConfig
import ITkTrackSummaryToolCfg
143 "TrackSummaryTool", acc.popToolsAndMerge(ITkTrackSummaryToolCfg(flags))
# Select the GNN back-end from flags.Tracking.GNN.ToolType: one of the
# GNNTrackFinderTool / GNNTrackReaderTool slots gets a real tool, the other
# defaults to None. Triton mode reuses the GNNTrackFinderTool slot.
146 if flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackFinder:
148 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
149 kwargs.setdefault(
"GNNTrackReaderTool",
None)
150 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackReader:
152 kwargs.setdefault(
"GNNTrackReaderTool", InDetGNNTrackReader)
153 kwargs.setdefault(
"GNNTrackFinderTool",
None)
154 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.Triton:
156 kwargs.setdefault(
"GNNTrackReaderTool",
None)
157 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
# NOTE(review): this message lists only GNNTrackFinder/GNNTrackReader, although
# Triton is also an accepted ToolType above; consider mentioning it.
159 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
# Input mode and reco-track-cut switch come from the GNN flags.
161 kwargs.setdefault(
"areInputClusters", flags.Tracking.GNN.useClusterTracks)
162 kwargs.setdefault(
"doRecoTrackCuts", flags.Tracking.GNN.doRecoTrackCuts)
165 if "InDetEtaDependentCutSvc" not in kwargs:
166 from InDetConfig.InDetEtaDependentCutsConfig
import ITkEtaDependentCutsSvcCfg
167 acc.merge(ITkEtaDependentCutsSvcCfg(flags))
168 kwargs.setdefault(
"InDetEtaDependentCutsSvc", acc.getService(
"ITkEtaDependentCutsSvc"+flags.Tracking.ActiveConfig.extension))
# Track-selection cuts for SiSPGNNTrackMaker default to the flags.Tracking.GNN
# values; explicit caller kwargs take precedence.
170 kwargs.setdefault(
"minClusters", flags.Tracking.GNN.minClusters)
171 kwargs.setdefault(
"pTmin", flags.Tracking.GNN.pTmin)
172 kwargs.setdefault(
"etamax", flags.Tracking.GNN.etamax)
173 kwargs.setdefault(
"minPixelClusters", flags.Tracking.GNN.minPixelClusters)
174 kwargs.setdefault(
"minStripClusters", flags.Tracking.GNN.minStripClusters)
# Register the configured SiSPGNNTrackMaker as an event algorithm.
176 acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs))
180 """Sets up a GNN for seeding algorithm and returns it."""
# Configures InDet.GNNSeedingTrackMaker. The `def` line and a few
# intermediate lines are not visible in this view.
181 acc = ComponentAccumulator()
# Conditions inputs: detector-element boundary links for ITk pixel and strip,
# plus the magnetic-field cache conditions algorithm.
183 from InDetConfig.SiCombinatorialTrackFinderToolConfig
import SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg, SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
184 acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags))
185 acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags))
188 from MagFieldServices.MagFieldServicesConfig
import (
189 AtlasFieldCacheCondAlgCfg)
190 acc.merge(AtlasFieldCacheCondAlgCfg(flags))
# The ROT creator, pattern propagator and Kalman updator are each added to
# `acc` as public tools and also passed to the algorithm as defaults.
192 from TrkConfig.TrkRIO_OnTrackCreatorConfig
import ITkRotCreatorCfg
193 ITkRotCreator = acc.popToolsAndMerge(ITkRotCreatorCfg(
194 flags, name=
"ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
195 acc.addPublicTool(ITkRotCreator)
196 kwargs.setdefault(
"RIOonTrackTool", ITkRotCreator)
198 from TrkConfig.TrkExRungeKuttaPropagatorConfig
import (
199 RungeKuttaPropagatorCfg)
200 ITkPatternPropagator = acc.popToolsAndMerge(
201 RungeKuttaPropagatorCfg(flags, name=
"ITkPatternPropagator"))
202 acc.addPublicTool(ITkPatternPropagator)
203 kwargs.setdefault(
"PropagatorTool", ITkPatternPropagator)
205 from TrkConfig.TrkMeasurementUpdatorConfig
import KalmanUpdator_xkCfg
206 ITkPatternUpdator = acc.popToolsAndMerge(
207 KalmanUpdator_xkCfg(flags, name=
"ITkPatternUpdator"))
208 acc.addPublicTool(ITkPatternUpdator)
209 kwargs.setdefault(
"UpdatorTool", ITkPatternUpdator)
# Boundary-check tool and pixel/strip conditions summary tools (private).
211 from InDetConfig.InDetBoundaryCheckToolConfig
import ITkBoundaryCheckToolCfg
212 kwargs.setdefault(
"BoundaryCheckTool", acc.popToolsAndMerge(
213 ITkBoundaryCheckToolCfg(flags)))
215 from PixelConditionsTools.ITkPixelConditionsSummaryConfig
import (
216 ITkPixelConditionsSummaryCfg)
217 kwargs.setdefault(
"PixelSummaryTool", acc.popToolsAndMerge(
218 ITkPixelConditionsSummaryCfg(flags)))
220 from SCT_ConditionsTools.ITkStripConditionsToolsConfig
import (
221 ITkStripConditionsSummaryToolCfg)
222 kwargs.setdefault(
"StripSummaryTool", acc.popToolsAndMerge(
223 ITkStripConditionsSummaryToolCfg(flags)))
# Finder/reader selection; the lines that create the active tool are not
# visible here, only the None defaults for the inactive slot.
# NOTE(review): this uses flags.Tracking.GNN.useTrackFinder/useTrackReader,
# whereas the SiSPGNNTrackMaker configuration selects via
# flags.Tracking.GNN.ToolType -- confirm the two flag sets stay consistent.
225 if flags.Tracking.GNN.useTrackFinder:
227 kwargs.setdefault(
"GNNTrackReaderTool",
None)
228 elif flags.Tracking.GNN.useTrackReader:
230 kwargs.setdefault(
"GNNTrackFinderTool",
None)
232 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
# SeedFitterToolCfg is defined elsewhere in the file (not in this view).
234 kwargs.setdefault(
"SeedFitterTool", acc.popToolsAndMerge(
SeedFitterToolCfg(flags)))
236 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
237 kwargs.setdefault(
"TrackFitter", acc.popToolsAndMerge(ITkTrackFitterCfg(flags)))
239 from InDetConfig.SiDetElementsRoadToolConfig
import ITkSiDetElementsRoadMaker_xkCfg
240 kwargs.setdefault(
"RoadTool", acc.popToolsAndMerge(ITkSiDetElementsRoadMaker_xkCfg(flags)))
# Track-finding cuts use index [0] of the flags.Tracking.ActiveConfig lists
# (presumably the first eta bin -- TODO confirm).
244 kwargs.setdefault(
"nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
245 kwargs.setdefault(
"nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
246 kwargs.setdefault(
"nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
247 kwargs.setdefault(
"nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])
249 kwargs.setdefault(
"pTmin", flags.Tracking.ActiveConfig.minPT[0])
250 kwargs.setdefault(
"pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
251 kwargs.setdefault(
"Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
252 kwargs.setdefault(
"Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
# NOTE(review): Xi2maxMultiTracks reuses Xi2max[0] rather than a dedicated
# flag; confirm this is intended.
253 kwargs.setdefault(
"Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
# Multi-track production is disabled by default.
254 kwargs.setdefault(
"doMultiTracksProd",
False)
256 acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))