ATLAS Offline Software
Public Member Functions | Private Member Functions | Private Attributes | List of all members
ONNXWrapper Class Reference

#include <ONNXWrapper.h>

Collaboration diagram for ONNXWrapper: (diagram not reproduced in this text rendering)

Public Member Functions

 ONNXWrapper (const std::string &model_path)
 
std::map< std::string, std::vector< float > > Run (std::map< std::string, std::vector< float >> inputs, int n_batches=1)
 
const std::map< std::string, std::vector< int64_t > > GetModelInputs ()
 
const std::map< std::string, std::vector< int64_t > > GetModelOutputs ()
 
const std::map< std::string, std::string > GetMETAData ()
 
std::string GetMETADataByKey (const char *key)
 
const std::vector< int64_t > & getInputShape (int input_nr)
 
const std::vector< int64_t > & getOutputShape (int output_nr)
 
const std::vector< const char * > & getInputNames ()
 
const std::vector< const char * > & getOutputNames ()
 
int getNumInputs () const
 
int getNumOutputs () const
 

Private Member Functions

const std::vector< int64_t > getShape (Ort::TypeInfo model_info)
 

Private Attributes

std::string m_modelName
 
std::string m_modelPath
 
size_t m_nr_inputs
 
size_t m_nr_output
 
std::map< std::string, std::vector< int64_t > > m_input_dims
 
std::map< std::string, std::vector< int64_t > > m_output_dims
 
std::unique_ptr< Ort::Session > m_onnxSession
 
std::unique_ptr< Ort::Env > m_onnxEnv
 
Ort::SessionOptions m_session_options
 
Ort::AllocatorWithDefaultOptions m_allocator
 
std::vector< const char * > m_output_names
 
std::vector< const char * > m_input_names
 

Detailed Description

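Wraps an ONNX Runtime inference session (Ort::Session). The constructor loads the model from m_modelPath and caches the names and shapes of its inputs and outputs. Run() evaluates the model on float inputs, and GetMETAData() / GetMETADataByKey() expose the model's custom metadata.

Definition at line 20 of file ONNXWrapper.h.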

Constructor & Destructor Documentation

◆ ONNXWrapper()

ONNXWrapper::ONNXWrapper ( const std::string &  model_path)

Definition at line 9 of file ONNXWrapper.cxx.

// ONNXWrapper constructor (ONNXWrapper.cxx, lines 9-49): creates the ONNX Runtime
// Env and Session for the model at m_modelPath and caches the names and shapes of
// every model input and output.
// NOTE(review): this rendered listing skips source lines 10 and 22 (part of the
// initializer list and the last argument of the Ort::Session constructor); see
// ONNXWrapper.cxx for the complete code.
9  :
// NOTE(review): m_onnxSession is declared at ONNXWrapper.h:40, before m_onnxEnv at
// ONNXWrapper.h:41. Members are destroyed in reverse declaration order, so the Env
// is destroyed before the Session created from it. If ONNX Runtime requires the Env
// to outlive its sessions, swap the two declarations - confirm.
11  m_onnxEnv(std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "")) {
12 
13  // initialize session options if needed
// Single intra-op thread and extended graph optimisations for every session.
14  m_session_options.SetIntraOpNumThreads(1);
15  m_session_options.SetGraphOptimizationLevel(
16  GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
17 
18 
19  // Initialise the ONNX environment and session using the above options and the model name
// (The Env already exists from the initializer list; this statement only creates
// the Session, loading the model from m_modelPath.)
20  m_onnxSession = std::make_unique<Ort::Session>(*m_onnxEnv,
21  m_modelPath.c_str(),
23 
24  // get the input nodes
25  m_nr_inputs = m_onnxSession->GetInputCount();
26 
27  // get the output nodes
28  m_nr_output = m_onnxSession->GetOutputCount();
29 
30 
31  // iterate over all input nodes
// Names are appended in session index order; getShape() pins the leading axis to 1.
32  for (std::size_t i = 0; i < m_nr_inputs; i++) {
// NOTE(review): .release() takes the string out of its Ort::AllocatedStringPtr, so it
// is no longer freed automatically. No destructor is listed for this class, so these
// names appear to leak once per ONNXWrapper instance. Keeping the AllocatedStringPtr
// objects in a member vector would keep the raw pointers valid without leaking - confirm.
33  const char* input_name = m_onnxSession->GetInputNameAllocated(i, m_allocator).release();
34 
35  m_input_names.push_back(input_name);
36 
37  m_input_dims[input_name] = getShape(m_onnxSession->GetInputTypeInfo(i));
38  }
39 
40  // iterate over all output nodes
// Same caching for outputs, in session index order.
41  for(std::size_t i = 0; i < m_nr_output; i++ ) {
// NOTE(review): same ownership concern as the input names above - released, never freed.
42  const char* output_name = m_onnxSession->GetOutputNameAllocated(i, m_allocator).release();
43 
44  m_output_names.push_back(output_name);
45 
46  m_output_dims[output_name] = getShape(m_onnxSession->GetOutputTypeInfo(i));
47 
48  }
49 }

Member Function Documentation

◆ getInputNames()

const std::vector< const char * > & ONNXWrapper::getInputNames ( )

Definition at line 156 of file ONNXWrapper.cxx.

156  {
157  //put the model access for input here
158  return m_input_names;
159 }

◆ getInputShape()

const std::vector< int64_t > & ONNXWrapper::getInputShape ( int  input_nr = 0)

Definition at line 166 of file ONNXWrapper.cxx.

166  {
167  //put the model access for input here
168  std::vector<const char*> names = getInputNames();
169  return m_input_dims[names.at(input_nr)];
170 }

◆ GetMETAData()

const std::map< std::string, std::string > ONNXWrapper::GetMETAData ( )

Definition at line 138 of file ONNXWrapper.cxx.

138  {
139  std::map<std::string, std::string> METAData_map;
140  auto metadata = m_onnxSession->GetModelMetadata();
141 
142  auto keys = metadata.GetCustomMetadataMapKeysAllocated(m_allocator);
143  // auto status = OrtApi::ModelMetadataGetCustomMetadataMapKeys(metadata, m_allocator, keys, nkeys)
144  for (size_t i = 0; i < keys.size(); i++) {
145  METAData_map[keys[i].get()]=this->GetMETADataByKey(keys[i].get());
146  }
147 
148  return METAData_map;
149 }

◆ GetMETADataByKey()

std::string ONNXWrapper::GetMETADataByKey ( const char *  key)

Definition at line 151 of file ONNXWrapper.cxx.

151  {
152  auto metadata = m_onnxSession->GetModelMetadata();
153  return metadata.LookupCustomMetadataMapAllocated(key, m_allocator).get();
154 }

◆ GetModelInputs()

const std::map< std::string, std::vector< int64_t > > ONNXWrapper::GetModelInputs ( )

Definition at line 120 of file ONNXWrapper.cxx.

120  {
121  std::map<std::string, std::vector<int64_t>> ModelInputINFO_map;
122 
123  for(std::size_t i = 0; i < m_nr_inputs; i++ ) {
124  ModelInputINFO_map[m_input_names.at(i)] = m_input_dims[m_input_names.at(i)];
125  }
126  return ModelInputINFO_map;
127 }

◆ GetModelOutputs()

const std::map< std::string, std::vector< int64_t > > ONNXWrapper::GetModelOutputs ( )

Definition at line 129 of file ONNXWrapper.cxx.

129  {
130  std::map<std::string, std::vector<int64_t>> ModelOutputINFO_map;
131 
132  for(std::size_t i = 0; i < m_nr_output; i++ ) {
133  ModelOutputINFO_map[m_output_names.at(i)] = m_output_dims[m_output_names.at(i)];
134  }
135  return ModelOutputINFO_map;
136 }

◆ getNumInputs()

int ONNXWrapper::getNumInputs ( ) const

Definition at line 178 of file ONNXWrapper.cxx.

178 { return m_input_names.size(); }

◆ getNumOutputs()

int ONNXWrapper::getNumOutputs ( ) const

Definition at line 179 of file ONNXWrapper.cxx.

179 { return m_output_names.size(); }

◆ getOutputNames()

const std::vector< const char * > & ONNXWrapper::getOutputNames ( )

Definition at line 161 of file ONNXWrapper.cxx.

161  {
162  //put the model access for outputs here
163  return m_output_names;
164 }

◆ getOutputShape()

const std::vector< int64_t > & ONNXWrapper::getOutputShape ( int  output_nr = 0)

Definition at line 172 of file ONNXWrapper.cxx.

172  {
173  //put the model access for outputs here
174  std::vector<const char*> names = getOutputNames();
175  return m_output_dims[names.at(output_nr)];
176 }

◆ getShape()

const std::vector< int64_t > ONNXWrapper::getShape ( Ort::TypeInfo  model_info)
private

Definition at line 181 of file ONNXWrapper.cxx.

181  {
182  auto tensor_info = model_info.GetTensorTypeAndShapeInfo();
183  std::vector<int64_t> dims = tensor_info.GetShape();
184  dims[0]=1;
185  return dims;
186  }

◆ Run()

std::map< std::string, std::vector< float > > ONNXWrapper::Run ( std::map< std::string, std::vector< float >>  inputs,
int  n_batches = 1 
)

Definition at line 51 of file ONNXWrapper.cxx.

52  {
53  for ( const auto &p : inputs ) // check for valid dimensions between batches and inputs
54  {
55  uint64_t n=1;
56  for (uint64_t i:m_input_dims[p.first])
57  {
58  n*=i;
59  }
60  if ( (p.second.size() % n) != 0){
61 
62  throw std::invalid_argument("For input '"+p.first+"' length not compatible with model. Expect a multiple of "+std::to_string(n)+", got "+std::to_string(p.second.size()));
63  }
64  if ( p.second.size()!=(n_batches*n)){
65  throw std::invalid_argument("Number of batches not compatible with length of vector");
66  }
67  }
68  // Create a CPU tensor to be used as input
69  Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
70 
71  // define input tensor
72  std::vector<Ort::Value> output_tensor;
73  std::vector<Ort::Value> input_tensor;
74 
75  // add the inputs to vector
76  for ( const auto &p : m_input_dims )
77  {
78  std::vector<int64_t> in_dims = p.second;
79  in_dims.at(0) = n_batches;
80  input_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
81  inputs[p.first].data(),
82  inputs[p.first].size(),
83  in_dims.data(),
84  in_dims.size()));
85  }
86 
87  // init output tensor and fill with zeros
88  std::map<std::string, std::vector<float>> outputs;
89  for ( const auto &p : m_output_dims ) {
90  std::vector<int64_t> out_dims = p.second;
91  out_dims.at(0) = n_batches;
92  // init output
93  int length = 1;
94  for(auto i : out_dims){ length*=i; }
95  std::vector<float> output(length,0);
96  // std::vector<float> output(m_output_dims[i][1], 0.0);
97  outputs[p.first] = output;
98  output_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
99  outputs[p.first].data(),
100  outputs[p.first].size(),
101  out_dims.data(),
102  out_dims.size()));
103  }
104 
105  Ort::Session& session = *m_onnxSession;
106 
107  // run the model
108  session.Run(Ort::RunOptions{nullptr},
109  m_input_names.data(),
110  input_tensor.data(),
111  2,
112  m_output_names.data(),
113  output_tensor.data(),
114  2);
115 
116  return outputs;
117  }

Member Data Documentation

◆ m_allocator

Ort::AllocatorWithDefaultOptions ONNXWrapper::m_allocator
private

Definition at line 45 of file ONNXWrapper.h.

◆ m_input_dims

std::map<std::string, std::vector<int64_t> > ONNXWrapper::m_input_dims
private

Definition at line 35 of file ONNXWrapper.h.

◆ m_input_names

std::vector<const char*> ONNXWrapper::m_input_names
private

Definition at line 51 of file ONNXWrapper.h.

◆ m_modelName

std::string ONNXWrapper::m_modelName
private

Definition at line 25 of file ONNXWrapper.h.

◆ m_modelPath

std::string ONNXWrapper::m_modelPath
private

Definition at line 26 of file ONNXWrapper.h.

◆ m_nr_inputs

size_t ONNXWrapper::m_nr_inputs
private

Definition at line 31 of file ONNXWrapper.h.

◆ m_nr_output

size_t ONNXWrapper::m_nr_output
private

Definition at line 32 of file ONNXWrapper.h.

◆ m_onnxEnv

std::unique_ptr< Ort::Env > ONNXWrapper::m_onnxEnv
private

Definition at line 41 of file ONNXWrapper.h.

◆ m_onnxSession

std::unique_ptr<Ort::Session> ONNXWrapper::m_onnxSession
private

Definition at line 40 of file ONNXWrapper.h.

◆ m_output_dims

std::map<std::string, std::vector<int64_t> > ONNXWrapper::m_output_dims
private

Definition at line 36 of file ONNXWrapper.h.

◆ m_output_names

std::vector<const char*> ONNXWrapper::m_output_names
private

Definition at line 50 of file ONNXWrapper.h.

◆ m_session_options

Ort::SessionOptions ONNXWrapper::m_session_options
private

Definition at line 44 of file ONNXWrapper.h.


The documentation for this class was generated from the following files: ONNXWrapper.h and ONNXWrapper.cxx.
ONNXWrapper::m_onnxSession
std::unique_ptr< Ort::Session > m_onnxSession
Definition: ONNXWrapper.h:40
ONNXWrapper::m_session_options
Ort::SessionOptions m_session_options
Definition: ONNXWrapper.h:44
ONNXWrapper::getInputNames
const std::vector< const char * > & getInputNames()
Definition: ONNXWrapper.cxx:156
ONNXWrapper::m_output_names
std::vector< const char * > m_output_names
Definition: ONNXWrapper.h:50
ONNXWrapper::m_input_names
std::vector< const char * > m_input_names
Definition: ONNXWrapper.h:51
ONNXWrapper::m_onnxEnv
std::unique_ptr< Ort::Env > m_onnxEnv
Definition: ONNXWrapper.h:41
ONNXWrapper::getOutputNames
const std::vector< const char * > & getOutputNames()
Definition: ONNXWrapper.cxx:161
ONNXWrapper::m_modelPath
std::string m_modelPath
Definition: ONNXWrapper.h:26
ONNXWrapper::m_nr_inputs
size_t m_nr_inputs
Definition: ONNXWrapper.h:31
python.oracle.Session
Session
Definition: oracle.py:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
ONNXWrapper::GetMETADataByKey
std::string GetMETADataByKey(const char *key)
Definition: ONNXWrapper.cxx:151
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
ONNXWrapper::getShape
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
Definition: ONNXWrapper.cxx:181
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
lumiFormat.i
int i
Definition: lumiFormat.py:85
beamspotman.n
n
Definition: beamspotman.py:731
python.subdetectors.mmg.names
names
Definition: mmg.py:8
ONNXWrapper::m_input_dims
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition: ONNXWrapper.h:35
xAOD::uint64_t
uint64_t
Definition: EventInfo_v1.cxx:123
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
merge.output
output
Definition: merge.py:17
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
ONNXWrapper::m_output_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition: ONNXWrapper.h:36
get
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition: hcg.cxx:127
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:798
ONNXWrapper::m_nr_output
size_t m_nr_output
Definition: ONNXWrapper.h:32
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
ONNXWrapper::m_allocator
Ort::AllocatorWithDefaultOptions m_allocator
Definition: ONNXWrapper.h:45
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37