ATLAS Offline Software
Loading...
Searching...
No Matches
get-onnx-model-info.cxx File Reference
#include <onnxruntime_cxx_api.h>
#include <iostream>
#include <iomanip>
#include <unordered_map>
#include <cstdint>
Include dependency graph for get-onnx-model-info.cxx:

Go to the source code of this file.

Functions

void pretty_print_table (const std::vector< std::string > &names, const std::vector< ONNXTensorElementDataType > &types, const std::vector< std::vector< int64_t > > &shapes)
int main (int narg, char *argv[])

Function Documentation

◆ main()

int main ( int narg,
char * argv[] )

Definition at line 76 of file get-onnx-model-info.cxx.

76 {
77 if (narg != 3 && narg != 2) {
78 std::cout << "usage: " << argv[0] << " <onnx_file> [key]" << std::endl;
79 return 1;
80 }
81
82 //load the onnx model to memory using the path
83 auto env = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_ERROR, "");
84
85 // initialize session options if needed
86 Ort::SessionOptions session_options;
87 session_options.SetIntraOpNumThreads(1);
88 session_options.SetGraphOptimizationLevel(
89 GraphOptimizationLevel::ORT_DISABLE_ALL);
90
91 // create session and load model into memory
92 auto session = std::make_unique< Ort::Session >(*env, argv[1],
93 session_options);
94
95 // cout the input nodes
96 size_t num_input_nodes = session->GetInputCount();
97 std::vector<std::string> names;
98 std::vector<ONNXTensorElementDataType> types;
99 std::vector<std::vector<int64_t>> shapes;
100
101 for (std::size_t i = 0; i < num_input_nodes; i++) {
102 char* input_name = session->GetInputNameAllocated(i, Ort::AllocatorWithDefaultOptions()).release();
103 names.push_back(input_name);
104
105 // get input type and shape
106 Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
107 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
108
109 types.push_back(tensor_info.GetElementType());
110 shapes.push_back(tensor_info.GetShape());
111 }
112
113 std::cout << std::endl << "input nodes: " << std::endl;
114 pretty_print_table(names, types, shapes);
115
116
117 // cout the output nodes
118 size_t num_output_nodes = session->GetOutputCount();
119 names.clear(); types.clear(); shapes.clear();
120
121 for (std::size_t i = 0; i < num_output_nodes; i++) {
122 char* output_name = session->GetOutputNameAllocated(i, Ort::AllocatorWithDefaultOptions()).release();
123 names.push_back(output_name);
124
125 // get input type and shape
126 Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
127 auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
128
129 types.push_back(tensor_info.GetElementType());
130 shapes.push_back(tensor_info.GetShape());
131 }
132
133 std::cout << std::endl << "output nodes: " << std::endl;
134 pretty_print_table(names, types, shapes);
135 std::cout << std::endl;
136
137
138 // get metadata
139 Ort::AllocatorWithDefaultOptions allocator;
140 Ort::ModelMetadata metadata = session->GetModelMetadata();
141 if (narg == 2) {
142 std::cout << "available metadata keys: ";
143 auto keys = metadata.GetCustomMetadataMapKeysAllocated(allocator);
144 for (uint64_t i = 0; i < keys.size(); i++) {
145 std::cout << keys[i].get();
146 if (i+1 < keys.size()) std::cout << ", ";
147 }
148 std::cout << std::endl;
149 return 2;
150 }
151 auto val = metadata.LookupCustomMetadataMapAllocated(argv[2], allocator);
152 std::cout << val.get() << std::endl;
153
154 return 0;
155}
static const std::vector< std::string > types
void pretty_print_table(const std::vector< std::string > &names, const std::vector< ONNXTensorElementDataType > &types, const std::vector< std::vector< int64_t > > &shapes)

◆ pretty_print_table()

void pretty_print_table ( const std::vector< std::string > & names,
const std::vector< ONNXTensorElementDataType > & types,
const std::vector< std::vector< int64_t > > & shapes )

Definition at line 11 of file get-onnx-model-info.cxx.

15 {
16 size_t max_length_name = 0;
17 for (const auto& name : names) {
18 max_length_name = std::max(max_length_name, name.length());
19 }
20
21 // Define the mapping of enum values to string representations
22 static const std::unordered_map<ONNXTensorElementDataType, std::string> typeMap = {
23 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "float"},
24 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, "double"},
25 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, "int8"},
26 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, "int16"},
27 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, "int32"},
28 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, "int64"},
29 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, "uint8"},
30 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, "uint16"},
31 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, "uint32"},
32 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, "uint64"},
33 {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, "bool"},
34 };
35
36 // format the shape as a string
37 size_t max_length_shape = 0;
38 std::vector<std::string> shape_strs;
39 for (const auto& shape: shapes) {
40 std::string shape_str = "[";
41 for (size_t j = 0; j < shape.size(); ++j) {
42 shape_str += std::to_string(shape[j]);
43 if (j != shape.size() - 1) {
44 shape_str += ", ";
45 }
46 }
47 shape_str += "]";
48 size_t l = shape_str.length();
49 shape_strs.push_back(std::move(shape_str));
50 max_length_shape = std::max(max_length_shape, l);
51 }
52
53 int line_length = max_length_name + 4 + 10 + 3 + max_length_shape;
54 std::string h_line(line_length, '-');
55 std::cout << h_line << std::endl;
56
57 // header
58 std::ios_base::fmtflags f( std::cout.flags() ); //save cout format flags
59 std::cout << std::left << std::setw(max_length_name + 4) << " name";
60 std::cout << std::setw(10) << "type";
61 std::cout << "shape" << std::endl;
62
63 std::cout << h_line << std::endl;
64
65 for (size_t i = 0; i < names.size(); i++) {
66 std::cout << std::left << std::setw(max_length_name + 4) << " " + names.at(i);
67 std::cout << std::setw(10) << typeMap.at(types.at(i));
68 std::cout << shape_strs.at(i) << std::endl;
69 }
70 std::cout << h_line << std::endl;
71 std::cout.flags( f );//restore format
72}
l
Printing final latex table to .tex output file.