ATLAS Offline Software
Loading...
Searching...
No Matches
GNNVertexFitterConfig.py
Go to the documentation of this file.
1# Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2
3from AthenaConfiguration.ComponentAccumulator import ComponentAccumulator
4from AthenaConfiguration.ComponentFactory import CompFactory
5from TrkConfig.TrkVKalVrtFitterConfig import TrkVKalVrtFitterCfg
6from BeamSpotConditions.BeamSpotConditionsConfig import BeamSpotCondAlgCfg
7from FlavorTagDiscriminants.FlavorTagNNConfig import MultifoldGNNCfg
8from os.path import commonpath
9from pathlib import PurePath
10from ParticleJetTools.JetParticleAssociationAlgConfig import (
11 JetParticleAssociationAlgCfg,
12)
13from BTagging.BTagTrackAugmenterAlgConfig import BTagTrackAugmenterAlgCfg
14
def GNNVertexFitterToolCfg(flags, name="GNNVertexFitterTool", **kwargs):
    """Configure a private ``Rec::GNNVertexFitterTool``.

    The beam-spot conditions algorithm is merged into the returned accumulator.
    Every tool property listed below receives a default only when the caller
    has not already set it in ``kwargs``.

    :param flags: configuration flags, forwarded to the dependent configs.
    :param name: instance name given to the tool.
    :param kwargs: tool properties; any value given here overrides the default.
    :returns: a ComponentAccumulator holding the tool as its private tool
        (retrieve it with ``popToolsAndMerge``).
    """
    acc = ComponentAccumulator()

    acc.merge(BeamSpotCondAlgCfg(flags))

    # Build the default VKalVrt fitter only when the caller did not provide one.
    # setdefault() would evaluate (and merge) the default configuration even
    # when it is going to be discarded.
    if "VertexFitterTool" not in kwargs:
        kwargs["VertexFitterTool"] = acc.popToolsAndMerge(TrkVKalVrtFitterCfg(flags))
    kwargs.setdefault("GNNModel", "GN2v01")
    kwargs.setdefault("JetCollection", "AntiKt4EMPFlowJets")
    kwargs.setdefault("includePrimaryVertex", False)
    kwargs.setdefault("removeNonHFVertices", False)
    kwargs.setdefault("doInclusiveVertexing", False)
    kwargs.setdefault("maxChi2", 20)
    kwargs.setdefault("HFRatio", 0)
    kwargs.setdefault("applyCuts", False)

    # ``name`` is a separate parameter (not in kwargs), so it must be forwarded
    # explicitly for the requested instance name to take effect.
    acc.setPrivateTools(CompFactory.Rec.GNNVertexFitterTool(name, **kwargs))

    return acc
31
def GNNVertexFitterAlgCfg(flags, jetcol="AntiKt4EMPFlowJets", inclusive=False, **kwargs):
    """Schedule GNN-based vertex fitting for a jet collection.

    Sets up the b-tagging track augmenter and the jet-to-track association for
    ``jetcol``. Then, for every network listed in ``flags.BTagging.NNs[jetcol]``,
    it schedules:

    * the multifold GNN decoration of the jets, and
    * a ``Rec::GNNVertexFitterAlg`` using a ``GNNVertexFitterTool`` configured
      with that network's name.

    :param flags: configuration flags.
    :param jetcol: name of the input jet container.
    :param inclusive: if True, the tool runs with ``doInclusiveVertexing`` and
        ``removeNonHFVertices`` enabled. The algorithm name gets an ``Incl``
        suffix and the output container an ``Inclusive`` prefix.
    :param kwargs: extra properties forwarded to every ``GNNVertexFitterAlg``.
    :returns: ComponentAccumulator containing the scheduled algorithms.
    :raises ValueError: if a configured network lists fewer than two folds.
    """
    acc = ComponentAccumulator()

    trackCollection = 'InDetTrackParticles'
    JetTrackAssociator = "TracksForBTagging"

    # Track augmenter
    acc.merge(BTagTrackAugmenterAlgCfg(flags))

    # Decorate the jets with their associated tracks under JetTrackAssociator
    acc.merge(JetParticleAssociationAlgCfg(
        flags,
        JetCollection=jetcol,
        InputParticleCollection=trackCollection,
        OutputParticleDecoration=JetTrackAssociator,
    ))

    # decorate b-tagging directly to the jets
    for networks in flags.BTagging.NNs.get(jetcol, []):
        nnFilePaths = networks['folds']
        # Explicit check rather than assert: asserts are stripped under -O.
        if len(nnFilePaths) <= 1:
            raise ValueError(
                f"GNN for jet collection '{jetcol}' must define more than one "
                f"fold, got {len(nnFilePaths)}"
            )

        # Build an algorithm name from the directory path shared by all folds
        common = commonpath(nnFilePaths)
        nn_name = '_'.join(PurePath(common).with_suffix('').parts)
        algname = f'{nn_name}_Jet'

        # Copy before adding the association entry: the dict returned by
        # networks.get() belongs to the flags and must not be mutated in place.
        remapping = dict(networks.get('remapping', {}))
        remapping['BTagTrackToJetAssociator'] = JetTrackAssociator

        acc.merge(MultifoldGNNCfg(
            flags=flags,
            BTaggingCollection=None,
            TrackCollection=trackCollection,
            nnFilePaths=nnFilePaths,
            remapping=remapping,
            JetCollection=jetcol,
        ))

        tool = acc.popToolsAndMerge(
            GNNVertexFitterToolCfg(
                flags,
                name=f'{algname}_VertexFitterTool',
                GNNModel=algname,
                doInclusiveVertexing=inclusive,
                # if running inclusive vertexing, remove vertices with no HF tracks
                removeNonHFVertices=inclusive,
            )
        )

        name = f'{algname}_VertexFitterAlg{"Incl" if inclusive else ""}'
        # NOTE(review): the output container name does not depend on the network,
        # so more than one network for ``jetcol`` would schedule several algorithms
        # writing the same container - confirm only one network is expected here.
        outcol = f'{"Inclusive" if inclusive else ""}GNNVertices'

        acc.addEventAlgo(
            CompFactory.Rec.GNNVertexFitterAlg(
                name=name,
                VtxTool=tool,
                inputJetContainer=jetcol,
                outputVertexContainer=outcol,
                **kwargs
            )
        )

    return acc
GNNVertexFitterToolCfg(flags, name="GNNVertexFitterTool", **kwargs)
GNNVertexFitterAlgCfg(flags, jetcol="AntiKt4EMPFlowJets", inclusive=False, **kwargs)