ATLAS Offline Software
Trigger
EFTracking
FPGATrackSim
FPGATrackSimAlgorithms
src
OnnxRuntimeBase.h
Go to the documentation of this file.
#pragma once

#include <cstdint>
#include <map>
#include <memory>
#include <vector>

#include <Eigen/Dense>
#include <onnxruntime_cxx_api.h>
#include <TString.h>
using
NetworkBatchInput
=
9
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
10
11
// General class that sets up the ONNX runtime framework for loading a ML model
12
// and using it for inference.
13
class
OnnxRuntimeBase
{
14
public
:
15
// int m_totalInputs;
16
17
TString
m_fileName
;
18
// int n_hits;
19
// std::string m_trackType;
20
// std::string m_order;
21
// bool m_scaled;
22
23
OnnxRuntimeBase
(TString
fileName
);
24
OnnxRuntimeBase
();
25
~OnnxRuntimeBase
(){}
26
void
initialize
(TString);
27
28
std::vector<float>
runONNXInference
(std::vector<float>& inputTensorValues)
const
;
29
std::vector<std::vector<float>>
runONNXInference
(
NetworkBatchInput
& inputTensorValues)
const
;
30
std::map<int, Eigen::MatrixXf>
runONNXInferenceMultilayerOutput
(
NetworkBatchInput
& inputTensorValues)
const
;
31
32
const
std::vector<int64_t>&
getInputNodesDims
(){
return
m_inputNodeDims
;};
33
const
std::vector<int64_t>&
getOutputNodesDims
(){
return
m_outputNodeDims
;};
34
35
private
:
37
std::unique_ptr<Ort::Session>
m_session
;
38
39
std::vector<const char*>
m_inputNodeNames
;
40
std::vector<int64_t>
m_inputNodeDims
;
41
std::vector<const char*>
m_outputNodeNames
;
42
std::vector<int64_t>
m_outputNodeDims
;
43
44
std::unique_ptr< Ort::Env >
m_env
;
45
46
};
47
OnnxRuntimeBase::runONNXInference
std::vector< float > runONNXInference(std::vector< float > &inputTensorValues) const
Definition:
OnnxRuntimeBase.cxx:65
OnnxRuntimeBase::m_env
std::unique_ptr< Ort::Env > m_env
Definition:
OnnxRuntimeBase.h:44
OnnxRuntimeBase::initialize
void initialize(TString)
Definition:
OnnxRuntimeBase.cxx:16
OnnxRuntimeBase::m_outputNodeDims
std::vector< int64_t > m_outputNodeDims
Definition:
OnnxRuntimeBase.h:42
OnnxRuntimeBase::getOutputNodesDims
const std::vector< int64_t > & getOutputNodesDims()
Definition:
OnnxRuntimeBase.h:33
OnnxRuntimeBase::m_inputNodeDims
std::vector< int64_t > m_inputNodeDims
Definition:
OnnxRuntimeBase.h:40
OnnxRuntimeBase::m_session
std::unique_ptr< Ort::Session > m_session
ONNX runtime session / model properties.
Definition:
OnnxRuntimeBase.h:33
OnnxRuntimeBase::m_inputNodeNames
std::vector< const char * > m_inputNodeNames
Definition:
OnnxRuntimeBase.h:39
FortranAlgorithmOptions.fileName
fileName
Definition:
FortranAlgorithmOptions.py:13
OnnxRuntimeBase::runONNXInferenceMultilayerOutput
std::map< int, Eigen::MatrixXf > runONNXInferenceMultilayerOutput(NetworkBatchInput &inputTensorValues) const
Definition:
OnnxRuntimeBase.cxx:143
OnnxRuntimeBase::m_fileName
TString m_fileName
Definition:
OnnxRuntimeBase.h:17
OnnxRuntimeBase::getInputNodesDims
const std::vector< int64_t > & getInputNodesDims()
Definition:
OnnxRuntimeBase.h:32
OnnxRuntimeBase
Definition:
OnnxRuntimeBase.h:13
OnnxRuntimeBase::m_outputNodeNames
std::vector< const char * > m_outputNodeNames
Definition:
OnnxRuntimeBase.h:41
OnnxRuntimeBase::OnnxRuntimeBase
OnnxRuntimeBase()
Definition:
OnnxRuntimeBase.cxx:14
OnnxRuntimeBase::~OnnxRuntimeBase
~OnnxRuntimeBase()
Definition:
OnnxRuntimeBase.h:25
NetworkBatchInput
Eigen::Matrix< float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor > NetworkBatchInput
Definition:
OnnxRuntimeBase.h:9
Generated on Sun Dec 22 2024 21:15:51 for ATLAS Offline Software by
1.8.18