Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
GNNVertexConfig.py
Go to the documentation of this file.
1 # Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
2 #!/usr/bin/env python
3 
4 from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
5 from AthenaCommon.Logging import logging
6 
# Module-level logger, registered under the logger name 'PHYS'.
logPHYS = logging.getLogger(name='PHYS')
8 
def addJetContextFlags(flags):
    """Register the ``Jet.Context.CustomVtxGNN`` flag on *flags*.

    The flag's value is computed lazily from the already-configured flags: it
    is a clone of ``Jet.Context.default`` that points at the
    ``PrimaryVertices_initial`` vertex collection and uses context-specific
    names for the track-vertex association and the selected-track containers.

    :param flags: configuration flags object providing ``addFlag``.
    """
    suffix = 'CustomVtxGNN'
    selectedTracks = "JetSelectedTracks" + suffix

    def _makeContext(prevflags):
        # Only these keys differ from the default jet context.
        overrides = dict(
            Vertices="PrimaryVertices_initial",
            GhostTracks="PseudoJetGhostTrack",
            GhostTracksLabel="GhostTrack",
            TVA="JetTrackVtxAssoc" + suffix,
            JetTracks=selectedTracks,
            JetTracksQualityCuts=selectedTracks + "_trackSelOpt",
        )
        return prevflags.Jet.Context.default.clone(**overrides)

    flags.addFlag(f"Jet.Context.{suffix}", _makeContext)
22 
23 
def _replaceItems(tup, orgName, newName):
    """Return a copy of the modifier tuple *tup* with *orgName* renamed to *newName*.

    Only the FIRST entry that contains *orgName* as a substring is rewritten
    (via ``str.replace`` on that entry); all other entries are left as-is.
    A warning is logged when no entry matches.

    NOTE(review): because this is a substring match, e.g. "JVT" would also
    hit "NNJVT" or "fJVT" if one of those appears earlier in *tup* -- confirm
    the ordering of the base modifier tuple if that matters.
    """
    newList = list(tup)
    for i, item in enumerate(newList):
        if orgName in item:
            newList[i] = item.replace(orgName, newName)
            logPHYS.info("Updated %s to %s", orgName, newName)
            return tuple(newList)
    logPHYS.warning("Failed to update %s to %s", orgName, newName)
    return tuple(newList)


def _updateCalibSequence(tup, rhoname="Kt4EMTopoCustomVtxGNNEventShape",
                         vertexContainer="PrimaryVertices_initial"):
    """Return a copy of *tup* with its calibration modifier redirected to CalibCustomVtxGNN.

    The first entry containing "Calib" is split on ':' and rebuilt as
    ``CalibCustomVtxGNN:<f1>:<f2>:<f3>:<rhoname>:<vertexContainer>[:<f6>]``,
    where ``f1``/``f2`` are the original calib context and data type, ``f3``
    is the original calibration sequence (empty if absent), fields 4 and 5 of
    the original entry are replaced, and field 6 is kept when present.
    A warning is logged when no entry matches.

    :raises ValueError: if the matching entry has fewer than three
        ':'-separated fields.
    """
    newList = list(tup)
    for i, item in enumerate(newList):
        if "Calib" not in item:
            continue
        calibspecs = item.split(":")
        if len(calibspecs) < 3:
            raise ValueError(
                f"Malformed calibration modifier spec '{item}': "
                "expected at least 'Calib:<calibcontext>:<data_type>'")
        calibcontext, data_type = calibspecs[1:3]
        calibseq = calibspecs[3] if len(calibspecs) > 3 else ""
        finalCalibString = (f"CalibCustomVtxGNN:{calibcontext}:{data_type}:{calibseq}"
                            f":{rhoname}:{vertexContainer}")
        if len(calibspecs) > 6:
            finalCalibString = f"{finalCalibString}:{calibspecs[6]}"
        newList[i] = finalCalibString
        logPHYS.info("Updated calibration modifier to %s", finalCalibString)
        return tuple(newList)
    logPHYS.warning("Failed to update calib sequence")
    return tuple(newList)


