ATLAS Offline Software
ONNXWrapper.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
5 
6 // Constructor:
7 // Use the path resolver to find the location of the network .onnx file
8 // initialise onnx environment
9 ONNXWrapper::ONNXWrapper(const std::string & model_path):
10  m_modelPath(PathResolverFindCalibFile(model_path)),
11  m_onnxEnv(std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "")) {
12 
13  // initialize session options if needed
14  m_session_options.SetIntraOpNumThreads(1);
15  m_session_options.SetGraphOptimizationLevel(
16  GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
17 
18 
19  // Initialise the ONNX environment and session using the above options and the model name
20  m_onnxSession = std::make_unique<Ort::Session>(*m_onnxEnv,
21  m_modelPath.c_str(),
23 
24  // get the input nodes
25  m_nr_inputs = m_onnxSession->GetInputCount();
26 
27  // get the output nodes
28  m_nr_output = m_onnxSession->GetOutputCount();
29 
30 
31  // iterate over all input nodes
32  for (std::size_t i = 0; i < m_nr_inputs; i++) {
33  const char* input_name = m_onnxSession->GetInputNameAllocated(i, m_allocator).release();
34 
35  m_input_names.push_back(input_name);
36 
37  m_input_dims[input_name] = getShape(m_onnxSession->GetInputTypeInfo(i));
38  }
39 
40  // iterate over all output nodes
41  for(std::size_t i = 0; i < m_nr_output; i++ ) {
42  const char* output_name = m_onnxSession->GetOutputNameAllocated(i, m_allocator).release();
43 
44  m_output_names.push_back(output_name);
45 
46  m_output_dims[output_name] = getShape(m_onnxSession->GetOutputTypeInfo(i));
47 
48  }
49 }
50 
51 std::map<std::string, std::vector<float>> ONNXWrapper::Run(
52  std::map<std::string, std::vector<float>> inputs, int n_batches) {
53  for ( const auto &p : inputs ) // check for valid dimensions between batches and inputs
54  {
55  uint64_t n=1;
56  for (uint64_t i:m_input_dims[p.first])
57  {
58  n*=i;
59  }
60  if ( (p.second.size() % n) != 0){
61 
62  throw std::invalid_argument("For input '"+p.first+"' length not compatible with model. Expect a multiple of "+std::to_string(n)+", got "+std::to_string(p.second.size()));
63  }
64  if ( p.second.size()!=(n_batches*n)){
65  throw std::invalid_argument("Number of batches not compatible with length of vector");
66  }
67  }
68  // Create a CPU tensor to be used as input
69  Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
70 
71  // define input tensor
72  std::vector<Ort::Value> output_tensor;
73  std::vector<Ort::Value> input_tensor;
74 
75  // add the inputs to vector
76  for ( const auto &p : m_input_dims )
77  {
78  std::vector<int64_t> in_dims = p.second;
79  in_dims.at(0) = n_batches;
80  input_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
81  inputs[p.first].data(),
82  inputs[p.first].size(),
83  in_dims.data(),
84  in_dims.size()));
85  }
86 
87  // init output tensor and fill with zeros
88  std::map<std::string, std::vector<float>> outputs;
89  for ( const auto &p : m_output_dims ) {
90  std::vector<int64_t> out_dims = p.second;
91  out_dims.at(0) = n_batches;
92  // init output
93  int length = 1;
94  for(auto i : out_dims){ length*=i; }
95  std::vector<float> output(length,0);
96  // std::vector<float> output(m_output_dims[i][1], 0.0);
97  outputs[p.first] = output;
98  output_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
99  outputs[p.first].data(),
100  outputs[p.first].size(),
101  out_dims.data(),
102  out_dims.size()));
103  }
104 
105  Ort::Session& session = *m_onnxSession;
106 
107  // run the model
108  session.Run(Ort::RunOptions{nullptr},
109  m_input_names.data(),
110  input_tensor.data(),
111  2,
112  m_output_names.data(),
113  output_tensor.data(),
114  2);
115 
116  return outputs;
117  }
118 
119 
120 const std::map<std::string, std::vector<int64_t>> ONNXWrapper::GetModelInputs() {
121  std::map<std::string, std::vector<int64_t>> ModelInputINFO_map;
122 
123  for(std::size_t i = 0; i < m_nr_inputs; i++ ) {
124  ModelInputINFO_map[m_input_names.at(i)] = m_input_dims[m_input_names.at(i)];
125  }
126  return ModelInputINFO_map;
127 }
128 
129 const std::map<std::string, std::vector<int64_t>> ONNXWrapper::GetModelOutputs() {
130  std::map<std::string, std::vector<int64_t>> ModelOutputINFO_map;
131 
132  for(std::size_t i = 0; i < m_nr_output; i++ ) {
133  ModelOutputINFO_map[m_output_names.at(i)] = m_output_dims[m_output_names.at(i)];
134  }
135  return ModelOutputINFO_map;
136 }
137 
138 const std::map<std::string, std::string> ONNXWrapper::GetMETAData() {
139  std::map<std::string, std::string> METAData_map;
140  auto metadata = m_onnxSession->GetModelMetadata();
141 
142  auto keys = metadata.GetCustomMetadataMapKeysAllocated(m_allocator);
143  // auto status = OrtApi::ModelMetadataGetCustomMetadataMapKeys(metadata, m_allocator, keys, nkeys)
144  for (size_t i = 0; i < keys.size(); i++) {
145  METAData_map[keys[i].get()]=this->GetMETADataByKey(keys[i].get());
146  }
147 
148  return METAData_map;
149 }
150 
151 std::string ONNXWrapper::GetMETADataByKey(const char * key){
152  auto metadata = m_onnxSession->GetModelMetadata();
153  return metadata.LookupCustomMetadataMapAllocated(key, m_allocator).get();
154 }
155 
156 const std::vector<const char*>& ONNXWrapper::getInputNames(){
157  //put the model access for input here
158  return m_input_names;
159 }
160 
161 const std::vector<const char*>& ONNXWrapper::getOutputNames(){
162  //put the model access for outputs here
163  return m_output_names;
164 }
165 
166 const std::vector<int64_t>& ONNXWrapper::getInputShape(int input_nr=0){
167  //put the model access for input here
168  std::vector<const char*> names = getInputNames();
169  return m_input_dims[names.at(input_nr)];
170 }
171 
172 const std::vector<int64_t>& ONNXWrapper::getOutputShape(int output_nr=0){
173  //put the model access for outputs here
174  std::vector<const char*> names = getOutputNames();
175  return m_output_dims[names.at(output_nr)];
176 }
177 
178 int ONNXWrapper::getNumInputs() const { return m_input_names.size(); }
179 int ONNXWrapper::getNumOutputs() const { return m_output_names.size(); }
180 
181 const std::vector<int64_t> ONNXWrapper::getShape(Ort::TypeInfo model_info) {
182  auto tensor_info = model_info.GetTensorTypeAndShapeInfo();
183  std::vector<int64_t> dims = tensor_info.GetShape();
184  dims[0]=1;
185  return dims;
186  }
ONNXWrapper::m_onnxSession
std::unique_ptr< Ort::Session > m_onnxSession
Definition: ONNXWrapper.h:40
ONNXWrapper::m_session_options
Ort::SessionOptions m_session_options
Definition: ONNXWrapper.h:44
ONNXWrapper::getInputNames
const std::vector< const char * > & getInputNames()
Definition: ONNXWrapper.cxx:156
ONNXWrapper::m_output_names
std::vector< const char * > m_output_names
Definition: ONNXWrapper.h:50
make_unique
std::unique_ptr< T > make_unique(Args &&... args)
Definition: SkimmingToolEXOT5.cxx:23
ONNXWrapper::m_input_names
std::vector< const char * > m_input_names
Definition: ONNXWrapper.h:51
ONNXWrapper::m_onnxEnv
std::unique_ptr< Ort::Env > m_onnxEnv
Definition: ONNXWrapper.h:41
ONNXWrapper::GetMETAData
const std::map< std::string, std::string > GetMETAData()
Definition: ONNXWrapper.cxx:138
ONNXWrapper::getOutputNames
const std::vector< const char * > & getOutputNames()
Definition: ONNXWrapper.cxx:161
ONNXWrapper::m_modelPath
std::string m_modelPath
Definition: ONNXWrapper.h:26
ONNXWrapper::m_nr_inputs
size_t m_nr_inputs
Definition: ONNXWrapper.h:31
python.oracle.Session
Session
Definition: oracle.py:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
ONNXWrapper::GetMETADataByKey
std::string GetMETADataByKey(const char *key)
Definition: ONNXWrapper.cxx:151
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
ONNXWrapper::getShape
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
Definition: ONNXWrapper.cxx:181
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
lumiFormat.i
int i
Definition: lumiFormat.py:85
beamspotman.n
n
Definition: beamspotman.py:731
python.subdetectors.mmg.names
names
Definition: mmg.py:8
ONNXWrapper::m_input_dims
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition: ONNXWrapper.h:35
ONNXWrapper::ONNXWrapper
ONNXWrapper(const std::string &model_path)
Definition: ONNXWrapper.cxx:9
xAOD::uint64_t
uint64_t
Definition: EventInfo_v1.cxx:123
ONNXWrapper.h
ONNXWrapper::Run
std::map< std::string, std::vector< float > > Run(std::map< std::string, std::vector< float >> inputs, int n_batches=1)
Definition: ONNXWrapper.cxx:51
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
merge.output
output
Definition: merge.py:17
ONNXWrapper::GetModelOutputs
const std::map< std::string, std::vector< int64_t > > GetModelOutputs()
Definition: ONNXWrapper.cxx:129
ONNXWrapper::getNumInputs
int getNumInputs() const
Definition: ONNXWrapper.cxx:178
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
ONNXWrapper::m_output_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition: ONNXWrapper.h:36
get
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition: hcg.cxx:127
ONNXWrapper::GetModelInputs
const std::map< std::string, std::vector< int64_t > > GetModelInputs()
Definition: ONNXWrapper.cxx:120
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:798
ONNXWrapper::getInputShape
const std::vector< int64_t > & getInputShape(int input_nr)
Definition: ONNXWrapper.cxx:166
ONNXWrapper::m_nr_output
size_t m_nr_output
Definition: ONNXWrapper.h:32
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
ONNXWrapper::getNumOutputs
int getNumOutputs() const
Definition: ONNXWrapper.cxx:179
ONNXWrapper::getOutputShape
const std::vector< int64_t > & getOutputShape(int output_nr)
Definition: ONNXWrapper.cxx:172
ONNXWrapper::m_allocator
Ort::AllocatorWithDefaultOptions m_allocator
Definition: ONNXWrapper.h:45
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37