ATLAS Offline Software
Loading...
Searching...
No Matches
InDetGNNHardScatterSelectionConfig.py
Go to the documentation of this file.
1# Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2# A GNN based algorithm for selecting the Hard Scatter vertex (truth primary vertex
3# that is simulated for the hard process, i.e. the PV in the TruthEvent)
4
5from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
6from AthenaConfiguration.ComponentFactory import CompFactory
7
8from TrkConfig.VertexFindingFlags import VertexSortingSetup
# Object pre-selection (pT/eta) and overlap removal configurations
10
def AsgPtEtaSelectionToolCfg(flags, name="AsgPtEtaSelectionTool", **kwargs):
    """Return an accumulator whose private tool is a CP::AsgPtEtaSelectionTool.

    ``maxEta`` defaults to 2.5 unless the caller supplies it; all other
    keyword arguments are forwarded to the tool unchanged. ``flags`` is not
    used by this function.
    """
    kwargs.setdefault("maxEta", 2.5)
    tool = CompFactory.CP.AsgPtEtaSelectionTool(name, **kwargs)
    acc = ComponentAccumulator()
    acc.setPrivateTools(tool)
    return acc
16
def AsgPtEtaSelectionToolGapCfg(flags, name="AsgPtEtaSelectionTool", **kwargs):
    """Like AsgPtEtaSelectionToolCfg, but with the eta-gap properties set.

    ``etaGapLow`` and ``etaGapHigh`` default to 1.37 and 1.52 unless the
    caller supplies them; everything else is passed straight through.
    """
    gap_defaults = {"etaGapLow": 1.37, "etaGapHigh": 1.52}
    for prop, value in gap_defaults.items():
        kwargs.setdefault(prop, value)
    return AsgPtEtaSelectionToolCfg(flags, name, **kwargs)
21
def AsgViewFromSelectionAlgCfg(flags, name="AsgViewFromSelectionAlg", **kwargs):
    """Return an accumulator holding a CP::AsgViewFromSelectionAlg event algorithm.

    Defaults, overridable through kwargs: ``selection="selectPtEta"`` and
    ``deepCopy=False``. ``flags`` is not used by this function.
    """
    defaults = (("selection", "selectPtEta"), ("deepCopy", False))
    for prop, value in defaults:
        kwargs.setdefault(prop, value)
    acc = ComponentAccumulator()
    acc.addEventAlgo(CompFactory.CP.AsgViewFromSelectionAlg(name, **kwargs))
    return acc
28
def GNNHSSelectionAlgCfg(flags, input, minPt):
    """Schedule a CP::AsgSelectionAlg applying the pT/eta pre-selection to one container.

    The result is written as the ``selectPtEta`` (``as_char``) decoration on
    the objects of *input*.
    - Electrons, photons and the EMTopo jet collections use the eta-gap
      selection tool (AsgPtEtaSelectionToolGapCfg).
    - Muons use the plain pT/eta tool (AsgPtEtaSelectionToolCfg).

    Args:
        flags: configuration flags.
        input: StoreGate key of the container to decorate.
        minPt: value forwarded as the selection tool's ``minPt`` property.

    Returns:
        A ComponentAccumulator with the selection algorithm and its tool.
    """
    cfg = ComponentAccumulator()

    # The custom-vertex GNN jet collection (used by GNNSequenceCfg when GNN
    # vertex sorting is configured) must be listed alongside the standard
    # jets. Otherwise it gets no selection tool and the jet minPt is ignored.
    eta_gap_inputs = (
        "Electrons",
        "Photons",
        "AntiKt4EMTopoJets",
        "AntiKt4EMTopoCustomVtxGNNJets",
    )

    selectionTool = None
    if input in eta_gap_inputs:
        selectionTool = cfg.popToolsAndMerge(
            AsgPtEtaSelectionToolGapCfg(flags, minPt=minPt))
    elif input == "Muons":
        selectionTool = cfg.popToolsAndMerge(
            AsgPtEtaSelectionToolCfg(flags, minPt=minPt))
    # NOTE(review): any other container still gets selectionTool=None, as
    # before; confirm whether that should be an error instead.

    selection_outputs = [
        ("SG::AuxVectorBase", f"StoreGateSvc+{input}.selectPtEta"),
        ("xAOD::IParticleContainer", f"StoreGateSvc+{input}.selectPtEta"),
    ]

    cfg.addEventAlgo(CompFactory.CP.AsgSelectionAlg(
        name="GNNHS_" + input + "_SelectionAlg",
        selectionTool=selectionTool,
        selectionDecoration="selectPtEta,as_char",
        particles=input,
        ExtraOutputs=selection_outputs))

    return cfg
