ATLAS Offline Software
IOnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 #ifndef AthOnnx_IOnnxRuntimeInferenceTool_H
3 #define AthOnnx_IOnnxRuntimeInferenceTool_H
4 
5 #include "AsgTools/IAsgTool.h"
6 
7 #include <memory>
8 #include <numeric>
9 #include <utility>
10 
11 #include <onnxruntime_cxx_api.h>
12 
13 
14 namespace AthOnnx {
48  {
50 
51  public:
52 
// NOTE(review): this is a rendered Doxygen dump — the original doc comments
// (source lines 53-57 etc.) were stripped; the descriptions below are taken
// from the generated member summaries elsewhere in this dump.
// Set the batch size to be used when building tensor shapes.
58  virtual void setBatchSize(int64_t batchSize) = 0;
59 
// Determine the batch size from the total data size for input `idx`.
66  virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
67 
// Add the input data to the input tensors (slot `idx`).
// batchSize == -1 presumably means "derive it from the data size" -- TODO
// confirm against the template definition in IOnnxRuntimeInferenceTool.icc.
76  template <typename T>
77  StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, unsigned idx = 0, int64_t batchSize = -1) const;
78 
// Add the output data to the output tensors (slot `idx`); same batchSize
// convention as addInput.
87  template <typename T>
88  StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, unsigned idx = 0, int64_t batchSize = -1) const;
89 
90 
// Perform inference on the prepared input tensors, filling the output tensors.
97  virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
98 
// Print information about the loaded model (implementation-defined detail).
99  virtual void printModelInfo() const = 0;
100 
101  protected:
// Input/output counts and per-tensor shapes of the loaded model.
// NOTE(review): not initialized at declaration -- concrete tools are expected
// to fill these in; confirm every implementation does before use.
102  unsigned m_numInputs;
103  unsigned m_numOutputs;
104  std::vector<std::vector<int64_t> > m_inputShapes;
105  std::vector<std::vector<int64_t> > m_outputShapes;
106 
107  private:
// Helper used by addInput/addOutput: wraps a std::vector<T> in an Ort::Value
// of shape `dataShape` with the given batch size (definition presumably in
// IOnnxRuntimeInferenceTool.icc, which this header includes).
108  template <typename T>
109  Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
110 
111  };
112 
114 } // namespace AthOnnx
115 
116 #endif
AthOnnx::IOnnxRuntimeInferenceTool::m_inputShapes
std::vector< std::vector< int64_t > > m_inputShapes
Definition: IOnnxRuntimeInferenceTool.h:104
AthOnnx::IOnnxRuntimeInferenceTool
Interface class for creating Onnx Runtime sessions.
Definition: IOnnxRuntimeInferenceTool.h:48
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthOnnx::IOnnxRuntimeInferenceTool::createTensor
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape, int64_t batchSize) const
asg::IAsgTool
Base class for the dual-use tool interface classes.
Definition: IAsgTool.h:41
AthOnnx::IOnnxRuntimeInferenceTool::addInput
StatusCode addInput(std::vector< Ort::Value > &inputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
add the input data to the input tensors
AthOnnx::IOnnxRuntimeInferenceTool::addOutput
StatusCode addOutput(std::vector< Ort::Value > &outputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
add the output data to the output tensors
IOnnxRuntimeInferenceTool.icc
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ASG_TOOL_INTERFACE
#define ASG_TOOL_INTERFACE(CLASSNAME)
Definition: AsgToolMacros.h:40
IAsgTool.h
PayloadHelpers::dataSize
size_t dataSize(TDA::PayloadIterator start)
Size in bytes of the buffer that is needed to decode next fragment data content.
Definition: TriggerEDMDeserialiserAlg.cxx:188
AthOnnx::IOnnxRuntimeInferenceTool::m_numOutputs
unsigned m_numOutputs
Definition: IOnnxRuntimeInferenceTool.h:103
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
AthOnnx::IOnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const =0
perform inference
AthOnnx::IOnnxRuntimeInferenceTool::m_outputShapes
std::vector< std::vector< int64_t > > m_outputShapes
Definition: IOnnxRuntimeInferenceTool.h:105
AthOnnx::IOnnxRuntimeInferenceTool::printModelInfo
virtual void printModelInfo() const =0
AthOnnx::IOnnxRuntimeInferenceTool::setBatchSize
virtual void setBatchSize(int64_t batchSize)=0
set batch size.
AthOnnx::IOnnxRuntimeInferenceTool::getBatchSize
virtual int64_t getBatchSize(int64_t dataSize, int idx=0) const =0
methods for determining batch size from the data size
AthOnnx
Namespace holding all of the Onnx Runtime example code.
Definition: EvaluateModel.cxx:13
AthOnnx::IOnnxRuntimeInferenceTool::m_numInputs
unsigned m_numInputs
Definition: IOnnxRuntimeInferenceTool.h:102