ATLAS Offline Software
Loading...
Searching...
No Matches
GNNVertexConfig.py
Go to the documentation of this file.
1# Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2#!/usr/bin/env python
3
4from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
5from AthenaCommon.Logging import logging
6
# Module-level logger on the 'PHYS' channel (AthenaCommon logging).
logPHYS = logging.getLogger('PHYS')
8
10 jetContextName = 'CustomVtxGNN'
11 def customVtxContext(prevflags):
12 context = prevflags.Jet.Context.default.clone(
13 Vertices = "PrimaryVertices_initial",
14 GhostTracks = "PseudoJetGhostTrack",
15 GhostTracksLabel = "GhostTrack",
16 TVA = "JetTrackVtxAssoc"+jetContextName,
17 JetTracks = "JetSelectedTracks"+jetContextName,
18 JetTracksQualityCuts = "JetSelectedTracks"+jetContextName+"_trackSelOpt"
19 )
20 return context
21 flags.addFlag(f"Jet.Context.{jetContextName}", customVtxContext)
22
23
def CustomJetsCfg(flags):
    """Configure an AntiKt R=0.4 EMTopo jet collection built in the
    'CustomVtxGNN' jet context.

    Steps:
      1. Derive a modifier list from the standard ``AntiKt4EMTopo``
         definition, rewriting the calibration spec and several
         track/vertex-dependent modifiers to ``...CustomVtxGNN`` variants.
      2. Register a custom constituent sequence, constituent modifiers,
         external inputs and jet modifiers. These are written into the
         ``stdConstitDic`` / ``stdContitModifDic`` / ``stdInputExtDic`` /
         ``stdJetModifiers`` dictionaries imported from JetRecConfig, so the
         registrations are visible to any later user of those dictionaries.
      3. Merge the jet reconstruction (``JetRecCfg``) for the new definition.

    Args:
        flags: configuration flags. ``flags.Jet.Context['CustomVtxGNN']`` is
            read, so that context must already be registered.

    Returns:
        ComponentAccumulator with the jet reconstruction merged in.
    """
    acc = ComponentAccumulator()

    CustomJetContainerName = "AntiKt4EMTopoCustomVtxGNNJets"

    from JetRecConfig.StandardJetConstits import stdInputExtDic, JetInputExternal, JetInputConstitSeq, JetConstitModifier, xAODType
    from JetRecConfig.JetDefinition import JetDefinition
    from JetRecConfig.StandardSmallRJets import AntiKt4EMTopo
    from JetRecTools import JetRecToolsConfig as jrtcfg
    from JetMomentTools import JetMomentToolsConfig
    from JetRecConfig.StandardJetConstits import stdConstitDic, stdContitModifDic
    from JetRecConfig.StandardJetContext import inputsFromContext
    from JetRecConfig.JetInputConfig import buildEventShapeAlg

    # Get custom jet context
    jetContextName = 'CustomVtxGNN'
    context = flags.Jet.Context[jetContextName]

    def replaceItems(tup, orgName, newName):
        """Return ``tup`` as a tuple with the FIRST item containing
        ``orgName`` rewritten via ``item.replace(orgName, newName)``.

        Only one item is rewritten. If no item contains ``orgName``, the
        items are returned unchanged and a warning is logged.
        """
        # NOTE(review): this is a substring test, so "JVT" also matches
        # "fJVT" / "NNJVT" if one of those precedes "JVT" in the tuple.
        # Confirm the ordering of AntiKt4EMTopo.modifiers makes this safe.
        newList = list(tup)
        for i, item in enumerate(newList):
            if orgName in item:
                newList[i] = item.replace(orgName, newName)
                logPHYS.info("Updated %s to %s", orgName, newName)
                return tuple(newList)
        logPHYS.warning("Failed to update %s to %s", orgName, newName)
        return tuple(newList)

    def updateCalibSequence(tup):
        """Replace the first item containing "Calib" with a
        ``CalibCustomVtxGNN`` spec.

        The original spec is split on ':'.
          - Fields [1] (context) and [2] (data type) are kept, as is field
            [3] (sequence; empty if absent).
          - Fields [4]/[5] are replaced by the custom rho container and
            ``PrimaryVertices_initial``.
          - Field [6] is appended if present.
          - Anything beyond [6] is dropped.

        Raises ValueError if the spec has fewer than three fields.
        """
        newList = list(tup)

        rhoname = "Kt4EMTopoCustomVtxGNNEventShape"

        for i, item in enumerate(newList):
            if "Calib" in item:
                calibspecs = item.split(":")
                # calibspecs[0] is the modifier name; it is replaced below.
                calibcontext, data_type = calibspecs[1:3]
                calibseq = ""
                if len(calibspecs) > 3:
                    calibseq = calibspecs[3]
                finalCalibString = f"CalibCustomVtxGNN:{calibcontext}:{data_type}:{calibseq}:{rhoname}:PrimaryVertices_initial"
                if len(calibspecs) > 6:
                    finalCalibString = f"{finalCalibString}:{calibspecs[6]}"
                newList[i] = finalCalibString
                logPHYS.info("Updated calib sequence to %s", finalCalibString)
                return tuple(newList)
        logPHYS.warning("Failed to update calib sequence")
        return tuple(newList)

    # Create modifier list and JetDefinition
    modsCustomVtxGNN = updateCalibSequence(AntiKt4EMTopo.modifiers)
    # Order matters: each call rewrites only the first substring match.
    for orgName, newName in (
        ("TrackMoments", "TrackMomentsCustomVtxGNN"),
        ("TrackSumMoments", "TrackSumMomentsCustomVtxGNN"),
        ("JVF", "JVFCustomVtxGNN"),
        ("JVT", "JVTCustomVtxGNN"),
        ("Charge", "ChargeCustomVtxGNN"),
        ("jetiso", "jetisoCustomVtxGNN"),
    ):
        modsCustomVtxGNN = replaceItems(modsCustomVtxGNN, orgName, newName)
    # NOTE(review): OriginSetPVCustomVtxGNN, NNJVTCustomVtxGNN,
    # QGTaggingCustomVtxGNN and fJVTCustomVtxGNN are registered below but are
    # not substituted here (except via an accidental substring hit in
    # replaceItems). Confirm this is intended.

    ghostCustomVtxGNN = AntiKt4EMTopo.ghostdefs

    # TODO: check this
    stdConstitDic["EMTopoOriginCustomVtxGNN"] = JetInputConstitSeq(
        "EMTopoOriginCustomVtxGNN", xAODType.CaloCluster, ["EMCustomVtxGNN", "OriginCustomVtxGNN"],
        "CaloCalTopoClusters", "EMOriginCustomVtxGNNTopoClusters", label="EMTopo")
    # Origin correction uses the un-resorted vertex container.
    stdContitModifDic["OriginCustomVtxGNN"] = JetConstitModifier(
        "OriginCustomVtxGNN", "CaloClusterConstituentsOrigin", prereqs=[inputsFromContext("Vertices")],
        properties=dict(VertexContainer="PrimaryVertices_initial"))
    stdContitModifDic["EMCustomVtxGNN"] = JetConstitModifier("EMCustomVtxGNN", "ClusterAtEMScaleTool")

    AntiKt4EMTopoCustomVtxGNN = JetDefinition("AntiKt", 0.4, stdConstitDic.EMTopoOriginCustomVtxGNN,
                                              infix="CustomVtxGNN",
                                              context=jetContextName,
                                              ghostdefs=ghostCustomVtxGNN,
                                              modifiers=modsCustomVtxGNN,
                                              lock=True,
                                              )

    def getUsedInVertexFitTrackDecoratorAlgCustomVtxGNN(jetdef, jetmod):
        """ Create the alg to decorate the used-in-fit information for AMVF """
        ctxDic = jetdef._contextDic

        from InDetUsedInFitTrackDecoratorTool.UsedInVertexFitTrackDecoratorConfig import getUsedInVertexFitTrackDecoratorAlg
        # Decoration names carry a "_forGNN" suffix; the same names are
        # consumed by the track-selection and TVA inputs below.
        alg = getUsedInVertexFitTrackDecoratorAlg(ctxDic['Tracks'], ctxDic['Vertices'],
                                                  vertexDeco='TTVA_AMVFVertices_forGNN',
                                                  weightDeco='TTVA_AMVFWeights_forGNN')
        return alg

    # Define new input variables for jet configuration
    stdInputExtDic[context['Vertices']] = JetInputExternal(context['Vertices'], xAODType.Vertex)

    stdInputExtDic["JetSelectedTracksCustomVtxGNN"] = JetInputExternal(
        "JetSelectedTracksCustomVtxGNN", xAODType.TrackParticle,
        prereqs=[f"input:{context['Tracks']}"],  # in std context, this is InDetTrackParticles (see StandardJetContext)
        algoBuilder=lambda jdef, _: jrtcfg.getTrackSelAlg(jdef, trackSelOpt=False,
                                                          DecorDeps=["TTVA_AMVFWeights_forGNN", "TTVA_AMVFVertices_forGNN"])
    )

    stdInputExtDic["JetTrackUsedInFitDecoCustomVtxGNN"] = JetInputExternal(
        "JetTrackUsedInFitDecoCustomVtxGNN", xAODType.TrackParticle,
        prereqs=[f"input:{context['Tracks']}",  # in std context, this is InDetTrackParticles (see StandardJetContext)
                 f"input:{context['Vertices']}"],
        algoBuilder=getUsedInVertexFitTrackDecoratorAlgCustomVtxGNN
    )

    stdInputExtDic["JetTrackVtxAssocCustomVtxGNN"] = JetInputExternal(
        "JetTrackVtxAssocCustomVtxGNN", xAODType.TrackParticle,
        algoBuilder=lambda jdef, _: jrtcfg.getJetTrackVtxAlg(jdef._contextDic, algname="jetTVACustomVtxGNN",
                                                             WorkingPoint="Nonprompt_All_MaxWeight",
                                                             AMVFVerticesDeco='TTVA_AMVFVertices_forGNN',
                                                             AMVFWeightsDeco='TTVA_AMVFWeights_forGNN'),
        prereqs=["input:JetTrackUsedInFitDecoCustomVtxGNN", f"input:{context['Vertices']}"])

    # NOTE(review): registered under "EventDensityCustomVtxGNN" but named
    # "EventDensity". For the jet definition above (constituent label
    # "EMTopo"), containername evaluates to "Kt4EMTopoEventShape". However,
    # rhoname in updateCalibSequence and RhoKey of jetisoCustomVtxGNN read
    # "Kt4EMTopoCustomVtxGNNEventShape". Confirm these names agree.
    stdInputExtDic["EventDensityCustomVtxGNN"] = JetInputExternal(
        "EventDensity", "EventShape", algoBuilder=buildEventShapeAlg,
        containername=lambda jetdef, _: "Kt4" + jetdef.inputdef.label + "EventShape",
        prereqs=lambda jetdef: ["input:" + jetdef.inputdef.name])

    from JetRecConfig.StandardJetMods import stdJetModifiers
    from JetRecConfig.JetDefinition import JetModifier
    from JetCalibTools import JetCalibToolsConfig

    stdJetModifiers.update(

        CalibCustomVtxGNN=JetModifier("JetCalibrationTool", "jetcalib_jetcoll_calibseqCustomVtxGNN",
                                      createfn=JetCalibToolsConfig.getJetCalibToolFromString,
                                      prereqs=lambda mod, jetdef: JetCalibToolsConfig.getJetCalibToolPrereqs(mod, jetdef) + [f"input:{context['Vertices']}"]),

        JVFCustomVtxGNN=JetModifier("JetVertexFractionTool", "jvfCustomVtxGNN",
                                    createfn=lambda jdef, _: JetMomentToolsConfig.getJVFTool(jdef, "CustomVtxGNN"),
                                    modspec="CustomVtxGNN",
                                    prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "mod:TrackMomentsCustomVtxGNN", f"input:{context['Vertices']}"],
                                    JetContainer=CustomJetContainerName),

        JVTCustomVtxGNN=JetModifier("JetVertexTaggerTool", "jvtCustomVtxGNN",
                                    createfn=lambda jdef, _: JetMomentToolsConfig.getJVTTool(jdef, "CustomVtxGNN"),
                                    modspec="CustomVtxGNN",
                                    prereqs=["mod:JVFCustomVtxGNN"], JetContainer=CustomJetContainerName),

        NNJVTCustomVtxGNN=JetModifier("JetVertexNNTagger", "nnjvtCustomVtxGNN",
                                      createfn=lambda jdef, _: JetMomentToolsConfig.getNNJvtTool(jdef, "CustomVtxGNN"),
                                      prereqs=["mod:JVFCustomVtxGNN"], JetContainer=CustomJetContainerName),

        OriginSetPVCustomVtxGNN=JetModifier("JetOriginCorrectionTool", "origin_setpvCustomVtxGNN",
                                            modspec="CustomVtxGNN",
                                            prereqs=["mod:JVFCustomVtxGNN"], JetContainer=CustomJetContainerName, OnlyAssignPV=True),

        TrackMomentsCustomVtxGNN=JetModifier("JetTrackMomentsTool", "trkmomsCustomVtxGNN",
                                             createfn=lambda jdef, _: JetMomentToolsConfig.getTrackMomentsTool(jdef, "CustomVtxGNN"),
                                             modspec="CustomVtxGNN",
                                             prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "ghost:Track"], JetContainer=CustomJetContainerName),

        TrackSumMomentsCustomVtxGNN=JetModifier("JetTrackSumMomentsTool", "trksummomsCustomVtxGNN",
                                                createfn=lambda jdef, _: JetMomentToolsConfig.getTrackSumMomentsTool(jdef, "CustomVtxGNN"),
                                                modspec="CustomVtxGNN",
                                                prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "ghost:Track"], JetContainer=CustomJetContainerName),

        ChargeCustomVtxGNN=JetModifier("JetChargeTool", "jetchargeCustomVtxGNN",
                                       prereqs=["ghost:Track"]),

        # On data (not Input.isMC), the JetPtAssociation modifier is also required.
        QGTaggingCustomVtxGNN=JetModifier("JetQGTaggerVariableTool", "qgtaggingCustomVtxGNN",
                                          createfn=lambda jdef, _: JetMomentToolsConfig.getQGTaggingTool(jdef, "CustomVtxGNN"),
                                          modspec="CustomVtxGNN",
                                          prereqs=lambda _, jdef:
                                              ["input:JetTrackVtxAssocCustomVtxGNN", "mod:TrackMomentsCustomVtxGNN"] +
                                              (["mod:JetPtAssociation"] if not jdef._cflags.Input.isMC else []),
                                          JetContainer=CustomJetContainerName),

        fJVTCustomVtxGNN=JetModifier("JetForwardPFlowJvtTool", "fJVTCustomVtxGNN",
                                     createfn=lambda jdef, _: JetMomentToolsConfig.getPFlowfJVTTool(jdef, "CustomVtxGNN"),
                                     modspec="CustomVtxGNN",
                                     prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "input:EventDensityCustomVtxGNN", f"input:{context['Vertices']}", "mod:NNJVTCustomVtxGNN"],
                                     JetContainer=CustomJetContainerName),

        jetisoCustomVtxGNN=JetModifier("JetIsolationTool", "isoCustomVtxGNN",
                                       JetContainer=CustomJetContainerName,
                                       InputConstitContainer="EMOriginCustomVtxGNNTopoClusters",
                                       IsolationCalculations=["IsoFixedCone:5:Pt", "IsoFixedCone:5:PtPUsub", ],
                                       RhoKey=lambda jetdef, specs: "Kt4" + jetdef.inputdef.label + "CustomVtxGNNEventShape",
                                       prereqs=["input:EventDensityCustomVtxGNN"],
                                       ),

    )

    from JetRecConfig.JetRecConfig import JetRecCfg

    acc.merge(JetRecCfg(flags, AntiKt4EMTopoCustomVtxGNN))

    return acc
