ATLAS Offline Software
Loading...
Searching...
No Matches
InDetGNNHardScatterSelection::GNN Class Reference

Implementation of the GNN used by the InDetGNNHardScatterSelection::GNNTool. More...

#include <GNN.h>

Collaboration diagram for InDetGNNHardScatterSelection::GNN:

Classes

struct  Decorators

Public Member Functions

 GNN (const std::string &nnFile)
 GNN (GNN &&)
 GNN (const GNN &)
virtual ~GNN ()
virtual void decorate (const xAOD::Vertex &vertex) const

Private Types

using TPC = xAOD::TrackParticleContainer
using TrackLinks = std::vector<ElementLink<TPC>>
template<typename T>
using Dec = SG::AuxElement::Decorator<T>
template<typename T>
using Decs = std::vector<std::pair<std::string, Dec<T>>>

Private Member Functions

std::set< std::string > createDecorators (const FlavorTagInference::OutputConfig &outConfig)

Private Attributes

std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
std::string m_input_node_name
std::vector< internal::VarFromVertex > m_varsFromVertex
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
Decorators m_decorators
float m_defaultValue {}

Detailed Description

Member Typedef Documentation

◆ Dec

◆ Decs

template<typename T>
using InDetGNNHardScatterSelection::GNN::Decs = std::vector<std::pair<std::string, Dec<T>>>
private

◆ TPC

◆ TrackLinks

Constructor & Destructor Documentation

◆ GNN() [1/3]

InDetGNNHardScatterSelection::GNN::GNN ( const std::string & nnFile)

Definition at line 22 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx.

22 :
23 m_saltModel(nullptr)
24 {
25
26 // Load and initialize the neural network model from the given file path.
27 std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file);
28 m_saltModel = std::make_shared<FlavorTagInference::SaltModel>(fullPathToOnnxFile);
29
30 // Extract metadata from the ONNX file, primarily about the model's inputs.
31 auto graph_config = m_saltModel->getGraphConfig();
32
33 // Create configuration objects for data preprocessing.
34 auto [inputs, constituents_configs] = dataprep::createGetterConfig(graph_config);
35
36 for (const auto& config : constituents_configs){
37 switch (config.type){
39 m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config));
40 break;
42 m_constituentsLoaders.push_back(std::make_shared<ElectronsLoader>(config));
43 break;
45 m_constituentsLoaders.push_back(std::make_shared<MuonsLoader>(config));
46 break;
48 m_constituentsLoaders.push_back(std::make_shared<JetsLoader>(config));
49 break;
51 m_constituentsLoaders.push_back(std::make_shared<PhotonsLoader>(config));
52 break;
54 m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config));
55 break;
56 }
57 }
58
60
61 // Retrieve the configuration for the model outputs.
62 FlavorTagInference::OutputConfig gnn_output_config = m_saltModel->getOutputConfig();
63
64 for (const auto& outNode : gnn_output_config) {
65 // the node's output name will be used to define the decoration name
66 std::string dec_name = outNode.name;
67 m_decorators.vertexFloat.emplace_back(outNode.name, Dec<float>(dec_name));
68 }
69 }
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:36
std::vector< internal::VarFromVertex > createVertexVarGetters(const std::vector< HSGNNInputConfig > &inputs)
std::tuple< std::vector< HSGNNInputConfig >, std::vector< ConstituentsInputConfig > > createGetterConfig(FlavorTagInference::SaltModelGraphConfig::GraphConfig &graph_config)

◆ GNN() [2/3]

InDetGNNHardScatterSelection::GNN::GNN ( GNN && )
default

◆ GNN() [3/3]

InDetGNNHardScatterSelection::GNN::GNN ( const GNN & )
default

◆ ~GNN()

InDetGNNHardScatterSelection::GNN::~GNN ( )
virtual default

Member Function Documentation

◆ createDecorators()

std::set< std::string > InDetGNNHardScatterSelection::GNN::createDecorators ( const FlavorTagInference::OutputConfig & outConfig)
private

◆ decorate()

void InDetGNNHardScatterSelection::GNN::decorate ( const xAOD::Vertex & vertex) const
virtual

Definition at line 75 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx.

75 {
76 /* Main function for decorating a vertex with GNN outputs. */
77 using namespace internal;
78
79 // prepare input
80 // -------------
81 std::map<std::string, FlavorTagInference::Inputs> gnn_input;
82
83 std::vector<float> vertex_feat;
84 vertex_feat.reserve(m_varsFromVertex.size());
85 for (const auto& getter: m_varsFromVertex) {
86 vertex_feat.push_back(getter(vertex).second);
87 }
88 std::vector<int64_t> vertexfeat_dim = {1, static_cast<int64_t>(vertex_feat.size())};
89
90 FlavorTagInference::Inputs vertex_info (vertex_feat, vertexfeat_dim);
91 gnn_input.insert({"vertex_features", vertex_info});
92
93 for (const auto& loader : m_constituentsLoaders){
94 auto [sequence_name, sequence_data, sequence_constituents] = loader->getData(vertex);
95 gnn_input.insert({sequence_name, sequence_data});
96 }
97
98 // run inference
99 // -------------
100 auto [out_f, out_vc, out_vf] = m_saltModel->runInference(gnn_input);
101
102 // decorate outputs
103 // ----------------
104 for (const auto& dec: m_decorators.vertexFloat) {
105 dec.second(vertex) = out_f.at(dec.first);
106 }
107 } // end of decorate()
std::pair< std::vector< float >, std::vector< int64_t > > Inputs

Member Data Documentation

◆ m_constituentsLoaders

std::vector<std::shared_ptr<IConstituentsLoader> > InDetGNNHardScatterSelection::GNN::m_constituentsLoaders
private

◆ m_decorators

Decorators InDetGNNHardScatterSelection::GNN::m_decorators
private

◆ m_defaultValue

float InDetGNNHardScatterSelection::GNN::m_defaultValue {}
private

◆ m_input_node_name

std::string InDetGNNHardScatterSelection::GNN::m_input_node_name
private

◆ m_saltModel

std::shared_ptr<const FlavorTagInference::SaltModel> InDetGNNHardScatterSelection::GNN::m_saltModel
private

◆ m_varsFromVertex

std::vector<internal::VarFromVertex> InDetGNNHardScatterSelection::GNN::m_varsFromVertex
private

The documentation for this class was generated from the following files: