ATLAS Offline Software
Classes | Public Member Functions | Public Attributes | Private Types | Private Member Functions | Private Attributes | List of all members
InDetGNNHardScatterSelection::GNN Class Reference

#include <GNN.h>

Collaboration diagram for InDetGNNHardScatterSelection::GNN:

Classes

struct  Decorators
 

Public Member Functions

 GNN (const std::string &nnFile)
 
 GNN (GNN &&)
 
 GNN (const GNN &)
 
virtual ~GNN ()
 
virtual void decorate (const xAOD::Vertex &vertex) const
 

Public Attributes

std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
 

Private Types

using TPC = xAOD::TrackParticleContainer
 
using TrackLinks = std::vector< ElementLink< TPC > >
 
template<typename T >
using Dec = SG::AuxElement::Decorator< T >
 
template<typename T >
using Decs = std::vector< std::pair< std::string, Dec< T > >>
 

Private Member Functions

std::set< std::string > createDecorators (const FlavorTagInference::SaltModel::OutputConfig &outConfig)
 

Private Attributes

std::string m_input_node_name
 
std::vector< internal::VarFromVertex > m_varsFromVertex
 
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
 
Decorators m_decorators
 
float m_defaultValue {}
 

Detailed Description

Implementation of the GNN used by the InDetGNNHardScatterSelection::GNNTool

NOTE: The GNN relies on decorations added in InDetGNNHardScatterSelection::VertexDecoratorAlg and, as such, should only be called from within that algorithm.

Author
Jackson Burzynski jacks.nosp@m.on.c.nosp@m.arl.b.nosp@m.urzy.nosp@m.nski@.nosp@m.cern.nosp@m..ch

Definition at line 42 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h.

Member Typedef Documentation

◆ Dec

template<typename T >
using InDetGNNHardScatterSelection::GNN::Dec = SG::AuxElement::Decorator<T>
private

◆ Decs

template<typename T >
using InDetGNNHardScatterSelection::GNN::Decs = std::vector<std::pair<std::string, Dec<T> >>
private

◆ TPC

◆ TrackLinks

Constructor & Destructor Documentation

◆ GNN() [1/3]

InDetGNNHardScatterSelection::GNN::GNN ( const std::string &  nnFile)

Definition at line 22 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx.

22  :
23  m_saltModel(nullptr)
24  {
25 
26  // Load and initialize the neural network model from the given file path.
27  std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file);
28  m_saltModel = std::make_shared<FlavorTagInference::SaltModel>(fullPathToOnnxFile);
29 
30  // Extract metadata from the ONNX file, primarily about the model's inputs.
31  auto graph_config = m_saltModel->getGraphConfig();
32 
33  // Create configuration objects for data preprocessing.
34  auto [inputs, constituents_configs] = dataprep::createGetterConfig(graph_config);
35 
36  for (const auto& config : constituents_configs){
37  switch (config.type){
38  case ConstituentsType::TRACK:
39  m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config));
40  break;
41  case ConstituentsType::ELECTRON:
42  m_constituentsLoaders.push_back(std::make_shared<ElectronsLoader>(config));
43  break;
44  case ConstituentsType::MUON:
45  m_constituentsLoaders.push_back(std::make_shared<MuonsLoader>(config));
46  break;
47  case ConstituentsType::JET:
48  m_constituentsLoaders.push_back(std::make_shared<JetsLoader>(config));
49  break;
50  case ConstituentsType::PHOTON:
51  m_constituentsLoaders.push_back(std::make_shared<PhotonsLoader>(config));
52  break;
53  case ConstituentsType::IPARTICLE:
54  m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config));
55  break;
56  }
57  }
58 
59  m_varsFromVertex = dataprep::createVertexVarGetters(inputs);
60 
61  // Retrieve the configuration for the model outputs.
62  FlavorTagInference::SaltModel::OutputConfig gnn_output_config = m_saltModel->getOutputConfig();
63 
64  for (const auto& outNode : gnn_output_config) {
65  // the node's output name will be used to define the decoration name
66  std::string dec_name = outNode.name;
67  m_decorators.vertexFloat.emplace_back(outNode.name, Dec<float>(dec_name));
68  }
69  }

◆ GNN() [2/3]

InDetGNNHardScatterSelection::GNN::GNN ( GNN &&  )
default

◆ GNN() [3/3]

InDetGNNHardScatterSelection::GNN::GNN ( const GNN &  )
default

◆ ~GNN()

InDetGNNHardScatterSelection::GNN::~GNN ( )
virtual default

Member Function Documentation

◆ createDecorators()

std::set<std::string> InDetGNNHardScatterSelection::GNN::createDecorators ( const FlavorTagInference::SaltModel::OutputConfig &  outConfig)
private

◆ decorate()

void InDetGNNHardScatterSelection::GNN::decorate ( const xAOD::Vertex &  vertex) const
virtual

Definition at line 75 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx.

75  {
76  /* Main function for decorating a vertex with GNN outputs. */
77  using namespace internal;
78 
79  // prepare input
80  // -------------
81  std::map<std::string, FlavorTagInference::Inputs> gnn_input;
82 
83  std::vector<float> vertex_feat;
84  vertex_feat.reserve(m_varsFromVertex.size());
85 for (const auto& getter: m_varsFromVertex) {
86  vertex_feat.push_back(getter(vertex).second);
87  }
88  std::vector<int64_t> vertexfeat_dim = {1, static_cast<int64_t>(vertex_feat.size())};
89 
90  FlavorTagInference::Inputs vertex_info (vertex_feat, vertexfeat_dim);
91  gnn_input.insert({"vertex_features", vertex_info});
92 
93  for (const auto& loader : m_constituentsLoaders){
94  auto [sequence_name, sequence_data, sequence_constituents] = loader->getData(vertex);
95  gnn_input.insert({sequence_name, sequence_data});
96  }
97 
98  // run inference
99  // -------------
100  auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input);
101 
102  // decorate outputs
103  // ----------------
104  for (const auto& dec: m_decorators.vertexFloat) {
105  dec.second(vertex) = out_f.at(dec.first);
106  }
107  } // end of decorate()

Member Data Documentation

◆ m_constituentsLoaders

std::vector<std::shared_ptr<IConstituentsLoader> > InDetGNNHardScatterSelection::GNN::m_constituentsLoaders
private

◆ m_decorators

Decorators InDetGNNHardScatterSelection::GNN::m_decorators
private

◆ m_defaultValue

float InDetGNNHardScatterSelection::GNN::m_defaultValue {}
private

◆ m_input_node_name

std::string InDetGNNHardScatterSelection::GNN::m_input_node_name
private

◆ m_saltModel

std::shared_ptr<const FlavorTagInference::SaltModel> InDetGNNHardScatterSelection::GNN::m_saltModel

◆ m_varsFromVertex

std::vector<internal::VarFromVertex> InDetGNNHardScatterSelection::GNN::m_varsFromVertex
private

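The documentation for this class was generated from the following files:
InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h
InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx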
InDetGNNHardScatterSelection::GNN::m_saltModel
std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h:52
InDetGNNHardScatterSelection::GNN::m_decorators
Decorators m_decorators
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h:75
python.SystemOfUnits.second
float second
Definition: SystemOfUnits.py:135
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
InDetGNNHardScatterSelection::dataprep::createGetterConfig
std::tuple< std::vector< HSGNNInputConfig >, std::vector< ConstituentsInputConfig > > createGetterConfig(FlavorTagInference::SaltModelGraphConfig::GraphConfig &graph_config)
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/DataPrepUtilities.cxx:118
InDetGNNHardScatterSelection::ConstituentsType::PHOTON
@ PHOTON
config
Definition: PhysicsAnalysis/AnalysisCommon/AssociationUtils/python/config.py:1
InDetGNNHardScatterSelection::ConstituentsType::TRACK
@ TRACK
InDetGNNHardScatterSelection::dataprep::createVertexVarGetters
std::vector< internal::VarFromVertex > createVertexVarGetters(const std::vector< HSGNNInputConfig > &inputs)
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/DataPrepUtilities.cxx:171
InDetGNNHardScatterSelection::ConstituentsType::ELECTRON
@ ELECTRON
InDetGNNHardScatterSelection::ConstituentsType::JET
@ JET
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:283
InDetGNNHardScatterSelection::ConstituentsType::MUON
@ MUON
InDetGNNHardScatterSelection::ConstituentsType::IPARTICLE
@ IPARTICLE
Trk::vertex
@ vertex
Definition: MeasurementType.h:21
FlavorTagInference::Inputs
std::pair< std::vector< float >, std::vector< int64_t > > Inputs
Definition: GNNDataLoader.h:16
InDetGNNHardScatterSelection::GNN::Decorators::vertexFloat
Decs< float > vertexFloat
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h:65
InDetGNNHardScatterSelection::GNN::m_constituentsLoaders
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h:73
InDetGNNHardScatterSelection::GNN::m_varsFromVertex
std::vector< internal::VarFromVertex > m_varsFromVertex
Definition: InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/InDetGNNHardScatterSelection/GNN.h:72