5 from pathlib
import Path
7 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
8 from AthenaConfiguration.ComponentFactory
import CompFactory
9 from InDetGNNTracking.InDetGNNTrackingConfigFlags
import GNNTrackFinderToolType
# NOTE(review): this copy of the file is damaged. The leading integer on
# each statement (e.g. "13 ") looks like a source line number fused in by a
# copy/paste, and the gaps between those numbers (13 -> 15 -> 21 -> 25 -> 29)
# show that original lines are missing: the `def` keyword line, the
# docstring delimiters, the creation of `acc`, the call that receives
# `Output=`, and anything after line 29. Not valid Python as shown; restore
# from upstream before changing logic.
#
# DumpObjects configuration fragment. Per its docstring it creates the
# algorithm that dumps GNN training information to a ROOT file.
#   flags   -- configuration flags; Tracking.GNN.DumpObjects.* supply the
#              ntuple file/tree name defaults below.
#   name    -- algorithm instance name (default "DumpObjects"); also
#              interpolated into the Output specification string.
#   outfile -- output file name (default "Dump_GNN4Itk.root").
#   kwargs  -- extra properties forwarded to CompFactory.InDet.DumpObjects;
#              values already present win over the setdefault() defaults.
13 flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
15 create algorithm which dumps GNN training information to ROOT file
# Output specification "<name> DATAFILE='<outfile>', OPT='RECREATE'". The
# call it is passed to is on a missing line (presumably a histogram/ntuple
# service -- confirm). The f-string's "f" prefix is split from its literal
# by the damage.
21 Output=[f
"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
# Ntuple file and tree names come from the DumpObjects flags unless the
# caller overrides them; rootFile defaults to True.
25 kwargs.setdefault(
"NtupleFileName", flags.Tracking.GNN.DumpObjects.NtupleFileName)
26 kwargs.setdefault(
"NtupleTreeName", flags.Tracking.GNN.DumpObjects.NtupleTreeName)
27 kwargs.setdefault(
"rootFile",
True)
# Register the DumpObjects algorithm on `acc` with the collected properties.
29 acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
# NOTE(review): damaged copy. The leading integers (e.g. "33 ") look like
# fused source line numbers, and their gaps show missing lines, including
# the `def` line and the creation of `acc`. Not valid Python as shown;
# restore from upstream before editing.
#
# ONNX-based GNN track-finder tool configuration fragment. Each
# hyperparameter below defaults to flags.Tracking.GNN.TrackFinder.<same
# name>; values already present in kwargs are left untouched (setdefault).
33 """Sets up a GNNTrackFinderTool tool and returns it."""
37 kwargs.setdefault(
"embeddingDim", flags.Tracking.GNN.TrackFinder.embeddingDim)
38 kwargs.setdefault(
"rVal", flags.Tracking.GNN.TrackFinder.rVal)
39 kwargs.setdefault(
"knnVal", flags.Tracking.GNN.TrackFinder.knnVal)
40 kwargs.setdefault(
"filterCut", flags.Tracking.GNN.TrackFinder.filterCut)
41 kwargs.setdefault(
"inputMLModelDir", flags.Tracking.GNN.TrackFinder.inputMLModelDir)
42 kwargs.setdefault(
"ccCut", flags.Tracking.GNN.TrackFinder.ccCut)
43 kwargs.setdefault(
"walkMin", flags.Tracking.GNN.TrackFinder.walkMin)
44 kwargs.setdefault(
"walkMax", flags.Tracking.GNN.TrackFinder.walkMax)
# Per-stage input feature names and scales (embedding, filter, GNN).
45 kwargs.setdefault(
"EmbeddingFeatureNames", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureNames)
46 kwargs.setdefault(
"EmbeddingFeatureScales", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureScales)
47 kwargs.setdefault(
"FilterFeatureNames", flags.Tracking.GNN.TrackFinder.FilterFeatureNames)
48 kwargs.setdefault(
"FilterFeatureScales", flags.Tracking.GNN.TrackFinder.FilterFeatureScales)
49 kwargs.setdefault(
"GNNFeatureNames", flags.Tracking.GNN.TrackFinder.GNNFeatureNames)
50 kwargs.setdefault(
"GNNFeatureScales", flags.Tracking.GNN.TrackFinder.GNNFeatureScales)
# Local import of the ONNX Runtime inference-tool factory. The three
# inference tools below share the execution provider chosen by
# flags.Tracking.GNN.TrackFinder.ORTExeProvider.
52 from AthOnnxComps.OnnxRuntimeInferenceConfig
import OnnxRuntimeInferenceToolCfg
53 ort_exe_provider = flags.Tracking.GNN.TrackFinder.ORTExeProvider
# One inference tool per stage ("Embedding", "Filtering", "GNN"), each
# merged into acc via popToolsAndMerge.
# NOTE(review): the code between "popToolsAndMerge(" and "ort_exe_provider,"
# (prefixes 55, 59, 63) is missing. It is presumably the start of an
# OnnxRuntimeInferenceToolCfg(...) call that supplies flags and the model
# file; restore it from upstream.
54 kwargs.setdefault(
"Embedding", acc.popToolsAndMerge(
56 ort_exe_provider, name=
"Embedding")
58 kwargs.setdefault(
"Filtering", acc.popToolsAndMerge(
60 ort_exe_provider, name=
"Filtering")
62 kwargs.setdefault(
"GNN", acc.popToolsAndMerge(
64 ort_exe_provider, name=
"GNN")
# Register InDet::SiGNNTrackFinderTool, configured with kwargs, as the
# accumulator's private tool.
67 acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
# NOTE(review): damaged copy. The leading integers are fused source line
# numbers, and lines are missing: the `def` line, prefixes 74-76 (which
# include the creation of `acc`), and prefixes 81-84 (the rest of the
# TritonToolCfg(...) call and its closing brackets). Not valid Python as
# shown.
#
# Triton-backed GNN track-finder tool configuration fragment. The
# "TritonTool" property defaults to a tool built by TritonToolCfg from
# flags.Tracking.GNN.Triton.model / .url / .port. InDet::GNNTrackFinderTritonTool
# is then registered as the accumulator's private tool.
72 """Sets up a GNNTrackFinderTritonTool tool and returns it."""
73 from AthTritonComps.TritonToolConfig
import TritonToolCfg
77 kwargs.setdefault(
"TritonTool", acc.popToolsAndMerge(
78 TritonToolCfg(flags, model_name=flags.Tracking.GNN.Triton.model,
79 url=flags.Tracking.GNN.Triton.url,
80 port=flags.Tracking.GNN.Triton.port,
85 acc.setPrivateTools(CompFactory.InDet.GNNTrackFinderTritonTool(name, **kwargs))
# NOTE(review): damaged copy. The leading integers are fused source line
# numbers; the `def` line and the creation/return of `acc` are not visible.
# Not valid Python as shown.
# SeedFitter tool configuration fragment: registers InDet::SeedFitterTool,
# built from `name` and kwargs, as the accumulator's private tool.
90 """Sets up a SeedFitter tool and returns it."""
94 acc.setPrivateTools(CompFactory.InDet.SeedFitterTool(name, **kwargs))
# NOTE(review): damaged copy. The leading integers are fused source line
# numbers; the `def` line and the creation/return of `acc` are not visible.
# Not valid Python as shown.
# SpacepointFeature tool configuration fragment: registers
# InDet::SpacepointFeatureTool, built from `name` and kwargs, as the
# accumulator's private tool.
99 """Sets up a SpacepointFeature tool and returns it."""
103 acc.setPrivateTools(CompFactory.InDet.SpacepointFeatureTool(name, **kwargs))
# NOTE(review): damaged copy. The leading integers are fused source line
# numbers; the `def` line and the creation/return of `acc` are not visible.
# Not valid Python as shown.
# GNNTrackReader tool configuration fragment. inputTracksDir and csvPrefix
# default to flags.Tracking.GNN.TrackReader.* (presumably the location and
# file-name prefix of pre-computed track CSV files -- confirm). The
# InDet::GNNTrackReaderTool is registered as the accumulator's private tool.
108 """Set up a GNNTrackReader tool and return it."""
112 kwargs.setdefault(
"inputTracksDir", flags.Tracking.GNN.TrackReader.inputTracksDir)
113 kwargs.setdefault(
"csvPrefix", flags.Tracking.GNN.TrackReader.csvPrefix)
115 acc.setPrivateTools(CompFactory.InDet.GNNTrackReaderTool(name, **kwargs))
# NOTE(review): damaged copy. Only the docstring and the first condition
# survive; the `def` line and everything between prefixes 121 and 127
# (including the body of this branch and any alternative) are missing, so
# what each branch configures is not visible here.
# GNNTrackMaker configuration fragment: chooses its setup based on
# flags.Tracking.GNN.usePixelHitsOnly.
119 """Sets up a GNNTrackMaker algorithm and returns it."""
121 if flags.Tracking.GNN.usePixelHitsOnly:
127 """Sets up a GNNTrackMaker algorithm and returns it."""
133 kwargs.setdefault(
"SeedFitterTool", SeedFitterTool)
135 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
137 kwargs.setdefault(
"TrackFitter", InDetTrackFitter)
139 if "TrackSummaryTool" not in kwargs:
140 from TrkConfig.TrkTrackSummaryToolConfig
import ITkTrackSummaryToolCfg
146 if flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackFinder:
148 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
149 kwargs.setdefault(
"GNNTrackReaderTool",
None)
150 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackReader:
152 kwargs.setdefault(
"GNNTrackReaderTool", InDetGNNTrackReader)
153 kwargs.setdefault(
"GNNTrackFinderTool",
None)
154 elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.Triton:
156 kwargs.setdefault(
"GNNTrackReaderTool",
None)
157 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
159 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
161 kwargs.setdefault(
"areInputClusters", flags.Tracking.GNN.useClusterTracks)
162 kwargs.setdefault(
"doRecoTrackCuts", flags.Tracking.GNN.doRecoTrackCuts)
165 if "InDetEtaDependentCutSvc" not in kwargs:
166 from InDetConfig.InDetEtaDependentCutsConfig
import ITkEtaDependentCutsSvcCfg
168 kwargs.setdefault(
"InDetEtaDependentCutsSvc", acc.getService(
"ITkEtaDependentCutsSvc"+flags.Tracking.ActiveConfig.extension))
170 kwargs.setdefault(
"minClusters", flags.Tracking.GNN.minClusters)
171 kwargs.setdefault(
"pTmin", flags.Tracking.GNN.pTmin)
172 kwargs.setdefault(
"etamax", flags.Tracking.GNN.etamax)
173 kwargs.setdefault(
"minPixelClusters", flags.Tracking.GNN.minPixelClusters)
174 kwargs.setdefault(
"minStripClusters", flags.Tracking.GNN.minStripClusters)
176 acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs))
# NOTE(review): damaged copy. The leading integers (e.g. "180 ") look like
# fused source line numbers, and their gaps show missing lines: the `def`
# line, the creation of `acc`, the calls that use several of the local
# imports below, and the opening arguments of several popToolsAndMerge(...)
# calls. Not valid Python as shown; restore from upstream before editing.
#
# GNN-seeded track maker configuration fragment. Per its docstring it sets
# up the "GNN for seeding" algorithm, registered at the end as
# InDet::GNNSeedingTrackMaker. Values already present in kwargs take
# precedence over every default set here.
180 """Sets up a GNN for seeding algorithm and returns it."""
# Detector-element boundary-link and magnetic-field-cache imports; the code
# that uses them is on missing lines.
183 from InDetConfig.SiCombinatorialTrackFinderToolConfig
import SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg, SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
188 from MagFieldServices.MagFieldServicesConfig
import (
189 AtlasFieldCacheCondAlgCfg)
# The ROT creator, pattern propagator and Kalman updator are each obtained
# from acc, also added as public tools, and passed as the RIOonTrackTool /
# PropagatorTool / UpdatorTool defaults.
192 from TrkConfig.TrkRIO_OnTrackCreatorConfig
import ITkRotCreatorCfg
194 flags, name=
"ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
195 acc.addPublicTool(ITkRotCreator)
196 kwargs.setdefault(
"RIOonTrackTool", ITkRotCreator)
198 from TrkConfig.TrkExRungeKuttaPropagatorConfig
import (
199 RungeKuttaPropagatorCfg)
200 ITkPatternPropagator = acc.popToolsAndMerge(
202 acc.addPublicTool(ITkPatternPropagator)
203 kwargs.setdefault(
"PropagatorTool", ITkPatternPropagator)
205 from TrkConfig.TrkMeasurementUpdatorConfig
import KalmanUpdator_xkCfg
206 ITkPatternUpdator = acc.popToolsAndMerge(
208 acc.addPublicTool(ITkPatternUpdator)
209 kwargs.setdefault(
"UpdatorTool", ITkPatternUpdator)
# Boundary-check tool and pixel/strip conditions-summary tools. The
# arguments of each popToolsAndMerge(...) are on missing lines.
211 from InDetConfig.InDetBoundaryCheckToolConfig
import ITkBoundaryCheckToolCfg
212 kwargs.setdefault(
"BoundaryCheckTool", acc.popToolsAndMerge(
215 from PixelConditionsTools.ITkPixelConditionsSummaryConfig
import (
216 ITkPixelConditionsSummaryCfg)
217 kwargs.setdefault(
"PixelSummaryTool", acc.popToolsAndMerge(
220 from SCT_ConditionsTools.ITkStripConditionsToolsConfig
import (
221 ITkStripConditionsSummaryToolCfg)
222 kwargs.setdefault(
"StripSummaryTool", acc.popToolsAndMerge(
# NOTE(review): this section selects the GNN tool with
# flags.Tracking.GNN.useTrackFinder / useTrackReader, while the
# SiSPGNNTrackMaker section of this file dispatches on
# flags.Tracking.GNN.ToolType (which also has a Triton option). Confirm
# these boolean flags still exist, and consider aligning the two sections.
# The lines that set the non-None tool in each branch, and the `else:`
# before the raise, are missing here.
225 if flags.Tracking.GNN.useTrackFinder:
227 kwargs.setdefault(
"GNNTrackReaderTool",
None)
228 elif flags.Tracking.GNN.useTrackReader:
230 kwargs.setdefault(
"GNNTrackFinderTool",
None)
232 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
# The seed fitter default comes from SeedFitterToolCfg(flags).
234 kwargs.setdefault(
"SeedFitterTool", acc.popToolsAndMerge(
SeedFitterToolCfg(flags)))
# Track-fitter and road-maker imports; the code that uses them is on missing
# lines.
236 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
239 from InDetConfig.SiDetElementsRoadToolConfig
import ITkSiDetElementsRoadMaker_xkCfg
# Track-quality cuts take the first element ([0]) of the list-valued
# flags.Tracking.ActiveConfig settings (presumably the first eta region --
# confirm).
244 kwargs.setdefault(
"nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
245 kwargs.setdefault(
"nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
246 kwargs.setdefault(
"nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
247 kwargs.setdefault(
"nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])
249 kwargs.setdefault(
"pTmin", flags.Tracking.ActiveConfig.minPT[0])
250 kwargs.setdefault(
"pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
251 kwargs.setdefault(
"Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
252 kwargs.setdefault(
"Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
# NOTE(review): Xi2maxMultiTracks reuses ActiveConfig.Xi2max[0] rather than
# a dedicated flag; confirm this is intended.
253 kwargs.setdefault(
"Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
# Multi-track production is disabled by default.
254 kwargs.setdefault(
"doMultiTracksProd",
False)
# Register InDet::GNNSeedingTrackMaker with the collected properties.
256 acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))