// NOTE(review): fragment of what appears to be a templated helper (sizeof(T) below
// implies `template <typename T>`) that wraps a raw data vector into a Triton
// tc::InferInput and appends it to `inputs`. The enclosing signature line(s) and
// several interior lines (embedded numbering jumps 25->28->31->33->42) are missing
// from this excerpt — in particular the `if (!err.IsOk())` headers that must have
// guarded the two std::cerr branches. TODO confirm against the full source.
23 const std::vector<int64_t>& shape,
24 const std::vector<T>&
data,
25 std::vector<std::shared_ptr<tc::InferInput>>& inputs)
// Raw pointer is only transiently owned; transferred to shared_ptr below.
28 tc::InferInput* rawInputPtr =
nullptr;
// `name` and `dtype` come from the (unseen) parameter list — presumably the
// Triton tensor name and datatype string; verify against the caller.
31 tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
// (missing line 32 — presumably `if (!err.IsOk()) {`) error branch:
33 std::cerr <<
"Unable to create input: " + name << std::endl;
// Adopt ownership of the successfully created input.
42 std::shared_ptr<tc::InferInput> input(rawInputPtr);
// Attach the payload as raw bytes; byte count = element count * element size.
43 err = input->AppendRaw(
reinterpret_cast<const uint8_t*
>(
data.data()),
44 data.size() *
sizeof(T));
// (missing line 45 — presumably another `if (!err.IsOk()) {`) error branch:
46 std::cerr <<
"Unable to set input data for: " + name << std::endl;
// Hand the finished input to the caller's collection.
50 inputs.push_back(std::move(input));
// NOTE(review): fragment of a templated output-extraction helper (sizeof(T)
// implies `template <typename T>`): copies the raw bytes of output tensor `name`
// (declared on an unseen line) out of a Triton InferResult into `outputVec`.
// Interior lines are missing (numbering jumps 59->65->67->71), including the
// declaration of `size` and the `if (!err.IsOk())` header guarding the cerr
// branch — TODO confirm against the full source.
56 const std::shared_ptr<tc::InferResult>& result,
57 std::vector<T>& outputVec)
59 const uint8_t* rawData =
nullptr;
// `size` is the byte count reported by Triton; its declaration is not in this excerpt.
65 tc::Error err = result->RawData(name, &rawData, &size);
// (missing line 66 — presumably `if (!err.IsOk()) {`) error branch:
67 std::cerr <<
"Unable to get raw output for: " + name << std::endl;
// Reinterpret the byte buffer as `size / sizeof(T)` elements of T.
// NOTE(review): assumes `size` is an exact multiple of sizeof(T) — a remainder
// would be silently truncated by the integer division; verify upstream guarantees.
71 outputVec.resize(size /
sizeof(T));
72 std::memcpy(outputVec.data(), rawData, size);
// NOTE(review): fragment of the SaltModelTriton constructor. Only the tail of
// the parameter list is visible (model_name, client_timeout, url, bearer —
// matching the declaration at the end of this excerpt); intervening parameters
// and most of the body (numbering jumps 85->94, 98->102, 104->112, 119->126,
// 129->135->144) are missing — TODO confirm against the full source.
80 ,
const std::string& model_name
81 ,
float client_timeout
83 ,
const std::string& url
85 ,
const std::string& bearer)
// A local ONNX Runtime session is opened purely to introspect the model file
// (output names/types/ranks below); inference itself goes through Triton.
// FATAL log level + empty log id keeps ORT quiet.
94 std::unique_ptr< Ort::Env > env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL,
"");
97 Ort::SessionOptions session_options;
// Single-threaded: this session is only used for metadata extraction.
98 session_options.SetIntraOpNumThreads(1);
// Severity 4 = fatal-only, consistent with the Env log level above.
102 session_options.SetLogSeverityLevel(4);
103 session_options.SetGraphOptimizationLevel(
104 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// Arena disabled — presumably to avoid holding memory for a throwaway session;
// verify intent against the full source.
112 session_options.DisableCpuMemArena();
115 Ort::AllocatorWithDefaultOptions allocator;
118 std::unique_ptr< Ort::Session > session = std::make_unique<Ort::Session>(
119 *env, path_to_onnx.c_str(), session_options);
// Model-version validation from metadata (m_metadata looks like an
// nlohmann::json-style object given .contains(); TODO confirm). The branch
// between this `if` and the first throw (lines 127-128) is missing from the
// excerpt — the "Unknown Onnx model version!" throw is presumably reached only
// for unrecognized version values, not merely because the key exists.
126 if (
m_metadata.contains(
"onnx_model_version")) {
129 throw std::runtime_error(
"Unknown Onnx model version!");
// (else branch:) the key is absent entirely — hard error.
135 throw std::runtime_error(
"Onnx model version not found in metadata");
// Per-output introspection (enclosing loop header over `i` is missing from the
// excerpt): name, element type, and rank of each ONNX output node, presumably
// stored into m_output_nodes used later during inference.
144 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
145 const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
146 const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
// NOTE(review): fragment of a model-name-determination routine. The enclosing
// signature and several branch/loop headers are missing (numbering jumps
// 172->177, 183->186, 186->189). Two strategies are visible: (1) take the key
// of the first entry under m_metadata["outputs"] (nlohmann::json-style
// .begin().key(); TODO confirm the metadata type), and (2) derive the name as
// the common prefix of the ONNX output names before the first '_', falling
// back to "UnknownModelName" and rejecting inconsistent prefixes. The
// condition selecting between (1) and (2) is not in this excerpt.
169 Ort::AllocatorWithDefaultOptions allocator;
// Strategy (1): name = key of the first declared output in the metadata.
172 return std::string(
m_metadata[
"outputs"].begin().key());
// Strategy (2): collect the '<prefix>' of each output named '<prefix>_...'.
177 std::set<std::string> model_types;
// (enclosing loop over output index `i` is missing from the excerpt)
179 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
180 size_t underscore_pos = name.find(
'_');
181 if (underscore_pos != std::string::npos) {
182 std::string substring = name.substr(0, underscore_pos);
183 model_types.insert(std::move(substring));
// (else branch, presumably:) an output without '_' makes the name undeterminable.
186 return std::string(
"UnknownModelName");
// All outputs must agree on a single prefix.
189 if (model_types.size() != 1) {
190 throw std::runtime_error(
"SaltModelTriton: model names are not consistent between outputs");
192 return *model_types.begin();
// NOTE(review): fragment of the const inference method. The signature start,
// the helper calls that fill each input / each output vector, the switch case
// labels, and several `if (!err.IsOk())` / loop-closing lines are missing
// (numbering jumps 218->220, 222->224, 231->235, 240->244, 249->251, etc.) —
// TODO confirm against the full source.
214 std::map<std::string, Inputs>& gnn_inputs)
const {
// Build one tc::InferInput per entry of gnn_inputs; `Inputs` is evidently a
// pair of (flat float data, int64 shape).
217 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
218 inputs_.reserve(gnn_inputs.size());
220 for (
auto& [inputName, inputInfo]: gnn_inputs) {
221 const std::vector<float>& inputData = inputInfo.first;
222 const std::vector<int64_t>& inputShape = inputInfo.second;
// (missing: the prepare-input call and its failure condition, line 223)
224 throw std::runtime_error(
"Failed to prepare input for inference");
// Triton's Infer() takes non-owning raw pointers; shared_ptrs above retain ownership.
229 std::vector<tc::InferInput*> rawInputs;
230 for(
auto& input : inputs_) {
231 rawInputs.push_back(input.get());
235 tc::InferResult* rawResultPtr =
nullptr;
// Bearer-token auth header; NOTE(review): visible unconditionally here, but a
// guard on m_bearer being non-empty may exist on an elided line — verify.
236 tc::Headers http_headers;
238 http_headers[
"authorization"] =
"Bearer " +
m_bearer;
// No gRPC compression.
240 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
// Synchronous inference; intermediate arguments (model name, inputs, outputs,
// headers — lines 245-248) are elided from this excerpt.
244 tc::Error err = client->Infer(&rawResultPtr
249 , compression_algorithm);
// (missing `if (!err.IsOk())` header, line 250) inference failure:
251 throw std::runtime_error(
"unable to run model "+
m_model_name +
" error: " + err.Message());
// (separate failure path, presumably when the client factory returned an error)
255 throw std::runtime_error(
"Failed to create Triton gRPC client");
// Adopt ownership of the result and demultiplex each configured output node.
260 std::shared_ptr<tc::InferResult> results(rawResultPtr);
261 for (
size_t node_idx = 0; node_idx <
m_output_nodes.size(); ++node_idx) {
// (missing: `output_node` binding from m_output_nodes[node_idx], line 262)
263 switch(output_node.type) {
// case: vector-of-float output (extraction call on elided line 267).
266 std::vector<float> outputVecFloat;
268 output.vecFloat[output_node.name] = std::move(outputVecFloat);
// case: scalar float output — must arrive as a length-1 vector.
273 std::vector<float> outputFloat;
275 if(outputFloat.size()==1) {
276 output.singleFloat[output_node.name] = outputFloat[0];
279 throw std::runtime_error(
"Vector of floats returned instead of a single float for " + output_node.name);
// case: int8 output, re-exposed to callers as vector<char>.
285 std::vector<int8_t> outputVecInt;
288 std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
289 output.vecChar[output_node.name] = std::move(outputVecChar);
// default: unrecognized node type is a hard error.
295 throw std::runtime_error(
"Unknown output type for the node " + output_node.name);
// NOTE(review): header-style constructor declaration pasted without a
// terminator; it runs past the visible excerpt. Its parameter tail
// (model_name, client_timeout, url, bearer) matches the definition fragment
// above; `port` and `useSSL` do not appear in the visible definition lines,
// presumably consumed on elided lines — verify against the full source.
SaltModelTriton(const std::string &path_to_onnx, const std::string &model_name, float client_timeout, int port, const std::string &url, bool useSSL, const std::string &bearer="")