213
214
# Main algorithm config
def GNNVertexCfg(flags, **kwargs):
    """Top-level configuration for GNN-based primary-vertex re-sorting.

    Steps:
      1. Remap the input ``PrimaryVertices`` container (and its aux store)
         via ``InputRenameCfg`` to ``PrimaryVertices_initial``.
      2. Merge the custom-vertex jets (``CustomJetsCfg``) and the GNN
         sequence (``GNNSequenceCfg``).
      3. Merge ``ResortVerticesCfg``, which reads ``PrimaryVertices_initial``
         and writes ``PrimaryVertices`` using the GNN vertex sorting tool
         configuration.

    Args:
        flags: configuration flags.
        **kwargs: accepted but currently unused.

    Returns:
        ComponentAccumulator holding the combined configuration.
    """
    from SGComps.AddressRemappingConfig import InputRenameCfg
    from InDetConfig.InDetGNNHardScatterSelectionConfig import GNNSequenceCfg
    from TrkConfig.TrkVertexToolsConfig import GNNVertexCollectionSortingToolCfg
    from InDetPriVxFinder.ResortVerticesConfig import ResortVerticesCfg

    result = ComponentAccumulator()

    # Move the input vertices aside so the re-sorted collection can be
    # written under the original "PrimaryVertices" key in the last step.
    renames = (
        ("xAOD::VertexContainer", "PrimaryVertices", "PrimaryVertices_initial"),
        ("xAOD::VertexAuxContainer", "PrimaryVerticesAux.", "PrimaryVertices_initialAux."),
    )
    for containerType, oldKey, newKey in renames:
        result.merge(InputRenameCfg(containerType, oldKey, newKey))

    result.merge(CustomJetsCfg(flags))
    result.merge(GNNSequenceCfg(flags))

    # NOTE(review): the sorting-tool configurator's return value is passed
    # straight to ResortVerticesCfg. Confirm whether it is a tool or a
    # ComponentAccumulator that needs popToolsAndMerge first.
    sortingTool = GNNVertexCollectionSortingToolCfg(flags)
    result.merge(ResortVerticesCfg(flags, "PrimaryVertices_initial", "PrimaryVertices", sortingTool))

    return result
235
void print(char *figname, TCanvas *c1)
getJetCalibToolPrereqs(modspec, jetdef)
getTrackMomentsTool(jetdef, modspec)
getNNJvtTool(jetdef, modspec)
getPFlowfJVTTool(jetdef, modspec)
getTrackSumMomentsTool(jetdef, modspec)
getQGTaggingTool(jetdef, modspec)
getJVFTool(jetdef, modspec)
GNNVertexCfg(flags, **kwargs)