ATLAS Offline Software
InDetGNNTrackingConfig.py
Go to the documentation of this file.
1 #
2 # Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 #
4 
5 from pathlib import Path
6 
7 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
8 from AthenaConfiguration.ComponentFactory import CompFactory
9 
10 
def DumpObjectsCfg(
        flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs):
    """Create the algorithm which dumps GNN training information to a ROOT file.

    Args:
        flags: configuration flags (not read by this function itself).
        name: name given to the DumpObjects algorithm; also used as the
            THistSvc output stream name.
        outfile: ROOT file attached to that stream (opened with RECREATE).
        **kwargs: extra properties forwarded to ``InDet.DumpObjects``.
            Defaults are filled in for ``NtupleFileName``,
            ``NtupleTreeName`` and ``rootFile``.

    Returns:
        ComponentAccumulator holding the THistSvc and the event algorithm.
    """
    acc = ComponentAccumulator()

    acc.addService(
        CompFactory.THistSvc(
            Output=[f"{name} DATAFILE='{outfile}', OPT='RECREATE'"]
        )
    )

    # NOTE(review): the stream above is named after `name`, while the default
    # ntuple path is fixed to "/DumpObjects/" -- presumably the two have to
    # agree when a non-default name is used; confirm against the C++ algorithm.
    kwargs.setdefault("NtupleFileName", "/DumpObjects/")
    kwargs.setdefault("NtupleTreeName", "GNN4ITk")
    kwargs.setdefault("rootFile", True)

    acc.addEventAlgo(CompFactory.InDet.DumpObjects(name, **kwargs))
    return acc
30 
def GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs):
    """Set up a SiGNNTrackFinderTool and return it as a private tool.

    Args:
        flags: configuration flags, passed on to the ONNX inference tool
            configuration.
        name: name of the tool instance.
        **kwargs: properties forwarded to ``InDet.SiGNNTrackFinderTool``.
            Defaults are supplied for the scalar properties below and for
            the ``Embedding``, ``Filtering`` and ``GNN`` inference tools.

    Returns:
        ComponentAccumulator whose private tool is the configured finder.
    """
    acc = ComponentAccumulator()

    kwargs.setdefault("embeddingDim", 8)
    kwargs.setdefault("rVal", 1.7)
    kwargs.setdefault("knnVal", 500)
    kwargs.setdefault("filterCut", 0.21)
    kwargs.setdefault("inputMLModelDir", "TrainedMLModels4ITk")
    kwargs.setdefault("UseCUDA", False)

    from AthOnnxComps.OnnxRuntimeInferenceConfig import OnnxRuntimeInferenceToolCfg

    # Take the model directory from the (possibly caller-overridden) property
    # so the ONNX inference tools and the finder tool always point at the same
    # place; with the default value the paths are unchanged.
    model_dir = Path(kwargs["inputMLModelDir"])
    onnx_models = (
        ("Embedding", "embedding.onnx"),
        ("Filtering", "filtering.onnx"),
        ("GNN", "gnn.onnx"),
    )
    for tool_property, model_file in onnx_models:
        kwargs.setdefault(tool_property, acc.popToolsAndMerge(
            OnnxRuntimeInferenceToolCfg(flags, model_dir / model_file)
        ))

    acc.setPrivateTools(CompFactory.InDet.SiGNNTrackFinderTool(name, **kwargs))
    return acc
56 
57 
def SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs):
    """Build an InDet.SeedFitterTool and hand it back as a private tool.

    All keyword arguments are forwarded unchanged to the tool; ``flags`` is
    accepted for interface uniformity and is not read here.
    """
    result = ComponentAccumulator()
    seed_fitter = CompFactory.InDet.SeedFitterTool(name, **kwargs)
    result.setPrivateTools(seed_fitter)
    return result
65 
def GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs):
    """Build an InDet.GNNTrackReaderTool and hand it back as a private tool.

    ``inputTracksDir`` and ``csvPrefix`` receive defaults unless the caller
    supplies them; every keyword argument is forwarded to the tool.
    """
    result = ComponentAccumulator()

    reader_defaults = (
        ("inputTracksDir", "gnntracks"),
        ("csvPrefix", "track"),
    )
    for prop, value in reader_defaults:
        kwargs.setdefault(prop, value)

    reader_tool = CompFactory.InDet.GNNTrackReaderTool(name, **kwargs)
    result.setPrivateTools(reader_tool)
    return result
