// Constructor fragment (signature and enclosing braces are outside this
// excerpt): creates the ONNX Runtime environment and the inference session.
// m_env owns an Ort::Env constructed with ORT_LOGGING_LEVEL_FATAL and an
// empty string as its second argument.
17 :
m_env (
std::make_unique<
Ort::Env>(ORT_LOGGING_LEVEL_FATAL,
""))
// Session configuration: one intra-op thread, extended graph optimisation,
// and the CPU memory arena switched off.
20 Ort::SessionOptions session_options;
21 session_options.SetIntraOpNumThreads(1);
// NOTE(review): severity 4 presumably corresponds to "fatal", matching the
// ORT_LOGGING_LEVEL_FATAL used for the environment above - confirm against
// the ONNX Runtime documentation.
25 session_options.SetLogSeverityLevel(4);
26 session_options.SetGraphOptimizationLevel(
27 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
35 session_options.DisableCpuMemArena();
// Default-options allocator; its use is not visible in this fragment.
38 Ort::AllocatorWithDefaultOptions allocator;
// Build the session from the environment, the model path (passed as a
// C string via path_to_onnx.c_str()) and the options configured above.
41 m_session = std::make_unique<Ort::Session>(
42 *
m_env, path_to_onnx.c_str(), session_options);
// Model-version validation: tests whether m_metadata has an
// "onnx_model_version" entry. Both failure paths raise std::runtime_error.
50 if (
m_metadata.contains(
"onnx_model_version")) {
// NOTE(review): the condition guarding this throw is not visible in this
// excerpt; presumably it fires when the stored version value is not one of
// the recognised versions - confirm.
53 throw std::runtime_error(
"Unknown Onnx model version!");
// NOTE(review): presumably the branch taken when the key is absent from
// m_metadata - the enclosing else is not visible here.
59 throw std::runtime_error(
"Onnx model version not found in metadata");
// Per-output inspection for output index i (the loop header that defines i
// is outside this excerpt): collects the output's name, element type and rank.
// The name returned by GetOutputNameAllocated is copied into a std::string
// immediately, so the copy does not depend on the returned handle.
73 const auto name = std::string(
m_session->GetOutputNameAllocated(i, allocator).get());
// Element type of the output tensor as reported by the session.
74 const auto type =
m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
// Rank = number of dimensions in the reported shape.
// NOTE(review): GetOutputTypeInfo(i) is queried a second time here, and the
// size() result is narrowed into an int.
75 const int rank =
m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
// Model-name lookup fragment (function header not visible in this excerpt).
// Two strategies appear below: read the name from m_metadata, or derive it
// from the prefix of the session's output names.
92 Ort::AllocatorWithDefaultOptions allocator;
// Returns the key of the first entry under m_metadata["outputs"].
// NOTE(review): the condition selecting this path is not visible here.
95 return std::string(
m_metadata[
"outputs"].begin().key());
// Set of distinct name prefixes found across the outputs.
99 std::set<std::string> model_names;
// For output index i (loop header outside this excerpt): copy the output
// name into a std::string.
101 const auto name = std::string(
m_session->GetOutputNameAllocated(i, allocator).get());
// The candidate model name is everything before the first underscore.
102 size_t underscore_pos = name.find(
'_');
103 if (underscore_pos != std::string::npos) {
104 model_names.insert(name.substr(0, underscore_pos));
// NOTE(review): presumably the else branch - an output name without an
// underscore makes the lookup return an empty string; confirm.
106 return std::string(
"");
// Every output must share one prefix; any other count of distinct prefixes
// is reported as an error.
109 if (model_names.size() != 1) {
110 throw std::runtime_error(
"SaltModel: model names are not consistent between outputs");
// Exactly one prefix remains: that is the model name.
112 return *model_names.begin();
// Inference fragment (function header and loop headers are outside this
// excerpt): wraps the inputs as ONNX Runtime tensors, gathers the input and
// output node names, and runs the session.
// NOTE(review): input_tensor_values is not used in the visible lines.
136 std::vector<float> input_tensor_values;
// Memory descriptor for CPU memory, using the arena allocator type and the
// default memory type.
139 auto memory_info = Ort::MemoryInfo::CreateCpu(
140 OrtArenaAllocator, OrtMemTypeDefault
142 std::vector<Ort::Value> input_tensors;
// One float tensor per input node_name, built from gnn_inputs.at(node_name):
// .first supplies the data pointer and count, .second supplies the second
// pointer/length pair.
// NOTE(review): presumably .first is the flat value buffer and .second the
// tensor shape, following the CreateTensor argument order - confirm. The
// data pointer is handed over as-is, so gnn_inputs presumably has to stay
// alive until Run() has finished - confirm.
144 input_tensors.push_back(Ort::Value::CreateTensor<float>(
145 memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
146 gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size())
// C-string views of the input names; the pointers come from name.c_str(),
// so they are only valid while the owning strings exist.
151 std::vector<const char*> input_node_names;
154 input_node_names.push_back(name.c_str());
// Same for the output names, taken from each node's name_in_model.
156 std::vector<const char*> output_node_names;
159 output_node_names.push_back(
node.name_in_model.c_str());
// Run the session with the collected input names/tensors and the requested
// output names.
167 auto output_tensors = session.Run(Ort::RunOptions{
nullptr},
168 input_node_names.data(), input_tensors.data(), input_node_names.size(),
169 output_node_names.data(), output_node_names.size()
// Output decoding: walks m_output_nodes by index and copies the tensor at
// the same index in output_tensors into the matching container of `output`,
// chosen by element type and rank.
// NOTE(review): `output_node` is defined on a line outside this excerpt;
// presumably it is m_output_nodes[node_idx] - confirm.
174 for (
size_t node_idx = 0; node_idx <
m_output_nodes.size(); ++node_idx) {
176 const auto& tensor = output_tensors[node_idx];
// Element type, shape and total element count of this output tensor.
177 auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
178 auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
// NOTE(review): the element count is narrowed into an int here.
179 int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
// FLOAT outputs: an empty shape is stored as a single float, a
// one-dimensional shape as a vector of floats.
180 if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
181 if (tensor_shape.size() == 0) {
182 output.singleFloat[output_node.name] = *tensor.GetTensorData<
float>();
183 }
else if (tensor_shape.size() == 1) {
// Copy `length` floats out of the tensor's buffer.
184 const float*
data = tensor.GetTensorData<
float>();
185 output.vecFloat[output_node.name] = std::vector<float>(
data,
data +
length);
// NOTE(review): presumably the else branch for any other FLOAT rank.
187 throw std::runtime_error(
"Unsupported tensor shape for FLOAT type");
189 }
// INT8 outputs: only one-dimensional tensors are handled; they are stored
// as a vector of char.
else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
190 if (tensor_shape.size() == 1) {
191 const char*
data = tensor.GetTensorData<
char>();
192 output.vecChar[output_node.name] = std::vector<char>(
data,
data +
length);
// NOTE(review): presumably the else branch for any other INT8 rank.
194 throw std::runtime_error(
"Unsupported tensor shape for INT8 type");
// NOTE(review): presumably the final else - any element type other than
// FLOAT or INT8 is rejected.
197 throw std::runtime_error(
"Unsupported tensor type");