ATLAS Offline Software
GNN.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 
4  This class is used in conjunction with OnnxUtil to run inference on a GNN model.
5  Whereas OnnxUtil handles the interfacing with the ONNX runtime, this class handles
6  the interfacing with the ATLAS EDM. It is responsible for collecting all the inputs
7  needed for inference, running inference (via OnnxUtil), and decorating the results
8  back to ATLAS EDM.
9 */
10 
11 #ifndef GNN_H
12 #define GNN_H
13 
14 // Tool includes
19 
23 
24 // EDM includes
26 #include "xAODJet/JetContainer.h"
27 
28 #include <memory>
29 #include <string>
30 #include <map>
31 
32 namespace FlavorTagDiscriminants {
33 
34  struct GNNOptions;
35  class OnnxUtil;
36  //
37  // Tool to to flavor tag jet/btagging object
38  // using GNN based taggers
39  class GNN
40  {
41  public:
42  // recommended constructor, file path + options
43  GNN(const std::string& nnFile, const GNNOptions& opts);
44  // redefined options constructor, will share underlying network
45  GNN(const GNN&, const GNNOptions& opts);
46  // legacy constructor
47  GNN(const std::string& nnFile,
48  const FlipTagConfig& flip_config = FlipTagConfig::STANDARD,
49  const std::map<std::string, std::string>& variableRemapping = {},
50  const TrackLinkType trackLinkType = TrackLinkType::TRACK_PARTICLE,
51  float defaultOutputValue = NAN);
52  GNN(GNN&&);
53  GNN(const GNN&);
54  virtual ~GNN();
55 
56  virtual void decorate(const xAOD::BTagging& btag) const;
57  virtual void decorate(const xAOD::Jet& jet) const;
58  virtual void decorateWithDefaults(const SG::AuxElement& jet) const;
59  void decorate(const xAOD::Jet& jet, const SG::AuxElement& decorated) const;
60 
61  virtual std::set<std::string> getDecoratorKeys() const;
62  virtual std::set<std::string> getAuxInputKeys() const;
63  virtual std::set<std::string> getConstituentAuxInputKeys() const;
64 
65  std::shared_ptr<const OnnxUtil> m_onnxUtil;
66  private:
67  // private constructor, delegate of the above public ones
68  GNN(std::shared_ptr<const OnnxUtil>, const GNNOptions& opts);
69  // type definitions for ONNX output decorators
71  using TrackLinks = std::vector<ElementLink<TPC>>;
72 
73  template<typename T>
75 
76  template<typename T>
77  using Decs = std::vector<std::pair<std::string, Dec<T>>>;
78 
79  struct Decorators {
86  };
87 
88  /* create all decorators */
89  std::tuple<FTagDataDependencyNames, std::set<std::string>>
90  createDecorators(const OnnxUtil::OutputConfig& outConfig, const FTagOptions& options);
91 
93  std::string m_input_node_name;
94  std::vector<internal::VarFromBTag> m_varsFromBTag;
95  std::vector<internal::VarFromJet> m_varsFromJet;
96  std::vector<std::shared_ptr<IConstituentsLoader>> m_constituentsLoaders;
97 
101  };
102 } // end namespace FlavorTagDiscriminants
103 #endif //GNN_H
FlavorTagDiscriminants::GNN::TrackLinks
std::vector< ElementLink< TPC > > TrackLinks
Definition: GNN.h:71
FlavorTagDiscriminants::GNN::GNN
GNN(const std::string &nnFile, const GNNOptions &opts)
Definition: GNN.cxx:28
FlavorTagDiscriminants::GNN::Decorators::jetFloat
Decs< float > jetFloat
Definition: GNN.h:80
FlavorTagDiscriminants::GNN::m_constituentsLoaders
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
Definition: GNN.h:96
FlavorTagDiscriminants::FlipTagConfig::STANDARD
@ STANDARD
FlavorTagDiscriminants
This file contains "getter" functions used for accessing tagger inputs from the EDM.
Definition: AssociationEnums.h:11
FlavorTagDiscriminants::GNNOptions
Definition: GNNOptions.h:16
FlavorTagDiscriminants::GNN::m_onnxUtil
std::shared_ptr< const OnnxUtil > m_onnxUtil
Definition: GNN.h:65
SG::AuxElement
Base class for elements of a container that can have aux data.
Definition: AuxElement.h:446
FlavorTagDiscriminants::GNN::m_varsFromJet
std::vector< internal::VarFromJet > m_varsFromJet
Definition: GNN.h:95
xAOD::TrackParticleContainer
TrackParticleContainer_v1 TrackParticleContainer
Definition of the current "TrackParticle container version".
Definition: Event/xAOD/xAODTracking/xAODTracking/TrackParticleContainer.h:14
FlavorTagDiscriminants::GNN::Decorators::trackChar
Decs< char > trackChar
Definition: GNN.h:84
SG::ConstAccessor
Helper class to provide constant type-safe access to aux data.
Definition: ConstAccessor.h:54
FlipTagEnums.h
FlavorTagDiscriminants::GNN::m_dataDependencyNames
FTagDataDependencyNames m_dataDependencyNames
Definition: GNN.h:100
GNNOptions.h
FlavorTagDiscriminants::GNN::getDecoratorKeys
virtual std::set< std::string > getDecoratorKeys() const
Definition: GNN.cxx:220
FlavorTagDiscriminants::GNN::m_varsFromBTag
std::vector< internal::VarFromBTag > m_varsFromBTag
Definition: GNN.h:94
FlavorTagDiscriminants::GNN::getConstituentAuxInputKeys
virtual std::set< std::string > getConstituentAuxInputKeys() const
Definition: GNN.cxx:226
jet
Definition: JetCalibTools_PlotJESFactors.cxx:23
FlavorTagDiscriminants::FTagDataDependencyNames
Definition: FTagDataDependencyNames.h:12
OnnxUtil
Definition: JetTagPerformanceCalibration/xAODBTaggingEfficiency/xAODBTaggingEfficiency/OnnxUtil.h:14
DataPrepUtilities.h
SG::Decorator
Helper class to provide type-safe access to aux data.
Definition: Decorator.h:58
FlavorTagDiscriminants::GNN::GNN
GNN(GNN &&)
FlavorTagDiscriminants::GNN::getAuxInputKeys
virtual std::set< std::string > getAuxInputKeys() const
Definition: GNN.cxx:223
FlavorTagDiscriminants::TrackLinkType
TrackLinkType
Definition: AssociationEnums.h:12
FlavorTagDiscriminants::GNN::Decorators::jetVecFloat
Decs< std::vector< float > > jetVecFloat
Definition: GNN.h:82
python.AtlRunQueryLib.options
options
Definition: AtlRunQueryLib.py:379
FlavorTagDiscriminants::GNN::m_input_node_name
std::string m_input_node_name
Definition: GNN.h:93
xAOD::BTagging_v1
Definition: BTagging_v1.h:39
DataVector< xAOD::TrackParticle_v1 >
TracksLoader.h
FlavorTagDiscriminants::GNN::Decorators::jetTrackLinks
Decs< TrackLinks > jetTrackLinks
Definition: GNN.h:83
FTagDataDependencyNames.h
FlavorTagDiscriminants::GNN::Decs
std::vector< std::pair< std::string, Dec< T > >> Decs
Definition: GNN.h:77
FlavorTagDiscriminants::FTagOptions
Definition: DataPrepUtilities.h:45
xAOD::Jet_v1
Class describing a jet.
Definition: Jet_v1.h:57
FlavorTagDiscriminants::GNN::m_jetLink
SG::AuxElement::ConstAccessor< ElementLink< xAOD::JetContainer > > m_jetLink
Definition: GNN.h:92
FlavorTagDiscriminants::GNN
Definition: GNN.h:40
FlavorTagDiscriminants::GNN::decorateWithDefaults
virtual void decorateWithDefaults(const SG::AuxElement &jet) const
Definition: GNN.cxx:109
JetContainer.h
FlavorTagDiscriminants::GNN::GNN
GNN(const GNN &)
FlavorTagDiscriminants::GNN::Decorators::trackFloat
Decs< float > trackFloat
Definition: GNN.h:85
FlavorTagDiscriminants::GNN::~GNN
virtual ~GNN()
FlavorTagDiscriminants::GNN::m_defaultValue
float m_defaultValue
Definition: GNN.h:99
FlavorTagDiscriminants::FlipTagConfig
FlipTagConfig
Definition: FlipTagEnums.h:14
athena.opts
opts
Definition: athena.py:86
FlavorTagDiscriminants::GNN::m_decorators
Decorators m_decorators
Definition: GNN.h:98
IParticlesLoader.h
FlavorTagDiscriminants::GNN::Decorators
Definition: GNN.h:79
AssociationEnums.h
FlavorTagDiscriminants::GNN::Decorators::jetVecChar
Decs< std::vector< char > > jetVecChar
Definition: GNN.h:81
BTaggingFwd.h
FlavorTagDiscriminants::GNN::createDecorators
std::tuple< FTagDataDependencyNames, std::set< std::string > > createDecorators(const OnnxUtil::OutputConfig &outConfig, const FTagOptions &options)
Definition: GNN.cxx:231
FlavorTagDiscriminants::TrackLinkType::TRACK_PARTICLE
@ TRACK_PARTICLE
FlavorTagDiscriminants::GNN::decorate
virtual void decorate(const xAOD::BTagging &btag) const
Definition: GNN.cxx:94