76 
def GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs):
    """Configure the GNN track maker selected by the tracking flags.

    When ``flags.Tracking.GNN.usePixelHitsOnly`` is set the seeding variant
    is configured, otherwise the end-to-end variant; ``name`` and all
    keyword arguments are passed straight through.
    """
    pixel_only = flags.Tracking.GNN.usePixelHitsOnly
    configure = GNNSeedingTrackMakerCfg if pixel_only else GNNEndToEndTrackMaker
    return configure(flags, name, **kwargs)
84 
def GNNEndToEndTrackMaker(flags, name="GNNEndToEndTrackMaker", **kwargs):
    """Configure the InDet.SiSPGNNTrackMaker event algorithm.

    A seed fitter and a track fitter are always configured. Exactly one of
    the GNN track finder / track reader tools is configured, chosen by
    ``flags.Tracking.GNN``; the other property defaults to ``None``.

    Raises:
        RuntimeError: if neither useTrackFinder nor useTrackReader is set.
    """
    result = ComponentAccumulator()

    kwargs.setdefault(
        "SeedFitterTool", result.popToolsAndMerge(SeedFitterToolCfg(flags)))

    from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg
    kwargs.setdefault(
        "TrackFitter", result.popToolsAndMerge(ITkTrackFitterCfg(flags)))

    # Decide which tool supplies the track candidates; the finder takes
    # precedence over the reader when both flags are set.
    gnn_flags = flags.Tracking.GNN
    if gnn_flags.useTrackFinder:
        active_prop, idle_prop = "GNNTrackFinderTool", "GNNTrackReaderTool"
        active_cfg = GNNTrackFinderToolCfg
    elif gnn_flags.useTrackReader:
        active_prop, idle_prop = "GNNTrackReaderTool", "GNNTrackFinderTool"
        active_cfg = GNNTrackReaderToolCfg
    else:
        raise RuntimeError("GNNTrackFinder or GNNTrackReader must be enabled!")

    kwargs.setdefault(active_prop, result.popToolsAndMerge(active_cfg(flags)))
    kwargs.setdefault(idle_prop, None)

    track_maker = CompFactory.InDet.SiSPGNNTrackMaker(name, **kwargs)
    result.addEventAlgo(track_maker)
    return result
