ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModel.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
8
9#include <stdexcept>
10#include <tuple>
11#include <set>
12
13namespace FlavorTagInference {
14
// Construct a SaltModel from an ONNX file: create the ORT environment and
// session, read the "gnn_config" metadata, determine the model version and
// record the input node names and output node configuration.
// Throws std::runtime_error when the model version cannot be established.
15 SaltModel::SaltModel(const std::string& path_to_onnx)
16 // load the onnx model into memory using the path given in path_to_onnx
17 : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
18 {
19 // initialize session options
20 Ort::SessionOptions session_options;
21 session_options.SetIntraOpNumThreads(1);
22
23 // Ignore all non-fatal errors. This isn't a good idea, but it's
24 // what we get for uploading semi-working graphs.
25 session_options.SetLogSeverityLevel(4);
26 session_options.SetGraphOptimizationLevel(
27 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
28 // this should reduce memory use while slowing things down slightly
29 // see
30 //
31 // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
32 //
33 // and also https://its.cern.ch/jira/browse/AFT-818
34 //
35 session_options.DisableCpuMemArena();
36
37 // declare an allocator with default options
38 Ort::AllocatorWithDefaultOptions allocator;
39
40 // create session and load model into memory
41 m_session = std::make_unique<Ort::Session>(
42 *m_env, path_to_onnx.c_str(), session_options);
43
44 // get metadata from the onnx model
// the JSON stored under the "gnn_config" custom metadata key drives everything below
45 m_metadata = loadMetadata("gnn_config");
46 m_num_inputs = m_session->GetInputCount();
47 m_num_outputs = m_session->GetOutputCount();
48
49 // get the onnx model version
50 if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
51 m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
// NOTE(review): listing line 52 is missing from this excerpt; presumably it
// opens the validity check that guards the throw below - confirm against the repository.
53 throw std::runtime_error("Unknown Onnx model version!");
54 }
55 } else { // metadata version is not set, infer from the presence of "outputs" key
56 if (m_metadata.contains("outputs")){
// NOTE(review): listing line 57 is missing from this excerpt; presumably it
// assigns m_onnx_model_version for models that carry an "outputs" key - confirm.
58 } else {
59 throw std::runtime_error("Onnx model version not found in metadata");
60 }
61 }
62
63 // get the model name
// NOTE(review): listing line 64 is missing from this excerpt; presumably
// m_model_name is filled here (see determineModelName()) - confirm.
65
66 // iterate over input nodes and get their names
67 for (size_t i = 0; i < m_num_inputs; i++) {
// the name is copied into the container before the temporary that owns the
// allocated C string is destroyed
68 m_input_node_names.push_back(m_session->GetInputNameAllocated(i, allocator).get());
69 }
70
71 // iterate over output nodes and get their configuration
72 for (size_t i = 0; i < m_num_outputs; i++) {
73 const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
74 const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
// rank = number of dimensions in the declared output shape
75 const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
// NOTE(review): listing line 76 is missing from this excerpt; it presumably holds
// the condition selecting between the two emplace_back forms below - confirm.
77 m_output_nodes.emplace_back(name, type, m_model_name);
78 } else {
79 m_output_nodes.emplace_back(name, type, rank);
80 }
81 }
82 }
83
84 const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
85 Ort::AllocatorWithDefaultOptions allocator;
86 Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
87 std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
88 return nlohmann::json::parse(metadataString);
89 }
90
// Work out the model name, either from the "outputs" entry of the metadata or
// from the prefix shared by the output node names.
// Returns an empty string if any output node name contains no '_'.
// Throws std::runtime_error if the output names do not share a single prefix.
91 const std::string SaltModel::determineModelName() const {
92 Ort::AllocatorWithDefaultOptions allocator;
// NOTE(review): listing line 93 is missing from this excerpt; it presumably holds
// the `if` whose else-branch starts at listing line 96 - confirm against the repository.
94 // get the model name directly from the metadata
// the key of the first entry under "outputs" is used as the name
95 return std::string(m_metadata["outputs"].begin().key());
96 } else {
97 // get the model name from the output node names
98 // each output node name is of the form "<model_name>_<output_name>"
// a set is used so that identical prefixes collapse to a single entry
99 std::set<std::string> model_names;
100 for (size_t i = 0; i < m_num_outputs; i++) {
101 const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
// the prefix ends at the first underscore
102 size_t underscore_pos = name.find('_');
103 if (underscore_pos != std::string::npos) {
104 model_names.insert(name.substr(0, underscore_pos));
105 } else {
// an output without '_' has no prefix to extract: give up with an empty name
106 return std::string("");
107 }
108 }
// anything other than exactly one distinct prefix is treated as an error
109 if (model_names.size() != 1) {
110 throw std::runtime_error("SaltModel: model names are not consistent between outputs");
111 }
112 return *model_names.begin();
113 }
114
115 }
116
120
// NOTE(review): the signature line (listing line 121) is missing from this
// excerpt; the page's symbol index lists getOutputConfig() returning
// const OutputConfig& - confirm against the repository.
// Hands back the output node configuration held in m_output_nodes.
122 return m_output_nodes;
123 }
124
128
129 const std::string& SaltModel::getModelName() const {
130 return m_model_name;
131 }
132
133
135
// NOTE(review): the signature line (listing line 134) is missing from this
// excerpt; the page's symbol index lists
// InferenceOutput runInference(InputMap& gnn_inputs) const - confirm.
// Body: wrap the inputs as ORT tensors, run the session, then sort each
// output tensor into the InferenceOutput container matching its element
// type and rank. Throws std::runtime_error for unsupported types or ranks.
// NOTE(review): input_tensor_values is not referenced anywhere else in this body.
136 std::vector<float> input_tensor_values;
137
138 // create input tensor object from data values
139 auto memory_info = Ort::MemoryInfo::CreateCpu(
140 OrtArenaAllocator, OrtMemTypeDefault
141 );
142 std::vector<Ort::Value> input_tensors;
// one tensor per model input, in the order of m_input_node_names;
// .first is passed as the data buffer and .second as the shape.
// at() throws if gnn_inputs has no entry for an input name.
143 for (auto& node_name : m_input_node_names) {
144 input_tensors.push_back(Ort::Value::CreateTensor<float>(
145 memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
146 gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size())
147 );
148 }
149
150 // casting vector<string> to vector<const char*>. this is what ORT expects
// the pointers refer to strings owned by the members, so they stay valid for the call
151 std::vector<const char*> input_node_names;
152 input_node_names.reserve(m_input_node_names.size());
153 for (const auto& name : m_input_node_names) {
154 input_node_names.push_back(name.c_str());
155 }
156 std::vector<const char*> output_node_names;
157 output_node_names.reserve(m_output_nodes.size());
158 for (const auto& node : m_output_nodes) {
159 output_node_names.push_back(node.name_in_model.c_str());
160 }
161
162 // score model & input tensor, get back output tensor
163 // Although Session::Run is non-const, the onnx authors say
164 // it is safe to call from multiple threads:
165 // https://github.com/microsoft/onnxruntime/discussions/10107
// ATLAS_THREAD_SAFE is the attribute macro read by the static checker
166 Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
167 auto output_tensors = session.Run(Ort::RunOptions{nullptr},
168 input_node_names.data(), input_tensors.data(), input_node_names.size(),
169 output_node_names.data(), output_node_names.size()
170 );
171
172 // Copy each output tensor into the container matching its element type and rank
173 InferenceOutput output;
// output_tensors is indexed in step with m_output_nodes, the order in which
// the output names were passed to Run
174 for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
175 const auto& output_node = m_output_nodes[node_idx];
176 const auto& tensor = output_tensors[node_idx];
177 auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
178 auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
// NOTE(review): the element count is stored in an int here
179 int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
180 if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
// rank 0 -> single float, rank 1 -> vector of floats, anything else is rejected
181 if (tensor_shape.size() == 0) {
182 output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
183 } else if (tensor_shape.size() == 1) {
184 const float* data = tensor.GetTensorData<float>();
185 output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
186 } else {
187 throw std::runtime_error("Unsupported tensor shape for FLOAT type");
188 }
189 } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
// only rank-1 INT8 tensors are accepted; they are copied out as chars
190 if (tensor_shape.size() == 1) {
191 const char* data = tensor.GetTensorData<char>();
192 output.vecChar[output_node.name] = std::vector<char>(data, data + length);
193 } else {
194 throw std::runtime_error("Unsupported tensor shape for INT8 type");
195 }
196 } else {
197 throw std::runtime_error("Unsupported tensor type");
198 }
199 }
200
201 return output;
202 }
203
204} // end of FlavorTagInference namespace
double length(const pvec &v)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
Define macros for attributes used to control the static checker.
#define ATLAS_THREAD_SAFE
virtual const std::string & getModelName() const override
const std::string determineModelName() const
Definition SaltModel.cxx:91
virtual const OutputConfig & getOutputConfig() const override
SaltModel(const std::string &path_to_onnx)
Definition SaltModel.cxx:15
const nlohmann::json loadMetadata(const std::string &key) const
Definition SaltModel.cxx:84
virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
virtual InferenceOutput runInference(InputMap &gnn_inputs) const override
virtual SaltModelVersion getSaltModelVersion() const override
Definition node.h:24
GraphConfig parse_json_graph(const nlohmann::json &metadata)
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:38
std::map< std::string, Inputs, std::less<> > InputMap
Definition ISaltModel.h:37
STL namespace.