def CustomJetsCfg(flags):
    """Configure an AntiKt4EMTopo jet collection built against the CustomVtxGNN context.

    The jet definition copies the standard AntiKt4EMTopo ghosts and modifiers,
    then swaps the calibration and several track-based modifiers for
    "CustomVtxGNN" variants that are configured for the container
    ``AntiKt4EMTopoCustomVtxGNNJets`` and the vertex collection from the
    custom context.

    Requires ``flags.Jet.Context['CustomVtxGNN']`` to exist (registered by
    ``addJetContextFlags``).

    Side effects: adds entries to the shared JetRecConfig registries
    ``stdConstitDic``, ``stdContitModifDic``, ``stdInputExtDic`` and
    ``stdJetModifiers``.

    :param flags: configuration flags.
    :returns: ComponentAccumulator holding the jet reconstruction (JetRecCfg).
    """
    from JetRecConfig.StandardJetConstits import (stdInputExtDic, stdConstitDic, stdContitModifDic,
                                                  JetInputExternal, JetInputConstitSeq,
                                                  JetConstitModifier, xAODType)
    from JetRecConfig.JetDefinition import JetDefinition, JetModifier
    from JetRecConfig.StandardSmallRJets import AntiKt4EMTopo
    from JetRecConfig.StandardJetContext import inputsFromContext
    from JetRecConfig.JetInputConfig import buildEventShapeAlg
    from JetRecConfig.StandardJetMods import stdJetModifiers
    from JetRecConfig.JetRecConfig import JetRecCfg
    from JetRecTools import JetRecToolsConfig as jrtcfg
    from JetMomentTools import JetMomentToolsConfig
    from JetCalibTools import JetCalibToolsConfig

    acc = ComponentAccumulator()

    CustomJetContainerName = "AntiKt4EMTopoCustomVtxGNNJets"

    # Get custom jet context
    jetContextName = 'CustomVtxGNN'
    context = flags.Jet.Context[jetContextName]

    # Modifier list: standard AntiKt4EMTopo modifiers with the calibration and
    # the track/vertex-dependent moments swapped for their CustomVtxGNN variants.
    # Renames are applied sequentially, in this order.
    # NOTE(review): QGTaggingCustomVtxGNN, fJVTCustomVtxGNN, NNJVTCustomVtxGNN and
    # OriginSetPVCustomVtxGNN are registered below but are not substituted here;
    # confirm whether the base tuple's corresponding entries should be swapped too.
    modsCustomVtxGNN = _updateCalibSequence(AntiKt4EMTopo.modifiers)
    for stdName in ("TrackMoments", "TrackSumMoments", "JVF", "JVT", "Charge", "jetiso"):
        modsCustomVtxGNN = _replaceItems(modsCustomVtxGNN, stdName, stdName + jetContextName)

    ghostCustomVtxGNN = AntiKt4EMTopo.ghostdefs

    # Constituents: CaloCalTopoClusters -> EM-scale tool -> origin-correction
    # tool using PrimaryVertices_initial.
    # TODO: check this
    stdConstitDic["EMTopoOriginCustomVtxGNN"] = JetInputConstitSeq(
        "EMTopoOriginCustomVtxGNN", xAODType.CaloCluster, ["EMCustomVtxGNN", "OriginCustomVtxGNN"],
        "CaloCalTopoClusters", "EMOriginCustomVtxGNNTopoClusters", label="EMTopo")
    stdContitModifDic["OriginCustomVtxGNN"] = JetConstitModifier(
        "OriginCustomVtxGNN", "CaloClusterConstituentsOrigin",
        prereqs=[inputsFromContext("Vertices")],
        properties=dict(VertexContainer="PrimaryVertices_initial"))
    stdContitModifDic["EMCustomVtxGNN"] = JetConstitModifier("EMCustomVtxGNN", "ClusterAtEMScaleTool", )

    AntiKt4EMTopoCustomVtxGNN = JetDefinition("AntiKt", 0.4, stdConstitDic.EMTopoOriginCustomVtxGNN,
                                              infix=jetContextName,
                                              context=jetContextName,
                                              ghostdefs=ghostCustomVtxGNN,
                                              modifiers=modsCustomVtxGNN,
                                              lock=True,
                                              )

    def getUsedInVertexFitTrackDecoratorAlgCustomVtxGNN(jetdef, jetmod):
        """ Create the alg to decorate the used-in-fit information for AMVF """
        context = jetdef._contextDic

        from InDetUsedInFitTrackDecoratorTool.UsedInVertexFitTrackDecoratorConfig import getUsedInVertexFitTrackDecoratorAlg
        alg = getUsedInVertexFitTrackDecoratorAlg(context['Tracks'], context['Vertices'],
                                                  vertexDeco='TTVA_AMVFVertices_forGNN',
                                                  weightDeco='TTVA_AMVFWeights_forGNN')
        return alg

    # Define new input variables for jet configuration
    stdInputExtDic[context['Vertices']] = JetInputExternal(context['Vertices'], xAODType.Vertex)

    stdInputExtDic["JetSelectedTracksCustomVtxGNN"] = JetInputExternal(
        "JetSelectedTracksCustomVtxGNN", xAODType.TrackParticle,
        prereqs=[f"input:{context['Tracks']}"],  # in std context, this is InDetTrackParticles (see StandardJetContext)
        algoBuilder=lambda jdef, _: jrtcfg.getTrackSelAlg(jdef, trackSelOpt=False,
                                                          DecorDeps=["TTVA_AMVFWeights_forGNN", "TTVA_AMVFVertices_forGNN"])
    )

    stdInputExtDic["JetTrackUsedInFitDecoCustomVtxGNN"] = JetInputExternal(
        "JetTrackUsedInFitDecoCustomVtxGNN", xAODType.TrackParticle,
        prereqs=[f"input:{context['Tracks']}",  # in std context, this is InDetTrackParticles (see StandardJetContext)
                 f"input:{context['Vertices']}"],
        algoBuilder=getUsedInVertexFitTrackDecoratorAlgCustomVtxGNN
    )

    stdInputExtDic["JetTrackVtxAssocCustomVtxGNN"] = JetInputExternal(
        "JetTrackVtxAssocCustomVtxGNN", xAODType.TrackParticle,
        algoBuilder=lambda jdef, _: jrtcfg.getJetTrackVtxAlg(jdef._contextDic, algname="jetTVACustomVtxGNN",
                                                             WorkingPoint="Nonprompt_All_MaxWeight",
                                                             AMVFVerticesDeco='TTVA_AMVFVertices_forGNN',
                                                             AMVFWeightsDeco='TTVA_AMVFWeights_forGNN'),
        prereqs=["input:JetTrackUsedInFitDecoCustomVtxGNN", f"input:{context['Vertices']}"])

    # NOTE(review): this input is registered under the name "EventDensity" and its
    # containername resolves to "Kt4<label>EventShape" ("Kt4EMTopoEventShape" for the
    # EMTopo label above), whereas _updateCalibSequence's default rhoname and
    # jetisoCustomVtxGNN's RhoKey use "Kt4EMTopoCustomVtxGNNEventShape". Confirm which
    # event-shape container is actually produced and consumed.
    stdInputExtDic["EventDensityCustomVtxGNN"] = JetInputExternal(
        "EventDensity", "EventShape", algoBuilder=buildEventShapeAlg,
        containername=lambda jetdef, _: "Kt4" + jetdef.inputdef.label + "EventShape",
        prereqs=lambda jetdef: ["input:" + jetdef.inputdef.name])

    stdJetModifiers.update(

        CalibCustomVtxGNN=JetModifier(
            "JetCalibrationTool", "jetcalib_jetcoll_calibseqCustomVtxGNN",
            createfn=JetCalibToolsConfig.getJetCalibToolFromString,
            prereqs=lambda mod, jetdef: JetCalibToolsConfig.getJetCalibToolPrereqs(mod, jetdef) + [f"input:{context['Vertices']}"]),

        JVFCustomVtxGNN=JetModifier(
            "JetVertexFractionTool", "jvfCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getJVFTool(jdef, "CustomVtxGNN"),
            modspec="CustomVtxGNN",
            prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "mod:TrackMomentsCustomVtxGNN", f"input:{context['Vertices']}"],
            JetContainer=CustomJetContainerName),

        JVTCustomVtxGNN=JetModifier(
            "JetVertexTaggerTool", "jvtCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getJVTTool(jdef, "CustomVtxGNN"),
            modspec="CustomVtxGNN",
            prereqs=["mod:JVFCustomVtxGNN"], JetContainer=CustomJetContainerName),

        NNJVTCustomVtxGNN=JetModifier(
            "JetVertexNNTagger", "nnjvtCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getNNJvtTool(jdef, "CustomVtxGNN"),
            prereqs=["mod:JVFCustomVtxGNN"], JetContainer=CustomJetContainerName),

        OriginSetPVCustomVtxGNN=JetModifier(
            "JetOriginCorrectionTool", "origin_setpvCustomVtxGNN",
            modspec="CustomVtxGNN",
            prereqs=["mod:JVFCustomVtxGNN"], JetContainer=CustomJetContainerName, OnlyAssignPV=True),

        TrackMomentsCustomVtxGNN=JetModifier(
            "JetTrackMomentsTool", "trkmomsCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getTrackMomentsTool(jdef, "CustomVtxGNN"),
            modspec="CustomVtxGNN",
            prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "ghost:Track"], JetContainer=CustomJetContainerName),

        TrackSumMomentsCustomVtxGNN=JetModifier(
            "JetTrackSumMomentsTool", "trksummomsCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getTrackSumMomentsTool(jdef, "CustomVtxGNN"),
            modspec="CustomVtxGNN",
            prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "ghost:Track"], JetContainer=CustomJetContainerName),

        ChargeCustomVtxGNN=JetModifier(
            "JetChargeTool", "jetchargeCustomVtxGNN",
            prereqs=["ghost:Track"]),

        QGTaggingCustomVtxGNN=JetModifier(
            "JetQGTaggerVariableTool", "qgtaggingCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getQGTaggingTool(jdef, "CustomVtxGNN"),
            modspec="CustomVtxGNN",
            # JetPtAssociation is only required for non-MC input.
            prereqs=lambda _, jdef:
                ["input:JetTrackVtxAssocCustomVtxGNN", "mod:TrackMomentsCustomVtxGNN"] +
                (["mod:JetPtAssociation"] if not jdef._cflags.Input.isMC else []),
            JetContainer=CustomJetContainerName),

        fJVTCustomVtxGNN=JetModifier(
            "JetForwardPFlowJvtTool", "fJVTCustomVtxGNN",
            createfn=lambda jdef, _: JetMomentToolsConfig.getPFlowfJVTTool(jdef, "CustomVtxGNN"),
            modspec="CustomVtxGNN",
            prereqs=["input:JetTrackVtxAssocCustomVtxGNN", "input:EventDensityCustomVtxGNN",
                     f"input:{context['Vertices']}", "mod:NNJVTCustomVtxGNN"],
            JetContainer=CustomJetContainerName),

        jetisoCustomVtxGNN=JetModifier(
            "JetIsolationTool", "isoCustomVtxGNN",
            JetContainer=CustomJetContainerName,
            InputConstitContainer="EMOriginCustomVtxGNNTopoClusters",
            IsolationCalculations=["IsoFixedCone:5:Pt", "IsoFixedCone:5:PtPUsub", ],
            RhoKey=lambda jetdef, specs: "Kt4" + jetdef.inputdef.label + "CustomVtxGNNEventShape",
            prereqs=["input:EventDensityCustomVtxGNN"],
        ),
    )

    acc.merge(JetRecCfg(flags, AntiKt4EMTopoCustomVtxGNN))

    return acc
