ATLAS Offline Software
Public Member Functions | Private Member Functions | Private Attributes | List of all members
ONNXWrapper Class Reference

#include <ONNXWrapper.h>

Collaboration diagram for ONNXWrapper:

Public Member Functions

 ONNXWrapper (const std::string model_path)
 
std::map< std::string, std::vector< float > > Run (std::map< std::string, std::vector< float >> inputs, int n_batches=1)
 
const std::map< std::string, std::vector< int64_t > > GetModelInputs ()
 
const std::map< std::string, std::vector< int64_t > > GetModelOutputs ()
 
const std::map< std::string, std::string > GetMETAData ()
 
std::string GetMETADataByKey (const char *key)
 
const std::vector< int64_t > & getInputShape (int input_nr)
 
const std::vector< int64_t > & getOutputShape (int output_nr)
 
const std::vector< const char * > & getInputNames ()
 
const std::vector< const char * > & getOutputNames ()
 
int getNumInputs () const
 
int getNumOutputs () const
 

Private Member Functions

const std::vector< int64_t > getShape (Ort::TypeInfo model_info)
 

Private Attributes

std::string m_modelName
 
std::string m_modelPath
 
size_t m_nr_inputs
 
size_t m_nr_output
 
std::map< std::string, std::vector< int64_t > > m_input_dims
 
std::map< std::string, std::vector< int64_t > > m_output_dims
 
std::unique_ptr< Ort::Session > m_onnxSession
 
std::unique_ptr< Ort::Env > m_onnxEnv
 
Ort::SessionOptions m_session_options
 
Ort::AllocatorWithDefaultOptions m_allocator
 
std::vector< const char * > m_output_names
 
std::vector< const char * > m_input_names
 

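Detailed Description

Thin wrapper around an ONNX Runtime session. It loads an .onnx model, caches the names and shapes of the model inputs and outputs (with the leading batch dimension recorded as 1), and runs float inference on flat input buffers via Run().

Definition at line 16 of file ONNXWrapper.h.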

Constructor & Destructor Documentation

◆ ONNXWrapper()

ONNXWrapper::ONNXWrapper ( const std::string  model_path)

Definition at line 3 of file ONNXWrapper.cxx.

3  {
4 
5  // Use the path resolver to find the location of the network .onnx file
7 
8  // init onnx envi
9  m_onnxEnv = std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "");
10 
11  // initialize session options if needed
12  m_session_options.SetIntraOpNumThreads(1);
13  m_session_options.SetGraphOptimizationLevel(
14  GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
15 
16 
17  // Initialise the ONNX environment and session using the above options and the model name
18  m_onnxSession = std::make_unique<Ort::Session>(*m_onnxEnv,
19  m_modelPath.c_str(),
21 
22  // get the input nodes
23  m_nr_inputs = m_onnxSession->GetInputCount();
24 
25  // get the output nodes
26  m_nr_output = m_onnxSession->GetOutputCount();
27 
28 
29  // iterate over all input nodes
30  for (std::size_t i = 0; i < m_nr_inputs; i++) {
31  const char* input_name = m_onnxSession->GetInputNameAllocated(i, m_allocator).release();
32 
33  m_input_names.push_back(input_name);
34 
35  m_input_dims[input_name] = getShape(m_onnxSession->GetInputTypeInfo(i));
36  }
37 
38  // iterate over all output nodes
39  for(std::size_t i = 0; i < m_nr_output; i++ ) {
40  const char* output_name = m_onnxSession->GetOutputNameAllocated(i, m_allocator).release();
41 
42  m_output_names.push_back(output_name);
43 
44  m_output_dims[output_name] = getShape(m_onnxSession->GetOutputTypeInfo(i));
45 
46  }
47 }

Member Function Documentation

◆ getInputNames()

const std::vector< const char * > & ONNXWrapper::getInputNames ( )

Definition at line 154 of file ONNXWrapper.cxx.

154  {
155  //put the model access for input here
156  return m_input_names;
157 }

◆ getInputShape()

const std::vector< int64_t > & ONNXWrapper::getInputShape ( int  input_nr = 0)

Definition at line 164 of file ONNXWrapper.cxx.

164  {
165  //put the model access for input here
166  std::vector<const char*> names = getInputNames();
167  return m_input_dims[names.at(input_nr)];
168 }

◆ GetMETAData()

const std::map< std::string, std::string > ONNXWrapper::GetMETAData ( )

Definition at line 136 of file ONNXWrapper.cxx.

136  {
137  std::map<std::string, std::string> METAData_map;
138  auto metadata = m_onnxSession->GetModelMetadata();
139 
140  auto keys = metadata.GetCustomMetadataMapKeysAllocated(m_allocator);
141  // auto status = OrtApi::ModelMetadataGetCustomMetadataMapKeys(metadata, m_allocator, keys, nkeys)
142  for (size_t i = 0; i < keys.size(); i++) {
143  METAData_map[keys[i].get()]=this->GetMETADataByKey(keys[i].get());
144  }
145 
146  return METAData_map;
147 }

◆ GetMETADataByKey()

std::string ONNXWrapper::GetMETADataByKey ( const char *  key)

Definition at line 149 of file ONNXWrapper.cxx.

149  {
150  auto metadata = m_onnxSession->GetModelMetadata();
151  return metadata.LookupCustomMetadataMapAllocated(key, m_allocator).get();
152 }

◆ GetModelInputs()

const std::map< std::string, std::vector< int64_t > > ONNXWrapper::GetModelInputs ( )

Definition at line 118 of file ONNXWrapper.cxx.

118  {
119  std::map<std::string, std::vector<int64_t>> ModelInputINFO_map;
120 
121  for(std::size_t i = 0; i < m_nr_inputs; i++ ) {
122  ModelInputINFO_map[m_input_names.at(i)] = m_input_dims[m_input_names.at(i)];
123  }
124  return ModelInputINFO_map;
125 }

◆ GetModelOutputs()

const std::map< std::string, std::vector< int64_t > > ONNXWrapper::GetModelOutputs ( )

Definition at line 127 of file ONNXWrapper.cxx.

127  {
128  std::map<std::string, std::vector<int64_t>> ModelOutputINFO_map;
129 
130  for(std::size_t i = 0; i < m_nr_output; i++ ) {
131  ModelOutputINFO_map[m_output_names.at(i)] = m_output_dims[m_output_names.at(i)];
132  }
133  return ModelOutputINFO_map;
134 }

◆ getNumInputs()

int ONNXWrapper::getNumInputs ( ) const

Definition at line 176 of file ONNXWrapper.cxx.

176 { return m_input_names.size(); }

◆ getNumOutputs()

int ONNXWrapper::getNumOutputs ( ) const

Definition at line 177 of file ONNXWrapper.cxx.

177 { return m_output_names.size(); }

◆ getOutputNames()

const std::vector< const char * > & ONNXWrapper::getOutputNames ( )

Definition at line 159 of file ONNXWrapper.cxx.

159  {
160  //put the model access for outputs here
161  return m_output_names;
162 }

◆ getOutputShape()

const std::vector< int64_t > & ONNXWrapper::getOutputShape ( int  output_nr = 0)

Definition at line 170 of file ONNXWrapper.cxx.

170  {
171  //put the model access for outputs here
172  std::vector<const char*> names = getOutputNames();
173  return m_output_dims[names.at(output_nr)];
174 }

◆ getShape()

const std::vector< int64_t > ONNXWrapper::getShape ( Ort::TypeInfo  model_info)
private

Definition at line 179 of file ONNXWrapper.cxx.

179  {
180  auto tensor_info = model_info.GetTensorTypeAndShapeInfo();
181  std::vector<int64_t> dims = tensor_info.GetShape();
182  dims[0]=1;
183  return dims;
184  }

◆ Run()

std::map< std::string, std::vector< float > > ONNXWrapper::Run ( std::map< std::string, std::vector< float >>  inputs,
int  n_batches = 1 
)

Definition at line 49 of file ONNXWrapper.cxx.

50  {
51  for ( const auto &p : inputs ) // check for valid dimensions between batches and inputs
52  {
53  uint64_t n=1;
54  for (uint64_t i:m_input_dims[p.first])
55  {
56  n*=i;
57  }
58  if ( (p.second.size() % n) != 0){
59 
60  throw std::invalid_argument("For input '"+p.first+"' length not compatible with model. Expect a multiple of "+std::to_string(n)+", got "+std::to_string(p.second.size()));
61  }
62  if ( p.second.size()!=(n_batches*n)){
63  throw std::invalid_argument("Number of batches not compatible with length of vector");
64  }
65  }
66  // Create a CPU tensor to be used as input
67  Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
68 
69  // define input tensor
70  std::vector<Ort::Value> output_tensor;
71  std::vector<Ort::Value> input_tensor;
72 
73  // add the inputs to vector
74  for ( const auto &p : m_input_dims )
75  {
76  std::vector<int64_t> in_dims = p.second;
77  in_dims.at(0) = n_batches;
78  input_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
79  inputs[p.first].data(),
80  inputs[p.first].size(),
81  in_dims.data(),
82  in_dims.size()));
83  }
84 
85  // init output tensor and fill with zeros
86  std::map<std::string, std::vector<float>> outputs;
87  for ( const auto &p : m_output_dims ) {
88  std::vector<int64_t> out_dims = p.second;
89  out_dims.at(0) = n_batches;
90  // init output
91  int length = 1;
92  for(auto i : out_dims){ length*=i; }
93  std::vector<float> output(length,0);
94  // std::vector<float> output(m_output_dims[i][1], 0.0);
95  outputs[p.first] = output;
96  output_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
97  outputs[p.first].data(),
98  outputs[p.first].size(),
99  out_dims.data(),
100  out_dims.size()));
101  }
102 
103  Ort::Session& session = *m_onnxSession;
104 
105  // run the model
106  session.Run(Ort::RunOptions{nullptr},
107  m_input_names.data(),
108  input_tensor.data(),
109  2,
110  m_output_names.data(),
111  output_tensor.data(),
112  2);
113 
114  return outputs;
115  }

Member Data Documentation

◆ m_allocator

Ort::AllocatorWithDefaultOptions ONNXWrapper::m_allocator
private

Definition at line 41 of file ONNXWrapper.h.

◆ m_input_dims

std::map<std::string, std::vector<int64_t> > ONNXWrapper::m_input_dims
private

Definition at line 31 of file ONNXWrapper.h.

◆ m_input_names

std::vector<const char*> ONNXWrapper::m_input_names
private

Definition at line 47 of file ONNXWrapper.h.

◆ m_modelName

std::string ONNXWrapper::m_modelName
private

Definition at line 21 of file ONNXWrapper.h.

◆ m_modelPath

std::string ONNXWrapper::m_modelPath
private

Definition at line 22 of file ONNXWrapper.h.

◆ m_nr_inputs

size_t ONNXWrapper::m_nr_inputs
private

Definition at line 27 of file ONNXWrapper.h.

◆ m_nr_output

size_t ONNXWrapper::m_nr_output
private

Definition at line 28 of file ONNXWrapper.h.

◆ m_onnxEnv

std::unique_ptr< Ort::Env > ONNXWrapper::m_onnxEnv
private

Definition at line 37 of file ONNXWrapper.h.

◆ m_onnxSession

std::unique_ptr<Ort::Session> ONNXWrapper::m_onnxSession
private

Definition at line 36 of file ONNXWrapper.h.

◆ m_output_dims

std::map<std::string, std::vector<int64_t> > ONNXWrapper::m_output_dims
private

Definition at line 32 of file ONNXWrapper.h.

◆ m_output_names

std::vector<const char*> ONNXWrapper::m_output_names
private

Definition at line 46 of file ONNXWrapper.h.

◆ m_session_options

Ort::SessionOptions ONNXWrapper::m_session_options
private

Definition at line 40 of file ONNXWrapper.h.


The documentation for this class was generated from the following files: ONNXWrapper.h and ONNXWrapper.cxx.
ONNXWrapper::m_onnxSession
std::unique_ptr< Ort::Session > m_onnxSession
Definition: ONNXWrapper.h:36
ONNXWrapper::m_session_options
Ort::SessionOptions m_session_options
Definition: ONNXWrapper.h:40
ONNXWrapper::getInputNames
const std::vector< const char * > & getInputNames()
Definition: ONNXWrapper.cxx:154
ONNXWrapper::m_output_names
std::vector< const char * > m_output_names
Definition: ONNXWrapper.h:46
ONNXWrapper::m_input_names
std::vector< const char * > m_input_names
Definition: ONNXWrapper.h:47
ONNXWrapper::m_onnxEnv
std::unique_ptr< Ort::Env > m_onnxEnv
Definition: ONNXWrapper.h:37
ONNXWrapper::getOutputNames
const std::vector< const char * > & getOutputNames()
Definition: ONNXWrapper.cxx:159
ONNXWrapper::m_modelPath
std::string m_modelPath
Definition: ONNXWrapper.h:22
ONNXWrapper::m_nr_inputs
size_t m_nr_inputs
Definition: ONNXWrapper.h:27
python.oracle.Session
Session
Definition: oracle.py:78
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
ONNXWrapper::GetMETADataByKey
std::string GetMETADataByKey(const char *key)
Definition: ONNXWrapper.cxx:149
python.checkMetadata.metadata
metadata
Definition: checkMetadata.py:175
ONNXWrapper::getShape
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
Definition: ONNXWrapper.cxx:179
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
lumiFormat.i
int i
Definition: lumiFormat.py:85
beamspotman.n
n
Definition: beamspotman.py:731
python.subdetectors.mmg.names
names
Definition: mmg.py:8
ONNXWrapper::m_input_dims
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition: ONNXWrapper.h:31
xAOD::uint64_t
uint64_t
Definition: EventInfo_v1.cxx:123
python.CreateTierZeroArgdict.outputs
outputs
Definition: CreateTierZeroArgdict.py:189
merge.output
output
Definition: merge.py:17
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
PathResolverFindCalibFile
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
Definition: PathResolver.cxx:431
ONNXWrapper::m_output_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition: ONNXWrapper.h:32
get
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition: hcg.cxx:127
python.Bindings.keys
keys
Definition: Control/AthenaPython/python/Bindings.py:798
ONNXWrapper::m_nr_output
size_t m_nr_output
Definition: ONNXWrapper.h:28
length
double length(const pvec &v)
Definition: FPGATrackSimLLPDoubletHoughTransformTool.cxx:26
ONNXWrapper::m_allocator
Ort::AllocatorWithDefaultOptions m_allocator
Definition: ONNXWrapper.h:41
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37