ATLAS Offline Software
ONNXWrapper.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 #ifndef ONNXUtils_h
5 #define ONNXUtils_h
6 
// STL includes
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Asg tool includes

// ONNX Library
#include <onnxruntime_cxx_api.h>
18 
19 
20 class ONNXWrapper {
21 
22  private:
23 
24  // Class properties
25  std::string m_modelName; // Path to the onnx file
26  std::string m_modelPath; // Output of the path resolver
27 
28  // Features of the network structure
29 
30  // input and output nodes
31  size_t m_nr_inputs;
32  size_t m_nr_output;
33 
34  // dimensions of the input and output
35  std::map<std::string, std::vector<int64_t>> m_input_dims;
36  std::map<std::string, std::vector<int64_t>> m_output_dims;
37 
38 
39  // ONNX session objects
40  std::unique_ptr<Ort::Session> m_onnxSession;
41  std::unique_ptr< Ort::Env > m_onnxEnv;
42 
43  // onnx session options
44  Ort::SessionOptions m_session_options;
45  Ort::AllocatorWithDefaultOptions m_allocator;
46 
47  // allocate memory
48 
49  // name of the outputs
50  std::vector<const char*> m_output_names;
51  std::vector<const char*> m_input_names;
52  const std::vector<int64_t> getShape(Ort::TypeInfo model_info);
53 
54  public:
55  // Constructor with parameters
56 
57  ONNXWrapper(const std::string & model_path);
58 
59  std::map<std::string, std::vector<float>> Run(
60  std::map<std::string,
61  std::vector<float>> inputs,
62  int n_batches=1);
63 
64  const std::map<std::string, std::vector<int64_t>> GetModelInputs();
65  const std::map<std::string, std::vector<int64_t>> GetModelOutputs();
66 
67  const std::map<std::string, std::string> GetMETAData();
68  std::string GetMETADataByKey(const char * key);
69  const std::vector<int64_t>& getInputShape(int input_nr);
70  const std::vector<int64_t>& getOutputShape(int output_nr);
71  const std::vector<const char*>& getInputNames();
72  const std::vector<const char*>& getOutputNames();
73  int getNumInputs() const;
74  int getNumOutputs() const;
75 };
76 
77 #endif
ONNXWrapper::m_onnxSession
std::unique_ptr< Ort::Session > m_onnxSession
Definition: ONNXWrapper.h:40
ONNXWrapper::m_session_options
Ort::SessionOptions m_session_options
Definition: ONNXWrapper.h:44
ONNXWrapper::getInputNames
const std::vector< const char * > & getInputNames()
Definition: ONNXWrapper.cxx:156
ONNXWrapper::m_output_names
std::vector< const char * > m_output_names
Definition: ONNXWrapper.h:50
ONNXWrapper::m_input_names
std::vector< const char * > m_input_names
Definition: ONNXWrapper.h:51
ONNXWrapper::m_onnxEnv
std::unique_ptr< Ort::Env > m_onnxEnv
Definition: ONNXWrapper.h:41
ONNXWrapper::GetMETAData
const std::map< std::string, std::string > GetMETAData()
Definition: ONNXWrapper.cxx:138
ONNXWrapper::getOutputNames
const std::vector< const char * > & getOutputNames()
Definition: ONNXWrapper.cxx:161
ONNXWrapper::m_modelPath
std::string m_modelPath
Definition: ONNXWrapper.h:26
ONNXWrapper
Definition: ONNXWrapper.h:20
ONNXWrapper::m_nr_inputs
size_t m_nr_inputs
Definition: ONNXWrapper.h:31
postInclude.inputs
inputs
Definition: postInclude.SortInput.py:15
ONNXWrapper::GetMETADataByKey
std::string GetMETADataByKey(const char *key)
Definition: ONNXWrapper.cxx:151
ONNXWrapper::m_modelName
std::string m_modelName
Definition: ONNXWrapper.h:25
ONNXWrapper::getShape
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
Definition: ONNXWrapper.cxx:181
ONNXWrapper::m_input_dims
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition: ONNXWrapper.h:35
ONNXWrapper::ONNXWrapper
ONNXWrapper(const std::string &model_path)
Definition: ONNXWrapper.cxx:9
ONNXWrapper::Run
std::map< std::string, std::vector< float > > Run(std::map< std::string, std::vector< float >> inputs, int n_batches=1)
Definition: ONNXWrapper.cxx:51
ONNXWrapper::GetModelOutputs
const std::map< std::string, std::vector< int64_t > > GetModelOutputs()
Definition: ONNXWrapper.cxx:129
PathResolver.h
ONNXWrapper::getNumInputs
int getNumInputs() const
Definition: ONNXWrapper.cxx:178
ONNXWrapper::m_output_dims
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition: ONNXWrapper.h:36
ONNXWrapper::GetModelInputs
const std::map< std::string, std::vector< int64_t > > GetModelInputs()
Definition: ONNXWrapper.cxx:120
ONNXWrapper::getInputShape
const std::vector< int64_t > & getInputShape(int input_nr)
Definition: ONNXWrapper.cxx:166
ONNXWrapper::m_nr_output
size_t m_nr_output
Definition: ONNXWrapper.h:32
ONNXWrapper::getNumOutputs
int getNumOutputs() const
Definition: ONNXWrapper.cxx:179
ONNXWrapper::getOutputShape
const std::vector< int64_t > & getOutputShape(int output_nr)
Definition: ONNXWrapper.cxx:172
ONNXWrapper::m_allocator
Ort::AllocatorWithDefaultOptions m_allocator
Definition: ONNXWrapper.h:45
mapkey::key
key
Definition: TElectronEfficiencyCorrectionTool.cxx:37