ATLAS Offline Software
MultifoldGNN.cxx
Go to the documentation of this file.
1 /*
2 + Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9 #include "xAODJet/JetContainer.h"
10 
11 using namespace FlavorTagInference;
12 
13 namespace {
14  const std::string jetLinkName = "jetLink";
15  template<typename T, typename C>
16  std::set<std::string> merged(T get, const C& c) {
17  auto first = get(*c.at(0));
18  for (size_t idx = 1; idx < c.size(); idx++) {
19  if (get(*c.at(idx)) != first) {
20  throw std::runtime_error("inconsistent dependencies in folds");
21  }
22  }
23  return first;
24  }
25  auto getNNs(
26  const std::vector<std::string>& nn_files,
27  const GNNOptions& o)
28  {
29  std::vector<std::shared_ptr<const GNN>> nns;
30  for (const auto& nn_file: nn_files) {
31  nns.emplace_back(std::make_shared<const GNN>(nn_file, o));
32  }
33  return nns;
34  }
35 }
36 
37 namespace FlavorTagInference {
38 
40  const std::vector<std::string>& nn_files,
41  const std::string& fold_hash_name,
42  const GNNOptions& o):
43  MultifoldGNN(getNNs(nn_files, o), fold_hash_name)
44  {
45  }
47  const std::vector<std::shared_ptr<const GNN>>& nns,
48  const std::string& fold_hash_name):
49  m_folds(nns),
50  m_fold_hash(fold_hash_name),
51  m_jetLink(jetLinkName)
52  {
53  }
// Defaulted copy constructor. m_folds holds std::shared_ptr<const GNN>, so a
// copy shares (rather than duplicates) the fold networks; the remaining
// members are copied memberwise.
55  MultifoldGNN::MultifoldGNN(const MultifoldGNN&) = default;
// Defaulted destructor: each fold is released when its last owning
// shared_ptr goes away.
56  MultifoldGNN::~MultifoldGNN() = default;
57 
58  void MultifoldGNN::decorate(const xAOD::IParticle& i_jet) const {
59  getFold(i_jet).decorate(i_jet);
60  }
62  getFold(i_jet).decorateWithDefaults(i_jet);
63  }
64 
65  // Dependencies
66  std::set<std::string> MultifoldGNN::getDecoratorKeys() const {
67  return merged([](const auto& f){ return f.getDecoratorKeys(); }, m_folds);
68  }
69  std::set<std::string> MultifoldGNN::getAuxInputKeys() const {
70  auto out = merged([](const auto& f){ return f.getAuxInputKeys(); }, m_folds);
72  return out;
73  }
74  std::set<std::string> MultifoldGNN::getConstituentAuxInputKeys() const {
75  return merged([](const auto& f){ return f.getConstituentAuxInputKeys(); }, m_folds);
76  }
77 
78  const GNN& MultifoldGNN::getFold(const SG::AuxElement& element) const {
79  return *m_folds.at(m_fold_hash(element) % m_folds.size());
80  }
81 
82 
83 }
FlavorTagInference::GNN::decorateWithDefaults
virtual void decorateWithDefaults(const xAOD::IParticle &jet) const
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/Root/GNN.cxx:98
GNN.h
FlavorTagInference::MultifoldGNN
Definition: MultifoldGNN.h:25
FlavorTagInference
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagInference::MultifoldGNN::decorate
void decorate(const xAOD::IParticle &i_jet) const
Definition: MultifoldGNN.cxx:58
BTagging.h
SG::AuxTypeRegistry::instance
static AuxTypeRegistry & instance()
Return the singleton registry instance.
Definition: AuxTypeRegistry.cxx:639
FlavorTagInference::MultifoldGNN::decorateWithDefaults
void decorateWithDefaults(const xAOD::IParticle &i_jet) const
Definition: MultifoldGNN.cxx:61
SG::AuxElement
Base class for elements of a container that can have aux data.
Definition: AuxElement.h:483
FlavorTagInference::MultifoldGNN::getDecoratorKeys
std::set< std::string > getDecoratorKeys() const
Definition: MultifoldGNN.cxx:66
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
xAOD::IParticle
Class providing the definition of the 4-vector interface.
Definition: Event/xAOD/xAODBase/xAODBase/IParticle.h:41
BTaggingConfigFlags.getNNs
def getNNs(flags)
Definition: BTaggingConfigFlags.py:78
dumpTruth.getName
getName
Definition: dumpTruth.py:34
FlavorTagInference::MultifoldGNN::getConstituentAuxInputKeys
std::set< std::string > getConstituentAuxInputKeys() const
Definition: MultifoldGNN.cxx:74
FlavorTagInference::MultifoldGNN::~MultifoldGNN
~MultifoldGNN()
hist_file_dump.f
f
Definition: hist_file_dump.py:140
FlavorTagInference::MultifoldGNN::MultifoldGNN
MultifoldGNN(const std::vector< std::string > &folds, const std::string &fold_hash_name, const FlavorTagInference::GNNOptions &opts)
Definition: MultifoldGNN.cxx:39
FlavorTagInference::GNNOptions
Definition: GNNOptions.h:16
MultifoldGNN.h
FlavorTagInference::GNN::decorate
virtual void decorate(const xAOD::IParticle &i_jet) const
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/Root/GNN.cxx:117
JetContainer.h
DeMoScan.first
bool first
Definition: DeMoScan.py:534
get
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition: hcg.cxx:127
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
FlavorTagInference::MultifoldGNN::getAuxInputKeys
std::set< std::string > getAuxInputKeys() const
Definition: MultifoldGNN.cxx:69
FlavorTagInference::MultifoldGNN::getFold
const FlavorTagInference::GNN & getFold(const SG::AuxElement &element) const
Definition: MultifoldGNN.cxx:78
python.compressB64.c
def c
Definition: compressB64.py:93
FlavorTagInference::GNN
Definition: PhysicsAnalysis/JetTagging/FlavorTagInference/FlavorTagInference/GNN.h:40
FlavorTagInference::MultifoldGNN::m_folds
std::vector< std::shared_ptr< const FlavorTagInference::GNN > > m_folds
Definition: MultifoldGNN.h:43
FlavorTagInference::MultifoldGNN::m_fold_hash
SG::AuxElement::ConstAccessor< uint32_t > m_fold_hash
Definition: MultifoldGNN.h:44