ATLAS Offline Software
ONNXWrapper.cxx
Go to the documentation of this file.
2 
3 ONNXWrapper::ONNXWrapper(const std::string model_path) {
4 
5  // Use the path resolver to find the location of the network .onnx file
7 
8  // init onnx envi
9  m_onnxEnv = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "");
10 
11  // initialize session options if needed
12  m_session_options.SetIntraOpNumThreads(1);
13  m_session_options.SetGraphOptimizationLevel(
14  GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
15 
16 
17  // Initialise the ONNX environment and session using the above options and the model name
18  m_onnxSession = std::make_unique<Ort::Session>(*m_onnxEnv,
19  m_modelPath.c_str(),
21 
22  // get the input nodes
23  m_nr_inputs = m_onnxSession->GetInputCount();
24 
25  // get the output nodes
26  m_nr_output = m_onnxSession->GetOutputCount();
27 
28 
29  // iterate over all input nodes
30  for (std::size_t i = 0; i < m_nr_inputs; i++) {
31  const char* input_name = m_onnxSession->GetInputNameAllocated(i, m_allocator).release();
32 
33  m_input_names.push_back(input_name);
34 
35  m_input_dims[input_name] = getShape(m_onnxSession->GetInputTypeInfo(i));
36  }
37 
38  // iterate over all output nodes
39  for(std::size_t i = 0; i < m_nr_output; i++ ) {
40  const char* output_name = m_onnxSession->GetOutputNameAllocated(i, m_allocator).release();
41 
42  m_output_names.push_back(output_name);
43 
44  m_output_dims[output_name] = getShape(m_onnxSession->GetOutputTypeInfo(i));
45 
46  }
47 }
48 
49 std::map<std::string, std::vector<float>> ONNXWrapper::Run(
50  std::map<std::string, std::vector<float>> inputs, int n_batches) {
51  for ( const auto &p : inputs ) // check for valid dimensions between batches and inputs
52  {
53  uint64_t n=1;
54  for (uint64_t i:m_input_dims[p.first])
55  {
56  n*=i;
57  }
58  if ( (p.second.size() % n) != 0){
59 
60  throw std::invalid_argument("For input '"+p.first+"' length not compatible with model. Expect a multiple of "+std::to_string(n)+", got "+std::to_string(p.second.size()));
61  }
62  if ( p.second.size()!=(n_batches*n)){
63  throw std::invalid_argument("Number of batches not compatible with length of vector");
64  }
65  }
66  // Create a CPU tensor to be used as input
67  Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
68 
69  // define input tensor
70  std::vector<Ort::Value> output_tensor;
71  std::vector<Ort::Value> input_tensor;
72 
73  // add the inputs to vector
74  for ( const auto &p : m_input_dims )
75  {
76  std::vector<int64_t> in_dims = p.second;
77  in_dims.at(0) = n_batches;
78  input_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
79  inputs[p.first].data(),
80  inputs[p.first].size(),
81  in_dims.data(),
82  in_dims.size()));
83  }
84 
85  // init output tensor and fill with zeros
86  std::map<std::string, std::vector<float>> outputs;
87  for ( const auto &p : m_output_dims ) {
88  std::vector<int64_t> out_dims = p.second;
89  out_dims.at(0) = n_batches;
90  // init output
91  int length = 1;
92  for(auto i : out_dims){ length*=i; }
93  std::vector<float> output(length,0);
94  // std::vector<float> output(m_output_dims[i][1], 0.0);
95  outputs[p.first] = output;
96  output_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
97  outputs[p.first].data(),
98  outputs[p.first].size(),
99  out_dims.data(),
100  out_dims.size()));
101  }
102 
103  Ort::Session& session = *m_onnxSession;
104 
105  // run the model
106  session.Run(Ort::RunOptions{nullptr},
107  m_input_names.data(),
108  input_tensor.data(),
109  2,
110  m_output_names.data(),
111  output_tensor.data(),
112  2);
113 
114  return outputs;
115  }
116 
117 
118 const std::map<std::string, std::vector<int64_t>> ONNXWrapper::GetModelInputs() {
119  std::map<std::string, std::vector<int64_t>> ModelInputINFO_map;
120 
121  for(std::size_t i = 0; i < m_nr_inputs; i++ ) {
122  ModelInputINFO_map[m_input_names.at(i)] = m_input_dims[m_input_names.at(i)];
123  }
124  return ModelInputINFO_map;
125 }
126 
127 const std::map<std::string, std::vector<int64_t>> ONNXWrapper::GetModelOutputs() {
128  std::map<std::string, std::vector<int64_t>> ModelOutputINFO_map;
129 
130  for(std::size_t i = 0; i < m_nr_output; i++ ) {
131  ModelOutputINFO_map[m_output_names.at(i)] = m_output_dims[m_output_names.at(i)];
132  }
133  return ModelOutputINFO_map;
134 }
135 
136 const std::map<std::string, std::string> ONNXWrapper::GetMETAData() {
137  std::map<std::string, std::string> METAData_map;
138  auto metadata = m_onnxSession->GetModelMetadata();
139 
140  auto keys = metadata.GetCustomMetadataMapKeysAllocated(m_allocator);
141  // auto status = OrtApi::ModelMetadataGetCustomMetadataMapKeys(metadata, m_allocator, keys, nkeys)
142  for (size_t i = 0; i < keys.size(); i++) {
143  METAData_map[keys[i].get()]=this->GetMETADataByKey(keys[i].get());
144  }
145 
146  return METAData_map;
147 }
148 
149 std::string ONNXWrapper::GetMETADataByKey(const char * key){
150  auto metadata = m_onnxSession->GetModelMetadata();
151  return metadata.LookupCustomMetadataMapAllocated(key, m_allocator).get();
152 }
153 
154 const std::vector<const char*>& ONNXWrapper::getInputNames(){
155  //put the model access for input here
156  return m_input_names;
157 }
158 
159 const std::vector<const char*>& ONNXWrapper::getOutputNames(){
160  //put the model access for outputs here
161  return m_output_names;
162 }
163 
164 const std::vector<int64_t>& ONNXWrapper::getInputShape(int input_nr=0){
165  //put the model access for input here
166  std::vector<const char*> names = getInputNames();
167  return m_input_dims[names.at(input_nr)];
168 }
169 
170 const std::vector<int64_t>& ONNXWrapper::getOutputShape(int output_nr=0){
171  //put the model access for outputs here
172  std::vector<const char*> names = getOutputNames();
173  return m_output_dims[names.at(output_nr)];
174 }
175 
176 int ONNXWrapper::getNumInputs() const { return m_input_names.size(); }
177 int ONNXWrapper::getNumOutputs() const { return m_output_names.size(); }
178 
179 const std::vector<int64_t> ONNXWrapper::getShape(Ort::TypeInfo model_info) {
180  auto tensor_info = model_info.GetTensorTypeAndShapeInfo();
181  std::vector<int64_t> dims = tensor_info.GetShape();
182  dims[0]=1;
183  return dims;
184  }
ONNXWrapper::m_onnxSession
std::unique_ptr< Ort::Session > m_onnxSession
Definition: ONNXWrapper.h:36
ONNXWrapper::m_session_options
Ort::SessionOptions m_session_options
Definition: ONNXWrapper.h:40
ONNXWrapper::getInputNames
const std::vector< const char * > & getInputNames()
Definition: ONNXWrapper.cxx:154
ONNXWrapper::m_output_names
std::vector< const char * > m_output_names
Definition: ONNXWrapper.h:46
ONNXWrapper::m_input_names
std::vector< const char * > m_input_names
Definition: ONNXWrapper.h:47
ONNXWrapper::m_onnxEnv
std::unique_ptr< Ort::Env > m_onnxEnv
Definition: ONNXWrapper.h:37
ONNXWrapper::GetMETAData
const std::map< std::string, std::string > GetMETAData()
Definition: ONNXWrapper.cxx:136
ONNXWrapper::getOutputNames
const std::vector< const char * > & getOutputNames()
Definition: ONNXWrapper.cxx:159
ONNXWrapper::m_modelPath
std::string m_modelPath
Definition: ONNXWrapper.h:22
ONNXWrapper::ONNXWrapper
ONNXWrapper(const std::string model_path)
Definition: ONNXWrapper.cxx:3
ONNXWrapper::m_nr_inputs
size_t m_nr_inputs
Definition: ONNXWrapper.h:27
python.oracle.Session
Session
Definition: oracle.py:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
ONNXWrapper::GetMETADataByKey
std::string GetMETADataByKey(const char *key)
Definition: ONNXWrapper.cxx:149
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
ONNXWrapper::getShape
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
Definition: ONNXWrapper.cxx:179
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
lumiFormat.i
int i
Definition: lumiFormat.py:85
beamspotman.n
n
Definition: beamspotman.py:731
python.subdetectors.mmg.names
names
Definition: mmg.py:8
ONNXWrapper::m_input_dims
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition: ONNXWrapper.h:31
xAOD::uint64_t
uint64_t
Definition: EventInfo_v1.cxx:123
ONNXWrapper.h
ONNXWrapper::Run
std::map< std::string, std::vector< float > > Run(std::map< std::string, std::vector< float >> inputs, int n_batches=1)
Definition: ONNXWrapper.cxx:49
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
merge.output
output
Definition: merge.py:17
ONNXWrapper::GetModelOutputs
const std::map< std::string, std::vector< int64_t > > GetModelOutputs()
Definition: ONNXWrapper.cxx:127
ONNXWrapper::getNumInputs
int getNumInputs() const
Definition: ONNXWrapper.cxx:176
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
ONNXWrapper::m_output_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition: ONNXWrapper.h:32
get
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition: hcg.cxx:127
ONNXWrapper::GetModelInputs
const std::map< std::string, std::vector< int64_t > > GetModelInputs()
Definition: ONNXWrapper.cxx:118
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:798
ONNXWrapper::getInputShape
const std::vector< int64_t > & getInputShape(int input_nr)
Definition: ONNXWrapper.cxx:164
ONNXWrapper::m_nr_output
size_t m_nr_output
Definition: ONNXWrapper.h:28
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
ONNXWrapper::getNumOutputs
int getNumOutputs() const
Definition: ONNXWrapper.cxx:177
ONNXWrapper::getOutputShape
const std::vector< int64_t > & getOutputShape(int output_nr)
Definition: ONNXWrapper.cxx:170
ONNXWrapper::m_allocator
Ort::AllocatorWithDefaultOptions m_allocator
Definition: ONNXWrapper.h:41
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37