ATLAS Offline Software
ONNXWrapper.h
Go to the documentation of this file.
1 #ifndef ONNXUtils_h
2 #define ONNXUtils_h
3 
// STL includes
#include <cstddef>   // size_t
#include <cstdint>   // int64_t
#include <map>
#include <memory>    // std::unique_ptr
#include <string>
#include <vector>

// Asg tool includes
// NOTE(review): the include that belonged here (original line 10) was lost in
// the rendered listing; the cross-reference index mentions PathResolver.h.

// ONNX Library
#include <onnxruntime_cxx_api.h>
14 
15 
16 class ONNXWrapper {
17 
18  private:
19 
20  // Class properties
21  std::string m_modelName; // Path to the onnx file
22  std::string m_modelPath; // Output of the path resolver
23 
24  // Features of the network structure
25 
26  // input and output nodes
27  size_t m_nr_inputs;
28  size_t m_nr_output;
29 
30  // dimensions of the input and output
31  std::map<std::string, std::vector<int64_t>> m_input_dims;
32  std::map<std::string, std::vector<int64_t>> m_output_dims;
33 
34 
35  // ONNX session objects
36  std::unique_ptr<Ort::Session> m_onnxSession;
37  std::unique_ptr< Ort::Env > m_onnxEnv;
38 
39  // onnx session options
40  Ort::SessionOptions m_session_options;
41  Ort::AllocatorWithDefaultOptions m_allocator;
42 
43  // allocate memory
44 
45  // name of the outputs
46  std::vector<const char*> m_output_names;
47  std::vector<const char*> m_input_names;
48  const std::vector<int64_t> getShape(Ort::TypeInfo model_info);
49 
50  public:
51  // Constructor with parameters
52 
53  ONNXWrapper(const std::string model_path);
54 
55  std::map<std::string, std::vector<float>> Run(
56  std::map<std::string,
57  std::vector<float>> inputs,
58  int n_batches=1);
59 
60  const std::map<std::string, std::vector<int64_t>> GetModelInputs();
61  const std::map<std::string, std::vector<int64_t>> GetModelOutputs();
62 
63  const std::map<std::string, std::string> GetMETAData();
64  std::string GetMETADataByKey(const char * key);
65  const std::vector<int64_t>& getInputShape(int input_nr);
66  const std::vector<int64_t>& getOutputShape(int output_nr);
67  const std::vector<const char*>& getInputNames();
68  const std::vector<const char*>& getOutputNames();
69  int getNumInputs() const;
70  int getNumOutputs() const;
71 };
72 
73 #endif
ONNXWrapper::m_onnxSession
std::unique_ptr< Ort::Session > m_onnxSession
Definition: ONNXWrapper.h:36
ONNXWrapper::m_session_options
Ort::SessionOptions m_session_options
Definition: ONNXWrapper.h:40
ONNXWrapper::getInputNames
const std::vector< const char * > & getInputNames()
Definition: ONNXWrapper.cxx:154
ONNXWrapper::m_output_names
std::vector< const char * > m_output_names
Definition: ONNXWrapper.h:46
ONNXWrapper::m_input_names
std::vector< const char * > m_input_names
Definition: ONNXWrapper.h:47
ONNXWrapper::m_onnxEnv
std::unique_ptr< Ort::Env > m_onnxEnv
Definition: ONNXWrapper.h:37
ONNXWrapper::GetMETAData
const std::map< std::string, std::string > GetMETAData()
Definition: ONNXWrapper.cxx:136
ONNXWrapper::getOutputNames
const std::vector< const char * > & getOutputNames()
Definition: ONNXWrapper.cxx:159
ONNXWrapper::m_modelPath
std::string m_modelPath
Definition: ONNXWrapper.h:22
ONNXWrapper::ONNXWrapper
ONNXWrapper(const std::string model_path)
Definition: ONNXWrapper.cxx:3
ONNXWrapper
Definition: ONNXWrapper.h:16
ONNXWrapper::m_nr_inputs
size_t m_nr_inputs
Definition: ONNXWrapper.h:27
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
ONNXWrapper::GetMETADataByKey
std::string GetMETADataByKey(const char *key)
Definition: ONNXWrapper.cxx:149
ONNXWrapper::m_modelName
std::string m_modelName
Definition: ONNXWrapper.h:21
ONNXWrapper::getShape
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
Definition: ONNXWrapper.cxx:179
ONNXWrapper::m_input_dims
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition: ONNXWrapper.h:31
ONNXWrapper::Run
std::map< std::string, std::vector< float > > Run(std::map< std::string, std::vector< float >> inputs, int n_batches=1)
Definition: ONNXWrapper.cxx:49
ONNXWrapper::GetModelOutputs
const std::map< std::string, std::vector< int64_t > > GetModelOutputs()
Definition: ONNXWrapper.cxx:127
PathResolver.h
ONNXWrapper::getNumInputs
int getNumInputs() const
Definition: ONNXWrapper.cxx:176
ONNXWrapper::m_output_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition: ONNXWrapper.h:32
ONNXWrapper::GetModelInputs
const std::map< std::string, std::vector< int64_t > > GetModelInputs()
Definition: ONNXWrapper.cxx:118
ONNXWrapper::getInputShape
const std::vector< int64_t > & getInputShape(int input_nr)
Definition: ONNXWrapper.cxx:164
ONNXWrapper::m_nr_output
size_t m_nr_output
Definition: ONNXWrapper.h:28
ONNXWrapper::getNumOutputs
int getNumOutputs() const
Definition: ONNXWrapper.cxx:177
ONNXWrapper::getOutputShape
const std::vector< int64_t > & getOutputShape(int output_nr)
Definition: ONNXWrapper.cxx:170
ONNXWrapper::m_allocator
Ort::AllocatorWithDefaultOptions m_allocator
Definition: ONNXWrapper.h:41
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37