def GNNVertexFitterAlgCfg(flags, jetcol="AntiKt4EMPFlowJets", inclusive=False, **kwargs):
    """Configure GNN-based vertex fitting for one jet collection.

    Merges the track augmentation and jet-track association configurations,
    then, for every network entry listed for ``jetcol`` in
    ``flags.BTagging.NNs``, merges the multifold GNN configuration and adds a
    ``CompFactory.Rec.GNNVertexFitterAlg`` driven by a tool built with
    ``GNNVertexFitterToolCfg``.

    Args:
        flags: Configuration flags. ``flags.BTagging.NNs`` is read with
            ``.get(jetcol, [])``; each entry must provide a ``'folds'``
            sequence with more than one network file path and may provide a
            ``'remapping'`` dict.
        jetcol: Name of the input jet container.
        inclusive: Forwarded to the tool as both ``doInclusiveVertexing`` and
            ``removeNonHFVertices``. Also adds the ``Incl`` suffix to the
            algorithm name and the ``Inclusive`` prefix to the output
            vertex container name.
        **kwargs: Extra properties forwarded to the vertex fitter algorithm.

    Returns:
        The ComponentAccumulator holding everything merged or added above.
    """
    acc = ComponentAccumulator()

    trackCollection = 'InDetTrackParticles'
    JetTrackAssociator = "TracksForBTagging"

    acc.merge(BTagTrackAugmenterAlgCfg(flags))

    # Decorate the jets with the tracks the networks will read.
    acc.merge(JetParticleAssociationAlgCfg(
        flags,
        JetCollection=jetcol,
        InputParticleCollection=trackCollection,
        OutputParticleDecoration=JetTrackAssociator,
    ))

    for networks in flags.BTagging.NNs.get(jetcol, []):
        nnFilePaths = networks['folds']
        assert len(nnFilePaths) > 1, (
            f"expected more than one fold for {jetcol}, got {len(nnFilePaths)}"
        )

        # The network name is the path shared by all folds, suffix stripped,
        # with the path components joined by underscores.
        common = commonpath(nnFilePaths)
        nn_name = '_'.join(PurePath(common).with_suffix('').parts)
        algname = f'{nn_name}_Jet'

        # Copy before adding the override: the dict returned by .get() is the
        # one stored in the flags, and writing into it would leak this
        # association name to every other user of the same network entry.
        remapping = dict(networks.get('remapping', {}))
        remapping['BTagTrackToJetAssociator'] = JetTrackAssociator

        acc.merge(MultifoldGNNCfg(
            flags=flags,
            BTaggingCollection=None,
            TrackCollection=trackCollection,
            nnFilePaths=nnFilePaths,
            remapping=remapping,
            JetCollection=jetcol,
        ))

        # NOTE(review): the fitter setup is kept inside the loop because it
        # depends on the per-network ``algname``; confirm this matches the
        # intended nesting. The output container name does not depend on the
        # network, so more than one network per jet collection would target
        # the same container.
        tool = acc.popToolsAndMerge(
            GNNVertexFitterToolCfg(
                flags,
                name=f'{algname}_VertexFitterTool',
                GNNModel=algname,
                doInclusiveVertexing=inclusive,
                removeNonHFVertices=inclusive,
            )
        )

        name = f'{algname}_VertexFitterAlg{"Incl" if inclusive else ""}'
        outcol = f'{"Inclusive" if inclusive else ""}GNNVertices'

        acc.addEventAlgo(
            CompFactory.Rec.GNNVertexFitterAlg(
                name=name,
                VtxTool=tool,
                inputJetContainer=jetcol,
                outputVertexContainer=outcol,
                **kwargs
            )
        )

    return acc