23 const std::vector<int64_t>& shape,
24 const std::vector<T>&
data,
25 std::vector<std::shared_ptr<tc::InferInput>>& inputs)
28 tc::InferInput* rawInputPtr =
nullptr;
31 tc::Error err = tc::InferInput::Create(&rawInputPtr, name, shape, dtype);
33 std::cerr <<
"Unable to create input: " + name << std::endl;
42 std::shared_ptr<tc::InferInput> input(rawInputPtr);
43 err = input->AppendRaw(
reinterpret_cast<const uint8_t*
>(
data.data()),
44 data.size() *
sizeof(T));
46 std::cerr <<
"Unable to set input data for: " + name << std::endl;
50 inputs.push_back(std::move(input));
56 const std::shared_ptr<tc::InferResult>& result,
57 std::vector<T>& outputVec)
59 const uint8_t* rawData =
nullptr;
65 tc::Error err = result->RawData(name, &rawData, &
size);
67 std::cerr <<
"Unable to get raw output for: " + name << std::endl;
71 outputVec.resize(
size /
sizeof(T));
72 std::memcpy(outputVec.data(), rawData,
size);
80 ,
const std::string& model_name
81 ,
float client_timeout
83 ,
const std::string& url
85 ,
const std::string& bearer)
94 std::unique_ptr< Ort::Env > env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL,
"");
97 Ort::SessionOptions session_options;
98 session_options.SetIntraOpNumThreads(1);
102 session_options.SetLogSeverityLevel(4);
103 session_options.SetGraphOptimizationLevel(
104 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
112 session_options.DisableCpuMemArena();
115 Ort::AllocatorWithDefaultOptions allocator;
118 std::unique_ptr< Ort::Session > session = std::make_unique<Ort::Session>(
119 *env, path_to_onnx.c_str(), session_options);
126 if (
m_metadata.contains(
"onnx_model_version")) {
129 throw std::runtime_error(
"Unknown Onnx model version!");
135 throw std::runtime_error(
"Onnx model version not found in metadata");
144 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
145 const auto type = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
146 const int rank = session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
169 Ort::AllocatorWithDefaultOptions allocator;
172 return std::string(
m_metadata[
"outputs"].begin().key());
177 std::set<std::string> model_types;
179 const auto name = std::string(session->GetOutputNameAllocated(i, allocator).get());
180 size_t underscore_pos = name.find(
'_');
181 if (underscore_pos != std::string::npos) {
182 std::string substring = name.substr(0, underscore_pos);
183 model_types.insert(std::move(substring));
186 return std::string(
"UnknownModelName");
189 if (model_types.size() != 1) {
190 throw std::runtime_error(
"SaltModelTriton: model names are not consistent between outputs");
192 return *model_types.begin();
216 std::vector<std::shared_ptr<tc::InferInput> > inputs_;
217 inputs_.reserve(gnn_inputs.size());
219 for (
auto& [inputName, inputInfo]: gnn_inputs) {
220 const std::vector<float>& inputData = inputInfo.first;
221 const std::vector<int64_t>& inputShape = inputInfo.second;
223 throw std::runtime_error(
"Failed to prepare input for inference");
228 std::vector<tc::InferInput*> rawInputs;
229 for(
auto& input : inputs_) {
230 rawInputs.push_back(input.get());
234 tc::InferResult* rawResultPtr =
nullptr;
235 tc::Headers http_headers;
237 http_headers[
"authorization"] =
"Bearer " +
m_bearer;
239 grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE;
243 tc::Error err = client->Infer(&rawResultPtr
248 , compression_algorithm);
250 throw std::runtime_error(
"unable to run model "+
m_model_name +
" error: " + err.Message());
254 throw std::runtime_error(
"Failed to create Triton gRPC client");
259 std::shared_ptr<tc::InferResult> results(rawResultPtr);
260 for (
size_t node_idx = 0; node_idx <
m_output_nodes.size(); ++node_idx) {
262 switch(output_node.type) {
265 std::vector<float> outputVecFloat;
267 output.vecFloat[output_node.name] = std::move(outputVecFloat);
272 std::vector<float> outputFloat;
274 if(outputFloat.size()==1) {
275 output.singleFloat[output_node.name] = outputFloat[0];
278 throw std::runtime_error(
"Vector of floats returned instead of a single float for " + output_node.name);
284 std::vector<int8_t> outputVecInt;
287 std::vector<char> outputVecChar(outputVecInt.begin(), outputVecInt.end());
288 output.vecChar[output_node.name] = std::move(outputVecChar);
294 throw std::runtime_error(
"Unknown output type for the node " + output_node.name);
SaltModelTriton(const std::string &path_to_onnx, const std::string &model_name, float client_timeout, int port, const std::string &url, bool useSSL, const std::string &bearer="")