111 
def GNNSeedingTrackMakerCfg(flags, name="GNNSeedingTrackMaker", **kwargs):
    """Set up the GNN-for-seeding track maker algorithm and return it.

    Args:
        flags: configuration flags; ``flags.Tracking.GNN`` selects the track
            finder or reader and ``flags.Tracking.ActiveConfig`` supplies the
            Kalman-filter cut values.
        name: name of the ``InDet.GNNSeedingTrackMaker`` algorithm.
        **kwargs: properties forwarded to the algorithm; every tool and cut
            set below is only a default and can be overridden by the caller.

    Returns:
        ComponentAccumulator holding the conditions algorithms, the public
        tools and the event algorithm.

    Raises:
        RuntimeError: if neither useTrackFinder nor useTrackReader is set.
    """
    acc = ComponentAccumulator()

    # Schedule the pixel and strip boundary-link conditions algorithms.
    from InDetConfig.SiCombinatorialTrackFinderToolConfig import (
        SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg,
        SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg)
    acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags))
    acc.merge(SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags))

    # To produce AtlasFieldCacheCondObj
    from MagFieldServices.MagFieldServicesConfig import (
        AtlasFieldCacheCondAlgCfg)
    acc.merge(AtlasFieldCacheCondAlgCfg(flags))

    from TrkConfig.TrkRIO_OnTrackCreatorConfig import ITkRotCreatorCfg
    ITkRotCreator = acc.popToolsAndMerge(ITkRotCreatorCfg(
        flags, name="ITkRotCreator"+flags.Tracking.ActiveConfig.extension))
    acc.addPublicTool(ITkRotCreator)
    kwargs.setdefault("RIOonTrackTool", ITkRotCreator)

    from TrkConfig.TrkExRungeKuttaPropagatorConfig import (
        RungeKuttaPropagatorCfg)
    ITkPatternPropagator = acc.popToolsAndMerge(
        RungeKuttaPropagatorCfg(flags, name="ITkPatternPropagator"))
    acc.addPublicTool(ITkPatternPropagator)
    kwargs.setdefault("PropagatorTool", ITkPatternPropagator)

    from TrkConfig.TrkMeasurementUpdatorConfig import KalmanUpdator_xkCfg
    ITkPatternUpdator = acc.popToolsAndMerge(
        KalmanUpdator_xkCfg(flags, name="ITkPatternUpdator"))
    acc.addPublicTool(ITkPatternUpdator)
    kwargs.setdefault("UpdatorTool", ITkPatternUpdator)

    from InDetConfig.InDetBoundaryCheckToolConfig import ITkBoundaryCheckToolCfg
    kwargs.setdefault("BoundaryCheckTool", acc.popToolsAndMerge(
        ITkBoundaryCheckToolCfg(flags)))

    from PixelConditionsTools.ITkPixelConditionsSummaryConfig import (
        ITkPixelConditionsSummaryCfg)
    kwargs.setdefault("PixelSummaryTool", acc.popToolsAndMerge(
        ITkPixelConditionsSummaryCfg(flags)))

    from SCT_ConditionsTools.ITkStripConditionsToolsConfig import (
        ITkStripConditionsSummaryToolCfg)
    kwargs.setdefault("StripSummaryTool", acc.popToolsAndMerge(
        ITkStripConditionsSummaryToolCfg(flags)))

    # Exactly one source of GNN track candidates is configured; the finder
    # takes precedence over the reader when both flags are set.
    if flags.Tracking.GNN.useTrackFinder:
        kwargs.setdefault("GNNTrackFinderTool", acc.popToolsAndMerge(GNNTrackFinderToolCfg(flags)))
        kwargs.setdefault("GNNTrackReaderTool", None)
    elif flags.Tracking.GNN.useTrackReader:
        kwargs.setdefault("GNNTrackReaderTool", acc.popToolsAndMerge(GNNTrackReaderToolCfg(flags)))
        kwargs.setdefault("GNNTrackFinderTool", None)
    else:
        raise RuntimeError("GNNTrackFinder or GNNTrackReader must be enabled!")

    kwargs.setdefault("SeedFitterTool", acc.popToolsAndMerge(SeedFitterToolCfg(flags)))

    from TrkConfig.CommonTrackFitterConfig import ITkTrackFitterCfg
    kwargs.setdefault("TrackFitter", acc.popToolsAndMerge(ITkTrackFitterCfg(flags)))

    from InDetConfig.SiDetElementsRoadToolConfig import ITkSiDetElementsRoadMaker_xkCfg
    kwargs.setdefault("RoadTool", acc.popToolsAndMerge(ITkSiDetElementsRoadMaker_xkCfg(flags)))

    # configurations for Kalman filter.
    # similar to https://gitlab.cern.ch/atlas/athena/-/blob/main/InnerDetector/InDetConfig/python/SiTrackMakerConfig.py#L188
    # Only the first entry of each per-bin cut list is used.
    kwargs.setdefault("nClustersMin", flags.Tracking.ActiveConfig.minClusters[0])
    kwargs.setdefault("nWeightedClustersMin", flags.Tracking.ActiveConfig.nWeightedClustersMin[0])
    kwargs.setdefault("nHolesMax", flags.Tracking.ActiveConfig.nHolesMax[0])
    kwargs.setdefault("nHolesGapMax", flags.Tracking.ActiveConfig.nHolesGapMax[0])

    kwargs.setdefault("pTmin", flags.Tracking.ActiveConfig.minPT[0])
    kwargs.setdefault("pTminBrem", flags.Tracking.ActiveConfig.minPTBrem[0])
    kwargs.setdefault("Xi2max", flags.Tracking.ActiveConfig.Xi2max[0])
    kwargs.setdefault("Xi2maxNoAdd", flags.Tracking.ActiveConfig.Xi2maxNoAdd[0])
    # The multi-track chi2 cut reuses the Xi2max flag.
    kwargs.setdefault("Xi2maxMultiTracks", flags.Tracking.ActiveConfig.Xi2max[0])
    kwargs.setdefault("doMultiTracksProd", False)

    acc.addEventAlgo(CompFactory.InDet.GNNSeedingTrackMaker(name, **kwargs))
    return acc
