#include <cstring>   // std::memcpy
#include <iostream>  // std::cerr
#include <onnxruntime_cxx_api.h>
template <>
struct TritonDType<int64_t> {
  static constexpr const char* value = "INT64";
};
template <typename T>
bool prepareInput(const std::string& name,
                  const std::vector<int64_t>& shape,
                  const std::vector<T>& data,
                  std::vector<std::shared_ptr<tc::InferInput>>& inputs)
{
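  // Triton inputs are created through a factory that hands back a raw
  // pointer; the dtype string comes from the TritonDType trait above.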
  tc::InferInput* rawInputPtr = nullptr;
  tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, TritonDType<T>::value);
  if (!err.IsOk()) {
    std::cerr << "Unable to create input: " << name << std::endl;
    return false;
  }
  std::shared_ptr<tc::InferInput> input(rawInputPtr);
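  // AppendRaw attaches the flat byte buffer; Triton interprets it according
  // to the shape and dtype declared at creation time.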
  err = input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
                         data.size() * sizeof(T));
  if (!err.IsOk()) {
    std::cerr << "Unable to set input data for: " << name << std::endl;
    return false;
  }
  inputs.push_back(std::move(input));
  return true;
}
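// Copies one named output tensor out of a Triton result into outputVec.
// (The void return type below is an assumption; the callers ignore it.)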
template <typename T>
void extractOutput(const std::string& name,
                   const std::shared_ptr<tc::InferResult>& result,
                   std::vector<T>& outputVec)
{
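  // RawData exposes a pointer into the response's internal buffer together
  // with its byte size; the bytes have to be copied out while the result
  // object is still alive.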
  const uint8_t* rawData = nullptr;
  size_t size = 0;
  tc::Error err = result->RawData(name, &rawData, &size);
  if (!err.IsOk()) {
    std::cerr << "Unable to get raw output for: " << name << std::endl;
    return;
  }
  // size is a byte count, so the element count is size / sizeof(T)
  outputVec.resize(size / sizeof(T));
  std::memcpy(outputVec.data(), rawData, size);
}
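// The constructor opens the ONNX file locally with ONNX Runtime, but only to
// read model metadata and the output-node layout; inference itself is
// delegated to the Triton server via the gRPC client further below.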
SaltModelTriton::SaltModelTriton(
    const std::string& path_to_onnx  // leading parameter assumed from its use below
    , const std::string& model_name
    , float client_timeout
    , const std::string& url)
  : m_model_name(model_name)
  , m_clientTimeout(client_timeout)
{
  std::unique_ptr<Ort::Env> env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "");
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetLogSeverityLevel(4);
  session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  session_options.DisableCpuMemArena();
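  // A single intra-op thread and a disabled CPU memory arena keep ONNX
  // Runtime from spawning its own thread pool or holding on to allocations;
  // severity level 4 silences everything below fatal.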
  Ort::AllocatorWithDefaultOptions allocator;
  std::unique_ptr<Ort::Session> session = std::make_unique<Ort::Session>(
      *env, path_to_onnx.c_str(), session_options);
  if (m_metadata.contains("onnx_model_version")) {
    if (/* the decoded version is not one the code recognises */) {
      throw std::runtime_error("Unknown Onnx model version!");
    }
  } else {
    throw std::runtime_error("Onnx model version not found in metadata");
  }
  for (size_t i = 0; i < session->GetOutputCount(); ++i) {
    const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
    const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
    const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
    // ...
  }
  // custom metadata values are stored in the ONNX custom metadata map under a string key
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::ModelMetadata modelMetadata = session->GetModelMetadata();
  std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator).get());
  // The model name is recovered from the output names: the prefix before the
  // first underscore, which must agree across all outputs.
  Ort::AllocatorWithDefaultOptions allocator;
  std::set<std::string> model_types;
  for (size_t i = 0; i < session->GetOutputCount(); ++i) {
    const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
    size_t underscore_pos = name.find('_');
    if (underscore_pos != std::string::npos) {
      std::string substring = name.substr(0, underscore_pos);
      model_types.insert(std::move(substring));
    } else {
      return std::string("UnknownModelName");
    }
  }
  if (model_types.size() != 1) {
    throw std::runtime_error("SaltModelTriton: model names are not consistent between outputs");
  }
  return *model_types.begin();
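// Inference entry point: each gnn_inputs entry pairs a flat float buffer with
// its shape (data first, shape second, as the unpacking below shows).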
    std::map<std::string, Inputs>& gnn_inputs) const {
  std::vector<std::shared_ptr<tc::InferInput>> inputs_;
  inputs_.reserve(gnn_inputs.size());
  for (auto& [inputName, inputInfo] : gnn_inputs) {
    const std::vector<float>& inputData = inputInfo.first;
    const std::vector<int64_t>& inputShape = inputInfo.second;
    if (!prepareInput<float>(inputName, inputShape, inputData, inputs_)) {
      throw std::runtime_error("Failed to prepare input for inference");
    }
  }
  std::vector<tc::InferInput*> rawInputs;
  for (auto& input : inputs_) {
    rawInputs.push_back(input.get());
  }
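  // inputs_ owns the InferInput objects; Triton's Infer call takes plain
  // pointers, so a parallel non-owning view is handed to the client.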
  tc::InferResult* rawResultPtr = nullptr;
  tc::Headers http_headers;
  grpc_compression_algorithm compression_algorithm =
      grpc_compression_algorithm::GRPC_COMPRESS_NONE;
  // getClient() is an assumed name for the thread-local accessor defined
  // below; the InferOptions wiring is likewise an assumption
  if (auto* client = getClient()) {
    tc::InferOptions options(m_model_name);
    options.client_timeout_ = m_clientTimeout;
    tc::Error err = client->Infer(&rawResultPtr, options, rawInputs,
                                  {}, http_headers, compression_algorithm);
    if (!err.IsOk()) {
      throw std::runtime_error("unable to run model " + m_model_name
                               + " error: " + err.Message());
    }
  } else {
    throw std::runtime_error("Failed to create Triton gRPC client");
  }
  // take ownership of the result, then unpack every declared output node
  std::shared_ptr<tc::InferResult> results(rawResultPtr);
  for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
    const auto& output_node = m_output_nodes[node_idx];
    // the OutputType labels below stand in for the model's real output-type
    // enum, which is defined elsewhere
    switch (output_node.type) {
      // a whole vector of floats is copied out as-is
      case OutputType::VECFLOAT: {
        std::vector<float> outputVecFloat;
        extractOutput<float>(output_node.name, results, outputVecFloat);
        output.vecFloat[output_node.name] = std::move(outputVecFloat);
        break;
      }
      // a scalar output: the returned tensor must hold exactly one element
      case OutputType::FLOAT: {
        std::vector<float> outputFloat;
        extractOutput<float>(output_node.name, results, outputFloat);
        if (outputFloat.size() == 1) {
          output.singleFloat[output_node.name] = outputFloat[0];
        } else {
          throw std::runtime_error(
              "Vector of floats returned instead of a single float for " + output_node.name);
        }
        break;
      }
      // char output travels as INT8 over Triton and is converted back here
      case OutputType::VECCHAR: {
        std::vector<int8_t> outputVecInt;
        extractOutput<int8_t>(output_node.name, results, outputVecInt);
        std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
        output.vecChar[output_node.name] = std::move(outputVecChar);
        break;
      }
      default:
        throw std::runtime_error("Unknown output type for the node " + output_node.name);
    }
  thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
  if (!threadClient) {
    tc::Error err = tc::InferenceServerGrpcClient::Create(&threadClient, url);
    if (!err.IsOk()) {
      std::cerr << "SaltModelTriton ERROR: Failed to create Triton gRPC client for model: "
                << m_model_name << " at URL: " << url << std::endl;
      std::cerr << err.Message() << std::endl;
    }
  }
  return threadClient.get();
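// A minimal usage sketch (hypothetical driver code; names and the method
// signature are assumptions, and Inputs is taken to be a (data, shape) pair
// as the unpacking in the inference method suggests):
//
//   std::map<std::string, Inputs> gnn_inputs;
//   gnn_inputs["jet_var"] = { {0.1f, 0.2f, 0.3f}, {1, 3} };  // flat data, shape
//   auto output = model.runInference(gnn_inputs);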