213 
214 
# Main algorithm config
def GNNVertexCfg(flags, **kwargs):
    """Top-level configuration for GNN-based primary-vertex re-sorting.

    The steps are scheduled in this order:
      1. rename the input ``PrimaryVertices`` container (and its aux store) to
         ``PrimaryVertices_initial``;
      2. build the CustomVtxGNN jets (``CustomJetsCfg``);
      3. add the GNN sequence (``GNNSequenceCfg``);
      4. re-sort ``PrimaryVertices_initial`` into a new ``PrimaryVertices``
         container using the GNN vertex-collection sorting tool.

    The jet step reads ``flags.Jet.Context['CustomVtxGNN']``, which is
    registered by ``addJetContextFlags``.

    :param flags: configuration flags.
    :param kwargs: accepted but currently unused.
    :returns: ComponentAccumulator with all of the above merged in.
    """
    result = ComponentAccumulator()

    # Move the input vertices aside so that the re-sorted collection produced
    # at the end of this function can take the "PrimaryVertices" key.
    from SGComps.AddressRemappingConfig import InputRenameCfg
    renames = (
        ("xAOD::VertexContainer", "PrimaryVertices", "PrimaryVertices_initial"),
        ("xAOD::VertexAuxContainer", "PrimaryVerticesAux.", "PrimaryVertices_initialAux."),
    )
    for typeName, oldKey, newKey in renames:
        result.merge(InputRenameCfg(typeName, oldKey, newKey))

    result.merge(CustomJetsCfg(flags))

    from InDetConfig.InDetGNNHardScatterSelectionConfig import GNNSequenceCfg
    result.merge(GNNSequenceCfg(flags))

    from TrkConfig.TrkVertexToolsConfig import GNNVertexCollectionSortingToolCfg
    from InDetPriVxFinder.ResortVerticesConfig import ResortVerticesCfg
    sortingTool = GNNVertexCollectionSortingToolCfg(flags)
    result.merge(ResortVerticesCfg(flags, "PrimaryVertices_initial", "PrimaryVertices", sortingTool))

    return result
