# ATLAS Offline Software: InDetGNNTrackingConfig.py
2# Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3#
4
5from pathlib import Path
6
7from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
8from AthenaConfiguration.ComponentFactory import CompFactory
9from InDetGNNTracking.InDetGNNTrackingConfigFlags import GNNTrackFinderToolType
10
11
def DumpObjectsCfg(
        flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
    '''
    create algorithm which dumps GNN training information to ROOT file

    :param flags: configuration flags; the ntuple file/tree names default to
        ``flags.Tracking.GNN.DumpObjects.NtupleFileName`` / ``NtupleTreeName``
    :param name: name of the InDet::DumpObjects algorithm; also used as the
        THistSvc output stream name
    :param outfile: ROOT file that the THistSvc stream writes to
    :param kwargs: property overrides forwarded to ``InDet.DumpObjects``
    :returns: ComponentAccumulator holding THistSvc and the algorithm
    '''
    acc = ComponentAccumulator()

    # Register an output stream named after the algorithm; the file is
    # recreated on every job.
    acc.addService(
        CompFactory.THistSvc(
            Output=[f"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
        )
    )

    kwargs.setdefault("NtupleFileName", flags.Tracking.GNN.DumpObjects.NtupleFileName)
    kwargs.setdefault("NtupleTreeName", flags.Tracking.GNN.DumpObjects.NtupleTreeName)
    kwargs.setdefault("rootFile", True)

    acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
    return acc
31
def GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs):
    """Sets up an InDet::SiGNNTrackFinderTool and returns it as a private tool.

    Hyper-parameters and feature name/scale lists default to the values in
    ``flags.Tracking.GNN.TrackFinder``. The ``Embedding``, ``Filtering`` and
    ``GNN`` ONNX inference tools are configured from ``embedding.onnx``,
    ``filtering.onnx`` and ``gnn.onnx`` inside ``inputMLModelDir``, using
    ``flags.Tracking.GNN.TrackFinder.ORTExeProvider``. This happens only when
    the caller has not supplied those tools.

    :param flags: configuration flags
    :param name: name of the tool instance
    :param kwargs: property overrides forwarded to ``InDet.SiGNNTrackFinderTool``
    :returns: ComponentAccumulator whose private tool is the track finder
    """
    acc = ComponentAccumulator()
    finder_flags = flags.Tracking.GNN.TrackFinder

    kwargs.setdefault("embeddingDim", finder_flags.embeddingDim)
    kwargs.setdefault("rVal", finder_flags.rVal)
    kwargs.setdefault("knnVal", finder_flags.knnVal)
    kwargs.setdefault("filterCut", finder_flags.filterCut)
    kwargs.setdefault("inputMLModelDir", finder_flags.inputMLModelDir)
    kwargs.setdefault("ccCut", finder_flags.ccCut)
    kwargs.setdefault("walkMin", finder_flags.walkMin)
    kwargs.setdefault("walkMax", finder_flags.walkMax)
    kwargs.setdefault("EmbeddingFeatureNames", finder_flags.EmbeddingFeatureNames)
    kwargs.setdefault("EmbeddingFeatureScales", finder_flags.EmbeddingFeatureScales)
    kwargs.setdefault("FilterFeatureNames", finder_flags.FilterFeatureNames)
    kwargs.setdefault("FilterFeatureScales", finder_flags.FilterFeatureScales)
    kwargs.setdefault("GNNFeatureNames", finder_flags.GNNFeatureNames)
    kwargs.setdefault("GNNFeatureScales", finder_flags.GNNFeatureScales)

    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
    ort_exe_provider = finder_flags.ORTExeProvider
    # Build each inference tool only when it is missing. Passing
    # acc.popToolsAndMerge(...) straight to setdefault would evaluate it
    # eagerly and merge an unused tool into acc even when overridden.
    for tool_name, model_file in (("Embedding", "embedding.onnx"),
                                  ("Filtering", "filtering.onnx"),
                                  ("GNN", "gnn.onnx")):
        if tool_name in kwargs:
            continue
        model_path = str(Path(kwargs["inputMLModelDir"]) / model_file)
        kwargs[tool_name] = acc.popToolsAndMerge(
            OnnxRuntimeInferenceToolCfg(flags, model_path, ort_exe_provider, name=tool_name))

    acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
    return acc
69
70
def GNNTrackFinderTritonToolCfg(flags, name='GNNTrackFinderTritonTool', **kwargs):
    """Sets up an InDet::GNNTrackFinderTritonTool and returns it as a private tool.

    Unless supplied by the caller, a Triton client tool is configured from
    ``flags.Tracking.GNN.Triton`` (``model``, ``url``, ``port``), and a default
    SpacepointFeatureTool is attached.

    :param flags: configuration flags
    :param name: name of the tool instance
    :param kwargs: property overrides forwarded to ``InDet.GNNTrackFinderTritonTool``
    :returns: ComponentAccumulator whose private tool is the Triton track finder
    """
    acc = ComponentAccumulator()

    # Only build default sub-tools when absent. setdefault() would evaluate
    # acc.popToolsAndMerge(...) eagerly and merge unused tools into acc.
    if "TritonTool" not in kwargs:
        from AthTritonComps.TritonToolConfig import TritonToolCfg
        triton_flags = flags.Tracking.GNN.Triton
        kwargs["TritonTool"] = acc.popToolsAndMerge(
            TritonToolCfg(flags, model_name=triton_flags.model,
                          url=triton_flags.url,
                          port=triton_flags.port))

    if "SpacepointFeatureTool" not in kwargs:
        kwargs["SpacepointFeatureTool"] = acc.popToolsAndMerge(
            SpacepointFeatureToolCfg(flags))

    acc.setPrivateTools(CompFactory.InDet.GNNTrackFinderTritonTool(name, **kwargs))
    return acc
87
88
def SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs):
    """Configure an InDet::SeedFitterTool and hand it back as a private tool.

    :param flags: configuration flags (not read by this function)
    :param name: instance name of the tool
    :param kwargs: property values forwarded verbatim to the tool
    :returns: ComponentAccumulator whose private tool is the seed fitter
    """
    result = ComponentAccumulator()
    seed_fitter = CompFactory.InDet.SeedFitterTool(name, **kwargs)
    result.setPrivateTools(seed_fitter)
    return result
