ATLAS Offline Software
MultifoldGNN.cxx
Go to the documentation of this file.
1 /*
2 + Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
#include "xAODJet/JetContainer.h"

#include <algorithm>
#include <stdexcept>
9 
10 using namespace FlavorTagInference;
11 
12 namespace {
13  const std::string jetLinkName = "jetLink";
14  template<typename T, typename C>
15  std::set<std::string> merged(T get, const C& c) {
16  auto first = get(*c.at(0));
17  for (size_t idx = 1; idx < c.size(); idx++) {
18  if (get(*c.at(idx)) != first) {
19  throw std::runtime_error("inconsistent dependencies in folds");
20  }
21  }
22  return first;
23  }
24  auto getNNs(
25  const std::vector<std::string>& nn_files,
26  const GNNOptions& o)
27  {
28  std::vector<std::shared_ptr<const GNN>> nns;
29  for (const auto& nn_file: nn_files) {
30  nns.emplace_back(std::make_shared<const GNN>(nn_file, o));
31  }
32  return nns;
33  }
34 }
35 
36 namespace FlavorTagInference {
37 
39  const std::vector<std::string>& nn_files,
40  const std::string& fold_hash_name,
41  const GNNOptions& o):
42  MultifoldGNN(getNNs(nn_files, o), fold_hash_name)
43  {
44  }
46  const std::vector<std::shared_ptr<const GNN>>& nns,
47  const std::string& fold_hash_name):
48  m_folds(nns),
49  m_fold_hash(fold_hash_name),
50  m_jetLink(jetLinkName)
51  {
52  }
54  MultifoldGNN::MultifoldGNN(const MultifoldGNN&) = default;
55  MultifoldGNN::~MultifoldGNN() = default;
56 
57  void MultifoldGNN::decorate(const xAOD::IParticle& i_jet) const {
58  getFold(i_jet).decorate(i_jet);
59  }
61  getFold(i_jet).decorateWithDefaults(i_jet);
62  }
63 
64  // Dependencies
65  std::set<std::string> MultifoldGNN::getDecoratorKeys() const {
66  return merged([](const auto& f){ return f.getDecoratorKeys(); }, m_folds);
67  }
68  std::set<std::string> MultifoldGNN::getAuxInputKeys() const {
69  auto out = merged([](const auto& f){ return f.getAuxInputKeys(); }, m_folds);
71  return out;
72  }
73  std::set<std::string> MultifoldGNN::getConstituentAuxInputKeys() const {
74  return merged([](const auto& f){ return f.getConstituentAuxInputKeys(); }, m_folds);
75  }
76 
77  const GNN& MultifoldGNN::getFold(const SG::AuxElement& element) const {
78  return *m_folds.at(m_fold_hash(element) % m_folds.size());
79  }
80 
81 
82 }
FlavorTagInference::GNN::decorateWithDefaults
virtual void decorateWithDefaults(const xAOD::IParticle &jet) const
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/Root/GNN.cxx:98
GNN.h
FlavorTagInference::MultifoldGNN
Definition: MultifoldGNN.h:25
FlavorTagInference
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/ConstituentsLoader.h:27
FlavorTagInference::MultifoldGNN::decorate
void decorate(const xAOD::IParticle &i_jet) const
Definition: MultifoldGNN.cxx:57
SG::AuxTypeRegistry::instance
static AuxTypeRegistry & instance()
Return the singleton registry instance.
Definition: AuxTypeRegistry.cxx:639
FlavorTagInference::MultifoldGNN::decorateWithDefaults
void decorateWithDefaults(const xAOD::IParticle &i_jet) const
Definition: MultifoldGNN.cxx:60
SG::AuxElement
Base class for elements of a container that can have aux data.
Definition: AuxElement.h:483
FlavorTagInference::MultifoldGNN::getDecoratorKeys
std::set< std::string > getDecoratorKeys() const
Definition: MultifoldGNN.cxx:65
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
xAOD::IParticle
Class providing the definition of the 4-vector interface.
Definition: Event/xAOD/xAODBase/xAODBase/IParticle.h:41
BTaggingConfigFlags.getNNs
def getNNs(flags)
Definition: BTaggingConfigFlags.py:78
dumpTruth.getName
getName
Definition: dumpTruth.py:34
FlavorTagInference::MultifoldGNN::getConstituentAuxInputKeys
std::set< std::string > getConstituentAuxInputKeys() const
Definition: MultifoldGNN.cxx:73
FlavorTagInference::MultifoldGNN::~MultifoldGNN
~MultifoldGNN()
hist_file_dump.f
f
Definition: hist_file_dump.py:140
FlavorTagInference::MultifoldGNN::MultifoldGNN
MultifoldGNN(const std::vector< std::string > &folds, const std::string &fold_hash_name, const FlavorTagInference::GNNOptions &opts)
Definition: MultifoldGNN.cxx:38
FlavorTagInference::GNNOptions
Definition: GNNOptions.h:15
MultifoldGNN.h
FlavorTagInference::GNN::decorate
virtual void decorate(const xAOD::IParticle &i_jet) const
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/Root/GNN.cxx:117
JetContainer.h
DeMoScan.first
bool first
Definition: DeMoScan.py:534
get
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition: hcg.cxx:127
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
FlavorTagInference::MultifoldGNN::getAuxInputKeys
std::set< std::string > getAuxInputKeys() const
Definition: MultifoldGNN.cxx:68
FlavorTagInference::MultifoldGNN::getFold
const FlavorTagInference::GNN & getFold(const SG::AuxElement &element) const
Definition: MultifoldGNN.cxx:77
python.compressB64.c
def c
Definition: compressB64.py:93
FlavorTagInference::GNN
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/GNN.h:38
FlavorTagInference::MultifoldGNN::m_folds
std::vector< std::shared_ptr< const FlavorTagInference::GNN > > m_folds
Definition: MultifoldGNN.h:43
FlavorTagInference::MultifoldGNN::m_fold_hash
SG::AuxElement::ConstAccessor< uint32_t > m_fold_hash
Definition: MultifoldGNN.h:44