ATLAS Offline Software
Loading...
Searching...
No Matches
SaltModelTriton.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
#include <onnxruntime_cxx_api.h>

#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <tuple>
15
// Compile-time lookup from a C++ element type to the datatype string that is
// handed to tc::InferInput::Create. Only the specialisations below exist;
// using any other element type fails at compile time (incomplete type).
template <typename T>
struct TritonDType;

template <>
struct TritonDType<float> {
  static constexpr const char* value = "FP32";
};

template <>
struct TritonDType<int64_t> {
  static constexpr const char* value = "INT64";
};
20
21template <typename T>
22bool prepareInput(const std::string& name,
23 const std::vector<int64_t>& shape,
24 const std::vector<T>& data,
25 std::vector<std::shared_ptr<tc::InferInput>>& inputs)
26{
27 const char* dtype = TritonDType<T>::value;
28 tc::InferInput* rawInputPtr = nullptr;
29
30 // create the InferInput object with the predefined name, shape, and data type.
31 tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
32 if(!err.IsOk()) {
33 std::cerr << "Unable to create input: " + name << std::endl;
34 return false;
35 }
36
37 // Append tensor values for this input from a byte array.
38 // Note: The vector is not copied and so it must not be modified or destroyed
39 // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
40 // Multiple calls can be made to this API to keep adding tensor data for this input.
41 // The data will be delivered in the order it was added.
42 std::shared_ptr<tc::InferInput> input(rawInputPtr);
43 err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
44 data.size() * sizeof(T));
45 if(!err.IsOk()) {
46 std::cerr << "Unable to set input data for: " + name << std::endl;
47 return false;
48 }
49
50 inputs.push_back(std::move(input));
51 return true;
52}
53
54template <typename T>
55bool extractOutput(const std::string& name,
56 const std::shared_ptr<tc::InferResult>& result,
57 std::vector<T>& outputVec)
58{
59 const uint8_t* rawData = nullptr;
60 size_t size = 0;
61
62 // Get access to the buffer holding raw results of specified output returned by the server.
63 // Note: the buffer is owned by InferResult instance.
64 // Users can copy out the data if required to extend the lifetime.
65 tc::Error err = result->RawData(name, &rawData, &size);
66 if(!err.IsOk()) {
67 std::cerr << "Unable to get raw output for: " + name << std::endl;
68 return false;
69 }
70
71 outputVec.resize(size / sizeof(T));
72 std::memcpy(outputVec.data(), rawData, size);
73 return true;
74}
75
76
77namespace FlavorTagInference {
78
// Constructor. Stores the connection settings (model name, client timeout,
// port, url, SSL flag, bearer token), then opens the ONNX file at path_to_onnx
// in a local onnxruntime session in order to read: the "gnn_config" metadata,
// the number of outputs, the model version, the model type and the
// configuration of every output node. Finally it creates the tc::InferOptions
// (model name, empty model version, client timeout) stored in m_options.
//
// The Ort::Env and Ort::Session created here are locals, so they are released
// when the constructor returns.
//
// Throws std::runtime_error if the model version is unknown or cannot be
// determined from the metadata.
//
// NOTE(review): the embedded listing numbers skip 128, 133 and 147, so the
// statements on those lines (they appear to open/fill the branches below) are
// not visible in this listing -- compare with the original file.
79 SaltModelTriton::SaltModelTriton(const std::string& path_to_onnx
80 , const std::string& model_name
81 , float client_timeout
82 , int port
83 , const std::string& url
84 , bool useSSL
85 , const std::string& bearer)
86 : m_model_name(model_name)
87 , m_clientTimeout(client_timeout)
88 , m_port(port)
89 , m_url(url)
90 , m_useSSL(useSSL)
91 , m_bearer(bearer)
92 //load the onnx model to memory using the path m_path_to_onnx
93 {
// onnxruntime environment; only fatal messages are logged
94 std::unique_ptr< Ort::Env > env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "");
95
96 // initialize session options
97 Ort::SessionOptions session_options;
98 session_options.SetIntraOpNumThreads(1);
99
100 // Ignore all non-fatal errors. This isn't a good idea, but it's
101 // what we get for uploading semi-working graphs.
102 session_options.SetLogSeverityLevel(4);
103 session_options.SetGraphOptimizationLevel(
104 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
105 // this should reduce memory use while slowing things down slightly
106 // see
107 //
108 // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
109 //
110 // and also https://its.cern.ch/jira/browse/AFT-818
111 //
112 session_options.DisableCpuMemArena();
113
114 // declare an allocator with default options
115 Ort::AllocatorWithDefaultOptions allocator;
116
117 // create session and load model into memory
118 std::unique_ptr< Ort::Session > session = std::make_unique<Ort::Session>(
119 *env, path_to_onnx.c_str(), session_options);
120
121 // get metadata from the onnx model
122 m_metadata = loadMetadata("gnn_config", session.get());
123 m_num_outputs = session->GetOutputCount();
124
125 // get the onnx model version
126 if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
127 m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
// NOTE(review): listing line 128 is missing here (presumably the check that guards this throw)
129 throw std::runtime_error("Unknown Onnx model version!");
130 }
131 } else { // metadata version is not set, infer from the presence of "outputs" key
132 if (m_metadata.contains("outputs")){
// NOTE(review): listing line 133 is missing here (presumably the version assignment for this case)
134 } else {
135 throw std::runtime_error("Onnx model version not found in metadata");
136 }
137 }
138
139 // get the model name
// (uses m_metadata and m_num_outputs, which are set above)
140 m_model_type = determineModelType(session.get());
141
142 // iterate over output nodes and get their configuration
143 for (size_t i = 0; i < m_num_outputs; i++) {
144 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
145 const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
146 const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
// NOTE(review): listing line 147 is missing here (the condition choosing between the two
// SaltModelOutput constructors: name/type/model type vs. name/type/rank)
// NOTE(review): saltModelOutput is const, so the std::move below still copies
148 const SaltModelOutput saltModelOutput(name, type, m_model_type);
149 m_output_nodes.push_back(std::move(saltModelOutput));
150 } else {
151 const SaltModelOutput saltModelOutput(name, type, rank);
152 m_output_nodes.push_back(std::move(saltModelOutput));
153 }
154 }
155
// inference options used by runInference for every request
156 m_options = std::make_unique<tc::InferOptions>(m_model_name);
157 m_options->model_version_ = ""; // Fixme: find a proper solution for this
158 m_options->client_timeout_ = m_clientTimeout;
159 }
160
161 const nlohmann::json SaltModelTriton::loadMetadata(const std::string& key, const Ort::Session* session) const {
162 Ort::AllocatorWithDefaultOptions allocator;
163 Ort::ModelMetadata modelMetadata = session->GetModelMetadata();
164 std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
165 return nlohmann::json::parse(metadataString);
166 }
167
// Work out the model type string for the loaded ONNX model.
//
// First branch: the key of the first entry of m_metadata["outputs"] is
// returned. Otherwise the type is derived from the output node names of
// `session`: the text before the first '_' of every name is collected, and
//  - "UnknownModelName" is returned as soon as a name contains no '_',
//  - std::runtime_error is thrown unless all names share one single prefix,
//  - that common prefix is returned.
//
// Reads m_metadata and m_num_outputs, so both must be set before the call.
//
// NOTE(review): the embedded listing numbers skip 170, i.e. the condition
// opening the first branch is not visible in this listing (presumably a test
// for the "outputs" key in m_metadata) -- confirm against the original file.
168 const std::string SaltModelTriton::determineModelType(const Ort::Session* session) const {
169 Ort::AllocatorWithDefaultOptions allocator;
171 // get the model name directly from the metadata
172 return std::string(m_metadata["outputs"].begin().key());
173 }
174 else {
175 // get the model name from the output node names
176 // each output node name is of the form "<model_type>_<output_name>"
// the set collapses identical prefixes, so its size tells whether they agree
177 std::set<std::string> model_types;
178 for (size_t i = 0; i < m_num_outputs; i++) {
179 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
180 size_t underscore_pos = name.find('_');
181 if (underscore_pos != std::string::npos) {
182 std::string substring = name.substr(0, underscore_pos);
183 model_types.insert(std::move(substring));
184 }
185 else {
186 return std::string("UnknownModelName");
187 }
188 }
// also throws when there are no outputs at all (empty set)
189 if (model_types.size() != 1) {
190 throw std::runtime_error("SaltModelTriton: model names are not consistent between outputs");
191 }
192 return *model_types.begin();
193 }
194 }
195
199
203
207
208 const std::string& SaltModelTriton::getModelName() const {
209 return m_model_type;
210 }
211
212
214
// Body of the inference call: wraps every entry of gnn_inputs in a Triton
// input, sends one Infer request through the per-thread gRPC client and copies
// the requested outputs into an InferenceOutput.
//
// Throws std::runtime_error if an input cannot be prepared, no client can be
// obtained, the Infer call fails, a single-float output does not contain
// exactly one value, or an output node has an unhandled type.
//
// NOTE(review): the function signature (listing line 213) and the `case`
// labels of the switch (listing lines 263, 270, 282, 291) are missing from
// this listing -- compare with the original file.
215 // Create tensor for the input data
216 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
217 inputs_.reserve(gnn_inputs.size());
218
// prepareInput does not copy the data, so gnn_inputs must stay unchanged
// until the Infer call below has returned
219 for (auto& [inputName, inputInfo]: gnn_inputs) {
220 const std::vector<float>& inputData = inputInfo.first; // ? good name ?
221 const std::vector<int64_t>& inputShape = inputInfo.second;
222 if(!prepareInput<float>(inputName, inputShape, inputData, inputs_)) {
223 throw std::runtime_error("Failed to prepare input for inference"); // ? more informative error message ?
224 }
225 }
226
227 // construct raw points for inference
// (non-owning; inputs_ keeps the objects alive)
228 std::vector<tc::InferInput*> rawInputs;
229 for(auto& input : inputs_) {
230 rawInputs.push_back(input.get());
231 }
232
233 // perform the inference
234 tc::InferResult* rawResultPtr = nullptr;
235 tc::Headers http_headers;
// the bearer token, if configured, is sent with every request
236 if (!m_bearer.empty()) {
237 http_headers["authorization"] = "Bearer " + m_bearer;
238 }
239 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
240
241 auto client = getClient();
242 if(client) {
243 tc::Error err = client->Infer(&rawResultPtr
244 , *m_options
245 , rawInputs
246 , {}
247 , http_headers
248 , compression_algorithm);
249 if(!err.IsOk()) {
250 throw std::runtime_error("unable to run model "+ m_model_name + " error: " + err.Message());
251 }
252 }
253 else {
254 throw std::runtime_error("Failed to create Triton gRPC client");
255 }
256
257 // Get the result of the inference
258 InferenceOutput output;
// takes ownership of the result returned by Infer
259 std::shared_ptr<tc::InferResult> results(rawResultPtr);
260 for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
261 const auto& output_node = m_output_nodes[node_idx];
262 switch(output_node.type) {
// vector-of-floats output
// NOTE(review): the return value of extractOutput is ignored; on failure an
// empty vector is stored without any exception
264 {
265 std::vector<float> outputVecFloat;
266 extractOutput<float>(output_node.name, results, outputVecFloat);
267 output.vecFloat[output_node.name] = std::move(outputVecFloat);
268 }
269 break;
// single-float output
// NOTE(review): extractOutput's return value is ignored here as well; a failed
// extraction leaves the vector empty and ends in the "Vector of floats" error
271 {
272 std::vector<float> outputFloat;
273 extractOutput<float>(output_node.name, results, outputFloat);
274 if(outputFloat.size()==1) {
275 output.singleFloat[output_node.name] = outputFloat[0];
276 }
277 else {
278 throw std::runtime_error("Vector of floats returned instead of a single float for " + output_node.name);
279 }
280 }
281 break;
// vector-of-chars output, read as int8_t and converted element by element
// NOTE(review): extractOutput's return value is ignored here too
283 {
284 std::vector<int8_t> outputVecInt;
285 extractOutput<int8_t>(output_node.name, results, outputVecInt);
286 // convert int8_t vector to char vector
287 std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
288 output.vecChar[output_node.name] = std::move(outputVecChar);
289 }
290 break;
292 [[fallthrough]];
293 default:
294 throw std::runtime_error("Unknown output type for the node " + output_node.name);
295 }
296 }
297
298 return output;
299 }
300
301 tc::InferenceServerGrpcClient* SaltModelTriton::getClient() const
302 {
303 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
304 if(!threadClient) {
305 std::string url = m_url + ":" + std::to_string(m_port);
306 tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, false, m_useSSL);
307 if (!err.IsOk()) {
308 std::cerr << "SaltModelTriton ERROR: Failed to create Triton gRPC client for model: " << m_model_name
309 << " at URL: " << url << std::endl;
310 std::cerr << err.Message() << std::endl;
311 return nullptr;
312 }
313 }
314 return threadClient.get();
315 }
316} // end of FlavorTagInference namespace
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
bool extractOutput(const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec)
bool prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs)
size_t size() const
Number of registered mappings.
Define macros for attributes used to control the static checker.
tc::InferenceServerGrpcClient * getClient() const
const nlohmann::json loadMetadata(const std::string &key, const Ort::Session *session) const
virtual const OutputConfig & getOutputConfig() const override
virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
virtual InferenceOutput runInference(InputMap &gnn_inputs) const override
std::unique_ptr< tc::InferOptions > m_options
const std::string determineModelType(const Ort::Session *session) const
virtual SaltModelVersion getSaltModelVersion() const override
SaltModelTriton(const std::string &path_to_onnx, const std::string &model_name, float client_timeout, int port, const std::string &url, bool useSSL, const std::string &bearer="")
virtual const std::string & getModelName() const override
GraphConfig parse_json_graph(const nlohmann::json &metadata)
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:38
std::map< std::string, Inputs, std::less<> > InputMap
Definition ISaltModel.h:37
static constexpr const char * value
static constexpr const char * value