// Load the ONNX model at path_to_onnx and set up the inference session.
SaltModel::SaltModel(const std::string& path_to_onnx)
  // keep the ONNX Runtime environment quiet: log fatal errors only
  : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
{
  // configure the session: single-threaded execution, quiet logging,
  // extended graph optimizations, no CPU memory arena
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetLogSeverityLevel(4);
  session_options.SetGraphOptimizationLevel(
    GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  session_options.DisableCpuMemArena();
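
  // Rationale for these options: one intra-op thread keeps all
  // parallelism under the calling framework's control; severity level 4
  // corresponds to fatal-only logging; disabling the CPU memory arena
  // avoids caching allocations between inferences.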

  // declare an allocator with default options
  Ort::AllocatorWithDefaultOptions allocator;

  // create the session and load the model into memory
  m_session = std::make_unique<Ort::Session>(
    *m_env, path_to_onnx.c_str(), session_options);

  // parse the JSON configuration stored in the model's custom metadata
  // (key name and the SaltModelVersion enum as declared in the header)
  m_metadata = loadMetadata("gnn_config");

  // an explicit version entry in the metadata wins; otherwise fail
  if (m_metadata.contains("onnx_model_version")) {
    m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
    if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
      throw std::runtime_error("Unknown Onnx model version!");
    }
  } else {
    throw std::runtime_error("Onnx model version not found in metadata");
  }

  // record the name, element type, and rank of every output node
  for (size_t i = 0; i < m_session->GetOutputCount(); ++i) {
    const auto name = std::string(
      m_session->GetOutputNameAllocated(i, allocator).get());
    const auto type = m_session->GetOutputTypeInfo(i)
      .GetTensorTypeAndShapeInfo().GetElementType();
    const int rank = m_session->GetOutputTypeInfo(i)
      .GetTensorTypeAndShapeInfo().GetShape().size();
    // descriptor constructor assumed to take (name, type, rank)
    m_output_nodes.emplace_back(name, type, rank);
  }
}
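
// For orientation, a rough sketch of the output-node descriptor filled
// above (field names follow their use in this file; the actual
// definition lives in the corresponding header):
//
//   struct SaltModelOutput {
//     std::string name;            // key used in the result maps
//     std::string name_in_model;   // raw node name in the ONNX graph
//     ONNXTensorElementDataType type;
//     int rank;
//   };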

// Look up one custom metadata entry by key and parse it as JSON.
const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
  std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(
    key.c_str(), allocator).get());
  return nlohmann::json::parse(metadataString);
}
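
// Caveat (an assumption about the ORT API): if the requested key is
// absent, LookupCustomMetadataMapAllocated may hand back a null pointer,
// and the std::string construction above would then be undefined; only
// keys known to be present should be passed in.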

const std::string SaltModel::determineModelName() const {
  Ort::AllocatorWithDefaultOptions allocator;
  // every output node is named "<model_name>_<output_name>": take the
  // prefix before the first underscore and require all outputs to agree
  std::set<std::string> model_names;
  for (size_t i = 0; i < m_session->GetOutputCount(); ++i) {
    const auto name = std::string(
      m_session->GetOutputNameAllocated(i, allocator).get());
    size_t underscore_pos = name.find('_');
    if (underscore_pos != std::string::npos) {
      model_names.insert(name.substr(0, underscore_pos));
    } else {
      return std::string("UnknownModelName");
    }
  }
  if (model_names.size() != 1) {
    throw std::runtime_error(
      "SaltModel: model names are not consistent between outputs");
  }
  return *model_names.begin();
}
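
// For example (hypothetical node names): outputs "GN2_pb", "GN2_pc" and
// "GN2_pu" all share the prefix "GN2", so the model name is "GN2"; an
// output name without any underscore yields "UnknownModelName" instead.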

// Run the network on the prepared inputs and unpack every output tensor.
// (InferenceOutput, the assumed result type, bundles the singleFloat,
// vecFloat, and vecChar maps filled at the end of this function.)
SaltModel::InferenceOutput SaltModel::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {
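
  // Each Inputs value pairs the flattened float data (.first) with the
  // int64 tensor shape (.second); the two (pointer, count) pairs feed
  // straight into Ort::Value::CreateTensor below.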
  // build one CPU-backed input tensor per input node
  auto memory_info = Ort::MemoryInfo::CreateCpu(
    OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<Ort::Value> input_tensors;
  for (const auto& node_name : m_input_node_names) {
    input_tensors.push_back(Ort::Value::CreateTensor<float>(
      memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
      gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size()));
  }

  // ONNX Runtime takes node names as arrays of C strings
  std::vector<const char*> input_node_names;
  input_node_names.reserve(m_input_node_names.size());
  for (const auto& name : m_input_node_names) {
    input_node_names.push_back(name.c_str());
  }
  std::vector<const char*> output_node_names;
  output_node_names.reserve(m_output_nodes.size());
  for (const auto& node : m_output_nodes) {
    output_node_names.push_back(node.name_in_model.c_str());
  }
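
  // the C-string pointers gathered above remain valid through Run()
  // because they point into long-lived member strings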

  // score the model; Session::Run is non-const but is documented to be
  // safe to call concurrently on a single session
  Ort::Session& session = *m_session;
  auto output_tensors = session.Run(Ort::RunOptions{nullptr},
    input_node_names.data(), input_tensors.data(), input_node_names.size(),
    output_node_names.data(), output_node_names.size());
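
  // Run() returns the outputs in the order they were requested, so
  // output_tensors[i] matches m_output_nodes[i] in the loop below.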
  // unpack each output tensor into the matching result map, dispatching
  // on element type and rank
  InferenceOutput output;
  for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
    const auto& output_node = m_output_nodes[node_idx];
    const auto& tensor = output_tensors[node_idx];
    auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
    auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
    size_t length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
    if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      if (tensor_shape.size() == 0) {
        // rank 0: a single scalar
        output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
      } else if (tensor_shape.size() == 1) {
        // rank 1: a flat vector of floats
        const float* data = tensor.GetTensorData<float>();
        output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for FLOAT type");
      }
    } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
      if (tensor_shape.size() == 1) {
        const char* data = tensor.GetTensorData<char>();
        output.vecChar[output_node.name] = std::vector<char>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for INT8 type");
      }
    } else {
      throw std::runtime_error("Unsupported tensor type");
    }
  }
  return output;
}