96
97
def SpacepointFeatureToolCfg(flags, name="SpacepointFeatureTool", **kwargs):
    """Configure an InDet::SpacepointFeatureTool and hand it back as a private tool.

    :param flags: configuration flags (not read by this function)
    :param name: instance name of the tool
    :param kwargs: property values forwarded verbatim to the tool
    :returns: ComponentAccumulator whose private tool is the feature tool
    """
    result = ComponentAccumulator()
    feature_tool = CompFactory.InDet.SpacepointFeatureTool(name, **kwargs)
    result.setPrivateTools(feature_tool)
    return result
105
106
def GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs):
    """Configure an InDet::GNNTrackReaderTool and hand it back as a private tool.

    ``inputTracksDir`` and ``csvPrefix`` default to the values in
    ``flags.Tracking.GNN.TrackReader`` unless given in ``kwargs``.

    :param flags: configuration flags
    :param name: instance name of the tool
    :param kwargs: property overrides forwarded to the tool
    :returns: ComponentAccumulator whose private tool is the track reader
    """
    result = ComponentAccumulator()

    reader_flags = flags.Tracking.GNN.TrackReader
    kwargs.setdefault("inputTracksDir", reader_flags.inputTracksDir)
    kwargs.setdefault("csvPrefix", reader_flags.csvPrefix)

    reader_tool = CompFactory.InDet.GNNTrackReaderTool(name, **kwargs)
    result.setPrivateTools(reader_tool)
    return result
