#include "lwtnn/parse_json.hh"
OnnxUtil::OnnxUtil(const std::string& path_to_onnx)
  : m_env (std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_FATAL, ""))
{
  Ort::SessionOptions session_options;
  session_options.SetIntraOpNumThreads(1);
  session_options.SetLogSeverityLevel(4);
  session_options.SetGraphOptimizationLevel(
    GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  Ort::AllocatorWithDefaultOptions allocator;
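  // Create the inference session, loading the model from disk into memory.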
  m_session = std::make_unique<Ort::Session>(
    *m_env, path_to_onnx.c_str(), session_options);
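  // Read the model configuration from the metadata and cache the
  // input/output node counts.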
  m_metadata = loadMetadata("gnn_config");
  m_num_inputs = m_session->GetInputCount();
  m_num_outputs = m_session->GetOutputCount();
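  // Determine the metadata schema version. Models that predate the
  // "onnx_model_version" key are identified as V0 by their "outputs" key.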
  if (m_metadata.contains("onnx_model_version")) {
    m_onnx_model_version = m_metadata["onnx_model_version"].get<OnnxModelVersion>();
    // reject versions the enum mapping could not resolve (UNKNOWN sentinel assumed)
    if (m_onnx_model_version == OnnxModelVersion::UNKNOWN) {
      throw std::runtime_error("Unknown Onnx model version!");
    }
  } else {
    if (m_metadata.contains("outputs")) {
      m_onnx_model_version = OnnxModelVersion::V0;
    } else {
      throw std::runtime_error("Onnx model version not found in metadata");
    }
  }
  m_model_name = determineModelName();
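  // Collect the input node names from the session.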
  for (size_t i = 0; i < m_num_inputs; i++) {
    std::string input_name = m_session->GetInputNameAllocated(i, allocator).get();
    // cache the name for later inference calls (member name assumed)
    m_input_node_names.push_back(input_name);
  }
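  // Collect the output node configuration. V0 models tag each output with
  // the model name; newer models record the tensor rank instead.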
  for (size_t i = 0; i < m_num_outputs; i++) {
    const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
    const auto type = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetElementType();
    const int rank = m_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape().size();
    if (m_onnx_model_version == OnnxModelVersion::V0) {
      const OnnxOutput onnxOutput(name, type, m_model_name);
      m_output_nodes.push_back(onnxOutput);
    } else {
      const OnnxOutput onnxOutput(name, type, rank);
      m_output_nodes.push_back(onnxOutput);
    }
  }
}
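// Look up a key in the model's custom metadata map and parse the stored
// string as JSON.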
const nlohmann::json OnnxUtil::loadMetadata(const std::string& key) const {
  Ort::AllocatorWithDefaultOptions allocator;
  Ort::ModelMetadata modelMetadata = m_session->GetModelMetadata();
  std::string metadataString(modelMetadata.LookupCustomMetadataMapAllocated(
    key.c_str(), allocator).get());
  return nlohmann::json::parse(metadataString);
}
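// Work out the model name: V0 models store it in the metadata, newer models
// encode it as a common prefix on the output node names.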
const std::string OnnxUtil::determineModelName() const {
  Ort::AllocatorWithDefaultOptions allocator;
  if (m_onnx_model_version == OnnxModelVersion::V0) {
    // V0: the model name is the key of the first entry in "outputs"
    return std::string(m_metadata["outputs"].begin().key());
  }
  // Newer models: every output is named "<model_name>_<suffix>", so take the
  // prefix before the first underscore and require it to be unique.
  std::set<std::string> model_names;
  for (size_t i = 0; i < m_num_outputs; i++) {
    const auto name = std::string(m_session->GetOutputNameAllocated(i, allocator).get());
    size_t underscore_pos = name.find('_');
    if (underscore_pos != std::string::npos) {
      model_names.insert(name.substr(0, underscore_pos));
    } else {
      return std::string("UnknownModelName");
    }
  }
  if (model_names.size() != 1) {
    throw std::runtime_error("OnnxUtil: model names are not consistent between outputs");
  }
  return *model_names.begin();
}
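// Build an lwtnn GraphConfig from the stored metadata. For versions newer
// than V0 the outputs are inferred from the model graph rather than the
// JSON, but lwt::parse_json_graph still expects an "outputs" key.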
const lwt::GraphConfig OnnxUtil::getLwtConfig() const {
  nlohmann::json metadataCopy = m_metadata;
  if (getOnnxModelVersion() != OnnxModelVersion::V0) {
    // insert an empty "outputs" object to keep the parser happy
    metadataCopy["outputs"] = nlohmann::json::object();
  }
  std::stringstream metadataStream;
  metadataStream << metadataCopy.dump();
  return lwt::parse_json_graph(metadataStream);
}
const OnnxUtil::OutputConfig& OnnxUtil::getOutputConfig() const {
  return m_output_nodes;
}

OnnxModelVersion OnnxUtil::getOnnxModelVersion() const {
  return m_onnx_model_version;
}
const std::string& OnnxUtil::getModelName() const {
  return m_model_name;
}

// Run the network on one set of inputs. (The method and return-type names
// below are assumed; the struct must expose the singleFloat/vecFloat/vecChar
// maps filled in the extraction loop.)
OnnxUtil::InferenceOutput OnnxUtil::runInference(
  std::map<std::string, Inputs>& gnn_inputs) const {
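  // Build one input tensor per input node, wrapping the caller's buffers
  // without copying. Each Inputs entry pairs flat data with its shape.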
  auto memory_info = Ort::MemoryInfo::CreateCpu(
    OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<Ort::Value> input_tensors;
  for (const auto& node_name : m_input_node_names) {
    input_tensors.push_back(Ort::Value::CreateTensor<float>(
      memory_info, gnn_inputs.at(node_name).first.data(), gnn_inputs.at(node_name).first.size(),
      gnn_inputs.at(node_name).second.data(), gnn_inputs.at(node_name).second.size()));
  }
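  // ORT takes the node names as raw const char* arrays, so cast the stored
  // std::strings down; the pointers remain owned by the members.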
  std::vector<const char*> input_node_names;
  input_node_names.reserve(m_input_node_names.size());
  for (const auto& name : m_input_node_names) {
    input_node_names.push_back(name.c_str());
  }
  std::vector<const char*> output_node_names;
  output_node_names.reserve(m_output_nodes.size());
  for (const auto& node : m_output_nodes) {
    output_node_names.push_back(node.name_in_model.c_str());
  }
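  // Score the model. Session::Run is non-const, but the onnxruntime
  // documentation states that concurrent Run calls on one session are safe.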
  Ort::Session& session = *m_session;
  auto output_tensors = session.Run(Ort::RunOptions{nullptr},
    input_node_names.data(), input_tensors.data(), input_node_names.size(),
    output_node_names.data(), output_node_names.size());
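  // Unpack the results: dispatch on each output tensor's element type and
  // rank, and file it into the matching output map.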
  InferenceOutput output;
  for (size_t node_idx = 0; node_idx < m_output_nodes.size(); ++node_idx) {
    const auto& output_node = m_output_nodes[node_idx];
    const auto& tensor = output_tensors[node_idx];
    const auto info = tensor.GetTensorTypeAndShapeInfo();
    const auto tensor_type = info.GetElementType();
    const auto tensor_shape = info.GetShape();
    const int length = info.GetElementCount();
    if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      if (tensor_shape.size() == 0) {
        // rank 0: a single scalar value
        output.singleFloat[output_node.name] = *tensor.GetTensorData<float>();
      } else if (tensor_shape.size() == 1) {
        // rank 1: a flat vector of floats
        const float* data = tensor.GetTensorData<float>();
        output.vecFloat[output_node.name] = std::vector<float>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for FLOAT type");
      }
    } else if (tensor_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) {
      if (tensor_shape.size() == 1) {
        // rank 1: a flat vector of int8, stored as char
        const char* data = tensor.GetTensorData<char>();
        output.vecChar[output_node.name] = std::vector<char>(data, data + length);
      } else {
        throw std::runtime_error("Unsupported tensor shape for INT8 type");
      }
    } else {
      throw std::runtime_error("Unsupported tensor type");
    }
  }
  return output;
}