ATLAS Offline Software
get-onnx-model-info.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include <onnxruntime_cxx_api.h>
6 #include <iostream>
7 #include <iomanip>
8 #include <unordered_map>
9 #include <cstdint>
10 
12  const std::vector<std::string>& names,
13  const std::vector<ONNXTensorElementDataType>& types,
14  const std::vector<std::vector<int64_t>>& shapes
15 ){
16  size_t max_length_name = 0;
17  for (const auto& name : names) {
18  max_length_name = std::max(max_length_name, name.length());
19  }
20 
21  // Define the mapping of enum values to string representations
22  static const std::unordered_map<ONNXTensorElementDataType, std::string> typeMap = {
23  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, "float"},
24  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, "double"},
25  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, "int8"},
26  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, "int16"},
27  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, "int32"},
28  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, "int64"},
29  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, "uint8"},
30  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, "uint16"},
31  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, "uint32"},
32  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, "uint64"},
33  {ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, "bool"},
34  };
35 
36  // format the shape as a string
37  size_t max_length_shape = 0;
38  std::vector<std::string> shape_strs;
39  for (const auto& shape: shapes) {
40  std::string shape_str = "[";
41  for (size_t j = 0; j < shape.size(); ++j) {
42  shape_str += std::to_string(shape[j]);
43  if (j != shape.size() - 1) {
44  shape_str += ", ";
45  }
46  }
47  shape_str += "]";
48  size_t l = shape_str.length();
49  shape_strs.push_back(std::move(shape_str));
50  max_length_shape = std::max(max_length_shape, l);
51  }
52 
53  int line_length = max_length_name + 4 + 10 + 3 + max_length_shape;
54  std::string h_line(line_length, '-');
55  std::cout << h_line << std::endl;
56 
57  // header
58  std::ios_base::fmtflags f( std::cout.flags() ); //save cout format flags
59  std::cout << std::left << std::setw(max_length_name + 4) << " name";
60  std::cout << std::setw(10) << "type";
61  std::cout << "shape" << std::endl;
62 
63  std::cout << h_line << std::endl;
64 
65  for (size_t i = 0; i < names.size(); i++) {
66  std::cout << std::left << std::setw(max_length_name + 4) << " " + names.at(i);
67  std::cout << std::setw(10) << typeMap.at(types.at(i));
68  std::cout << shape_strs.at(i) << std::endl;
69  }
70  std::cout << h_line << std::endl;
71  std::cout.flags( f );//restore format
72 }
73 
74 
75 
76 int main(int narg, char* argv[]) {
77  if (narg != 3 && narg != 2) {
78  std::cout << "usage: " << argv[0] << " <onnx_file> [key]" << std::endl;
79  return 1;
80  }
81 
82  //load the onnx model to memory using the path
83  auto env = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_ERROR, "");
84 
85  // initialize session options if needed
86  Ort::SessionOptions session_options;
87  session_options.SetIntraOpNumThreads(1);
88  session_options.SetGraphOptimizationLevel(
89  GraphOptimizationLevel::ORT_DISABLE_ALL);
90 
91  // create session and load model into memory
92  auto session = std::make_unique< Ort::Session >(*env, argv[1],
93  session_options);
94 
95  // cout the input nodes
96  size_t num_input_nodes = session->GetInputCount();
97  std::vector<std::string> names;
98  std::vector<ONNXTensorElementDataType> types;
99  std::vector<std::vector<int64_t>> shapes;
100 
101  for (std::size_t i = 0; i < num_input_nodes; i++) {
102  char* input_name = session->GetInputNameAllocated(i, Ort::AllocatorWithDefaultOptions()).release();
103  names.push_back(input_name);
104 
105  // get input type and shape
106  Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
107  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
108 
109  types.push_back(tensor_info.GetElementType());
110  shapes.push_back(tensor_info.GetShape());
111  }
112 
113  std::cout << std::endl << "input nodes: " << std::endl;
114  pretty_print_table(names, types, shapes);
115 
116 
117  // cout the output nodes
118  size_t num_output_nodes = session->GetOutputCount();
119  names.clear(); types.clear(); shapes.clear();
120 
121  for (std::size_t i = 0; i < num_output_nodes; i++) {
122  char* output_name = session->GetOutputNameAllocated(i, Ort::AllocatorWithDefaultOptions()).release();
123  names.push_back(output_name);
124 
125  // get input type and shape
126  Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
127  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
128 
129  types.push_back(tensor_info.GetElementType());
130  shapes.push_back(tensor_info.GetShape());
131  }
132 
133  std::cout << std::endl << "output nodes: " << std::endl;
134  pretty_print_table(names, types, shapes);
135  std::cout << std::endl;
136 
137 
138  // get metadata
139  Ort::AllocatorWithDefaultOptions allocator;
140  Ort::ModelMetadata metadata = session->GetModelMetadata();
141  if (narg == 2) {
142  std::cout << "available metadata keys: ";
143  auto keys = metadata.GetCustomMetadataMapKeysAllocated(allocator);
144  for (uint64_t i = 0; i < keys.size(); i++) {
145  std::cout << keys[i].get();
146  if (i+1 < keys.size()) std::cout << ", ";
147  }
148  std::cout << std::endl;
149  return 2;
150  }
151  auto val = metadata.LookupCustomMetadataMapAllocated(argv[2], allocator);
152  std::cout << val.get() << std::endl;
153 
154  return 0;
155 }
python.CaloRecoConfig.f
f
Definition: CaloRecoConfig.py:127
max
#define max(a, b)
Definition: cfImp.cxx:41
AthenaPoolTestRead.flags
flags
Definition: AthenaPoolTestRead.py:8
UploadAMITag.l
list l
Definition: UploadAMITag.larcaf.py:158
LArCellConditions.argv
argv
Definition: LArCellConditions.py:112
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
pretty_print_table
void pretty_print_table(const std::vector< std::string > &names, const std::vector< ONNXTensorElementDataType > &types, const std::vector< std::vector< int64_t >> &shapes)
Definition: get-onnx-model-info.cxx:11
main
int main(int narg, char *argv[])
Definition: get-onnx-model-info.cxx:76
lumiFormat.i
int i
Definition: lumiFormat.py:92
python.subdetectors.mmg.names
names
Definition: mmg.py:8
xAOD::uint64_t
uint64_t
Definition: EventInfo_v1.cxx:123
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
Pythia8_RapidityOrderMPI.val
val
Definition: Pythia8_RapidityOrderMPI.py:14
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:790
python.DataFormatRates.env
env
Definition: DataFormatRates.py:32