117
def GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs):
    """Dispatch to the GNN track-maker configuration selected by the flags.

    When ``flags.Tracking.GNN.usePixelHitsOnly`` is set, the GNN-for-seeding
    maker (``GNNSeedingTrackMakerCfg``) is configured. Otherwise the
    end-to-end maker (``GNNEndToEndTrackMaker``) is configured. ``name`` and
    ``kwargs`` are forwarded unchanged.
    """
    builder = (GNNSeedingTrackMakerCfg
               if flags.Tracking.GNN.usePixelHitsOnly
               else GNNEndToEndTrackMaker)
    return builder(flags, name, **kwargs)
125
def GNNEndToEndTrackMaker(flags, name="GNNEndToEndTrackMaker", **kwargs):
    """Sets up the end-to-end GNN track maker (InDet::SiSPGNNTrackMaker) and returns it.

    Any tool or service not supplied via ``kwargs`` is configured from
    ``flags``. Which GNN component is attached (local finder, Triton finder
    or CSV track reader) is chosen by ``flags.Tracking.GNN.ToolType``.

    :param flags: configuration flags
    :param name: name of the algorithm instance
    :param kwargs: property overrides forwarded to ``InDet.SiSPGNNTrackMaker``
    :returns: ComponentAccumulator holding the configured event algorithm
    :raises RuntimeError: if ``flags.Tracking.GNN.ToolType`` is not a
        supported ``GNNTrackFinderToolType``
    """
    acc = ComponentAccumulator()

    # Default tools are built only when missing from kwargs. Passing
    # acc.popToolsAndMerge(...) to setdefault would evaluate it eagerly and
    # merge unused tools into acc.
    if "SeedFitterTool" not in kwargs:
        kwargs["SeedFitterTool"] = acc.popToolsAndMerge(SeedFitterToolCfg(flags))

    if "TrackFitter" not in kwargs:
        from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg
        kwargs["TrackFitter"] = acc.popToolsAndMerge(ITkTrackFitterCfg(flags))

    if "TrackSummaryTool" not in kwargs:
        from TrkConfig.TrkTrackSummaryToolConfig import ITkTrackSummaryToolCfg
        kwargs["TrackSummaryTool"] = acc.popToolsAndMerge(ITkTrackSummaryToolCfg(flags))

    # Exactly one of GNNTrackFinderTool / GNNTrackReaderTool is active; the
    # other is explicitly set to None unless the caller overrides it.
    tool_type = flags.Tracking.GNN.ToolType
    if tool_type == GNNTrackFinderToolType.TrackFinder:
        if "GNNTrackFinderTool" not in kwargs:
            kwargs["GNNTrackFinderTool"] = acc.popToolsAndMerge(GNNTrackFinderToolCfg(flags))
        kwargs.setdefault("GNNTrackReaderTool", None)
    elif tool_type == GNNTrackFinderToolType.TrackReader:
        if "GNNTrackReaderTool" not in kwargs:
            kwargs["GNNTrackReaderTool"] = acc.popToolsAndMerge(GNNTrackReaderToolCfg(flags))
        kwargs.setdefault("GNNTrackFinderTool", None)
    elif tool_type == GNNTrackFinderToolType.Triton:
        if "GNNTrackFinderTool" not in kwargs:
            kwargs["GNNTrackFinderTool"] = acc.popToolsAndMerge(GNNTrackFinderTritonToolCfg(flags))
        kwargs.setdefault("GNNTrackReaderTool", None)
    else:
        raise RuntimeError(
            f"Unsupported flags.Tracking.GNN.ToolType {tool_type!r}: "
            "GNNTrackFinder, GNNTrackReader or Triton must be enabled!")

    kwargs.setdefault("areInputClusters", flags.Tracking.GNN.useClusterTracks)
    kwargs.setdefault("doRecoTrackCuts", flags.Tracking.GNN.doRecoTrackCuts)

    # Add the eta-dependent cut service. The guard key must match the
    # property name that is set below ("InDetEtaDependentCutsSvc"). The
    # original checked "InDetEtaDependentCutSvc", so a caller-supplied
    # service never skipped the default merge.
    if "InDetEtaDependentCutsSvc" not in kwargs:
        from InDetConfig.InDetEtaDependentCutsConfig import ITkEtaDependentCutsSvcCfg
        acc.merge(ITkEtaDependentCutsSvcCfg(flags))
        kwargs["InDetEtaDependentCutsSvc"] = acc.getService(
            "ITkEtaDependentCutsSvc" + flags.Tracking.ActiveConfig.extension)

    kwargs.setdefault("minClusters", flags.Tracking.GNN.minClusters)
    kwargs.setdefault("pTmin", flags.Tracking.GNN.pTmin)
    kwargs.setdefault("etamax", flags.Tracking.GNN.etamax)
    kwargs.setdefault("minPixelClusters", flags.Tracking.GNN.minPixelClusters)
    kwargs.setdefault("minStripClusters", flags.Tracking.GNN.minStripClusters)

    acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs))
    return acc
