ATLAS Offline Software
Loading...
Searching...
No Matches
ONNXWrapper Class Reference

#include <ONNXWrapper.h>

Collaboration diagram for ONNXWrapper:

Public Member Functions

 ONNXWrapper (const std::string &model_path)
std::map< std::string, std::vector< float > > Run (std::map< std::string, std::vector< float > > inputs, int n_batches=1)
const std::map< std::string, std::vector< int64_t > > GetModelInputs ()
const std::map< std::string, std::vector< int64_t > > GetModelOutputs ()
const std::map< std::string, std::string > GetMETAData ()
std::string GetMETADataByKey (const char *key)
const std::vector< int64_t > & getInputShape (int input_nr)
const std::vector< int64_t > & getOutputShape (int output_nr)
const std::vector< const char * > & getInputNames ()
const std::vector< const char * > & getOutputNames ()
int getNumInputs () const
int getNumOutputs () const

Private Member Functions

const std::vector< int64_t > getShape (Ort::TypeInfo model_info)

Private Attributes

std::string m_modelName
std::string m_modelPath
size_t m_nr_inputs
size_t m_nr_output
std::map< std::string, std::vector< int64_t > > m_input_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
std::unique_ptr< Ort::Session > m_onnxSession
std::unique_ptr< Ort::Env > m_onnxEnv
Ort::SessionOptions m_session_options
Ort::AllocatorWithDefaultOptions m_allocator
std::vector< const char * > m_output_names
std::vector< const char * > m_input_names

Detailed Description

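Thin wrapper around an ONNX Runtime session. The constructor creates an Ort::Session for the model file and caches the name and shape of every input and output node (with the leading dimension of each shape set to 1); Run() then performs inference on flat float vectors keyed by node name.

Definition at line 20 of file ONNXWrapper.h.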

Constructor & Destructor Documentation

◆ ONNXWrapper()

ONNXWrapper::ONNXWrapper ( const std::string & model_path)

Definition at line 9 of file ONNXWrapper.cxx.

9 :
10 m_modelPath(PathResolverFindCalibFile(model_path)),
11 m_onnxEnv(std::make_unique< Ort::Env >(ORT_LOGGING_LEVEL_WARNING, "")) {
12
13 // initialize session options if needed
14 m_session_options.SetIntraOpNumThreads(1);
15 m_session_options.SetGraphOptimizationLevel(
16 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
17
18
19 // Initialise the ONNX environment and session using the above options and the model name
20 m_onnxSession = std::make_unique<Ort::Session>(*m_onnxEnv,
21 m_modelPath.c_str(),
22 m_session_options);
23
24 // get the input nodes
25 m_nr_inputs = m_onnxSession->GetInputCount();
26
27 // get the output nodes
28 m_nr_output = m_onnxSession->GetOutputCount();
29
30
31 // iterate over all input nodes
32 for (std::size_t i = 0; i < m_nr_inputs; i++) {
33 const char* input_name = m_onnxSession->GetInputNameAllocated(i, m_allocator).release();
34
35 m_input_names.push_back(input_name);
36
37 m_input_dims[input_name] = getShape(m_onnxSession->GetInputTypeInfo(i));
38 }
39
40 // iterate over all output nodes
41 for(std::size_t i = 0; i < m_nr_output; i++ ) {
42 const char* output_name = m_onnxSession->GetOutputNameAllocated(i, m_allocator).release();
43
44 m_output_names.push_back(output_name);
45
46 m_output_dims[output_name] = getShape(m_onnxSession->GetOutputTypeInfo(i));
47
48 }
49}
std::string PathResolverFindCalibFile(const std::string &logical_file_name)
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
size_t m_nr_output
Definition ONNXWrapper.h:32
Ort::AllocatorWithDefaultOptions m_allocator
Definition ONNXWrapper.h:45
std::string m_modelPath
Definition ONNXWrapper.h:26
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition ONNXWrapper.h:35
size_t m_nr_inputs
Definition ONNXWrapper.h:31
std::vector< const char * > m_output_names
Definition ONNXWrapper.h:50
Ort::SessionOptions m_session_options
Definition ONNXWrapper.h:44
std::unique_ptr< Ort::Session > m_onnxSession
Definition ONNXWrapper.h:40
std::vector< const char * > m_input_names
Definition ONNXWrapper.h:51
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition ONNXWrapper.h:36
std::unique_ptr< Ort::Env > m_onnxEnv
Definition ONNXWrapper.h:41

Member Function Documentation

◆ getInputNames()

const std::vector< const char * > & ONNXWrapper::getInputNames ( )

Definition at line 156 of file ONNXWrapper.cxx.

156 {
157 //put the model access for input here
158 return m_input_names;
159}

◆ getInputShape()

const std::vector< int64_t > & ONNXWrapper::getInputShape ( int input_nr = 0)

Definition at line 166 of file ONNXWrapper.cxx.

166 {
167 //put the model access for input here
168 std::vector<const char*> names = getInputNames();
169 return m_input_dims[names.at(input_nr)];
170}
const std::vector< const char * > & getInputNames()

◆ GetMETAData()

const std::map< std::string, std::string > ONNXWrapper::GetMETAData ( )

Definition at line 138 of file ONNXWrapper.cxx.

138 {
139 std::map<std::string, std::string> METAData_map;
140 auto metadata = m_onnxSession->GetModelMetadata();
141
142 auto keys = metadata.GetCustomMetadataMapKeysAllocated(m_allocator);
143 // auto status = OrtApi::ModelMetadataGetCustomMetadataMapKeys(metadata, m_allocator, keys, nkeys)
144 for (size_t i = 0; i < keys.size(); i++) {
145 METAData_map[keys[i].get()]=this->GetMETADataByKey(keys[i].get());
146 }
147
148 return METAData_map;
149}
std::string GetMETADataByKey(const char *key)
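T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition hcg.cxx:130
(Note: this cross-reference is a name collision in the generated documentation; keys[i].get() in the listing above is called on the elements returned by GetCustomMetadataMapKeysAllocated, not on this hcg.cxx helper.)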

◆ GetMETADataByKey()

std::string ONNXWrapper::GetMETADataByKey ( const char * key)

Definition at line 151 of file ONNXWrapper.cxx.

151 {
152 auto metadata = m_onnxSession->GetModelMetadata();
153 return metadata.LookupCustomMetadataMapAllocated(key, m_allocator).get();
154}

◆ GetModelInputs()

const std::map< std::string, std::vector< int64_t > > ONNXWrapper::GetModelInputs ( )

Definition at line 120 of file ONNXWrapper.cxx.

120 {
121 std::map<std::string, std::vector<int64_t>> ModelInputINFO_map;
122
123 for(std::size_t i = 0; i < m_nr_inputs; i++ ) {
124 ModelInputINFO_map[m_input_names.at(i)] = m_input_dims[m_input_names.at(i)];
125 }
126 return ModelInputINFO_map;
127}

◆ GetModelOutputs()

const std::map< std::string, std::vector< int64_t > > ONNXWrapper::GetModelOutputs ( )

Definition at line 129 of file ONNXWrapper.cxx.

129 {
130 std::map<std::string, std::vector<int64_t>> ModelOutputINFO_map;
131
132 for(std::size_t i = 0; i < m_nr_output; i++ ) {
133 ModelOutputINFO_map[m_output_names.at(i)] = m_output_dims[m_output_names.at(i)];
134 }
135 return ModelOutputINFO_map;
136}

◆ getNumInputs()

int ONNXWrapper::getNumInputs ( ) const

Definition at line 178 of file ONNXWrapper.cxx.

178{ return m_input_names.size(); }

◆ getNumOutputs()

int ONNXWrapper::getNumOutputs ( ) const

Definition at line 179 of file ONNXWrapper.cxx.

179{ return m_output_names.size(); }

◆ getOutputNames()

const std::vector< const char * > & ONNXWrapper::getOutputNames ( )

Definition at line 161 of file ONNXWrapper.cxx.

161 {
162 //put the model access for outputs here
163 return m_output_names;
164}

◆ getOutputShape()

const std::vector< int64_t > & ONNXWrapper::getOutputShape ( int output_nr = 0)

Definition at line 172 of file ONNXWrapper.cxx.

172 {
173 //put the model access for outputs here
174 std::vector<const char*> names = getOutputNames();
175 return m_output_dims[names.at(output_nr)];
176}
const std::vector< const char * > & getOutputNames()

◆ getShape()

const std::vector< int64_t > ONNXWrapper::getShape ( Ort::TypeInfo model_info)
private

Definition at line 181 of file ONNXWrapper.cxx.

181 {
182 auto tensor_info = model_info.GetTensorTypeAndShapeInfo();
183 std::vector<int64_t> dims = tensor_info.GetShape();
184 dims[0]=1;
185 return dims;
186 }

◆ Run()

std::map< std::string, std::vector< float > > ONNXWrapper::Run ( std::map< std::string, std::vector< float > > inputs,
int n_batches = 1 )

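Runs inference for n_batches batches. Each supplied input vector must hold exactly n_batches times the per-batch element count of its input node; otherwise std::invalid_argument is thrown. Returns one flat float vector per output node, keyed by output name.

Note: as listed, the call to Ort::Session::Run passes a hard-coded input count and output count of 2, and the tensors are built in the key order of m_input_dims / m_output_dims rather than in the order of m_input_names / m_output_names. The method is therefore only reliable for models with exactly two inputs and two outputs whose names sort in the same order as the model declares them.

Definition at line 51 of file ONNXWrapper.cxx.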

52 {
53 for ( const auto &p : inputs ) // check for valid dimensions between batches and inputs
54 {
55 uint64_t n=1;
56 for (uint64_t i:m_input_dims[p.first])
57 {
58 n*=i;
59 }
60 if ( (p.second.size() % n) != 0){
61
62 throw std::invalid_argument("For input '"+p.first+"' length not compatible with model. Expect a multiple of "+std::to_string(n)+", got "+std::to_string(p.second.size()));
63 }
64 if ( p.second.size()!=(n_batches*n)){
65 throw std::invalid_argument("Number of batches not compatible with length of vector");
66 }
67 }
68 // Create a CPU tensor to be used as input
69 Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
70
71 // define input tensor
72 std::vector<Ort::Value> output_tensor;
73 std::vector<Ort::Value> input_tensor;
74
75 // add the inputs to vector
76 for ( const auto &p : m_input_dims )
77 {
78 std::vector<int64_t> in_dims = p.second;
79 in_dims.at(0) = n_batches;
80 input_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
81 inputs[p.first].data(),
82 inputs[p.first].size(),
83 in_dims.data(),
84 in_dims.size()));
85 }
86
87 // init output tensor and fill with zeros
88 std::map<std::string, std::vector<float>> outputs;
89 for ( const auto &p : m_output_dims ) {
90 std::vector<int64_t> out_dims = p.second;
91 out_dims.at(0) = n_batches;
92 // init output
93 int length = 1;
94 for(auto i : out_dims){ length*=i; }
95 std::vector<float> output(length,0);
96 // std::vector<float> output(m_output_dims[i][1], 0.0);
97 outputs[p.first] = std::move(output);
98 output_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
99 outputs[p.first].data(),
100 outputs[p.first].size(),
101 out_dims.data(),
102 out_dims.size()));
103 }
104
105 Ort::Session& session = *m_onnxSession;
106
107 // run the model
108 session.Run(Ort::RunOptions{nullptr},
109 m_input_names.data(),
110 input_tensor.data(),
111 2,
112 m_output_names.data(),
113 output_tensor.data(),
114 2);
115
116 return outputs;
117 }
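double length(const pvec &v)
output
Definition merge.py:16
(Note: these cross-references are name collisions in the generated documentation; in the listing above, length and output are local variables of Run.)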

Member Data Documentation

◆ m_allocator

Ort::AllocatorWithDefaultOptions ONNXWrapper::m_allocator
private

Definition at line 45 of file ONNXWrapper.h.

◆ m_input_dims

std::map<std::string, std::vector<int64_t> > ONNXWrapper::m_input_dims
private

Definition at line 35 of file ONNXWrapper.h.

◆ m_input_names

std::vector<const char*> ONNXWrapper::m_input_names
private

Definition at line 51 of file ONNXWrapper.h.

◆ m_modelName

std::string ONNXWrapper::m_modelName
private

Definition at line 25 of file ONNXWrapper.h.

◆ m_modelPath

std::string ONNXWrapper::m_modelPath
private

Definition at line 26 of file ONNXWrapper.h.

◆ m_nr_inputs

size_t ONNXWrapper::m_nr_inputs
private

Definition at line 31 of file ONNXWrapper.h.

◆ m_nr_output

size_t ONNXWrapper::m_nr_output
private

Definition at line 32 of file ONNXWrapper.h.

◆ m_onnxEnv

std::unique_ptr< Ort::Env > ONNXWrapper::m_onnxEnv
private

Definition at line 41 of file ONNXWrapper.h.

◆ m_onnxSession

std::unique_ptr<Ort::Session> ONNXWrapper::m_onnxSession
private

Definition at line 40 of file ONNXWrapper.h.

◆ m_output_dims

std::map<std::string, std::vector<int64_t> > ONNXWrapper::m_output_dims
private

Definition at line 36 of file ONNXWrapper.h.

◆ m_output_names

std::vector<const char*> ONNXWrapper::m_output_names
private

Definition at line 50 of file ONNXWrapper.h.

◆ m_session_options

Ort::SessionOptions ONNXWrapper::m_session_options
private

Definition at line 44 of file ONNXWrapper.h.


The documentation for this class was generated from the following files: