#include <ONNXWrapper.h>
Definition at line 16 of file ONNXWrapper.h.
◆ ONNXWrapper()
ONNXWrapper::ONNXWrapper(const std::string model_path)
Definition at line 3 of file ONNXWrapper.cxx.
9 m_onnxEnv = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "");
14 GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
◆ getInputNames()
const std::vector<const char*>& ONNXWrapper::getInputNames()
◆ getInputShape()
const std::vector<int64_t>& ONNXWrapper::getInputShape(int input_nr = 0)
◆ GetMETAData()
const std::map<std::string, std::string> ONNXWrapper::GetMETAData()
Definition at line 136 of file ONNXWrapper.cxx.
137 std::map<std::string, std::string> METAData_map;
142 for (size_t i = 0; i < keys.size(); i++) {
◆ GetMETADataByKey()
std::string ONNXWrapper::GetMETADataByKey(const char* key)
◆ GetModelInputs()
const std::map<std::string, std::vector<int64_t>> ONNXWrapper::GetModelInputs()
Definition at line 118 of file ONNXWrapper.cxx.
119 std::map<std::string, std::vector<int64_t>> ModelInputINFO_map;
124 return ModelInputINFO_map;
◆ GetModelOutputs()
const std::map<std::string, std::vector<int64_t>> ONNXWrapper::GetModelOutputs()
Definition at line 127 of file ONNXWrapper.cxx.
128 std::map<std::string, std::vector<int64_t>> ModelOutputINFO_map;
133 return ModelOutputINFO_map;
◆ getNumInputs()
int ONNXWrapper::getNumInputs() const
◆ getNumOutputs()
int ONNXWrapper::getNumOutputs() const
◆ getOutputNames()
const std::vector<const char*>& ONNXWrapper::getOutputNames()
◆ getOutputShape()
const std::vector<int64_t>& ONNXWrapper::getOutputShape(int output_nr = 0)
◆ getShape()
const std::vector<int64_t> ONNXWrapper::getShape(Ort::TypeInfo model_info) [private]
Definition at line 179 of file ONNXWrapper.cxx.
180 auto tensor_info = model_info.GetTensorTypeAndShapeInfo();
181 std::vector<int64_t> dims = tensor_info.GetShape();
◆ Run()
std::map<std::string, std::vector<float>> ONNXWrapper::Run(std::map<std::string, std::vector<float>> inputs, int n_batches = 1)
Definition at line 49 of file ONNXWrapper.cxx.
58 if ((p.second.size() % n) != 0) {
60 throw std::invalid_argument("For input '" + p.first + "' length not compatible with model. Expect a multiple of " + std::to_string(n) + ", got " + std::to_string(p.second.size()));
62 if (p.second.size() != (n_batches * n)) {
63 throw std::invalid_argument("Number of batches not compatible with length of vector");
67 Ort::MemoryInfo memory_info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
70 std::vector<Ort::Value> output_tensor;
71 std::vector<Ort::Value> input_tensor;
76 std::vector<int64_t> in_dims = p.second;
77 in_dims.at(0) = n_batches;
78 input_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
86 std::map<std::string, std::vector<float>> outputs;
88 std::vector<int64_t> out_dims = p.second;
89 out_dims.at(0) = n_batches;
96 output_tensor.push_back(Ort::Value::CreateTensor<float>(memory_info,
106 session.Run(Ort::RunOptions{nullptr},
111 output_tensor.data(),
◆ m_allocator
Ort::AllocatorWithDefaultOptions ONNXWrapper::m_allocator [private]
◆ m_input_dims
std::map<std::string, std::vector<int64_t>> ONNXWrapper::m_input_dims [private]
◆ m_input_names
std::vector<const char*> ONNXWrapper::m_input_names [private]
◆ m_modelName
std::string ONNXWrapper::m_modelName [private]
◆ m_modelPath
std::string ONNXWrapper::m_modelPath [private]
◆ m_nr_inputs
size_t ONNXWrapper::m_nr_inputs [private]
◆ m_nr_output
size_t ONNXWrapper::m_nr_output [private]
◆ m_onnxEnv
std::unique_ptr<Ort::Env> ONNXWrapper::m_onnxEnv [private]
◆ m_onnxSession
std::unique_ptr<Ort::Session> ONNXWrapper::m_onnxSession [private]
◆ m_output_dims
std::map<std::string, std::vector<int64_t>> ONNXWrapper::m_output_dims [private]
◆ m_output_names
std::vector<const char*> ONNXWrapper::m_output_names [private]
◆ m_session_options
Ort::SessionOptions ONNXWrapper::m_session_options [private]
The documentation for this class was generated from the following files: