ATLAS Offline Software
SaltModelTriton.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
9#include <onnxruntime_cxx_api.h>
10
11#include <stdexcept>
12#include <tuple>
13#include <set>
14#include <memory>
15
// DType traits for Triton: map a C++ element type to the datatype
// string Triton expects (see the Triton model-configuration datatype
// table). Unspecialized use is a compile error, which is intentional.
template <typename T> struct TritonDType;
template <> struct TritonDType<float>   { static constexpr const char* value = "FP32"; };
template <> struct TritonDType<double>  { static constexpr const char* value = "FP64"; };
template <> struct TritonDType<int32_t> { static constexpr const char* value = "INT32"; };
template <> struct TritonDType<int64_t> { static constexpr const char* value = "INT64"; };
20
21template <typename T>
22bool prepareInput(const std::string& name,
23 const std::vector<int64_t>& shape,
24 const std::vector<T>& data,
25 std::vector<std::shared_ptr<tc::InferInput>>& inputs)
26{
27 const char* dtype = TritonDType<T>::value;
28 tc::InferInput* rawInputPtr = nullptr;
29
30 // create the InferInput object with the predefined name, shape, and data type.
31 tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
32 if(!err.IsOk()) {
33 std::cerr << "Unable to create input: " + name << std::endl;
34 return false;
35 }
36
37 // Append tensor values for this input from a byte array.
38 // Note: The vector is not copied and so it must not be modified or destroyed
39 // until this input is no longer needed (that is until the Infer() call(s) that use the input have completed).
40 // Multiple calls can be made to this API to keep adding tensor data for this input.
41 // The data will be delivered in the order it was added.
42 std::shared_ptr<tc::InferInput> input(rawInputPtr);
43 err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
44 data.size() * sizeof(T));
45 if(!err.IsOk()) {
46 std::cerr << "Unable to set input data for: " + name << std::endl;
47 return false;
48 }
49
50 inputs.push_back(std::move(input));
51 return true;
52}
53
54template <typename T>
55bool extractOutput(const std::string& name,
56 const std::shared_ptr<tc::InferResult>& result,
57 std::vector<T>& outputVec)
58{
59 const uint8_t* rawData = nullptr;
60 size_t size = 0;
61
62 // Get access to the buffer holding raw results of specified output returned by the server.
63 // Note: the buffer is owned by InferResult instance.
64 // Users can copy out the data if required to extend the lifetime.
65 tc::Error err = result->RawData(name, &rawData, &size);
66 if(!err.IsOk()) {
67 std::cerr << "Unable to get raw output for: " + name << std::endl;
68 return false;
69 }
70
71 outputVec.resize(size / sizeof(T));
72 std::memcpy(outputVec.data(), rawData, size);
73 return true;
74}
75
76
77namespace FlavorTagInference {
78
// Construct the Triton-backed Salt model wrapper.
//
// The ONNX file at `path_to_onnx` is opened locally with onnxruntime
// only to read its metadata (model version, model type, output-node
// configuration); actual inference is delegated to a Triton server
// later (see runInference). The local Ort session and env are released
// when the constructor returns.
//
// @param path_to_onnx   local path to the ONNX model file
// @param model_name     model name as registered on the Triton server
// @param client_timeout request timeout forwarded to tc::InferOptions
// @param port           Triton server port (used by getClient)
// @param url            Triton server host (used by getClient)
// @param useSSL         open the gRPC channel with SSL (used by getClient)
// @param bearer         optional bearer token, sent as an HTTP
//                       "authorization" header in runInference
79 SaltModelTriton::SaltModelTriton(const std::string& path_to_onnx
80 , const std::string& model_name
81 , float client_timeout
82 , int port
83 , const std::string& url
84 , bool useSSL
85 , const std::string& bearer)
86 : m_model_name(model_name)
87 , m_clientTimeout(client_timeout)
88 , m_port(port)
89 , m_url(url)
90 , m_useSSL(useSSL)
91 , m_bearer(bearer)
92 //load the onnx model to memory using the path m_path_to_onnx
93 {
// Local, constructor-scoped ORT environment; fatal-only logging.
94 std::unique_ptr< Ort::Env > env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "");
95 
96 // initialize session options
97 Ort::SessionOptions session_options;
98 session_options.SetIntraOpNumThreads(1);
99 
100 // Ignore all non-fatal errors. This isn't a good idea, but it's
101 // what we get for uploading semi-working graphs.
102 session_options.SetLogSeverityLevel(4);
103 session_options.SetGraphOptimizationLevel(
104 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
105 // this should reduce memory use while slowing things down slightly
106 // see
107 //
108 // https://github.com/microsoft/onnxruntime/issues/11627#issuecomment-1137668551
109 //
110 // and also https://its.cern.ch/jira/browse/AFT-818
111 //
112 session_options.DisableCpuMemArena();
113 
114 // declare an allocator with default options
115 Ort::AllocatorWithDefaultOptions allocator;
116 
117 // create session and load model into memory
118 std::unique_ptr< Ort::Session > session = std::make_unique<Ort::Session>(
119 *env, path_to_onnx.c_str(), session_options);
120 
121 // get metadata from the onnx model
122 m_metadata = loadMetadata("gnn_config", session.get());
123 m_num_outputs = session->GetOutputCount();
124 
125 // get the onnx model version
126 if (m_metadata.contains("onnx_model_version")) { // metadata version is explicitly set
127 m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
// NOTE(review): original line 128 is missing from this extraction —
// presumably a validity check on m_onnx_model_version guards the throw
// below; confirm against the original source file.
129 throw std::runtime_error("Unknown Onnx model version!");
130 }
131 } else { // metadata version is not set, infer from the presence of "outputs" key
132 if (m_metadata.contains("outputs")){
// NOTE(review): original line 133 is missing — presumably it assigns a
// default SaltModelVersion when "outputs" is present; confirm.
134 } else {
135 throw std::runtime_error("Onnx model version not found in metadata");
136 }
137 }
138 
139 // get the model name
140 m_model_type = determineModelType(session.get());
141 
142 // iterate over output nodes and get their configuration
143 for (size_t i = 0; i < m_num_outputs; i++) {
144 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
145 const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
146 const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
// NOTE(review): original line 147 is missing — presumably a condition
// (likely on the model version) that selects between the two
// SaltModelOutput constructor forms below; confirm.
148 const SaltModelOutput saltModelOutput(name, type, m_model_type);
149 m_output_nodes.push_back(std::move(saltModelOutput));
150 } else {
151 const SaltModelOutput saltModelOutput(name, type, rank);
152 m_output_nodes.push_back(std::move(saltModelOutput));
153 }
154 }
155 
// Options reused for every Infer() call made by runInference.
156 m_options = std::make_unique<tc::InferOptions>(m_model_name);
157 m_options->model_version_ = ""; // Fixme: find a proper solution for this
158 m_options->client_timeout_ = m_clientTimeout;
159 }
160
161 const nlohmann::json SaltModelTriton::loadMetadata(const std::string& key, const Ort::Session* session) const {
162 Ort::AllocatorWithDefaultOptions allocator;
163 Ort::ModelMetadata modelMetadata = session->GetModelMetadata();
164 std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
165 return nlohmann::json::parse(metadataString);
166 }
167
// Determine the model type string.
//
// Preferred path: take the first key of m_metadata["outputs"] directly
// from the model metadata. Fallback path: derive it from the output
// node names, which are expected to share a common
// "<model_type>_<output_name>" prefix.
//
// @param session open ORT session used to enumerate output names
// @return the model type string, or "UnknownModelName" if an output
//         name has no '_' separator
// @throws std::runtime_error if the output-name prefixes disagree
168 const std::string SaltModelTriton::determineModelType(const Ort::Session* session) const {
169 Ort::AllocatorWithDefaultOptions allocator;
// NOTE(review): original line 170 is missing from this extraction —
// presumably `if (m_metadata.contains("outputs")) {` (the `else` on
// line 174 below needs a matching `if`); confirm against the original.
171 // get the model name directly from the metadata
172 return std::string(m_metadata["outputs"].begin().key());
173 }
174 else {
175 // get the model name from the output node names
176 // each output node name is of the form "<model_type>_<output_name>"
177 std::set<std::string> model_types;
178 for (size_t i = 0; i < m_num_outputs; i++) {
179 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
180 size_t underscore_pos = name.find('_');
181 if (underscore_pos != std::string::npos) {
182 std::string substring = name.substr(0, underscore_pos);
183 model_types.insert(std::move(substring));
184 }
185 else {
186 return std::string("UnknownModelName");
187 }
188 }
// All outputs must agree on a single prefix; anything else is a
// malformed model.
189 if (model_types.size() != 1) {
190 throw std::runtime_error("SaltModelTriton: model names are not consistent between outputs");
191 }
192 return *model_types.begin();
193 }
194 }
195
199
203
207
208 const std::string& SaltModelTriton::getModelName() const {
209 return m_model_type;
210 }
211
// Run one inference round-trip against the Triton server.
//
// Each entry of gnn_inputs maps an input tensor name to a pair of
// (flat float data, shape). Results are unpacked into an
// InferenceOutput according to the output-node types collected in the
// constructor.
//
// @throws std::runtime_error on input-preparation failure, client
//         creation failure, inference failure, a scalar output that is
//         not size 1, or an unknown output node type
//
// NOTE(review): original line 213 is missing from this extraction —
// presumably "InferenceOutput SaltModelTriton::runInference(" (the
// signature line matching the override declared in the header).
212
214 std::map<std::string, Inputs>& gnn_inputs) const {
215 
216 // Create tensor for the input data
217 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
218 inputs_.reserve(gnn_inputs.size());
219 
220 for (auto& [inputName, inputInfo]: gnn_inputs) {
221 const std::vector<float>& inputData = inputInfo.first; // ? good name ?
222 const std::vector<int64_t>& inputShape = inputInfo.second;
223 if(!prepareInput<float>(inputName, inputShape, inputData, inputs_)) {
224 throw std::runtime_error("Failed to prepare input for inference"); // ? more informative error message ?
225 }
226 }
227 
228 // construct raw points for inference
229 std::vector<tc::InferInput*> rawInputs;
230 for(auto& input : inputs_) {
231 rawInputs.push_back(input.get());
232 }
233 
234 // perform the inference
235 tc::InferResult* rawResultPtr = nullptr;
236 tc::Headers http_headers;
237 if (!m_bearer.empty()) {
238 http_headers["authorization"] = "Bearer " + m_bearer;
239 }
240 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
241 
// getClient() returns a lazily-created, per-thread gRPC client, or
// nullptr on connection-setup failure.
242 auto client = getClient();
243 if(client) {
244 tc::Error err = client->Infer(&rawResultPtr
245 , *m_options
246 , rawInputs
247 , {}
248 , http_headers
249 , compression_algorithm);
250 if(!err.IsOk()) {
251 throw std::runtime_error("unable to run model "+ m_model_name + " error: " + err.Message());
252 }
253 }
254 else {
255 throw std::runtime_error("Failed to create Triton gRPC client");
256 }
257 
258 // Get the result of the inference
259 InferenceOutput output;
// Take ownership of the raw result so it is freed on every exit path.
260 std::shared_ptr<tc::InferResult> results(rawResultPtr);
261 for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
262 const auto& output_node = m_output_nodes[node_idx];
263 switch(output_node.type) {
// NOTE(review): the `case` labels of this switch (original lines 264,
// 271, 283, 292) are missing from this extraction — presumably the
// SaltModelOutput output-type enumerators (vector-of-float, single
// float, vector-of-char, and an unknown value falling through to
// default); confirm against the original source.
// NOTE(review): the bool returned by extractOutput is ignored in each
// branch below — a failed extraction silently yields an empty vector.
265 {
266 std::vector<float> outputVecFloat;
267 extractOutput<float>(output_node.name, results, outputVecFloat);
268 output.vecFloat[output_node.name] = std::move(outputVecFloat);
269 }
270 break;
272 {
273 std::vector<float> outputFloat;
274 extractOutput<float>(output_node.name, results, outputFloat);
275 if(outputFloat.size()==1) {
276 output.singleFloat[output_node.name] = outputFloat[0];
277 }
278 else {
279 throw std::runtime_error("Vector of floats returned instead of a single float for " + output_node.name);
280 }
281 }
282 break;
284 {
285 std::vector<int8_t> outputVecInt;
286 extractOutput<int8_t>(output_node.name, results, outputVecInt);
287 // convert int8_t vector to char vector
288 std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
289 output.vecChar[output_node.name] = std::move(outputVecChar);
290 }
291 break;
293 [[fallthrough]];
294 default:
295 throw std::runtime_error("Unknown output type for the node " + output_node.name);
296 }
297 }
298 
299 return output;
300 }
301
302 tc::InferenceServerGrpcClient* SaltModelTriton::getClient() const
303 {
304 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
305 if(!threadClient) {
306 std::string url = m_url + ":" + std::to_string(m_port);
307 tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url, false, m_useSSL);
308 if (!err.IsOk()) {
309 std::cerr << "SaltModelTriton ERROR: Failed to create Triton gRPC client for model: " << m_model_name
310 << " at URL: " << url << std::endl;
311 std::cerr << err.Message() << std::endl;
312 return nullptr;
313 }
314 }
315 return threadClient.get();
316 }
317} // end of FlavorTagInference namespace
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
bool extractOutput(const std::string &name, const std::shared_ptr< tc::InferResult > &result, std::vector< T > &outputVec)
bool prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::shared_ptr< tc::InferInput > > &inputs)
Define macros for attributes used to control the static checker.
tc::InferenceServerGrpcClient * getClient() const
virtual InferenceOutput runInference(std::map< std::string, Inputs > &gnn_inputs) const override
const nlohmann::json loadMetadata(const std::string &key, const Ort::Session *session) const
virtual const OutputConfig & getOutputConfig() const override
virtual const SaltModelGraphConfig::GraphConfig getGraphConfig() const override
std::unique_ptr< tc::InferOptions > m_options
const std::string determineModelType(const Ort::Session *session) const
virtual SaltModelVersion getSaltModelVersion() const override
SaltModelTriton(const std::string &path_to_onnx, const std::string &model_name, float client_timeout, int port, const std::string &url, bool useSSL, const std::string &bearer="")
virtual const std::string & getModelName() const override
GraphConfig parse_json_graph(const nlohmann::json &metadata)
This file contains "getter" functions used for accessing tagger inputs from the EDM.
std::vector< SaltModelOutput > OutputConfig
Definition ISaltModel.h:36
static constexpr const char * value
static constexpr const char * value