53
54
def GNNHSOverlapRemovalToolCfg(flags, name="GNNHS_OverlapRemovalToolCfg", **kwargs):
    """Build the ORUtils::OverlapRemovalTool used by the GNN hard-scatter selection.

    The master tool and every overlap sub-tool it creates share the same
    ``InputLabel``, ``OutputLabel`` and ``OutputPassValue``. Their defaults
    are "selectPtEta", "passesOR" and True.

    Sub-tools passed in through kwargs (EleEleORT, EleMuORT, EleJetORT,
    MuJetORT, PhoEleORT, PhoMuORT, PhoJetORT) are kept; otherwise default
    ones are used. The muon-jet tool reads "PrimaryVertices_initial" when
    GNN vertex sorting is configured, and "PrimaryVertices" otherwise.

    Returns:
        A ComponentAccumulator whose private tool is the OverlapRemovalTool.
    """
    label_defaults = {
        "InputLabel": "selectPtEta",
        "OutputLabel": "passesOR",
        "OutputPassValue": True,
    }
    for key, value in label_defaults.items():
        kwargs.setdefault(key, value)
    shared = {key: kwargs[key] for key in label_defaults}

    uses_gnn_sorting = (
        flags.Tracking.PriVertex.sortingSetup is VertexSortingSetup.GNNSorting)
    pv_container = "PrimaryVertices_initial" if uses_gnn_sorting else "PrimaryVertices"

    or_utils = CompFactory.ORUtils
    # Default sub-tools are always constructed; setdefault keeps caller ones.
    default_subtools = [
        ("EleEleORT", or_utils.EleEleOverlapTool(**shared)),
        ("EleMuORT", or_utils.EleMuSharedTrkOverlapTool(**shared)),
        ("EleJetORT", or_utils.EleJetOverlapTool(**shared)),
        ("MuJetORT", or_utils.MuJetOverlapTool(
            PVContainerName=pv_container, AllowNoPV=True, **shared)),
        ("PhoEleORT", or_utils.DeltaROverlapTool(**shared)),
        ("PhoMuORT", or_utils.DeltaROverlapTool(**shared)),
        ("PhoJetORT", or_utils.DeltaROverlapTool(**shared)),
    ]
    for key, subtool in default_subtools:
        kwargs.setdefault(key, subtool)

    acc = ComponentAccumulator()
    acc.setPrivateTools(or_utils.OverlapRemovalTool(name, **kwargs))
    return acc
