ATLAS Offline Software
Loading...
Searching...
No Matches
MultifoldGNN.cxx
Go to the documentation of this file.
1/*
2+ Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
7
9
10using namespace FlavorTagInference;
11
12namespace {
13 const std::string jetLinkName = "jetLink";
14 auto getNNs(
15 const std::vector<std::string>& nn_files,
16 const GNNOptions& o)
17 {
18 std::vector<std::shared_ptr<const GNN>> nns;
19 for (const auto& nn_file: nn_files) {
20 nns.emplace_back(std::make_shared<const GNN>(nn_file, o));
21 }
22 return nns;
23 }
24}
25
26namespace FlavorTagInference {
27
29 const std::vector<std::string>& nn_files,
30 const std::string& fold_hash_name,
31 const GNNOptions& o):
32 MultifoldGNN(getNNs(nn_files, o), fold_hash_name)
33 {
34 }
36 const std::vector<std::shared_ptr<const GNN>>& nns,
37 const std::string& fold_hash_name):
38 m_folds(nns),
39 m_fold_hash(fold_hash_name),
40 m_jetLink(jetLinkName)
41 {
42 }
46
47 void MultifoldGNN::decorate(const xAOD::IParticle& i_jet) const {
48 getFold(i_jet).decorate(i_jet);
49 }
51 getFold(i_jet).decorateWithDefaults(i_jet);
52 }
53
54 // Dependencies
56 auto first = m_folds.at(0)->getDependencies();
57 for (size_t idx = 1; idx< m_folds.size(); idx++) {
58 if (m_folds.at(idx)->getDependencies() != first) {
59 throw std::runtime_error("inconsistent dependencies in folds");
60 }
61 }
62 // this algorithm also depends on the jet fold hash, make sure
63 // it's declared.
64 first.bTagInputs.insert(
66 return first;
67 }
68
69 const GNN& MultifoldGNN::getFold(const SG::AuxElement& element) const {
70 return *m_folds.at(m_fold_hash(element) % m_folds.size());
71 }
72
73
74}
virtual void decorateWithDefaults(const xAOD::IParticle &jet) const
virtual void decorate(const xAOD::IParticle &i_jet) const
SG::AuxElement::ConstAccessor< uint32_t > m_fold_hash
SG::AuxElement::ConstAccessor< ElementLink< xAOD::JetContainer > > m_jetLink
void decorate(const xAOD::IParticle &i_jet) const
FTagDataDependencyNames getDependencies() const
MultifoldGNN(const std::vector< std::string > &folds, const std::string &fold_hash_name, const FlavorTagInference::GNNOptions &opts)
void decorateWithDefaults(const xAOD::IParticle &i_jet) const
const FlavorTagInference::GNN & getFold(const SG::AuxElement &element) const
std::vector< std::shared_ptr< const FlavorTagInference::GNN > > m_folds
Base class for elements of a container that can have aux data.
Definition AuxElement.h:483
static AuxTypeRegistry & instance()
Return the singleton registry instance.
Class providing the definition of the 4-vector interface.
This file contains "getter" functions used for accessing tagger inputs from the EDM.