python.InDetGNNTrackingConfig.GNNSeedingTrackMakerCfg
def GNNSeedingTrackMakerCfg(flags, name="GNNSeedingTrackMaker", **kwargs)
Definition: InDetGNNTrackingConfig.py:112
python.TrkRIO_OnTrackCreatorConfig.ITkRotCreatorCfg
def ITkRotCreatorCfg(flags, name='ITkRotCreator', **kwargs)
Definition: TrkRIO_OnTrackCreatorConfig.py:134
python.SiDetElementsRoadToolConfig.ITkSiDetElementsRoadMaker_xkCfg
def ITkSiDetElementsRoadMaker_xkCfg(flags, name="ITkSiRoadMaker", **kwargs)
Definition: SiDetElementsRoadToolConfig.py:61
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
python.TrkMeasurementUpdatorConfig.KalmanUpdator_xkCfg
def KalmanUpdator_xkCfg(flags, name='KalmanUpdator_xk', **kwargs)
Definition: TrkMeasurementUpdatorConfig.py:14
python.InDetBoundaryCheckToolConfig.ITkBoundaryCheckToolCfg
def ITkBoundaryCheckToolCfg(flags, name='ITkBoundaryCheckTool', **kwargs)
Definition: InDetBoundaryCheckToolConfig.py:102
ITkPixelConditionsSummaryConfig.ITkPixelConditionsSummaryCfg
def ITkPixelConditionsSummaryCfg(flags, name="ITkPixelConditionsSummary", **kwargs)
Definition: ITkPixelConditionsSummaryConfig.py:12
python.CommonTrackFitterConfig.ITkTrackFitterCfg
def ITkTrackFitterCfg(flags, name='ITkTrackFitter', **kwargs)
ITk configs #####.
Definition: CommonTrackFitterConfig.py:126
python.SiCombinatorialTrackFinderToolConfig.SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg
def SiDetElementBoundaryLinksCondAlg_xk_ITkStrip_Cfg(flags, name="ITkSiDetElementBoundaryLinksStripCondAlg", **kwargs)
Definition: SiCombinatorialTrackFinderToolConfig.py:51
python.InDetGNNTrackingConfig.SeedFitterToolCfg
def SeedFitterToolCfg(flags, name="SeedFitterTool", **kwargs)
Definition: InDetGNNTrackingConfig.py:58
python.TrkExRungeKuttaPropagatorConfig.RungeKuttaPropagatorCfg
def RungeKuttaPropagatorCfg(flags, name='AtlasRungeKuttaPropagator', **kwargs)
Definition: TrkExRungeKuttaPropagatorConfig.py:9
python.InDetGNNTrackingConfig.GNNEndToEndTrackMaker
def GNNEndToEndTrackMaker(flags, name="GNNEndToEndTrackMaker", **kwargs)
Definition: InDetGNNTrackingConfig.py:85
python.InDetGNNTrackingConfig.GNNTrackFinderToolCfg
def GNNTrackFinderToolCfg(flags, name='GNNTrackFinderTool', **kwargs)
Definition: InDetGNNTrackingConfig.py:31
python.InDetGNNTrackingConfig.GNNTrackReaderToolCfg
def GNNTrackReaderToolCfg(flags, name='GNNTrackReaderTool', **kwargs)
Definition: InDetGNNTrackingConfig.py:66
python.OnnxRuntimeInferenceConfig.OnnxRuntimeInferenceToolCfg
def OnnxRuntimeInferenceToolCfg(flags, str model_fname=None, Optional[OnnxRuntimeType] execution_provider=None, name="OnnxRuntimeInferenceTool", **kwargs)
Definition: OnnxRuntimeInferenceConfig.py:9
python.ITkStripConditionsToolsConfig.ITkStripConditionsSummaryToolCfg
def ITkStripConditionsSummaryToolCfg(flags, name="ITkStripConditionsSummaryTool", **kwargs)
Definition: ITkStripConditionsToolsConfig.py:16
python.MagFieldServicesConfig.AtlasFieldCacheCondAlgCfg
def AtlasFieldCacheCondAlgCfg(flags, **kwargs)
Definition: MagFieldServicesConfig.py:8
python.InDetGNNTrackingConfig.DumpObjectsCfg
def DumpObjectsCfg(flags, name="DumpObjects", outfile="Dump_GNN4Itk.root", **kwargs)
Definition: InDetGNNTrackingConfig.py:11
python.SiCombinatorialTrackFinderToolConfig.SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg
def SiDetElementBoundaryLinksCondAlg_xk_ITkPixel_Cfg(flags, name="ITkSiDetElementBoundaryLinksPixelCondAlg", **kwargs)
Definition: SiCombinatorialTrackFinderToolConfig.py:36
python.InDetGNNTrackingConfig.GNNTrackMakerCfg
def GNNTrackMakerCfg(flags, name="GNNTrackMaker", **kwargs)
Definition: InDetGNNTrackingConfig.py:77