5 #include <onnxruntime_cxx_api.h>
8 #include <unordered_map>
12 const std::vector<std::string>&
names,
13 const std::vector<ONNXTensorElementDataType>& types,
14 const std::vector<std::vector<int64_t>>& shapes
16 size_t max_length_name = 0;
18 max_length_name =
std::max(max_length_name,
name.length());
22 static const std::unordered_map<ONNXTensorElementDataType, std::string> typeMap = {
23 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
"float"},
24 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
"double"},
25 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
"int8"},
26 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
"int16"},
27 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
"int32"},
28 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
"int64"},
29 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
"uint8"},
30 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
"uint16"},
31 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
"uint32"},
32 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
"uint64"},
33 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
"bool"},
37 size_t max_length_shape = 0;
38 std::vector<std::string> shape_strs;
39 for (
const auto& shape: shapes) {
40 std::string shape_str =
"[";
41 for (
size_t j = 0; j < shape.size(); ++j) {
43 if (j != shape.size() - 1) {
48 size_t l = shape_str.length();
49 shape_strs.push_back(std::move(shape_str));
50 max_length_shape =
std::max(max_length_shape,
l);
53 int line_length = max_length_name + 4 + 10 + 3 + max_length_shape;
54 std::string h_line(line_length,
'-');
55 std::cout << h_line << std::endl;
58 std::ios_base::fmtflags
f( std::cout.
flags() );
59 std::cout << std::left << std::setw(max_length_name + 4) <<
" name";
60 std::cout << std::setw(10) <<
"type";
61 std::cout <<
"shape" << std::endl;
63 std::cout << h_line << std::endl;
65 for (
size_t i = 0;
i <
names.size();
i++) {
66 std::cout << std::left << std::setw(max_length_name + 4) <<
" " +
names.at(
i);
67 std::cout << std::setw(10) << typeMap.at(types.at(
i));
68 std::cout << shape_strs.at(
i) << std::endl;
70 std::cout << h_line << std::endl;
77 if (narg != 3 && narg != 2) {
78 std::cout <<
"usage: " <<
argv[0] <<
" <onnx_file> [key]" << std::endl;
83 auto env = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_ERROR,
"");
86 Ort::SessionOptions session_options;
87 session_options.SetIntraOpNumThreads(1);
88 session_options.SetGraphOptimizationLevel(
89 GraphOptimizationLevel::ORT_DISABLE_ALL);
92 auto session = std::make_unique< Ort::Session >(*
env,
argv[1],
96 size_t num_input_nodes = session->GetInputCount();
97 std::vector<std::string>
names;
98 std::vector<ONNXTensorElementDataType> types;
99 std::vector<std::vector<int64_t>> shapes;
101 for (std::size_t
i = 0;
i < num_input_nodes;
i++) {
102 char* input_name = session->GetInputNameAllocated(
i, Ort::AllocatorWithDefaultOptions()).release();
103 names.push_back(input_name);
106 Ort::TypeInfo type_info = session->GetInputTypeInfo(
i);
107 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
109 types.push_back(tensor_info.GetElementType());
110 shapes.push_back(tensor_info.GetShape());
113 std::cout << std::endl <<
"input nodes: " << std::endl;
118 size_t num_output_nodes = session->GetOutputCount();
119 names.clear(); types.clear(); shapes.clear();
121 for (std::size_t
i = 0;
i < num_output_nodes;
i++) {
122 char* output_name = session->GetOutputNameAllocated(
i, Ort::AllocatorWithDefaultOptions()).release();
123 names.push_back(output_name);
126 Ort::TypeInfo type_info = session->GetOutputTypeInfo(
i);
127 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
129 types.push_back(tensor_info.GetElementType());
130 shapes.push_back(tensor_info.GetShape());
133 std::cout << std::endl <<
"output nodes: " << std::endl;
135 std::cout << std::endl;
139 Ort::AllocatorWithDefaultOptions allocator;
140 Ort::ModelMetadata
metadata = session->GetModelMetadata();
142 std::cout <<
"available metadata keys: ";
143 auto keys =
metadata.GetCustomMetadataMapKeysAllocated(allocator);
145 std::cout <<
keys[
i].get();
146 if (
i+1 <
keys.size()) std::cout <<
", ";
148 std::cout << std::endl;
151 auto val =
metadata.LookupCustomMetadataMapAllocated(
argv[2], allocator);
152 std::cout <<
val.get() << std::endl;