ATLAS Offline Software
InDetGNNTrackingConfig.py
Go to the documentation of this file.
1 #
2 # Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 #
4 
5 from pathlib import Path
6 
7 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
8 from AthenaConfiguration.ComponentFactory import CompFactory
9 from InDetGNNTracking.InDetGNNTrackingConfigFlags import GNNTrackFinderToolType
10 
11 
def DumpObjectsCfg(
        flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
    """Create the algorithm which dumps GNN training information to a ROOT file.

    A THistSvc output stream named after ``name`` is registered so that
    ``outfile`` is (re)created, then ``InDet.DumpObjects`` is scheduled as an
    event algorithm.

    ``NtupleFileName`` and ``NtupleTreeName`` default to the corresponding
    ``flags.Tracking.GNN.DumpObjects`` values and ``rootFile`` defaults to
    ``True``; values already present in ``kwargs`` take precedence.

    Returns the ComponentAccumulator holding the service and the algorithm.
    """
    acc = ComponentAccumulator()

    # The stream name is the algorithm name, so distinct instances write to
    # distinct THistSvc streams.
    acc.addService(
        CompFactory.THistSvc(
            Output=[f"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
        )
    )

    kwargs.setdefault("NtupleFileName", flags.Tracking.GNN.DumpObjects.NtupleFileName)
    kwargs.setdefault("NtupleTreeName", flags.Tracking.GNN.DumpObjects.NtupleTreeName)
    kwargs.setdefault("rootFile", True)

    acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
    return acc
31 
def GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs):
    """Configure an InDet.SiGNNTrackFinderTool and return it as a private tool.

    Every tool property not supplied through ``kwargs`` is filled from the
    equally named flag under ``flags.Tracking.GNN.TrackFinder``.  Three ONNX
    runtime inference tools (``Embedding``, ``Filtering``, ``GNN``) are set up
    from the model files found in ``inputMLModelDir``.
    """
    result = ComponentAccumulator()

    finder_flags = flags.Tracking.GNN.TrackFinder

    # Property names coincide with the flag names, so one table drives both.
    flag_backed_properties = (
        "embeddingDim",
        "rVal",
        "knnVal",
        "filterCut",
        "inputMLModelDir",
        "ccCut",
        "walkMin",
        "walkMax",
        "EmbeddingFeatureNames",
        "EmbeddingFeatureScales",
        "FilterFeatureNames",
        "FilterFeatureScales",
        "GNNFeatureNames",
        "GNNFeatureScales",
    )
    for prop in flag_backed_properties:
        kwargs.setdefault(prop, getattr(finder_flags, prop))

    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg
    provider = finder_flags.ORTExeProvider

    # Use the effective directory: a caller-supplied value wins over the flag.
    model_dir = Path(kwargs["inputMLModelDir"])
    onnx_models = (
        ("Embedding", "embedding.onnx"),
        ("Filtering", "filtering.onnx"),
        ("GNN", "gnn.onnx"),
    )
    for tool_key, model_file in onnx_models:
        onnx_tool = result.popToolsAndMerge(
            OnnxRuntimeInferenceToolCfg(
                flags, str(model_dir / model_file), provider, name=tool_key))
        kwargs.setdefault(tool_key, onnx_tool)

    result.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
    return result
69 
70 
def GNNTrackFinderTritonToolCfg(flags, name='GNNTrackFinderTritonTool', **kwargs):
    """Configure an InDet.GNNTrackFinderTritonTool and return it as a private tool.

    The ``TritonTool`` is built from the model, url and port flags under
    ``flags.Tracking.GNN.Triton``; the ``SpacepointFeatureTool`` comes from
    SpacepointFeatureToolCfg.  Values passed in ``kwargs`` take precedence.
    """
    from AthTritonComps.TritonToolConfig import TritonToolCfg

    result = ComponentAccumulator()

    triton_flags = flags.Tracking.GNN.Triton
    triton_cfg = TritonToolCfg(
        flags,
        model_name=triton_flags.model,
        url=triton_flags.url,
        port=triton_flags.port,
    )
    kwargs.setdefault("TritonTool", result.popToolsAndMerge(triton_cfg))

    feature_tool = result.popToolsAndMerge(SpacepointFeatureToolCfg(flags))
    kwargs.setdefault("SpacepointFeatureTool", feature_tool)

    result.setPrivateTools(
        CompFactory.InDet.GNNTrackFinderTritonTool(name, **kwargs))
    return result
87 
88 
def SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs):
    """Build an InDet.SeedFitterTool and return it as a private tool.

    ``flags`` is accepted for interface uniformity but not read here; all
    ``kwargs`` are forwarded to the tool unchanged.
    """
    result = ComponentAccumulator()
    seed_fitter = CompFactory.InDet.SeedFitterTool(name, **kwargs)
    result.setPrivateTools(seed_fitter)
    return result
