// NOTE(review): extraction fragment — each line carries an embedded original
// file line number ("23 ", "24 ", ...) and the gaps in that numbering show
// that the function header (template<typename T>, name, leading parameters
// such as `name`/`dtype`) and the `if (!err.IsOk())` guards around the
// std::cerr calls are NOT visible in this chunk. Confirm against the full file.
//
// Purpose (from visible code): create a Triton client input tensor with the
// given `shape`, attach the raw bytes of `data`, and append the resulting
// shared_ptr to `inputs`.
23 const std::vector<int64_t>& shape,
24 const std::vector<T>&
data,
25 std::vector<std::shared_ptr<tc::InferInput>>& inputs)
// Raw pointer is only a staging variable; ownership is transferred to a
// shared_ptr a few lines below so the InferInput is released automatically.
28 tc::InferInput* rawInputPtr =
nullptr;
31 tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
// presumably inside an error-check branch elided from this view — TODO confirm
33 std::cerr <<
"Unable to create input: " + name << std::endl;
// Adopt the raw pointer; from here on RAII owns the InferInput.
42 std::shared_ptr<tc::InferInput> input(rawInputPtr);
// AppendRaw takes an untyped byte view of the data vector: pointer to the
// first element reinterpreted as bytes, length = element count * sizeof(T).
43 err = input->AppendRaw(
reinterpret_cast<const uint8_t*
>(
data.data()),
44 data.size() *
sizeof(T));
// presumably inside an error-check branch elided from this view — TODO confirm
46 std::cerr <<
"Unable to set input data for: " + name << std::endl;
// Hand the prepared input to the caller's collection.
50 inputs.push_back(std::move(input));
// NOTE(review): extraction fragment — the function header (template/name and
// the `name`/`size` parameters or locals it uses) and the error-check guard
// around the std::cerr call are elided from this view; embedded numbers are
// original file line numbers. Confirm against the full file.
//
// Purpose (from visible code): copy the raw bytes of a named Triton result
// tensor into `outputVec`, resizing it to hold size/sizeof(T) elements.
56 const std::shared_ptr<tc::InferResult>&
result,
57 std::vector<T>& outputVec)
59 const uint8_t* rawData =
nullptr;
// RawData returns a pointer into the result object plus the byte size; the
// buffer is owned by `result`, hence the memcpy into the caller's vector.
65 tc::Error err =
result->RawData(name, &rawData, &size);
// presumably inside an error-check branch elided from this view — TODO confirm
67 std::cerr <<
"Unable to get raw output for: " + name << std::endl;
// Size the destination in elements, then copy the raw bytes wholesale.
// assumes `size` is an exact multiple of sizeof(T) — TODO confirm
71 outputVec.resize(size /
sizeof(T));
72 std::memcpy(outputVec.data(), rawData, size);
// NOTE(review): extraction fragment of a constructor (or init function):
// only the trailing, comma-first parameters are visible; the function name,
// earlier parameters (including `path_to_onnx`, used below) and surrounding
// control flow are elided from this view. Confirm against the full file.
80 ,
const std::string& model_name
81 ,
float client_timeout
83 ,
const std::string& url
// ONNX Runtime environment with logging suppressed to FATAL and an empty
// log identifier string.
92 std::unique_ptr< Ort::Env > env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL,
"");
95 Ort::SessionOptions session_options;
// Single intra-op thread; severity 4 (fatal-only) matches the quiet Env above.
96 session_options.SetIntraOpNumThreads(1);
100 session_options.SetLogSeverityLevel(4);
101 session_options.SetGraphOptimizationLevel(
102 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// CPU memory arena disabled — NOTE(review): rationale not visible here;
// presumably to reduce resident memory, confirm against project docs.
110 session_options.DisableCpuMemArena();
113 Ort::AllocatorWithDefaultOptions allocator;
// Load the ONNX model from `path_to_onnx` (declared outside this fragment).
116 std::unique_ptr< Ort::Session > session = std::make_unique<Ort::Session>(
117 *env, path_to_onnx.c_str(), session_options);
// NOTE(review): extraction fragment — the branch bodies between the
// `contains("onnx_model_version")` check and the two throws are elided, as is
// the loop header over output index `i` for the three lines at the bottom.
// The visible shape suggests: if the key exists but holds an unrecognized
// value -> "Unknown Onnx model version!"; if the key is absent -> "Onnx model
// version not found in metadata". Confirm against the full file.
124 if (
m_metadata.contains(
"onnx_model_version")) {
127 throw std::runtime_error(
"Unknown Onnx model version!");
133 throw std::runtime_error(
"Onnx model version not found in metadata");
// Per-output introspection (inside a loop over `i` not visible here):
// allocated name string, element type, and tensor rank from the shape.
142 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
143 const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
144 const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
// NOTE(review): extraction fragment — the enclosing function signature, the
// condition guarding the early metadata-based return, and the loop header
// over output index `i` are elided from this view. Confirm against the full
// file.
//
// Purpose (from visible code): derive the model name either from the first
// key under metadata["outputs"], or by collecting the prefix before the first
// '_' of every session output name and requiring that prefix to be unique.
167 Ort::AllocatorWithDefaultOptions allocator;
// Early return path: first key of the "outputs" metadata object.
170 return std::string(
m_metadata[
"outputs"].begin().key());
175 std::set<std::string> model_types;
// Inside a loop over outputs (header not visible): take the text before the
// first underscore of each output name as the candidate model name.
177 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
178 size_t underscore_pos = name.find(
'_');
179 if (underscore_pos != std::string::npos) {
180 std::string substring = name.substr(0, underscore_pos);
181 model_types.insert(std::move(substring));
// Fallback when an output name has no underscore — presumably an else/early
// return branch; TODO confirm placement against the full file.
184 return std::string(
"UnknownModelName");
// Every output must agree on a single prefix, otherwise the model is
// considered inconsistent.
187 if (model_types.size() != 1) {
188 throw std::runtime_error(
"SaltModelTriton: model names are not consistent between outputs");
190 return *model_types.begin();
// NOTE(review): extraction fragment of the inference entry point — the
// function name/return type, the error-check conditions before several
// throws, the client->Infer argument list (lines 240-243), the case labels of
// the output-type switch, and most closing braces are elided from this view.
// Confirm against the full file.
//
// Purpose (from visible code): package `gnn_inputs` (name -> {float data,
// int64 shape}) into Triton InferInput objects, run a gRPC Infer call, then
// route each output node into `output.vecFloat`, `output.singleFloat`, or
// `output.vecChar` according to its declared type.
212 std::map<std::string, Inputs>& gnn_inputs)
const {
// Owned inputs; reserve avoids reallocation while raw pointers are taken
// from these later.
215 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
216 inputs_.reserve(gnn_inputs.size());
218 for (
auto& [inputName, inputInfo]: gnn_inputs) {
219 const std::vector<float>& inputData = inputInfo.first;
220 const std::vector<int64_t>& inputShape = inputInfo.second;
// presumably thrown when the per-input preparation helper reports an error
// (guard elided from this view) — TODO confirm
222 throw std::runtime_error(
"Failed to prepare input for inference");
// The Triton client API wants non-owning raw pointers; lifetime is still
// held by `inputs_` above.
227 std::vector<tc::InferInput*> rawInputs;
228 for(
auto& input : inputs_) {
229 rawInputs.push_back(input.get());
233 tc::InferResult* rawResultPtr =
nullptr;
234 tc::Headers http_headers;
// gRPC payload compression explicitly disabled.
235 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
// Blocking inference call; intermediate arguments (options, inputs, outputs,
// headers — original lines 240-243) are elided from this view.
239 tc::Error err = client->Infer(&rawResultPtr
244 , compression_algorithm);
246 throw std::runtime_error(
"unable to run model "+
m_model_name +
" error: " + err.Message());
// presumably reached when the client handle itself is null/invalid (guard
// elided from this view) — TODO confirm
250 throw std::runtime_error(
"Failed to create Triton gRPC client");
// Adopt the raw result so it is freed on every exit path below.
255 std::shared_ptr<tc::InferResult> results(rawResultPtr);
256 for (
size_t node_idx = 0; node_idx <
m_output_nodes.size(); ++node_idx) {
// Dispatch on the node's declared output type; case labels elided.
258 switch(output_node.type) {
// Case: vector-of-float output, moved into the map wholesale.
261 std::vector<float> outputVecFloat;
263 output.vecFloat[output_node.name] = std::move(outputVecFloat);
// Case: scalar float output — exactly one element is required.
268 std::vector<float> outputFloat;
270 if(outputFloat.size()==1) {
271 output.singleFloat[output_node.name] = outputFloat[0];
274 throw std::runtime_error(
"Vector of floats returned instead of a single float for " + output_node.name);
// Case: int8 output, widened element-by-element into a char vector.
280 std::vector<int8_t> outputVecInt;
283 std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
284 output.vecChar[output_node.name] = std::move(outputVecChar);
// Default case: unrecognized node type is a hard error.
290 throw std::runtime_error(
"Unknown output type for the node " + output_node.name);