ATLAS Offline Software
SaltModel.cxx
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/

// local includes: the class declaration and the checker-macro header that
// provides ATLAS_THREAD_SAFE (exact paths assumed)
#include "FlavorTagInference/SaltModel.h"
#include "CxxUtils/checker_macros.h"

#include <stdexcept>
#include <tuple>
#include <set>
namespace FlavorTagInference {

  SaltModel::SaltModel(const std::string& path_to_onnx)
    // load the onnx model into memory using path_to_onnx
    : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
  {
    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    // this should reduce memory use while slowing things down slightly
    // see
    //
    // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
    //
    // and also https://its.cern.ch/jira/browse/AFT-818
    //
    session_options.DisableCpuMemArena();

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create session and load model into memory
    m_session = std::make_unique<Ort::Session>(
      *m_env, path_to_onnx.c_str(), session_options);

    // get metadata from the onnx model
    m_metadata = loadMetadata("gnn_config");
    m_num_inputs = m_session->GetInputCount();
    m_num_outputs = m_session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
      if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer from the presence of "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = SaltModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }

    // get the model name
    m_model_name = determineModelName();

    // iterate over input nodes and get their names
    for (size_t i = 0; i < m_num_inputs; i++) {
      m_input_node_names.push_back(m_session->GetInputNameAllocated(i, allocator).get());
    }

    // iterate over output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
      const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
      if (m_onnx_model_version == SaltModelVersion::V0) {
        m_output_nodes.emplace_back(name, type, m_model_name);
      } else {
        m_output_nodes.emplace_back(name, type, rank);
      }
    }
  }
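
  // The model configuration is stored in the ONNX file as a JSON string under a
  // custom metadata key ("gnn_config" above); loadMetadata looks it up and parses it.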
  const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
    return nlohmann::json::parse(metadataString);
  }

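  // Determine the model name: V0 models take it directly from the "outputs"
  // metadata block, newer models take the common prefix (everything before the
  // first underscore) of the output node names.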
  const std::string SaltModel::determineModelName() const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    } else {
      // get the model name from the output node names
      // each output node name is of the form "<model_name>_<output_name>"
      std::set<std::string> model_names;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          model_names.insert(name.substr(0, underscore_pos));
        } else {
          return std::string("UnknownModelName");
        }
      }
      if (model_names.size() != 1) {
        throw std::runtime_error("SaltModel: model names are not consistent between outputs");
      }
      return *model_names.begin();
    }
  }

  // simple accessors (getGraphConfig and getSaltModelVersion bodies are assumed one-liners)
  const SaltModelGraphConfig::GraphConfig SaltModel::getGraphConfig() const {
    return SaltModelGraphConfig::parse_json_graph(m_metadata);
  }

  const OutputConfig& SaltModel::getOutputConfig() const {
    return m_output_nodes;
  }

  SaltModelVersion SaltModel::getSaltModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& SaltModel::getModelName() const {
    return m_model_name;
  }

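  // Run the network on one set of inputs. Each entry of gnn_inputs maps an input
  // node name to a pair of (flattened float values, tensor shape); the data is
  // wrapped by onnxruntime without being copied.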
  InferenceOutput SaltModel::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

    std::vector<float> input_tensor_values;

    // create input tensor object from data values
    auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtArenaAllocator, OrtMemTypeDefault
    );
    std::vector<Ort::Value> input_tensors;
    for (auto& node_name : m_input_node_names) {
      input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
        gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size())
      );
    }

    // casting vector<string> to vector<const char*>. this is what ORT expects
    std::vector<const char*> input_node_names;
    input_node_names.reserve(m_input_node_names.size());
    for (const auto& name : m_input_node_names) {
      input_node_names.push_back(name.c_str());
    }
    std::vector<const char*> output_node_names;
    output_node_names.reserve(m_output_nodes.size());
    for (const auto& node : m_output_nodes) {
      output_node_names.push_back(node.name_in_model.c_str());
    }

    // score model & input tensor, get back output tensor
    // Although Session::Run is non-const, the onnx authors say
    // it is safe to call from multiple threads:
    // https://github.com/microsoft/onnxruntime/discussions/10107
    Ort::Session& session ATLAS_THREAD_SAFE = *m_session;
    auto output_tensors = session.Run(Ort::RunOptions{nullptr},
      input_node_names.data(), input_tensors.data(), input_node_names.size(),
      output_node_names.data(), output_node_names.size()
    );

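    // Supported output layouts: a rank-0 float tensor becomes a single float,
    // a rank-1 float tensor a vector of floats, and a rank-1 int8 tensor a
    // vector of chars; anything else is rejected.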
    // extract the output tensors into the InferenceOutput structure
    InferenceOutput output;
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      const auto& tensor = output_tensors[node_idx];
      auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
      auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
      int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
      if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
        if (tensor_shape.size() == 0) {
          output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
        } else if (tensor_shape.size() == 1) {
          const float* data = tensor.GetTensorData<float>();
          output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for FLOAT type");
        }
      } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
        if (tensor_shape.size() == 1) {
          const char* data = tensor.GetTensorData<char>();
          output.vecChar[output_node.name] = std::vector<char>(data, data + length);
        } else {
          throw std::runtime_error("Unsupported tensor shape for INT8 type");
        }
      } else {
        throw std::runtime_error("Unsupported tensor type");
      }
    }

    return output;
  }

} // end of FlavorTagInference namespace
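
// Usage sketch (illustrative only): the ONNX file name, input node name, shapes,
// and output key below are hypothetical, and the exact location of the Inputs /
// InferenceOutput typedefs (assumed to be at namespace scope) may differ.
//
//   FlavorTagInference::SaltModel model("BTagging/some_network.onnx");
//
//   // one entry per input node: {flattened float values, tensor shape}
//   std::map<std::string, FlavorTagInference::Inputs> inputs;
//   inputs["jet_features"] = { {0.1f, 0.2f, 0.3f}, {1, 3} };
//
//   FlavorTagInference::InferenceOutput out = model.runInference(inputs);
//   float score = out.singleFloat.at(model.getModelName() + "_pb");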