96 
97 
def SpacepointFeatureToolCfg(flags, name="SpacepointFeatureTool", **kwargs):
    """Build an InDet.SpacepointFeatureTool and return it as a private tool.

    ``flags`` is accepted for interface uniformity but not read here; all
    ``kwargs`` are forwarded to the tool unchanged.
    """
    result = ComponentAccumulator()
    feature_tool = CompFactory.InDet.SpacepointFeatureTool(name, **kwargs)
    result.setPrivateTools(feature_tool)
    return result
105 
106 
def GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs):
    """Configure an InDet.GNNTrackReaderTool and return it as a private tool.

    ``inputTracksDir`` and ``csvPrefix`` default to the equally named flags
    under ``flags.Tracking.GNN.TrackReader`` unless given in ``kwargs``.
    """
    result = ComponentAccumulator()

    reader_flags = flags.Tracking.GNN.TrackReader
    for prop in ("inputTracksDir", "csvPrefix"):
        kwargs.setdefault(prop, getattr(reader_flags, prop))

    reader_tool = CompFactory.InDet.GNNTrackReaderTool(name, **kwargs)
    result.setPrivateTools(reader_tool)
    return result
117 
def GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs):
    """Return the GNN track-maker configuration selected by the flags.

    When ``flags.Tracking.GNN.usePixelHitsOnly`` is set the seeding variant
    (GNNSeedingTrackMakerCfg) is used, otherwise the end-to-end variant
    (GNNEndToEndTrackMaker).  ``name`` and ``kwargs`` are passed through.
    """
    make_cfg = (
        GNNSeedingTrackMakerCfg
        if flags.Tracking.GNN.usePixelHitsOnly
        else GNNEndToEndTrackMaker
    )
    return make_cfg(flags, name, **kwargs)
125 
def GNNEndToEndTrackMaker(flags, name="GNNEndToEndTrackMaker", **kwargs):
    """Set up the InDet.SiSPGNNTrackMaker algorithm and return its accumulator.

    Tools and cut values not supplied through ``kwargs`` are configured here:
    the seed fitter, the ITk track fitter, the ITk track summary tool, the
    GNN track finder or reader chosen by ``flags.Tracking.GNN.ToolType``, the
    eta-dependent cuts service and the cluster / pT / eta cuts taken from
    ``flags.Tracking.GNN``.

    Raises:
        RuntimeError: if ``flags.Tracking.GNN.ToolType`` is none of
            TrackFinder, TrackReader or Triton.
    """

    acc = ComponentAccumulator()

    SeedFitterTool = acc.popToolsAndMerge(SeedFitterToolCfg(flags))
    kwargs.setdefault("SeedFitterTool", SeedFitterTool)

    from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg
    InDetTrackFitter = acc.popToolsAndMerge(ITkTrackFitterCfg(flags))
    kwargs.setdefault("TrackFitter", InDetTrackFitter)

    if "TrackSummaryTool" not in kwargs:
        from TrkConfig.TrkTrackSummaryToolConfig import ITkTrackSummaryToolCfg

        kwargs.setdefault(
            "TrackSummaryTool", acc.popToolsAndMerge(ITkTrackSummaryToolCfg(flags))
        )

    # Exactly one of the finder / reader tools is configured; the other
    # property is explicitly set to None.
    if flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackFinder:
        InDetGNNTrackFinderTool = acc.popToolsAndMerge(GNNTrackFinderToolCfg(flags))
        kwargs.setdefault("GNNTrackFinderTool", InDetGNNTrackFinderTool)
        kwargs.setdefault("GNNTrackReaderTool", None)
    elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.TrackReader:
        InDetGNNTrackReader = acc.popToolsAndMerge(GNNTrackReaderToolCfg(flags))
        kwargs.setdefault("GNNTrackReaderTool", InDetGNNTrackReader)
        kwargs.setdefault("GNNTrackFinderTool", None)
    elif flags.Tracking.GNN.ToolType == GNNTrackFinderToolType.Triton:
        # The Triton-backed finder is passed through the finder property.
        InDetGNNTrackFinderTool = acc.popToolsAndMerge(GNNTrackFinderTritonToolCfg(flags))
        kwargs.setdefault("GNNTrackReaderTool", None)
        kwargs.setdefault("GNNTrackFinderTool", InDetGNNTrackFinderTool)
    else:
        raise RuntimeError("GNNTrackFinder or GNNTrackReader must be enabled!")

    kwargs.setdefault("areInputClusters", flags.Tracking.GNN.useClusterTracks)
    kwargs.setdefault("doRecoTrackCuts", flags.Tracking.GNN.doRecoTrackCuts)

    # Add the eta dependent cut service.  The guard must test the same key
    # that is set below ("InDetEtaDependentCutsSvc"); the previous spelling
    # "InDetEtaDependentCutSvc" never matched, so the service was merged even
    # when the caller had already provided one.
    if "InDetEtaDependentCutsSvc" not in kwargs:
        from InDetConfig.InDetEtaDependentCutsConfig import ITkEtaDependentCutsSvcCfg
        acc.merge(ITkEtaDependentCutsSvcCfg(flags))
        kwargs.setdefault(
            "InDetEtaDependentCutsSvc",
            acc.getService("ITkEtaDependentCutsSvc" + flags.Tracking.ActiveConfig.extension),
        )

    kwargs.setdefault("minClusters", flags.Tracking.GNN.minClusters)
    kwargs.setdefault("pTmin", flags.Tracking.GNN.pTmin)
    kwargs.setdefault("etamax", flags.Tracking.GNN.etamax)
    kwargs.setdefault("minPixelClusters", flags.Tracking.GNN.minPixelClusters)
    kwargs.setdefault("minStripClusters", flags.Tracking.GNN.minStripClusters)

    acc.addEventAlgo(CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs))
    return acc
