ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeBase.h
Go to the documentation of this file.
1
2#include <vector>
3#include <map>
4#include <Eigen/Dense>
5#include <onnxruntime_cxx_api.h>
6#include <TString.h>
7
9 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
10
11// General class that sets up the ONNX runtime framework for loading an ML model
12// and using it for inference.
14 public:
15 // int m_totalInputs;
16
/// Name/path of the ONNX model file (public data member).
/// NOTE(review): presumably set from the constructor / initialize() argument — confirm in the .cpp.
17 TString m_fileName;
18 // int n_hits;
19 // std::string m_trackType;
20 // std::string m_order;
21 // bool m_scaled;
22
/// Construct the wrapper for the model file @p fileName.
/// NOTE(review): single-argument and not `explicit`, so a TString converts to
/// OnnxRuntimeBase implicitly.
23 OnnxRuntimeBase(TString fileName);
/// Set up the ONNX runtime for a model.
/// NOTE(review): the parameter is unnamed here; presumably the model file path — confirm in the .cpp.
26 void initialize(TString);
27
/// Run the model on a single flat input vector; the result is returned as a flat float vector.
/// The input is taken by non-const reference even though the method is const.
/// NOTE(review): presumably because the ONNX tensor wrapper needs a mutable data pointer — confirm.
28 std::vector<float> runONNXInference(std::vector<float>& inputTensorValues) const;
/// Batched overload taking a vector of input vectors (presumably one inner vector per sample — confirm).
29 std::vector<std::vector<float>> runONNXInference(std::vector<std::vector<float> >& inputTensorValues) const;
/// Batched overload taking a row-major float matrix (NetworkBatchInput).
/// NOTE(review): presumably one row per sample — confirm in the .cpp.
30 std::vector<std::vector<float>> runONNXInference(NetworkBatchInput& inputTensorValues) const;
/// Batched inference that returns one Eigen::MatrixXf per map entry.
/// NOTE(review): the int key is presumably the output-node index — confirm in the .cpp.
31 std::map<int, Eigen::MatrixXf> runONNXInferenceMultilayerOutput(NetworkBatchInput& inputTensorValues) const;
32
33 const std::vector<int64_t>& getInputNodesDims(){return m_inputNodeDims;};
34 const std::vector<int64_t>& getOutputNodesDims(){return m_outputNodeDims;};
35
36 private:
/// ONNX runtime session / model properties.
38 std::unique_ptr<Ort::Session> m_session;
39
/// Input/output node names and dimensions of the loaded model.
/// The name vectors hold non-owning `const char*` pointers: the character
/// buffers they point to must stay alive for as long as these pointers are used.
/// NOTE(review): confirm in the .cpp which object owns those buffers.
40 std::vector<const char*> m_inputNodeNames;
41 std::vector<int64_t> m_inputNodeDims;
42 std::vector<const char*> m_outputNodeNames;
43 std::vector<int64_t> m_outputNodeDims;
44
/// ONNX runtime environment, owned by this object.
/// NOTE(review): members are destroyed in reverse declaration order, so m_env is
/// destroyed *before* m_session. If ONNX Runtime requires the Env to outlive
/// sessions created from it (verify against the ORT docs), make sure the
/// destructor resets m_session first, or declare m_env above m_session.
45 std::unique_ptr< Ort::Env > m_env;
46
47};
48
Eigen::Matrix< float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > NetworkBatchInput
const std::vector< int64_t > & getInputNodesDims()
std::unique_ptr< Ort::Env > m_env
OnnxRuntimeBase(TString fileName)
std::vector< const char * > m_inputNodeNames
const std::vector< int64_t > & getOutputNodesDims()
std::vector< int64_t > m_outputNodeDims
std::map< int, Eigen::MatrixXf > runONNXInferenceMultilayerOutput(NetworkBatchInput &inputTensorValues) const
std::vector< int64_t > m_inputNodeDims
std::vector< float > runONNXInference(std::vector< float > &inputTensorValues) const
std::vector< const char * > m_outputNodeNames
std::unique_ptr< Ort::Session > m_session
ONNX runtime session / model properties.
void initialize()