5 from pathlib
import Path
7 from AthenaConfiguration.ComponentAccumulator
import ComponentAccumulator
8 from AthenaConfiguration.ComponentFactory
import CompFactory
11 flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
13 create algorithm which dumps GNN training information to ROOT file
19 Output=[f
"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
23 kwargs.setdefault(
"NtupleFileName", flags.Tracking.GNN.DumpObjects.NtupleFileName)
24 kwargs.setdefault(
"NtupleTreeName", flags.Tracking.GNN.DumpObjects.NtupleTreeName)
25 kwargs.setdefault(
"rootFile",
True)
27 acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
31 """Sets up a GNNTrackFinderTool tool and returns it."""
35 kwargs.setdefault(
"embeddingDim", flags.Tracking.GNN.TrackFinder.embeddingDim)
36 kwargs.setdefault(
"rVal", flags.Tracking.GNN.TrackFinder.rVal)
37 kwargs.setdefault(
"knnVal", flags.Tracking.GNN.TrackFinder.knnVal)
38 kwargs.setdefault(
"filterCut", flags.Tracking.GNN.TrackFinder.filterCut)
39 kwargs.setdefault(
"inputMLModelDir", flags.Tracking.GNN.TrackFinder.inputMLModelDir)
40 kwargs.setdefault(
"ccCut", flags.Tracking.GNN.TrackFinder.ccCut)
41 kwargs.setdefault(
"walkMin", flags.Tracking.GNN.TrackFinder.walkMin)
42 kwargs.setdefault(
"walkMax", flags.Tracking.GNN.TrackFinder.walkMax)
43 kwargs.setdefault(
"EmbeddingFeatureNames", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureNames)
44 kwargs.setdefault(
"EmbeddingFeatureScales", flags.Tracking.GNN.TrackFinder.EmbeddingFeatureScales)
45 kwargs.setdefault(
"FilterFeatureNames", flags.Tracking.GNN.TrackFinder.FilterFeatureNames)
46 kwargs.setdefault(
"FilterFeatureScales", flags.Tracking.GNN.TrackFinder.FilterFeatureScales)
47 kwargs.setdefault(
"GNNFeatureNames", flags.Tracking.GNN.TrackFinder.GNNFeatureNames)
48 kwargs.setdefault(
"GNNFeatureScales", flags.Tracking.GNN.TrackFinder.GNNFeatureScales)
50 from AthOnnxComps.OnnxRuntimeInferenceConfig
import OnnxRuntimeInferenceToolCfg
51 ort_exe_provider = flags.Tracking.GNN.TrackFinder.ORTExeProvider
52 kwargs.setdefault(
"Embedding", acc.popToolsAndMerge(
54 ort_exe_provider, name=
"Embedding")
56 kwargs.setdefault(
"Filtering", acc.popToolsAndMerge(
58 ort_exe_provider, name=
"Filtering")
60 kwargs.setdefault(
"GNN", acc.popToolsAndMerge(
62 ort_exe_provider, name=
"GNN")
65 acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
70 """Sets up a SeedFitter tool and returns it."""
74 acc.setPrivateTools(CompFactory.InDet.SeedFitterTool(name, **kwargs))
79 """Sets up a SpacepointFeature tool and returns it."""
83 acc.setPrivateTools(CompFactory.InDet.SpacepointFeatureTool(name, **kwargs))
88 """Set up a GNNTrackReader tool and return it."""
92 kwargs.setdefault(
"inputTracksDir", flags.Tracking.GNN.TrackReader.inputTracksDir)
93 kwargs.setdefault(
"csvPrefix", flags.Tracking.GNN.TrackReader.csvPrefix)
95 acc.setPrivateTools(CompFactory.InDet.GNNTrackReaderTool(name, **kwargs))
99 """Sets up a GNNTrackMaker algorithm and returns it."""
101 if flags.Tracking.GNN.usePixelHitsOnly:
107 """Sets up a GNNTrackMaker algorithm and returns it."""
113 kwargs.setdefault(
"SeedFitterTool", SeedFitterTool)
115 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
117 kwargs.setdefault(
"TrackFitter", InDetTrackFitter)
119 if "TrackSummaryTool" not in kwargs:
120 from TrkConfig.TrkTrackSummaryToolConfig
import ITkTrackSummaryToolCfg
126 if flags.Tracking.GNN.useTrackFinder:
128 kwargs.setdefault(
"GNNTrackFinderTool", InDetGNNTrackFinderTool)
129 kwargs.setdefault(
"GNNTrackReaderTool",
None)
130 elif flags.Tracking.GNN.useTrackReader:
132 kwargs.setdefault(
"GNNTrackReaderTool", InDetGNNTrackReader)
133 kwargs.setdefault(
"GNNTrackFinderTool",
None)
135 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
137 kwargs.setdefault(
"areInputClusters", flags.Tracking.GNN.useClusterTracks)
138 kwargs.setdefault(
"doRecoTrackCuts", flags.Tracking.GNN.doRecoTrackCuts)
141 if "InDetEtaDependentCutSvc" not in kwargs:
142 from InDetConfig.InDetEtaDependentCutsConfig
import ITkEtaDependentCutsSvcCfg
144 kwargs.setdefault(
"InDetEtaDependentCutsSvc", acc.getService(
"ITkEtaDependentCutsSvc"+flags.Tracking.ActiveConfig.extension))
146 kwargs.setdefault(
"minClusters", flags.Tracking.GNN.minClusters)
147 kwargs.setdefault(
"pTmin", flags.Tracking.GNN.pTmin)
148 kwargs.setdefault(
"etamax", flags.Tracking.GNN.etamax)
149 kwargs.setdefault(
"minPixelClusters", flags.Tracking.GNN.minPixelClusters)
150 kwargs.setdefault(
"minStripClusters", flags.Tracking.GNN.minStripClusters)
152 acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs))
156 """Sets up a GNN for seeding algorithm and returns it."""
159 from InDetConfig.SiCombinatorialTrackFinderToolConfig
import SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg, SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
164 from MagFieldServices.MagFieldServicesConfig
import (
165 AtlasFieldCacheCondAlgCfg)
168 from TrkConfig.TrkRIO_OnTrackCreatorConfig
import ITkRotCreatorCfg
170 flags, name=
"ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
171 acc.addPublicTool(ITkRotCreator)
172 kwargs.setdefault(
"RIOonTrackTool", ITkRotCreator)
174 from TrkConfig.TrkExRungeKuttaPropagatorConfig
import (
175 RungeKuttaPropagatorCfg)
176 ITkPatternPropagator = acc.popToolsAndMerge(
178 acc.addPublicTool(ITkPatternPropagator)
179 kwargs.setdefault(
"PropagatorTool", ITkPatternPropagator)
181 from TrkConfig.TrkMeasurementUpdatorConfig
import KalmanUpdator_xkCfg
182 ITkPatternUpdator = acc.popToolsAndMerge(
184 acc.addPublicTool(ITkPatternUpdator)
185 kwargs.setdefault(
"UpdatorTool", ITkPatternUpdator)
187 from InDetConfig.InDetBoundaryCheckToolConfig
import ITkBoundaryCheckToolCfg
188 kwargs.setdefault(
"BoundaryCheckTool", acc.popToolsAndMerge(
191 from PixelConditionsTools.ITkPixelConditionsSummaryConfig
import (
192 ITkPixelConditionsSummaryCfg)
193 kwargs.setdefault(
"PixelSummaryTool", acc.popToolsAndMerge(
196 from SCT_ConditionsTools.ITkStripConditionsToolsConfig
import (
197 ITkStripConditionsSummaryToolCfg)
198 kwargs.setdefault(
"StripSummaryTool", acc.popToolsAndMerge(
201 if flags.Tracking.GNN.useTrackFinder:
203 kwargs.setdefault(
"GNNTrackReaderTool",
None)
204 elif flags.Tracking.GNN.useTrackReader:
206 kwargs.setdefault(
"GNNTrackFinderTool",
None)
208 raise RuntimeError(
"GNNTrackFinder or GNNTrackReader must be enabled!")
210 kwargs.setdefault(
"SeedFitterTool", acc.popToolsAndMerge(
SeedFitterToolCfg(flags)))
212 from TrkConfig.CommonTrackFitterConfig
import ITkTrackFitterCfg
215 from InDetConfig.SiDetElementsRoadToolConfig
import ITkSiDetElementsRoadMaker_xkCfg
220 kwargs.setdefault(
"nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
221 kwargs.setdefault(
"nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
222 kwargs.setdefault(
"nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
223 kwargs.setdefault(
"nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])
225 kwargs.setdefault(
"pTmin", flags.Tracking.ActiveConfig.minPT[0])
226 kwargs.setdefault(
"pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
227 kwargs.setdefault(
"Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
228 kwargs.setdefault(
"Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
229 kwargs.setdefault(
"Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
230 kwargs.setdefault(
"doMultiTracksProd",
False)
232 acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))