87
def GNNHSOverlapRemovalAlgCfg(flags, name="GNNHS_OverlapRemovalAlg",
                              overlapInputNames=None, overlapOutputNames=None, **kwargs):
    """Schedule CP::OverlapRemovalAlg and one CP::AsgViewFromSelectionAlg per object type.

    The overlap-removal algorithm:
    - reads the containers in *overlapInputNames*, declaring a dependency on
      their ``selectPtEta`` decoration;
    - decorates each of them with ``<OutputLabel>,as_char``.
    For every object type, a deep copy of the objects passing that decoration
    is then written under the matching key of *overlapOutputNames*.

    Args:
        flags: configuration flags, forwarded to GNNHSOverlapRemovalToolCfg.
        name: name of the OverlapRemovalAlg.
        overlapInputNames: dict mapping object type ("jets", "electrons",
            "muons" or "photons") to its input container key. Required.
        overlapOutputNames: dict mapping the same object types to the output
            container key. Required; must cover every key of overlapInputNames.
        **kwargs: extra OverlapRemovalAlg properties. Defaults:
            - ``OutputLabel``: "passesOR";
            - ``affectingSystematicsFilter``: ".*";
            - ``overlapTool``: the tool from GNNHSOverlapRemovalToolCfg.

    Returns:
        A ComponentAccumulator with the scheduled algorithms.

    Raises:
        ValueError: if a name mapping is missing, an object type is not
            supported, or an input object type has no output name.
    """
    # Container types registered for each deep-copied output container.
    or_container_types = {
        "jets": ["xAOD::JetContainer", "xAOD::IParticleContainer"],
        "electrons": ["xAOD::ElectronContainer", "xAOD::IParticleContainer"],
        "muons": ["xAOD::MuonContainer", "xAOD::IParticleContainer"],
        "photons": ["xAOD::PhotonContainer", "xAOD::EgammaContainer", "xAOD::IParticleContainer"],
    }

    # Validate up front; the None defaults would otherwise fail later with an
    # opaque TypeError/KeyError.
    if overlapInputNames is None or overlapOutputNames is None:
        raise ValueError(
            "GNNHSOverlapRemovalAlgCfg requires both overlapInputNames and overlapOutputNames")
    unsupported = sorted(set(overlapInputNames) - set(or_container_types))
    if unsupported:
        raise ValueError(f"Unsupported object type(s) for overlap removal: {unsupported}")
    missing = sorted(set(overlapInputNames) - set(overlapOutputNames))
    if missing:
        raise ValueError(f"No output container name given for object type(s): {missing}")

    cfg = ComponentAccumulator()

    kwargs.setdefault("OutputLabel", "passesOR")
    kwargs.setdefault("affectingSystematicsFilter", ".*")
    label = kwargs["OutputLabel"]
    decoration = label + ",as_char"

    for obj, container in overlapInputNames.items():
        kwargs.setdefault(obj, container)
        kwargs.setdefault(obj + "Decoration", decoration)

    # The algorithm consumes the pre-selection decoration. Sorting keeps the
    # property value reproducible: iteration order of a set of strings
    # changes between Python processes (hash randomisation).
    extra_inputs = set(kwargs.get("ExtraInputs", set()))
    extra_inputs.update(
        ("xAOD::IParticleContainer", f"StoreGateSvc+{container}.selectPtEta")
        for container in overlapInputNames.values())
    kwargs["ExtraInputs"] = sorted(extra_inputs)

    # Decorations written on the inputs, derived from the actual inputs and
    # label instead of a hard-coded list of all four object types.
    overlap_outputs = []
    for container in overlapInputNames.values():
        overlap_outputs += [
            ("SG::AuxVectorBase", f"StoreGateSvc+{container}.{label}"),
            ("xAOD::IParticleContainer", f"StoreGateSvc+{container}.{label}"),
        ]

    # Only build (and merge) the default tool if the caller gave none;
    # dict.setdefault would evaluate it unconditionally.
    if "overlapTool" not in kwargs:
        kwargs["overlapTool"] = cfg.popToolsAndMerge(GNNHSOverlapRemovalToolCfg(flags))

    cfg.addEventAlgo(CompFactory.CP.OverlapRemovalAlg(name, ExtraOutputs=overlap_outputs, **kwargs))

    for obj, container in overlapInputNames.items():
        output_name = overlapOutputNames[obj]
        view_outputs = [("xAOD::AuxContainerBase", f"StoreGateSvc+{output_name}Aux.")]
        view_outputs += [
            (out_type, f"StoreGateSvc+{output_name}")
            for out_type in or_container_types[obj]
        ]
        # The deep copy carries both the OR and the pre-selection decorations.
        for deco in (label, "selectPtEta"):
            view_outputs += [
                ("SG::AuxVectorBase", f"StoreGateSvc+{output_name}.{deco}"),
                ("xAOD::IParticleContainer", f"StoreGateSvc+{output_name}.{deco}"),
            ]

        cfg.addEventAlgo(CompFactory.CP.AsgViewFromSelectionAlg(
            name="GNNHS_" + obj + "_ORSelectionAlg",
            input=container,
            output=output_name,
            selection=[decoration],
            deepCopy=True,
            ExtraOutputs=view_outputs))

    return cfg
147
148
149# GNNTool + VertexDecoratorAlg configs
150
def GNNToolCfg(flags, name="HardScatterSelectionGNNTool", **kwargs):
    """Return an accumulator whose private tool is an InDetGNNHardScatterSelection::GNNTool.

    All keyword arguments are forwarded to the tool; ``flags`` is not used.
    """
    tool = CompFactory.InDetGNNHardScatterSelection.GNNTool(name, **kwargs)
    result = ComponentAccumulator()
    result.setPrivateTools(tool)
    return result