178 
def GNNSeedingTrackMakerCfg(flags, name="GNNSeedingTrackMaker", **kwargs):
    """Set up the InDet.GNNSeedingTrackMaker algorithm and return its accumulator.

    Configures the condition algorithms, the public pattern-recognition tools
    (RIO-on-track creator, propagator, updator), the boundary-check and
    conditions-summary tools, the GNN track finder or reader, the seed
    fitter, track fitter and road maker, plus the Kalman-filter cut values
    read from ``flags.Tracking.ActiveConfig``.  Values already present in
    ``kwargs`` take precedence over these defaults.

    Raises:
        RuntimeError: if neither ``flags.Tracking.GNN.useTrackFinder`` nor
            ``flags.Tracking.GNN.useTrackReader`` is set.
    """
    acc = ComponentAccumulator()

    from InDetConfig.SiCombinatorialTrackFinderToolConfig import (
        SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg,
        SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg)
    # NOTE(review): the two merges below are reconstructed — the listing this
    # block was recovered from omits the lines following the import; confirm
    # against the upstream file.
    acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags))
    acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags))

    # To produce AtlasFieldCacheCondObj
    from MagFieldServices.MagFieldServicesConfig import (
        AtlasFieldCacheCondAlgCfg)
    acc.merge(AtlasFieldCacheCondAlgCfg(flags))

    from TrkConfig.TrkRIO_OnTrackCreatorConfig import ITkRotCreatorCfg
    ITkRotCreator = acc.popToolsAndMerge(ITkRotCreatorCfg(
        flags, name="ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
    acc.addPublicTool(ITkRotCreator)
    kwargs.setdefault("RIOonTrackTool", ITkRotCreator)

    from TrkConfig.TrkExRungeKuttaPropagatorConfig import (
        RungeKuttaPropagatorCfg)
    ITkPatternPropagator = acc.popToolsAndMerge(
        RungeKuttaPropagatorCfg(flags, name="ITkPatternPropagator"))
    acc.addPublicTool(ITkPatternPropagator)
    kwargs.setdefault("PropagatorTool", ITkPatternPropagator)

    from TrkConfig.TrkMeasurementUpdatorConfig import KalmanUpdator_xkCfg
    ITkPatternUpdator = acc.popToolsAndMerge(
        KalmanUpdator_xkCfg(flags, name="ITkPatternUpdator"))
    acc.addPublicTool(ITkPatternUpdator)
    kwargs.setdefault("UpdatorTool", ITkPatternUpdator)

    from InDetConfig.InDetBoundaryCheckToolConfig import ITkBoundaryCheckToolCfg
    kwargs.setdefault("BoundaryCheckTool", acc.popToolsAndMerge(
        ITkBoundaryCheckToolCfg(flags)))

    # NOTE(review): the argument lines of the next two popToolsAndMerge calls
    # are reconstructed from the imports directly above them; confirm against
    # the upstream file.
    from PixelConditionsTools.ITkPixelConditionsSummaryConfig import (
        ITkPixelConditionsSummaryCfg)
    kwargs.setdefault("PixelSummaryTool", acc.popToolsAndMerge(
        ITkPixelConditionsSummaryCfg(flags)))

    from SCT_ConditionsTools.ITkStripConditionsToolsConfig import (
        ITkStripConditionsSummaryToolCfg)
    kwargs.setdefault("StripSummaryTool", acc.popToolsAndMerge(
        ITkStripConditionsSummaryToolCfg(flags)))

    # NOTE(review): this selection reads useTrackFinder / useTrackReader while
    # GNNEndToEndTrackMaker selects via flags.Tracking.GNN.ToolType — confirm
    # that both sets of flags are still defined.
    if flags.Tracking.GNN.useTrackFinder:
        kwargs.setdefault("GNNTrackFinderTool", acc.popToolsAndMerge(GNNTrackFinderToolCfg(flags)))
        kwargs.setdefault("GNNTrackReaderTool", None)
    elif flags.Tracking.GNN.useTrackReader:
        kwargs.setdefault("GNNTrackReaderTool", acc.popToolsAndMerge(GNNTrackReaderToolCfg(flags)))
        kwargs.setdefault("GNNTrackFinderTool", None)
    else:
        raise RuntimeError("GNNTrackFinder or GNNTrackReader must be enabled!")

    kwargs.setdefault("SeedFitterTool", acc.popToolsAndMerge(SeedFitterToolCfg(flags)))

    from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg
    kwargs.setdefault("TrackFitter", acc.popToolsAndMerge(ITkTrackFitterCfg(flags)))

    from InDetConfig.SiDetElementsRoadToolConfig import ITkSiDetElementsRoadMaker_xkCfg
    kwargs.setdefault("RoadTool", acc.popToolsAndMerge(ITkSiDetElementsRoadMaker_xkCfg(flags)))

    # configurations for Kalman filter.
    # similar to https://gitlab.cern.ch/atlas/athena/-/blob/main/InnerDetector/InDetConfig/python/SiTrackMakerConfig.py#L188
    # Only the first entry ([0]) of each ActiveConfig list is used.
    kwargs.setdefault("nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
    kwargs.setdefault("nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
    kwargs.setdefault("nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
    kwargs.setdefault("nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])

    kwargs.setdefault("pTmin", flags.Tracking.ActiveConfig.minPT[0])
    kwargs.setdefault("pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
    kwargs.setdefault("Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
    kwargs.setdefault("Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
    # Xi2maxMultiTracks is filled from the Xi2max flag.
    kwargs.setdefault("Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
    kwargs.setdefault("doMultiTracksProd", False)

    acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))
    return acc
python.InDetGNNTrackingConfig.GNNSeedingTrackMakerCfg
def GNNSeedingTrackMakerCfg(flags, name="GNNSeedingTrackMaker", **kwargs)
Definition: InDetGNNTrackingConfig.py:179
python.TrkRIO_OnTrackCreatorConfig.ITkRotCreatorCfg
def ITkRotCreatorCfg(flags, name='ITkRotCreator', **kwargs)
Definition: TrkRIO_OnTrackCreatorConfig.py:134
python.SiDetElementsRoadToolConfig.ITkSiDetElementsRoadMaker_xkCfg
def ITkSiDetElementsRoadMaker_xkCfg(flags, name="ITkSiRoadMaker", **kwargs)
Definition: SiDetElementsRoadToolConfig.py:61
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
python.TrkMeasurementUpdatorConfig.KalmanUpdator_xkCfg
def KalmanUpdator_xkCfg(flags, name='KalmanUpdator_xk', **kwargs)
Definition: TrkMeasurementUpdatorConfig.py:14
python.InDetBoundaryCheckToolConfig.ITkBoundaryCheckToolCfg
def ITkBoundaryCheckToolCfg(flags, name='ITkBoundaryCheckTool', **kwargs)
Definition: InDetBoundaryCheckToolConfig.py:102
python.InDetEtaDependentCutsConfig.ITkEtaDependentCutsSvcCfg
def ITkEtaDependentCutsSvcCfg(flags, name='ITkEtaDependentCutsSvc', **kwargs)
Definition: InDetEtaDependentCutsConfig.py:7
ITkPixelConditionsSummaryConfig.ITkPixelConditionsSummaryCfg
def ITkPixelConditionsSummaryCfg(flags, name="ITkPixelConditionsSummary", **kwargs)
Definition: ITkPixelConditionsSummaryConfig.py:13
python.CommonTrackFitterConfig.ITkTrackFitterCfg
def ITkTrackFitterCfg(flags, name='ITkTrackFitter', **kwargs)
ITk configs #####.
Definition: CommonTrackFitterConfig.py:134
python.TritonToolConfig.TritonToolCfg
def TritonToolCfg(flags, str model_name, str url, int port=8001, str model_version="", float timeout=0., bool ssl=False, name="TritonTool", **kwargs)
Definition: TritonToolConfig.py:6
python.SiCombinatorialTrackFinderToolConfig.SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
def SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags, name="ITkSiDetElementBoundaryLinksStripCondAlg", **kwargs)
Definition: SiCombinatorialTrackFinderToolConfig.py:51
python.InDetGNNTrackingConfig.SeedFitterToolCfg
def SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs)
Definition: InDetGNNTrackingConfig.py:89
python.TrkExRungeKuttaPropagatorConfig.RungeKuttaPropagatorCfg
def RungeKuttaPropagatorCfg(flags, name='AtlasRungeKuttaPropagator', **kwargs)
Definition: TrkExRungeKuttaPropagatorConfig.py:9
python.InDetGNNTrackingConfig.GNNEndToEndTrackMaker
def GNNEndToEndTrackMaker(flags, name="GNNEndToEndTrackMaker", **kwargs)
Definition: InDetGNNTrackingConfig.py:126
python.InDetGNNTrackingConfig.GNNTrackFinderToolCfg
def GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs)
Definition: InDetGNNTrackingConfig.py:32
python.TrkTrackSummaryToolConfig.ITkTrackSummaryToolCfg
def ITkTrackSummaryToolCfg(flags, name='ITkTrackSummaryTool', **kwargs)
Definition: TrkTrackSummaryToolConfig.py:94
python.InDetGNNTrackingConfig.GNNTrackReaderToolCfg
def GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs)
Definition: InDetGNNTrackingConfig.py:107
python.OnnxRuntimeInferenceConfig.OnnxRuntimeInferenceToolCfg
def OnnxRuntimeInferenceToolCfg(flags, str model_fname=None, Optional[OnnxRuntimeType] execution_provider=None, name="OnnxRuntimeInferenceTool", **kwargs)
Definition: OnnxRuntimeInferenceConfig.py:9
python.ITkStripConditionsToolsConfig.ITkStripConditionsSummaryToolCfg
def ITkStripConditionsSummaryToolCfg(flags, name="ITkStripConditionsSummaryTool", **kwargs)
Definition: ITkStripConditionsToolsConfig.py:17
str
Definition: BTagTrackIpAccessor.cxx:11
python.MagFieldServicesConfig.AtlasFieldCacheCondAlgCfg
def AtlasFieldCacheCondAlgCfg(flags, **kwargs)
Definition: MagFieldServicesConfig.py:8
python.InDetGNNTrackingConfig.DumpObjectsCfg
def DumpObjectsCfg(flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs)
Definition: InDetGNNTrackingConfig.py:12
python.SiCombinatorialTrackFinderToolConfig.SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg
def SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags, name="ITkSiDetElementBoundaryLinksPixelCondAlg", **kwargs)
Definition: SiCombinatorialTrackFinderToolConfig.py:36
python.InDetGNNTrackingConfig.SpacepointFeatureToolCfg
def SpacepointFeatureToolCfg(flags, name="SpacepointFeatureTool", **kwargs)
Definition: InDetGNNTrackingConfig.py:98
python.InDetGNNTrackingConfig.GNNTrackMakerCfg
def GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs)
Definition: InDetGNNTrackingConfig.py:118
python.InDetGNNTrackingConfig.GNNTrackFinderTritonToolCfg
def GNNTrackFinderTritonToolCfg(flags, name='GNNTrackFinderTritonTool', **kwargs)
Definition: InDetGNNTrackingConfig.py:71