#include "lwtnn/parse_json.hh"
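// NOTE: the include list is truncated in this copy of the file. The headers
// below are assumptions added so the translation unit is self-contained; the
// path of the class header in particular is hypothetical.
#include "FlavorTagDiscriminants/SaltModel.h"

#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>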
SaltModel::SaltModel(const std::string& path_to_onnx)
  // load the onnx model into memory; logging below FATAL is suppressed
  : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
{
  // initialise the session options: single-threaded execution, quiet
  // logging, and extended graph optimisations
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetLogSeverityLevel(4);
  session_options.SetGraphOptimizationLevel(
    GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  Ort::AllocatorWithDefaultOptions allocator;
  // create the session and load the model into memory
  m_session = std::make_unique<Ort::Session>(
    *m_env, path_to_onnx.c_str(), session_options);
  // read the model configuration from the onnx custom metadata and cache
  // the input/output node counts
  m_metadata = loadMetadata("gnn_config");
  m_num_inputs = m_session->GetInputCount();
  m_num_outputs = m_session->GetOutputCount();
  // determine the onnx model version: newer models store it explicitly in
  // the metadata, while V0 models are identified by their "outputs" key
  if (m_metadata.contains("onnx_model_version")) {
    m_onnx_model_version = m_metadata["onnx_model_version"].get<SaltModelVersion>();
    // reconstructed guard: assumes unrecognised version strings deserialise
    // to an UNKNOWN enumerator
    if (m_onnx_model_version == SaltModelVersion::UNKNOWN) {
      throw std::runtime_error("Unknown Onnx model version!");
    }
  } else if (m_metadata.contains("outputs")) {
    m_onnx_model_version = SaltModelVersion::V0;
  } else {
    throw std::runtime_error("Onnx model version not found in metadata");
  }
  // derive the model name used to label the outputs
  m_model_name = determineModelName();
  // iterate over the input nodes and cache their names
  // (the m_input_node_names member is an assumption here; the line that
  // stores each name did not survive in this copy of the file)
  for (size_t i = 0; i < m_num_inputs; i++) {
    std::string input_name = m_session->GetInputNameAllocated(i, allocator).get();
    m_input_node_names.push_back(input_name);
  }
  // iterate over the output nodes and record their configuration
  for (size_t i = 0; i < m_num_outputs; i++) {
    const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
    const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
    const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
    if (m_onnx_model_version == SaltModelVersion::V0) {
      // V0 output node names don't carry the model name, so pass it along
      const SaltModelOutput saltModelOutput(name, type, m_model_name);
      m_output_nodes.push_back(saltModelOutput);
    } else {
      const SaltModelOutput saltModelOutput(name, type, rank);
      m_output_nodes.push_back(saltModelOutput);
    }
  }
}
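// Usage sketch (hypothetical caller code, not part of this file):
//   SaltModel model("path/to/network.onnx");
//   for (const auto& node : model.getOutputConfig()) { /* inspect nodes */ }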
const nlohmann::json SaltModel::loadMetadata(const std::string& key) const {
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
  std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(
    key.c_str(), allocator).get());
  return nlohmann::json::parse(metadataString);
}
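// note: "gnn_config" is the custom-metadata key that the training framework
// writes into the onnx file; loadMetadata parses its value as json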
const std::string SaltModel::determineModelName() const {
  Ort::AllocatorWithDefaultOptions allocator;
  if (m_onnx_model_version == SaltModelVersion::V0) {
    // V0 models: use the name of the (single) output object in the metadata
    return std::string(m_metadata["outputs"].begin().key());
  }
  // newer models: take the common prefix of the output node names, which
  // follow a "<model_name>_<output_name>" convention
  std::set<std::string> model_names;
  for (size_t i = 0; i < m_num_outputs; i++) {
    const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
    size_t underscore_pos = name.find('_');
    if (underscore_pos != std::string::npos) {
      std::string substring = name.substr(0, underscore_pos);
      model_names.insert(substring);
    } else {
      return std::string("UnknownModelName");
    }
  }
  if (model_names.size() != 1) {
    throw std::runtime_error("SaltModel: model names are not consistent between outputs");
  }
  return *model_names.begin();
}
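// For example (hypothetical node names): outputs "MyTagger_pb", "MyTagger_pc"
// and "MyTagger_pu" share the prefix "MyTagger", which becomes the model
// name; mixed prefixes trigger the exception above.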
const lwt::GraphConfig SaltModel::getLwtConfig() const {
  // round-trip the metadata through a string stream so lwtnn can parse it.
  // Assumption: for versions newer than V0 the metadata carries no
  // "outputs" key, so an empty object is inserted to satisfy the lwtnn
  // parser (the lines doing this did not survive in this copy).
  nlohmann::json metadataCopy = m_metadata;
  if (getSaltModelVersion() != SaltModelVersion::V0) {
    metadataCopy["outputs"] = nlohmann::json::object();
  }
  std::stringstream metadataStream;
  metadataStream << metadataCopy.dump();
  return lwt::parse_json_graph(metadataStream);
}
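// The returned lwt::GraphConfig is the legacy lwtnn configuration object,
// useful for downstream code that still expects that format (e.g. to look
// up input variable names).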
const SaltModel::OutputConfig& SaltModel::getOutputConfig() const {
  return m_output_nodes;
}

SaltModelVersion SaltModel::getSaltModelVersion() const {
  return m_onnx_model_version;
}
const std::string& SaltModel::getModelName() const {
  return m_model_name;
}
// note: the InferenceOutput return-type name is an assumption, inferred
// from the member accesses on "output" below; only the parameter list
// survived in this copy of the file
SaltModel::InferenceOutput SaltModel::runInference(
    std::map<std::string, Inputs>& gnn_inputs) const {
  std::vector<float> input_tensor_values;
  // build one input tensor per input node; each entry of gnn_inputs holds
  // a (flattened data, shape) pair for that node
  auto memory_info = Ort::MemoryInfo::CreateCpu(
    OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<Ort::Value> input_tensors;
  for (const auto& node_name : m_input_node_names) {
    input_tensors.push_back(Ort::Value::CreateTensor<float>(
      memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
      gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size()));
  }
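  // Here Inputs is assumed to be a pair of (std::vector<float> data,
  // std::vector<int64_t> shape): .first supplies the tensor values and
  // .second the shape, matching Ort::Value::CreateTensor's argument order.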
  // ORT's Run call expects arrays of C strings, so cast the stored names
  std::vector<const char*> input_node_names;
  input_node_names.reserve(m_input_node_names.size());
  for (const auto& name : m_input_node_names) {
    input_node_names.push_back(name.c_str());
  }
  std::vector<const char*> output_node_names;
  output_node_names.reserve(m_output_nodes.size());
  for (const auto& node : m_output_nodes) {
    output_node_names.push_back(node.name_in_model.c_str());
  }
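  // Ort::Session::Run is not const-qualified, but the onnxruntime
  // maintainers document concurrent Run calls on one session as safe,
  // which is why a non-const reference is taken inside this const method.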
  Ort::Session& session = *m_session;
  auto output_tensors = session.Run(Ort::RunOptions{nullptr},
    input_node_names.data(), input_tensors.data(), input_node_names.size(),
    output_node_names.data(), output_node_names.size());
  // extract the outputs, dispatching on element type and rank
  InferenceOutput output;
  for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
    const auto& output_node = m_output_nodes[node_idx];
    const auto& tensor = output_tensors[node_idx];
    auto tensor_type = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementType();
    auto tensor_shape = tensor.GetTypeInfo().GetTensorTypeAndShapeInfo().GetShape();
    int length = tensor.GetTensorTypeAndShapeInfo().GetElementCount();
    if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      if (tensor_shape.size() == 0) {
        // rank-0: a single float
        output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
      } else if (tensor_shape.size() == 1) {
        // rank-1: a vector of floats (the vecFloat member name is assumed,
        // mirroring singleFloat above)
        const float* data = tensor.GetTensorData<float>();
        output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for FLOAT type");
      }
    } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
      if (tensor_shape.size() == 1) {
        // rank-1: a vector of chars (the vecChar member name is assumed)
        const char* data = tensor.GetTensorData<char>();
        output.vecChar[output_node.name] = std::vector<char>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for INT8 type");
      }
    } else {
      throw std::runtime_error("Unsupported tensor type");
    }
  }
  return output;
}