// NOTE: the constructor signature is reconstructed from the use of
// path_to_onnx below; only the initializer list appears in the original.
SaltModel::SaltModel(const std::string& path_to_onnx)
  // Keep ONNX Runtime quiet except for fatal errors.
  : m_env(std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, "")) {
  // Configure the session: single-threaded, quiet logging, extended graph
  // optimizations, and no CPU memory arena.
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetLogSeverityLevel(4);
  session_options.SetGraphOptimizationLevel(
      GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  session_options.DisableCpuMemArena();
  Ort::AllocatorWithDefaultOptions allocator;

  // Create the session, loading the model into memory.
  m_session = std::make_unique<Ort::Session>(
      *m_env, path_to_onnx.c_str(), session_options);

  // Read the configuration embedded in the model's custom metadata.
  m_metadata = loadMetadata("gnn_config");
  m_num_inputs = m_session->GetInputCount();
  m_num_outputs = m_session->GetOutputCount();
  // Determine the model version: newer models declare it explicitly in the
  // metadata, while legacy (V0) models are identified by an "outputs" key.
  if (m_metadata.contains("onnx_model_version")) {
    m_onnx_model_version =
        m_metadata["onnx_model_version"].get<SaltModelVersion>();
    // Assumes the enum defines an UNKNOWN sentinel for unrecognized
    // versions; this guard is implied by the error message below.
    if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
      throw std::runtime_error("Unknown Onnx model version!");
    }
  } else if (m_metadata.contains("outputs")) {
    m_onnx_model_version = SaltModelVersion::V0;
  } else {
    throw std::runtime_error("Onnx model version not found in metadata");
  }

  m_model_name = determineModelName();
  // Collect the input node names; m_input_node_names is assumed to be the
  // member that runInference below iterates over.
  for (size_t i = 0; i < m_num_inputs; i++) {
    std::string input_name =
        m_session->GetInputNameAllocated(i, allocator).get();
    m_input_node_names.push_back(input_name);
  }
  // Collect the output node configuration. V0 models tag each output with
  // the model name; newer models use the tensor rank instead.
  for (size_t i = 0; i < m_num_outputs; i++) {
    const auto name =
        std::string(m_session->GetOutputNameAllocated(i, allocator).get());
    const auto type = m_session->GetOutputTypeInfo(i)
                          .GetTensorTypeAndShapeInfo().GetElementType();
    const int rank = m_session->GetOutputTypeInfo(i)
                         .GetTensorTypeAndShapeInfo().GetShape().size();
    if (m_onnx_model_version == SaltModelVersion::V0) {
      const SaltModelOutput saltModelOutput(name, type, m_model_name);
      m_output_nodes.push_back(saltModelOutput);
    } else {
      const SaltModelOutput saltModelOutput(name, type, rank);
      m_output_nodes.push_back(saltModelOutput);
    }
  }
}

// Look up a custom-metadata entry by key and parse it as JSON. NOTE: the
// signature and return statement are reconstructed; the nlohmann::json
// return type is assumed from the .contains()/.get<>() usage above.
nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
  std::string metadataString(
      modelMetadata.LookupCustomMetadataMapAllocated(
          key.c_str(), allocator).get());
  return nlohmann::json::parse(metadataString);
}
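
// Illustrative round trip (values hypothetical): if the exporter embedded
// the string '{"onnx_model_version": ...}' under the custom-metadata key
// "gnn_config", loadMetadata("gnn_config") returns it parsed as JSON.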

const std::string SaltModel::determineModelName() const {
  Ort::AllocatorWithDefaultOptions allocator;
  if (m_onnx_model_version == SaltModelVersion::V0) {
    // V0: the single top-level key under "outputs" is the model name.
    return std::string(m_metadata["outputs"].begin().key());
  } else {
    // Newer models: every output is named "<model_name>_<suffix>", so the
    // model name is the common prefix before the first underscore.
    std::set<std::string> model_names;
    for (size_t i = 0; i < m_num_outputs; i++) {
      const auto name =
          std::string(m_session->GetOutputNameAllocated(i, allocator).get());
      size_t underscore_pos = name.find('_');
      if (underscore_pos != std::string::npos) {
        std::string substring = name.substr(0, underscore_pos);
        model_names.insert(substring);
      } else {
        return std::string("UnknownModelName");
      }
    }
    if (model_names.size() != 1) {
      throw std::runtime_error(
          "SaltModel: model names are not consistent between outputs");
    }
    return *model_names.begin();
  }
}
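
// Illustrative example (names hypothetical): output nodes "GN2_pb",
// "GN2_pc", and "GN2_pu" all yield the prefix "GN2", which becomes the
// model name; mixed prefixes such as "GN2_pb" and "GN3_pc" trigger the
// consistency error above.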

const SaltModelGraphConfig::GraphConfig SaltModel::getGraphConfig() const {
  // ...
}

const SaltModel::OutputConfig& SaltModel::getOutputConfig() const {
  return m_output_nodes;
}

// NOTE: the name of this accessor is an assumption; only its return
// statement appears in the original.
SaltModelVersion SaltModel::getSaltModelVersion() const {
  return m_onnx_model_version;
}

const std::string& SaltModel::getModelName() const {
  return m_model_name;
}

// NOTE: the return type and function name are assumed; only the parameter
// list of this signature appears in the original.
SaltModel::InferenceOutput SaltModel::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {

  std::vector<float> input_tensor_values;

  // Build one input tensor per input node. Each Inputs entry holds the
  // flattened float data (first) and the tensor shape (second).
  auto memory_info = Ort::MemoryInfo::CreateCpu(
      OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<Ort::Value> input_tensors;
  for (const auto& node_name : m_input_node_names) {
    input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info,
        gnn_inputs.at(node_name).first.data(),
        gnn_inputs.at(node_name).first.size(),
        gnn_inputs.at(node_name).second.data(),
        gnn_inputs.at(node_name).second.size()));
  }
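
  // Layout note (illustrative): for an input of shape {1, 3}, the Inputs
  // pair would be
  //   { {x0, x1, x2},   // .first: flattened float data
  //     {1, 3} }        // .second: the tensor shape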
  // Collect raw C-string node names, as required by the ORT C++ API.
  std::vector<const char*> input_node_names;
  input_node_names.reserve(m_input_node_names.size());
  for (const auto& name : m_input_node_names) {
    input_node_names.push_back(name.c_str());
  }

  std::vector<const char*> output_node_names;
  output_node_names.reserve(m_output_nodes.size());
  for (const auto& node : m_output_nodes) {
    output_node_names.push_back(node.name_in_model.c_str());
  }

  // NOTE: `session` is assumed to refer to the member session; the lines
  // defining it are not part of the original.
  Ort::Session& session = *m_session;
  auto output_tensors = session.Run(Ort::RunOptions{nullptr},
      input_node_names.data(), input_tensors.data(), input_node_names.size(),
      output_node_names.data(), output_node_names.size());

  // Initialize the output container and unpack each output tensor into it,
  // keyed by the output node name.
  InferenceOutput output;
  for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
    const auto& output_node = m_output_nodes[node_idx];
    const auto& tensor = output_tensors[node_idx];
    auto tensor_type =
        tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
    auto tensor_shape =
        tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
    int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
    if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      if (tensor_shape.size() == 0) {
        // Scalar output.
        output.singleFloat[output_node.name] =
            *tensor.GetTensorData<float>();
      } else if (tensor_shape.size() == 1) {
        // Vector output. The copy into a vecFloat map is an assumption;
        // the original only shows the raw data pointer.
        const float* data = tensor.GetTensorData<float>();
        output.vecFloat[output_node.name] =
            std::vector<float>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for FLOAT type");
      }
    } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
      if (tensor_shape.size() == 1) {
        // As above, the copy into a vecChar map is an assumption.
        const char* data = tensor.GetTensorData<char>();
        output.vecChar[output_node.name] =
            std::vector<char>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for INT8 type");
      }
    } else {
      throw std::runtime_error("Unsupported tensor type");
    }
  }

  return output;
}