155
def GNNHSVertexDecoratorAlgCfg(flags, name="GNNHS_VertexDecoratorAlg", **kwargs):
    """Schedule InDetGNNHardScatterSelection::VertexDecoratorAlg.

    Defaults filled in when the caller does not supply them:
      * ``photonsIn``: "Photons";
      * ``gnnTool``: a GNNTool loading the v1.2 baseline ONNX file;
      * ``TrackVertexAssociationTool``: the tool from GNNHS_TTVAToolCfg.

    Data dependencies are added to ``ExtraInputs`` for:
      * every given ``electronsIn``/``muonsIn``/``photonsIn``/``jetsIn`` container;
      * the photon pointing decorations (zCommon, caloPointingZ, zCommonError).

    Returns:
        A ComponentAccumulator with the algorithm and its tools.
    """
    cfg = ComponentAccumulator()

    kwargs.setdefault("photonsIn", "Photons")

    if "gnnTool" not in kwargs:
        kwargs["gnnTool"] = cfg.popToolsAndMerge(
            GNNToolCfg(flags,
                       nnFile="InDetGNNHardScatterSelection/v1.2/HSGNN_baseline_v1.2.onnx"))

    if "TrackVertexAssociationTool" not in kwargs:
        from TrackVertexAssociationTool.TrackVertexAssociationToolConfig import GNNHS_TTVAToolCfg
        kwargs["TrackVertexAssociationTool"] = cfg.popToolsAndMerge(
            GNNHS_TTVAToolCfg(flags))

    vertex_extra_inputs = set(kwargs.get("ExtraInputs", set()))
    for cont in ("electronsIn", "muonsIn", "photonsIn", "jetsIn"):
        if cont in kwargs:
            vertex_extra_inputs.add(("xAOD::IParticleContainer", f"StoreGateSvc+{kwargs[cont]}"))

    # photonsIn is always set (defaulted above), so its pointing decorations
    # are always declared as inputs.
    photons_key = kwargs["photonsIn"]
    vertex_extra_inputs.update(
        ("xAOD::IParticleContainer", f"StoreGateSvc+{photons_key}.{deco}")
        for deco in ("zCommon", "caloPointingZ", "zCommonError"))

    # Sorted for a reproducible property value; a set of strings iterates in a
    # process-dependent order (hash randomisation).
    kwargs["ExtraInputs"] = sorted(vertex_extra_inputs)

    cfg.addEventAlgo(CompFactory.InDetGNNHardScatterSelection.VertexDecoratorAlg(name, **kwargs))
    return cfg
187
188
189# Global Sequence
def GNNSequenceCfg(flags, doOverlapRemoval=True):
    """Build the full GNN hard-scatter vertex selection sequence.

    Steps:
      1. pT/eta pre-selection (``selectPtEta``) on jets, electrons, muons
         and photons;
      2. optionally, overlap removal writing ``<container>_OR`` deep copies;
      3. photon pointing decorations (zCommon, caloPointingZ, zCommonError)
         on the photon container used downstream;
      4. the VertexDecoratorAlg, fed with the containers from step 2 (or the
         original containers when overlap removal is disabled).

    When GNN vertex sorting is configured, "AntiKt4EMTopoCustomVtxGNNJets"
    and "PrimaryVertices_initial" are used. Otherwise "AntiKt4EMTopoJets"
    and "PrimaryVertices" are used.

    Args:
        flags: configuration flags.
        doOverlapRemoval: schedule overlap removal (step 2) if True.

    Returns:
        A ComponentAccumulator with the services and algorithms.
    """
    cfg = ComponentAccumulator()
    cfg.addService(CompFactory.CP.SelectionNameSvc("SelectionNameSvc"))

    # Evaluated once: it drives both the jet collection and the vertex container.
    gnn_sorting = flags.Tracking.PriVertex.sortingSetup is VertexSortingSetup.GNNSorting

    inputCollections = {
        "jets": "AntiKt4EMTopoCustomVtxGNNJets" if gnn_sorting else "AntiKt4EMTopoJets",
        "electrons": "Electrons",
        "muons": "Muons",
        "photons": "Photons",
    }

    # Minimum pT per object type (presumably MeV, the ATLAS convention -- confirm).
    ptThresholds = {
        "jets": 15000,
        "electrons": 4500,
        "muons": 3000,
        "photons": 10000,
    }

    for obj, container in inputCollections.items():
        cfg.merge(GNNHSSelectionAlgCfg(flags, input=container,
                                       minPt=ptThresholds[obj]))

    if doOverlapRemoval:
        # OverlapRemovalAlg is configured with affectingSystematicsFilter,
        # so the CP systematics service is only needed in this branch.
        cfg.addService(CompFactory.CP.SystematicsSvc("SystematicsSvc"))
        overlapOutputNames = {
            obj: f"{container}_OR" for obj, container in inputCollections.items()
        }
        cfg.merge(GNNHSOverlapRemovalAlgCfg(flags, overlapInputNames=inputCollections,
                                            overlapOutputNames=overlapOutputNames))
    else:
        # If overlap removal is disabled, use the original containers.
        overlapOutputNames = inputCollections

    from PhotonVertexSelection.PhotonVertexSelectionConfig import DecoratePhotonPointingAlgCfg
    photon_key = overlapOutputNames["photons"]
    pointing_decorations = ("zCommon", "caloPointingZ", "zCommonError")
    cfg.merge(
        DecoratePhotonPointingAlgCfg(
            flags,
            PhotonContainerKey=photon_key,
            ExtraInputs=[
                ("xAOD::EgammaContainer", f"StoreGateSvc+{photon_key}"),
            ],
            ExtraOutputs=[
                ("xAOD::IParticleContainer", f"StoreGateSvc+{photon_key}.{deco}")
                for deco in pointing_decorations
            ],
        )
    )

    cfg.merge(
        GNNHSVertexDecoratorAlgCfg(
            flags,
            vertexIn="PrimaryVertices_initial" if gnn_sorting else "PrimaryVertices",
            electronsIn=overlapOutputNames["electrons"],
            muonsIn=overlapOutputNames["muons"],
            photonsIn=overlapOutputNames["photons"],
            jetsIn=overlapOutputNames["jets"],
        )
    )

    return cfg
