ATLAS Offline Software
OnnxRuntimeBase.h
Go to the documentation of this file.
1 
2 #include <vector>
3 #include <map>
4 #include <Eigen/Dense>
5 #include <onnxruntime_cxx_api.h>
6 #include <TString.h>
7 
9  Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
10 
11 // General class that sets up the ONNX runtime framework for loading a ML model
12 // and using it for inference.
14  public:
15  // int m_totalInputs;
16 
17  TString m_fileName;
18  // int n_hits;
19  // std::string m_trackType;
20  // std::string m_order;
21  // bool m_scaled;
22 
23  OnnxRuntimeBase(TString fileName);
26  void initialize(TString);
27 
28  std::vector<float> runONNXInference(std::vector<float>& inputTensorValues) const;
29  std::vector<std::vector<float>> runONNXInference(NetworkBatchInput& inputTensorValues) const;
30  std::map<int, Eigen::MatrixXf> runONNXInferenceMultilayerOutput(NetworkBatchInput& inputTensorValues) const;
31 
32  const std::vector<int64_t>& getInputNodesDims(){return m_inputNodeDims;};
33  const std::vector<int64_t>& getOutputNodesDims(){return m_outputNodeDims;};
34 
35  private:
37  std::unique_ptr<Ort::Session> m_session;
38 
39  std::vector<const char*> m_inputNodeNames;
40  std::vector<int64_t> m_inputNodeDims;
41  std::vector<const char*> m_outputNodeNames;
42  std::vector<int64_t> m_outputNodeDims;
43 
44  std::unique_ptr< Ort::Env > m_env;
45 
46 };
47 
OnnxRuntimeBase::runONNXInference
std::vector< float > runONNXInference(std::vector< float > &inputTensorValues) const
Definition: OnnxRuntimeBase.cxx:65
OnnxRuntimeBase::m_env
std::unique_ptr< Ort::Env > m_env
Definition: OnnxRuntimeBase.h:44
OnnxRuntimeBase::initialize
void initialize(TString)
Definition: OnnxRuntimeBase.cxx:16
OnnxRuntimeBase::m_outputNodeDims
std::vector< int64_t > m_outputNodeDims
Definition: OnnxRuntimeBase.h:42
OnnxRuntimeBase::getOutputNodesDims
const std::vector< int64_t > & getOutputNodesDims()
Definition: OnnxRuntimeBase.h:33
OnnxRuntimeBase::m_inputNodeDims
std::vector< int64_t > m_inputNodeDims
Definition: OnnxRuntimeBase.h:40
OnnxRuntimeBase::m_session
std::unique_ptr< Ort::Session > m_session
ONNX runtime session / model properties.
Definition: OnnxRuntimeBase.h:37
OnnxRuntimeBase::m_inputNodeNames
std::vector< const char * > m_inputNodeNames
Definition: OnnxRuntimeBase.h:39
FortranAlgorithmOptions.fileName
fileName
Definition: FortranAlgorithmOptions.py:13
OnnxRuntimeBase::runONNXInferenceMultilayerOutput
std::map< int, Eigen::MatrixXf > runONNXInferenceMultilayerOutput(NetworkBatchInput &inputTensorValues) const
Definition: OnnxRuntimeBase.cxx:143
OnnxRuntimeBase::m_fileName
TString m_fileName
Definition: OnnxRuntimeBase.h:17
OnnxRuntimeBase::getInputNodesDims
const std::vector< int64_t > & getInputNodesDims()
Definition: OnnxRuntimeBase.h:32
OnnxRuntimeBase
Definition: OnnxRuntimeBase.h:13
OnnxRuntimeBase::m_outputNodeNames
std::vector< const char * > m_outputNodeNames
Definition: OnnxRuntimeBase.h:41
OnnxRuntimeBase::OnnxRuntimeBase
OnnxRuntimeBase()
Definition: OnnxRuntimeBase.cxx:14
OnnxRuntimeBase::~OnnxRuntimeBase
~OnnxRuntimeBase()
Definition: OnnxRuntimeBase.h:25
NetworkBatchInput
Eigen::Matrix< float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > NetworkBatchInput
Definition: OnnxRuntimeBase.h:9