178
def GNNSeedingTrackMakerCfg(flags, name="GNNSeedingTrackMaker", **kwargs):
    """Sets up a GNN for seeding algorithm and returns it.

    Configures an ``InDet::GNNSeedingTrackMaker`` event algorithm together
    with the conditions algorithms and tools wired up here:
      - ITk pixel/strip detector-element boundary links;
      - the magnetic-field cache;
      - ROT creator, pattern propagator and pattern updator (all also
        registered as public tools);
      - boundary check and pixel/strip conditions summaries;
      - a GNN track finder or reader;
      - seed fitter, ITk track fitter and SiDetElements road maker.
    Pattern-recognition cuts are taken from entry ``[0]`` of the
    corresponding ``flags.Tracking.ActiveConfig`` lists (presumably the first
    eta bin; confirm).

    NOTE(review): every ``kwargs.setdefault(key, acc.popToolsAndMerge(...))``
    below evaluates its default eagerly. The tool is therefore built and
    merged into ``acc`` even when the caller passes ``key`` explicitly.

    :param flags: configuration flags
    :param name: name of the algorithm instance
    :param kwargs: property overrides forwarded to ``InDet.GNNSeedingTrackMaker``
    :returns: ComponentAccumulator holding the algorithm and its dependencies
    :raises RuntimeError: if neither ``flags.Tracking.GNN.useTrackFinder`` nor
        ``flags.Tracking.GNN.useTrackReader`` is set
    """
    acc = ComponentAccumulator()

    # Conditions algorithms providing detector-element boundary links.
    from InDetConfig.SiCombinatorialTrackFinderToolConfig import SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg, SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
    acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags))
    acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags))

    # To produce AtlasFieldCacheCondObj
    from MagFieldServices.MagFieldServicesConfig import (
        AtlasFieldCacheCondAlgCfg)
    acc.merge(AtlasFieldCacheCondAlgCfg(flags))

    # The ROT creator, propagator and updator are registered as public tools
    # as well as being passed to the algorithm.
    from TrkConfig.TrkRIO_OnTrackCreatorConfig import ITkRotCreatorCfg
    ITkRotCreator = acc.popToolsAndMerge(ITkRotCreatorCfg(
        flags, name="ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
    acc.addPublicTool(ITkRotCreator)
    kwargs.setdefault("RIOonTrackTool", ITkRotCreator)

    from TrkConfig.TrkExRungeKuttaPropagatorConfig import (
        RungeKuttaPropagatorCfg)
    ITkPatternPropagator = acc.popToolsAndMerge(
        RungeKuttaPropagatorCfg(flags, name="ITkPatternPropagator"))
    acc.addPublicTool(ITkPatternPropagator)
    kwargs.setdefault("PropagatorTool", ITkPatternPropagator)

    from TrkConfig.TrkMeasurementUpdatorConfig import KalmanUpdator_xkCfg
    ITkPatternUpdator = acc.popToolsAndMerge(
        KalmanUpdator_xkCfg(flags, name="ITkPatternUpdator"))
    acc.addPublicTool(ITkPatternUpdator)
    kwargs.setdefault("UpdatorTool", ITkPatternUpdator)

    from InDetConfig.InDetBoundaryCheckToolConfig import ITkBoundaryCheckToolCfg
    kwargs.setdefault("BoundaryCheckTool", acc.popToolsAndMerge(
        ITkBoundaryCheckToolCfg(flags)))

    from PixelConditionsTools.ITkPixelConditionsSummaryConfig import (
        ITkPixelConditionsSummaryCfg)
    kwargs.setdefault("PixelSummaryTool", acc.popToolsAndMerge(
        ITkPixelConditionsSummaryCfg(flags)))

    from SCT_ConditionsTools.ITkStripConditionsToolsConfig import (
        ITkStripConditionsSummaryToolCfg)
    kwargs.setdefault("StripSummaryTool", acc.popToolsAndMerge(
        ITkStripConditionsSummaryToolCfg(flags)))

    # Exactly one of the GNN finder / reader tools is attached; the other is
    # set to None.
    # NOTE(review): this selects via useTrackFinder/useTrackReader, while
    # GNNEndToEndTrackMaker selects via flags.Tracking.GNN.ToolType (which
    # also supports Triton). Confirm both flag sets exist and stay consistent.
    if flags.Tracking.GNN.useTrackFinder:
        kwargs.setdefault("GNNTrackFinderTool", acc.popToolsAndMerge(GNNTrackFinderToolCfg(flags)))
        kwargs.setdefault("GNNTrackReaderTool", None)
    elif flags.Tracking.GNN.useTrackReader:
        kwargs.setdefault("GNNTrackReaderTool", acc.popToolsAndMerge(GNNTrackReaderToolCfg(flags)))
        kwargs.setdefault("GNNTrackFinderTool", None)
    else:
        raise RuntimeError("GNNTrackFinder or GNNTrackReader must be enabled!")

    kwargs.setdefault("SeedFitterTool", acc.popToolsAndMerge(SeedFitterToolCfg(flags)))

    from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg
    kwargs.setdefault("TrackFitter", acc.popToolsAndMerge(ITkTrackFitterCfg(flags)))

    from InDetConfig.SiDetElementsRoadToolConfig import ITkSiDetElementsRoadMaker_xkCfg
    kwargs.setdefault("RoadTool", acc.popToolsAndMerge(ITkSiDetElementsRoadMaker_xkCfg(flags)))

    # configurations for Kalman filter.
    # similar to https://gitlab.cern.ch/atlas/athena/-/blob/main/InnerDetector/InDetConfig/python/SiTrackMakerConfig.py#L188
    kwargs.setdefault("nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
    kwargs.setdefault("nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
    kwargs.setdefault("nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
    kwargs.setdefault("nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])

    kwargs.setdefault("pTmin", flags.Tracking.ActiveConfig.minPT[0])
    kwargs.setdefault("pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
    kwargs.setdefault("Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
    kwargs.setdefault("Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
    # The multi-track chi2 cut reuses the Xi2max value.
    kwargs.setdefault("Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
    # Multi-track production is disabled by default here.
    kwargs.setdefault("doMultiTracksProd", False)

    acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))
    return acc
DumpObjectsCfg(flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs)
SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs)
GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs)
GNNSeedingTrackMakerCfg(flags, name="GNNSeedingTrackMaker", **kwargs)
GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs)
GNNEndToEndTrackMaker(flags, name="GNNEndToEndTrackMaker", **kwargs)
GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs)
GNNTrackFinderTritonToolCfg(flags, name='GNNTrackFinderTritonTool', **kwargs)
SpacepointFeatureToolCfg(flags, name="SpacepointFeatureTool", **kwargs)