235 
python.JetAnalysisCommon.ComponentAccumulator
ComponentAccumulator
Definition: JetAnalysisCommon.py:302
JetCalibToolsConfig.getJetCalibToolPrereqs
def getJetCalibToolPrereqs(modspec, jetdef)
Definition: JetCalibToolsConfig.py:202
python.TrkVertexToolsConfig.GNNVertexCollectionSortingToolCfg
def GNNVertexCollectionSortingToolCfg(flags, name="GNNVertexCollectionSortingTool", **kwargs)
Definition: TrkVertexToolsConfig.py:53
python.StandardJetContext.inputsFromContext
def inputsFromContext(inputKey, prefix="", suffix="")
Definition: StandardJetContext.py:112
python.InDetGNNHardScatterSelectionConfig.GNNSequenceCfg
def GNNSequenceCfg(flags, **kwargs)
Definition: InDetGNNHardScatterSelectionConfig.py:148
python.JetRecConfig.JetRecCfg
def JetRecCfg(flags, jetdef, returnConfiguredDef=False)
Top level functions returning ComponentAccumulator out of JetDefinition.
Definition: JetRecConfig.py:36
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
python.GNNVertexConfig.addJetContextFlags
def addJetContextFlags(flags)
Definition: GNNVertexConfig.py:9
print
void print(char *figname, TCanvas *c1)
Definition: TRTCalib_StrawStatusPlots.cxx:25
UsedInVertexFitTrackDecoratorConfig.getUsedInVertexFitTrackDecoratorAlg
def getUsedInVertexFitTrackDecoratorAlg(trackCont="InDetTrackParticles", vtxCont="PrimaryVertices", vertexDeco="TTVA_AMVFVertices_forReco", weightDeco="TTVA_AMVFWeights_forReco")
Definition: UsedInVertexFitTrackDecoratorConfig.py:16
JetMomentToolsConfig.getNNJvtTool
def getNNJvtTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:117
python.GNNVertexConfig.GNNVertexCfg
def GNNVertexCfg(flags, **kwargs)
Definition: GNNVertexConfig.py:216
JetMomentToolsConfig.getPFlowfJVTTool
def getPFlowfJVTTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:195
python.GNNVertexConfig.CustomJetsCfg
def CustomJetsCfg(flags)
Definition: GNNVertexConfig.py:24
ResortVerticesConfig.ResortVerticesCfg
def ResortVerticesCfg(flags, vxin, vxout, vxsortercfg, algname="ResortVx")
Definition: ResortVerticesConfig.py:8
JetMomentToolsConfig.getJVFTool
def getJVFTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:90
JetMomentToolsConfig.getTrackMomentsTool
def getTrackMomentsTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:126
JetMomentToolsConfig.getJVTTool
def getJVTTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:107
JetMomentToolsConfig.getQGTaggingTool
def getQGTaggingTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:182
JetMomentToolsConfig.getTrackSumMomentsTool
def getTrackSumMomentsTool(jetdef, modspec)
Definition: JetMomentToolsConfig.py:142
AddressRemappingConfig.InputRenameCfg
def InputRenameCfg(type, from_name, to_name)
Definition: AddressRemappingConfig.py:28