ATLAS Offline Software
Loading...
Searching...
No Matches
InDetGNNHardScatterSelection::GNN Class Reference

Implementation of the GNN used by the InDetGNNHardScatterSelection::GNNTool. More...

#include <GNN.h>

Collaboration diagram for InDetGNNHardScatterSelection::GNN:

Classes

struct  Decorators

Public Member Functions

 GNN (const std::string &nnFile)
 GNN (GNN &&)=delete
 GNN (const GNN &)=delete
virtual ~GNN ()
virtual float decorate (const xAOD::Vertex &verrtex) const

Private Types

using TPC = xAOD::TrackParticleContainer
using TrackLinks = std::vector<ElementLink<TPC>>
template<typename T>
using Dec = SG::AuxElement::Decorator<T>
template<typename T>
using Decs = std::vector<std::pair<std::string, Dec<T>>>

Private Member Functions

std::set< std::string > createDecorators (const FlavorTagInference::OutputConfig &outConfig)

Private Attributes

std::shared_ptr< const FlavorTagInference::SaltModel > m_saltModel
std::string m_modelPath
std::string m_input_node_name
std::vector< internal::VarFromVertex > m_varsFromVertex
std::vector< std::shared_ptr< IConstituentsLoader > > m_constituentsLoaders
Decorators m_decorators
float m_defaultValue {}

Detailed Description

Implementation of the GNN used by the InDetGNNHardScatterSelection::GNNTool.

Member Typedef Documentation

◆ Dec

template<typename T>
using InDetGNNHardScatterSelection::GNN::Dec = SG::AuxElement::Decorator<T>
private

◆ Decs

template<typename T>
using InDetGNNHardScatterSelection::GNN::Decs = std::vector<std::pair<std::string, Dec<T>>>
private

◆ TPC

using InDetGNNHardScatterSelection::GNN::TPC = xAOD::TrackParticleContainer
private

◆ TrackLinks

using InDetGNNHardScatterSelection::GNN::TrackLinks = std::vector<ElementLink<TPC>>
private

Constructor & Destructor Documentation

◆ GNN() [1/3]

InDetGNNHardScatterSelection::GNN::GNN ( const std::string & nnFile)

Definition at line 28 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx.

28 :
29 m_saltModel(nullptr),
32 {
33
34 // Load and initialize the neural network model from the given file path.
35 std::string fullPathToOnnxFile = PathResolverFindCalibFile(nn_file);
36 m_modelPath = fullPathToOnnxFile;
37 m_saltModel = std::make_shared<FlavorTagInference::SaltModel>(fullPathToOnnxFile);
38 auto graph_config = m_saltModel->getGraphConfig();
39
40 // Vertex tensor key for runInference: must match InputNodeConfig::name in embedded gnn_config
41 // (same metadata SaltModel uses for graph_config).
42 std::vector<std::string> input_node_names;
43 input_node_names.reserve(graph_config.inputs.size());
44 for (const auto& node : graph_config.inputs) {
45 input_node_names.push_back(node.name);
46 }
47
48 if (std::find(input_node_names.begin(), input_node_names.end(), "vertice_features") != input_node_names.end()) {
49 m_input_node_name = "vertice_features";
50 } else if (std::find(input_node_names.begin(), input_node_names.end(), "vertex_features") != input_node_names.end()) {
51 m_input_node_name = "vertex_features";
52 } else if (graph_config.inputs.size() == 1) {
53 m_input_node_name = graph_config.inputs.front().name;
54 } else {
55 std::ostringstream msg;
56 msg << "Unsupported scalar vertex input name in model '" << nn_file << "'. Graph config inputs: ";
57 for (const auto& name : input_node_names) {
58 msg << name << " ";
59 }
60 throw std::runtime_error(msg.str());
61 }
62
63 // Create configuration objects for data preprocessing.
64 auto [inputs, constituents_configs] = dataprep::createGetterConfig(graph_config);
65
66 for (const auto& config : constituents_configs){
67 switch (config.type){
69 m_constituentsLoaders.push_back(std::make_shared<TracksLoader>(config));
70 break;
72 m_constituentsLoaders.push_back(std::make_shared<ElectronsLoader>(config));
73 break;
75 m_constituentsLoaders.push_back(std::make_shared<MuonsLoader>(config));
76 break;
78 m_constituentsLoaders.push_back(std::make_shared<JetsLoader>(config));
79 break;
81 m_constituentsLoaders.push_back(std::make_shared<PhotonsLoader>(config));
82 break;
84 m_constituentsLoaders.push_back(std::make_shared<IParticlesLoader>(config));
85 break;
86 }
87 }
88
90
91 // Retrieve the configuration for the model outputs.
92 FlavorTagInference::OutputConfig gnn_output_config = m_saltModel->getOutputConfig();
93
94 for (const auto& outNode : gnn_output_config) {
95 // the node's output name will be used to define the decoration name
96 std::string dec_name = outNode.name;
97 m_decorators.vertexFloat.emplace_back(outNode.name, Dec<float>(dec_name));
98 if (outNode.name == "salt_phsvertex") {
99 m_decorators.vertexFloat.emplace_back("salt_phsvertex", Dec<float>("HSGN2_phsvertex"));
100 }
101 }
102 }
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:38
std::vector< internal::VarFromVertex > createVertexVarGetters(const std::vector< HSGNNInputConfig > &inputs)
std::tuple< std::vector< HSGNNInputConfig >, std::vector< ConstituentsInputConfig > > createGetterConfig(FlavorTagInference::SaltModelGraphConfig::GraphConfig &graph_config)
MsgStream & msg
Definition testRead.cxx:32

