ATLAS Offline Software
SaltModelTriton.cxx
/*
  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
*/

#include "SaltModelTriton.h"
#include <onnxruntime_cxx_api.h>

#include <stdexcept>
#include <tuple>
#include <set>
#include <memory>
#include <iostream>  // for std::cerr
#include <cstring>   // for std::memcpy

// DType traits for Triton
template <typename T> struct TritonDType;
template <> struct TritonDType<float> { static constexpr const char* value = "FP32"; };
template <> struct TritonDType<int64_t> { static constexpr const char* value = "INT64"; };
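// Only FP32 and INT64 inputs are used by this client; any other element type
// would need its own TritonDType specialization to compile.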

template <typename T>
bool prepareInput(const std::string& name,
                  const std::vector<int64_t>& shape,
                  const std::vector<T>& data,
                  std::vector<std::shared_ptr<tc::InferInput>>& inputs)
{
  const char* dtype = TritonDType<T>::value;
  tc::InferInput* rawInputPtr = nullptr;

  // create the InferInput object with the predefined name, shape, and data type.
  tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
  if(!err.IsOk()) {
    std::cerr << "Unable to create input: " + name << std::endl;
    return false;
  }

  // Append tensor values for this input from a byte array.
  // Note: The vector is not copied and so it must not be modified or destroyed
  // until this input is no longer needed (that is, until the Infer() call(s) that use the input have completed).
  // Multiple calls can be made to this API to keep adding tensor data for this input.
  // The data will be delivered in the order it was added.
  std::shared_ptr<tc::InferInput> input(rawInputPtr);
  err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
                         data.size() * sizeof(T));
  if(!err.IsOk()) {
    std::cerr << "Unable to set input data for: " + name << std::endl;
    return false;
  }

  inputs.push_back(std::move(input));
  return true;
}

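// Copy the raw bytes of the named output tensor returned by the server into a
// typed std::vector.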
template <typename T>
bool extractOutput(const std::string& name,
                   const std::shared_ptr<tc::InferResult>& result,
                   std::vector<T>& outputVec)
{
  const uint8_t* rawData = nullptr;
  size_t size = 0;

  // Get access to the buffer holding raw results of the specified output returned by the server.
  // Note: the buffer is owned by the InferResult instance.
  // Users can copy out the data if required to extend the lifetime.
  tc::Error err = result->RawData(name, &rawData, &size);
  if(!err.IsOk()) {
    std::cerr << "Unable to get raw output for: " + name << std::endl;
    return false;
  }

  outputVec.resize(size / sizeof(T));
  std::memcpy(outputVec.data(), rawData, size);
  return true;
}


namespace FlavorTagInference {

  SaltModelTriton::SaltModelTriton(const std::string& path_to_onnx
                                   , const std::string& model_name
                                   , float client_timeout
                                   , int port
                                   , const std::string& url
                                   , bool useSSL)
    : m_model_name(model_name)
    , m_clientTimeout(client_timeout)
    , m_port(port)
    , m_url(url)
    , m_useSSL(useSSL)
    // load the ONNX model into memory using path_to_onnx
  {
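    // The ONNX Runtime session created below is used only to read the model
    // metadata and the output-node configuration; the inference itself is
    // delegated to the Triton server through the gRPC client.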
    std::unique_ptr< Ort::Env > env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "");

    // initialize session options
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(1);

    // Ignore all non-fatal errors. This isn't a good idea, but it's
    // what we get for uploading semi-working graphs.
    session_options.SetLogSeverityLevel(4);
    session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    // this should reduce memory use while slowing things down slightly
    // see
    //
    // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
    //
    // and also https://its.cern.ch/jira/browse/AFT-818
    //
    session_options.DisableCpuMemArena();

    // declare an allocator with default options
    Ort::AllocatorWithDefaultOptions allocator;

    // create session and load model into memory
    std::unique_ptr< Ort::Session > session = std::make_unique<Ort::Session>(
      *env, path_to_onnx.c_str(), session_options);

    // get metadata from the onnx model
    m_metadata = loadMetadata("gnn_config", session.get());
    m_num_outputs = session->GetOutputCount();

    // get the onnx model version
    if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
      m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
      if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
        throw std::runtime_error("Unknown Onnx model version!");
      }
    } else { // metadata version is not set, infer from the presence of the "outputs" key
      if (m_metadata.contains("outputs")) {
        m_onnx_model_version = SaltModelVersion::V0;
      } else {
        throw std::runtime_error("Onnx model version not found in metadata");
      }
    }

    // get the model name
    m_model_type = determineModelType(session.get());

    // iterate over output nodes and get their configuration
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
      const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
      const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
      if (m_onnx_model_version == SaltModelVersion::V0) {
        const SaltModelOutput saltModelOutput(name, type, m_model_type);
        m_output_nodes.push_back(std::move(saltModelOutput));
      } else {
        const SaltModelOutput saltModelOutput(name, type, rank);
        m_output_nodes.push_back(std::move(saltModelOutput));
      }
    }

    m_options = std::make_unique<tc::InferOptions>(m_model_name);
    m_options->model_version_ = ""; // Fixme: find a proper solution for this
    m_options->client_timeout_ = m_clientTimeout;
  }

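  // Read the JSON configuration stored under the given key in the ONNX
  // model's custom metadata map.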
  const nlohmann::json SaltModelTriton::loadMetadata(const std::string& key, const Ort::Session* session) const {
    Ort::AllocatorWithDefaultOptions allocator;
    Ort::ModelMetadata modelMetadata = session->GetModelMetadata();
    std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
    return nlohmann::json::parse(metadataString);
  }

  const std::string SaltModelTriton::determineModelType(const Ort::Session* session) const {
    Ort::AllocatorWithDefaultOptions allocator;
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // get the model name directly from the metadata
      return std::string(m_metadata["outputs"].begin().key());
    }
    else {
      // get the model name from the output node names
      // each output node name is of the form "<model_type>_<output_name>"
      std::set<std::string> model_types;
      for (size_t i = 0; i < m_num_outputs; i++) {
        const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
        size_t underscore_pos = name.find('_');
        if (underscore_pos != std::string::npos) {
          std::string substring = name.substr(0, underscore_pos);
          model_types.insert(std::move(substring));
        }
        else {
          return std::string("UnknownModelName");
        }
      }
      if (model_types.size() != 1) {
        throw std::runtime_error("SaltModelTriton: model names are not consistent between outputs");
      }
      return *model_types.begin();
    }
  }

  const SaltModelGraphConfig::GraphConfig SaltModelTriton::getGraphConfig() const {
    return SaltModelGraphConfig::parse_json_graph(m_metadata);
  }

  const SaltModelTriton::OutputConfig& SaltModelTriton::getOutputConfig() const {
    return m_output_nodes;
  }

  SaltModelVersion SaltModelTriton::getSaltModelVersion() const {
    return m_onnx_model_version;
  }

  const std::string& SaltModelTriton::getModelName() const {
    return m_model_type;
  }

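  // Each entry of gnn_inputs pairs a flattened float buffer with its tensor
  // shape; every input is sent to the Triton server as an FP32 tensor.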
  InferenceOutput SaltModelTriton::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

    // Create tensors for the input data
    std::vector<std::shared_ptr<tc::InferInput> > inputs_;
    inputs_.reserve(gnn_inputs.size());

    for (auto& [inputName, inputInfo]: gnn_inputs) {
      const std::vector<float>& inputData = inputInfo.first;
      const std::vector<int64_t>& inputShape = inputInfo.second;
      if(!prepareInput<float>(inputName, inputShape, inputData, inputs_)) {
        throw std::runtime_error("Failed to prepare input '" + inputName + "' for inference");
      }
    }

    // collect raw pointers for the inference call
    std::vector<tc::InferInput*> rawInputs;
    for(auto& input : inputs_) {
      rawInputs.push_back(input.get());
    }

    // perform the inference
    tc::InferResult* rawResultPtr = nullptr;
    tc::Headers http_headers;
    grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;

    auto client = getClient();
    if(client) {
      tc::Error err = client->Infer(&rawResultPtr
                                    , *m_options
                                    , rawInputs
                                    , {}
                                    , http_headers
                                    , compression_algorithm);
      if(!err.IsOk()) {
        throw std::runtime_error("unable to run model " + m_model_name + " error: " + err.Message());
      }
    }
    else {
      throw std::runtime_error("Failed to create Triton gRPC client");
    }

    // Get the result of the inference
    InferenceOutput output;
    std::shared_ptr<tc::InferResult> results(rawResultPtr);
    for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
      const auto& output_node = m_output_nodes[node_idx];
      switch(output_node.type) {
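      // The node type determines which InferenceOutput container is filled:
      // a float vector, a single float, or a char vector.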
      case SaltModelOutput::OutputType::VECFLOAT:
        {
          std::vector<float> outputVecFloat;
          extractOutput<float>(output_node.name, results, outputVecFloat);
          output.vecFloat[output_node.name] = std::move(outputVecFloat);
        }
        break;
      case SaltModelOutput::OutputType::FLOAT:
        {
          std::vector<float> outputFloat;
          extractOutput<float>(output_node.name, results, outputFloat);
          if(outputFloat.size()==1) {
            output.singleFloat[output_node.name] = outputFloat[0];
          }
          else {
            throw std::runtime_error("Vector of floats returned instead of a single float for " + output_node.name);
          }
        }
        break;
      case SaltModelOutput::OutputType::VECCHAR:
        {
          std::vector<int8_t> outputVecInt;
          extractOutput<int8_t>(output_node.name, results, outputVecInt);
          // convert int8_t vector to char vector
          std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
          output.vecChar[output_node.name] = std::move(outputVecChar);
        }
        break;
      case SaltModelOutput::OutputType::UNKNOWN:
        [[fallthrough]];
      default:
        throw std::runtime_error("Unknown output type for the node " + output_node.name);
      }
    }

    return output;
  }

  tc::InferenceServerGrpcClient* SaltModelTriton::getClient() const
  {
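    // One gRPC client per thread: it is created lazily on first use and then
    // reused for all later calls from the same thread, so concurrent callers
    // never share a client instance.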
    thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
    if(!threadClient) {
      std::string url = m_url + ":" + std::to_string(m_port);
      tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, false, m_useSSL);
      if (!err.IsOk()) {
        std::cerr << "SaltModelTriton ERROR: Failed to create Triton gRPC client for model: " << m_model_name
                  << " at URL: " << url << std::endl;
        std::cerr << err.Message() << std::endl;
        return nullptr;
      }
    }
    return threadClient.get();
  }
} // end of FlavorTagInference namespace