  : m_env(std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
{
  // Configure the ONNX Runtime session: a single intra-op thread,
  // fatal-only logging, and extended graph optimizations.
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetLogSeverityLevel(4);
  session_options.SetGraphOptimizationLevel(
    GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  Ort::AllocatorWithDefaultOptions allocator;

  // Create the inference session from the ONNX file.
  m_session = std::make_unique<Ort::Session>(
    *m_env, path_to_onnx.c_str(), session_options);

  // Read the model metadata and the number of input/output nodes.
  m_metadata = loadMetadata("gnn_config");
  m_num_inputs = m_session->GetInputCount();
  m_num_outputs = m_session->GetOutputCount();

  // Determine the ONNX model version from the metadata.
  if (m_metadata.contains("onnx_model_version")) {
    m_onnx_model_version =
      m_metadata["onnx_model_version"].get<SaltModelVersion>();
    // Reject version strings that did not map to a known enum value.
    if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
      throw std::runtime_error("Unknown Onnx model version!");
    }
  } else if (m_metadata.contains("outputs")) {
    // Models without an explicit version entry are treated as V0.
    m_onnx_model_version = SaltModelVersion::V0;
  } else {
    throw std::runtime_error("Onnx model version not found in metadata");
  }

  m_model_name = determineModelName();

  // Collect the input node names.
  for (size_t i = 0; i < m_num_inputs; i++) {
    std::string input_name = m_session->GetInputNameAllocated(i, allocator).get();
    m_input_node_names.push_back(input_name);  // member name assumed; not shown in this excerpt
  }

  // Collect the output nodes: name, element type and tensor rank.
  for (size_t i = 0; i < m_num_outputs; i++) {
    const auto name = std::string(
      m_session->GetOutputNameAllocated(i, allocator).get());
    const auto type = m_session->GetOutputTypeInfo(i)
      .GetTensorTypeAndShapeInfo().GetElementType();
    const int rank = m_session->GetOutputTypeInfo(i)
      .GetTensorTypeAndShapeInfo().GetShape().size();
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // V0: outputs are built from the name, element type and model name.
      const SaltModelOutput saltModelOutput(name, type, m_model_name);
      m_output_nodes.push_back(saltModelOutput);
    } else {
      // Later versions: outputs are built from the name, element type and rank.
      const SaltModelOutput saltModelOutput(name, type, rank);
      m_output_nodes.push_back(saltModelOutput);
    }
  }
}
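
The loadMetadata("gnn_config") helper used above is not shown in this excerpt. A minimal sketch of what such a lookup could look like with the ONNX Runtime C++ API and nlohmann::json follows; the function name loadMetadataSketch, the includes, and the error message are illustrative assumptions rather than the actual implementation.

#include <nlohmann/json.hpp>
#include <onnxruntime_cxx_api.h>
#include <stdexcept>
#include <string>

// Look up a key in the ONNX file's custom metadata map and parse its value as JSON.
nlohmann::json loadMetadataSketch(Ort::Session& session, const std::string& key) {
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::ModelMetadata metadata = session.GetModelMetadata();
  Ort::AllocatedStringPtr value =
    metadata.LookupCustomMetadataMapAllocated(key.c_str(), allocator);
  if (!value) {
    throw std::runtime_error("metadata key not found: " + key);
  }
  return nlohmann::json::parse(value.get());
}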