267
268
if __name__ == "__main__":
    # Stand-alone test: build the GNN hard-scatter selection sequence on the
    # default Run-3 MC AOD test file. The configuration is printed and
    # pickled; pass --norun to skip running it.
    import sys

    from AthenaConfiguration.AllConfigFlags import initConfigFlags
    flags = initConfigFlags()

    # AOD input
    from AthenaConfiguration.TestDefaults import defaultTestFiles
    flags.Input.Files = defaultTestFiles.AOD_RUN3_MC
    flags.Exec.MaxEvents = 100
    flags.lock()

    from AthenaConfiguration.MainServicesConfig import MainServicesCfg
    top_acc = MainServicesCfg(flags)

    from AthenaPoolCnvSvc.PoolReadConfig import PoolReadCfg
    top_acc.merge(PoolReadCfg(flags))

    from InDetPhysValMonitoring.addRecoJetsConfig import AddRecoJetsIfNotExistingCfg
    top_acc.merge(AddRecoJetsIfNotExistingCfg(flags, "AntiKt4EMTopoJets"))

    top_acc.merge(GNNSequenceCfg(flags))

    from AthenaCommon.Constants import DEBUG
    top_acc.foreach_component("AthEventSeq/*").OutputLevel = DEBUG
    top_acc.printConfig(withDetails=True, summariseProps=True)

    # Use a context manager so the pickle file is flushed and closed
    # deterministically instead of leaking the handle.
    with open("GNNSequenceConfig.pkl", "wb") as pkl_file:
        top_acc.store(pkl_file)

    if "--norun" not in sys.argv:
        sc = top_acc.run(1)
        if sc.isFailure():
            sys.exit(-1)
STL class.
GNNHSVertexDecoratorAlgCfg(flags, name="GNNHS_VertexDecoratorAlg", **kwargs)
AsgPtEtaSelectionToolCfg(flags, name="AsgPtEtaSelectionTool", **kwargs)
GNNToolCfg(flags, name="HardScatterSelectionGNNTool", **kwargs)
GNNHSOverlapRemovalAlgCfg(flags, name="GNNHS_OverlapRemovalAlg", overlapInputNames=None, overlapOutputNames=None, **kwargs)
GNNHSOverlapRemovalToolCfg(flags, name="GNNHS_OverlapRemovalToolCfg", **kwargs)
AsgViewFromSelectionAlgCfg(flags, name="AsgViewFromSelectionAlg", **kwargs)
AsgPtEtaSelectionToolGapCfg(flags, name="AsgPtEtaSelectionTool", **kwargs)