ATLAS Offline Software
Loading...
Searching...
No Matches
ONNXWrapper.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef ONNXUtils_h
5#define ONNXUtils_h
6
7// STL includes
8#include <string>
9#include <vector>
10#include <map>
11
12// Asg tool includes
13
15
16// ONNX Library
17#include <onnxruntime_cxx_api.h>
18
19
21
22 private:
23
24 // Class properties
25 std::string m_modelName; // Path to the onnx file
26 std::string m_modelPath; // Output of the path resolver
27
28 // Features of the network structure
29
30 // input and output nodes
33
34 // dimensions of the input and output
35 std::map<std::string, std::vector<int64_t>> m_input_dims;
36 std::map<std::string, std::vector<int64_t>> m_output_dims;
37
38
39 // ONNX session objects
40 std::unique_ptr<Ort::Session> m_onnxSession;
41 std::unique_ptr< Ort::Env > m_onnxEnv;
42
43 // onnx session options
44 Ort::SessionOptions m_session_options;
45 Ort::AllocatorWithDefaultOptions m_allocator;
46
47 // allocate memory
48
49 // name of the outputs
50 std::vector<const char*> m_output_names;
51 std::vector<const char*> m_input_names;
52 const std::vector<int64_t> getShape(Ort::TypeInfo model_info);
53
54 public:
55 // Constructor with parameters
56
57 ONNXWrapper(const std::string & model_path);
58
59 std::map<std::string, std::vector<float>> Run(
60 std::map<std::string,
61 std::vector<float>> inputs,
62 int n_batches=1);
63
64 const std::map<std::string, std::vector<int64_t>> GetModelInputs();
65 const std::map<std::string, std::vector<int64_t>> GetModelOutputs();
66
67 const std::map<std::string, std::string> GetMETAData();
68 std::string GetMETADataByKey(const char * key);
69 const std::vector<int64_t>& getInputShape(int input_nr);
70 const std::vector<int64_t>& getOutputShape(int output_nr);
71 const std::vector<const char*>& getInputNames();
72 const std::vector<const char*>& getOutputNames();
73 int getNumInputs() const;
74 int getNumOutputs() const;
75};
76
77#endif
const std::vector< int64_t > getShape(Ort::TypeInfo model_info)
int getNumOutputs() const
int getNumInputs() const
size_t m_nr_output
Definition ONNXWrapper.h:32
const std::map< std::string, std::string > GetMETAData()
Ort::AllocatorWithDefaultOptions m_allocator
Definition ONNXWrapper.h:45
const std::vector< const char * > & getOutputNames()
std::string m_modelPath
Definition ONNXWrapper.h:26
ONNXWrapper(const std::string &model_path)
std::string GetMETADataByKey(const char *key)
std::map< std::string, std::vector< int64_t > > m_input_dims
Definition ONNXWrapper.h:35
std::map< std::string, std::vector< float > > Run(std::map< std::string, std::vector< float > > inputs, int n_batches=1)
const std::vector< const char * > & getInputNames()
const std::map< std::string, std::vector< int64_t > > GetModelInputs()
size_t m_nr_inputs
Definition ONNXWrapper.h:31
std::vector< const char * > m_output_names
Definition ONNXWrapper.h:50
Ort::SessionOptions m_session_options
Definition ONNXWrapper.h:44
std::unique_ptr< Ort::Session > m_onnxSession
Definition ONNXWrapper.h:40
std::vector< const char * > m_input_names
Definition ONNXWrapper.h:51
const std::vector< int64_t > & getInputShape(int input_nr)
std::map< std::string, std::vector< int64_t > > m_output_dims
Definition ONNXWrapper.h:36
const std::vector< int64_t > & getOutputShape(int output_nr)
const std::map< std::string, std::vector< int64_t > > GetModelOutputs()
std::unique_ptr< Ort::Env > m_onnxEnv
Definition ONNXWrapper.h:41
std::string m_modelName
Definition ONNXWrapper.h:25