◆ GNN() [2/3]

InDetGNNHardScatterSelection::GNN::GNN ( GNN && )
delete

◆ GNN() [3/3]

InDetGNNHardScatterSelection::GNN::GNN ( const GNN & )
delete

◆ ~GNN()

InDetGNNHardScatterSelection::GNN::~GNN ( )
virtual default

Member Function Documentation

◆ createDecorators()

std::set< std::string > InDetGNNHardScatterSelection::GNN::createDecorators ( const FlavorTagInference::OutputConfig & outConfig)
private

◆ decorate()

float InDetGNNHardScatterSelection::GNN::decorate ( const xAOD::Vertex & verrtex) const
virtual

Definition at line 106 of file InnerDetector/InDetRecTools/InDetGNNHardScatterSelection/Root/GNN.cxx.

106 {
107 /* Main function for decorating a vertex with GNN outputs. */
108 using namespace internal;
109
110 // prepare input
111 // -------------
113
114 std::vector<float> vertex_feat;
115 vertex_feat.reserve(m_varsFromVertex.size());
116 for (const auto& getter: m_varsFromVertex) {
117 vertex_feat.push_back(getter(vertex).second);
118 }
119 std::vector<int64_t> vertexfeat_dim = {1, static_cast<int64_t>(vertex_feat.size())};
120
121 for (const auto& value : vertex_feat) {
122 if (!std::isfinite(value)) {
123 throw std::runtime_error("Non-finite scalar vertex input before runInference");
124 }
125 }
126
127 FlavorTagInference::Inputs vertex_info (vertex_feat, vertexfeat_dim);
128 gnn_input.insert({m_input_node_name, vertex_info});
129 // Provide common scalar aliases to absorb metadata/model naming mismatches.
130 constexpr std::array<const char*, 4> scalar_input_aliases = {
131 "vertice_features", "vertex_features", "vertice_var", "vertex_var"
132 };
133 for (const char* alias : scalar_input_aliases) {
134 gnn_input.insert({alias, vertex_info});
135 }
136
137 for (const auto& loader : m_constituentsLoaders){
138 auto [sequence_name, sequence_data, sequence_constituents] = loader->getData(vertex);
139 gnn_input.insert({sequence_name, sequence_data});
140 const std::string legacy_sequence_name = loader->getName();
141 if (legacy_sequence_name != sequence_name) {
142 gnn_input.insert({legacy_sequence_name, sequence_data});
143 }
144 }
145
146 // run inference
147 // -------------
148 // FPE warning are hidden but should be resolved.
149 // Related JIRA Ticket : https://its.cern.ch/jira/browse/ATLASRECTS-8386
150 std::feclearexcept(FE_ALL_EXCEPT);
151
152 const FlavorTagInference::InferenceOutput inference_output =
153 m_saltModel->runInference(gnn_input);
154 const auto& out_f = inference_output.singleFloat;
155
156 // Get HSGNN score
157 // ----------------
158 float score = -999;
159 FlavorTagInference::OutputConfig gnn_output_config = m_saltModel->getOutputConfig();
160
161 for (const auto& outNode : gnn_output_config) {
162 std::string score_name = outNode.name;
163
164 if (outNode.name.find("_phsvertex") != std::string::npos) {
165 score = out_f.at(score_name);
166 }
167 }
168 std::feclearexcept(FE_ALL_EXCEPT);
169
170 return score;
171
172 } // end of decorate()
std::map< std::string, Inputs, std::less<> > InputMap
Definition ISaltModel.h:37
std::map< std::string, float > singleFloat
Definition ISaltModel.h:41

Member Data Documentation

◆ m_constituentsLoaders

std::vector<std::shared_ptr<IConstituentsLoader> > InDetGNNHardScatterSelection::GNN::m_constituentsLoaders
private

◆ m_decorators

Decorators InDetGNNHardScatterSelection::GNN::m_decorators
private

◆ m_defaultValue

float InDetGNNHardScatterSelection::GNN::m_defaultValue {}
private

◆ m_input_node_name

std::string InDetGNNHardScatterSelection::GNN::m_input_node_name
private

◆ m_modelPath

std::string InDetGNNHardScatterSelection::GNN::m_modelPath
private

◆ m_saltModel

std::shared_ptr<const FlavorTagInference::SaltModel> InDetGNNHardScatterSelection::GNN::m_saltModel
private

◆ m_varsFromVertex

std::vector<internal::VarFromVertex> InDetGNNHardScatterSelection::GNN::m_varsFromVertex
private

The documentation for this class was generated from the following files: