ATLAS Offline Software
Interface class for creating Onnx Runtime sessions.
#include <IOnnxRuntimeInferenceTool.h>
Public Member Functions
- `virtual void setBatchSize(int64_t batchSize) = 0` : set batch size.
- `virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0` : determine the batch size from the data size.
- `template<typename T> StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, unsigned idx = 0, int64_t batchSize = -1) const` : add the input data to the input tensors.
- `template<typename T> StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, unsigned idx = 0, int64_t batchSize = -1) const` : add the output data to the output tensors.
- `virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0` : perform inference.
- `virtual void printModelInfo() const = 0`
- `virtual void print() const = 0` : print the state of the tool.

Protected Attributes
- `unsigned m_numInputs`
- `unsigned m_numOutputs`
- `std::vector<std::vector<int64_t>> m_inputShapes`
- `std::vector<std::vector<int64_t>> m_outputShapes`

Private Member Functions
- `template<typename T> Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const`
Interface class for creating Onnx Runtime sessions. It is thread safe, supports models with various numbers of inputs and outputs, and supports models with dynamic batch size. It defines a standardized procedure to perform Onnx Runtime inference. Assuming the tool m_onnxTool is created and initialized, inference is performed as:
```c++
m_onnxTool->inference(inputTensors, outputTensors);
```
Definition at line 47 of file IOnnxRuntimeInferenceTool.h.
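The page above only shows the final inference call; the preceding tensor-preparation steps did not survive extraction. A hedged end-to-end usage sketch, assuming `m_onnxTool` is an initialized handle to this tool and that `inputData` and `outputScores` are hypothetical `std::vector<float>` buffers sized for the model, might look like:

```c++
// Sketch only: inputData / outputScores are illustrative names,
// not part of this interface.
std::vector<Ort::Value> inputTensors;
std::vector<Ort::Value> outputTensors;

// Derive the batch size from the flat input data size (input node 0).
int64_t batchSize = m_onnxTool->getBatchSize(inputData.size());

// Wrap the data buffers as Ort tensors.
ATH_CHECK(m_onnxTool->addInput(inputTensors, inputData, 0, batchSize));
ATH_CHECK(m_onnxTool->addOutput(outputTensors, outputScores, 0, batchSize));

// Run the model; results are written into outputScores.
ATH_CHECK(m_onnxTool->inference(inputTensors, outputTensors));
```

Because addOutput registers the output buffer before the call, the inference results land directly in the caller's vector with no extra copy.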
```c++
template<typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors,
                                                        std::vector<T>& data,
                                                        unsigned idx = 0,
                                                        int64_t batchSize = -1) const
```
Add the input data to the input tensors.

Parameters:
- inputTensors : the input tensor container
- data : the input data
- idx : the index of the input node
- batchSize : the batch size
```c++
template<typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors,
                                                         std::vector<T>& data,
                                                         unsigned idx = 0,
                                                         int64_t batchSize = -1) const
```
Add the output data to the output tensors.

Parameters:
- outputTensors : the output tensor container
- data : the output data
- idx : the index of the output node
- batchSize : the batch size
```c++
template<typename T>
Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data,
                                                            const std::vector<int64_t>& dataShape,
                                                            int64_t batchSize) const
```
[private]
```c++
virtual int64_t AthOnnx::IOnnxRuntimeInferenceTool::getBatchSize(int64_t dataSize, int idx = 0) const = 0
```
[pure virtual]

Determine the batch size from the data size.

Parameters:
- dataSize : the size of the input data, e.g. std::vector<T>::size()
- idx : the index of the input node

Implemented in AthOnnx::OnnxRuntimeInferenceTool.
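This page does not spell out the formula, but a plausible reading of "determining batch size from the data size", assuming the non-batch dimensions of the input shape are fixed, is the flat data size divided by the element count per sample. A minimal self-contained sketch of that arithmetic (an illustration, not the ATLAS implementation):

```cpp
#include <cstdint>
#include <vector>

// Sketch: derive a batch size from a flat data size and an input shape
// whose leading (batch) dimension is dynamic, i.e. -1.
int64_t guessBatchSize(int64_t dataSize, const std::vector<int64_t>& shape) {
    int64_t perSample = 1;
    for (std::size_t i = 1; i < shape.size(); ++i) {
        perSample *= shape[i];  // product of the non-batch dimensions
    }
    // Reject inconsistent sizes rather than truncating.
    if (perSample <= 0 || dataSize % perSample != 0) return -1;
    return dataSize / perSample;
}
```

For example, 24 floats against a shape of {-1, 3, 4} (12 elements per sample) would yield a batch size of 2.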
```c++
virtual StatusCode AthOnnx::IOnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors,
                                                                 std::vector<Ort::Value>& outputTensors) const = 0
```
[pure virtual]

Perform inference.

Parameters:
- inputTensors : the input tensor container
- outputTensors : the output tensor container

Implemented in AthOnnx::OnnxRuntimeInferenceTool.
```c++
virtual void AthOnnx::IOnnxRuntimeInferenceTool::print() const = 0
```
[pure virtual, inherited]

Print the state of the tool.

Implemented in JetRecTool, JetFinder, JetModifiedMassDrop, JetFromPseudojet, JetReclusterer, JetReclusteringTool, JetTruthLabelingTool, JetPileupLabelingTool, HI::HIPileupTool, asg::AsgTool, JetDumper, JetBottomUpSoftDrop, JetRecursiveSoftDrop, JetSoftDrop, JetConstituentsRetriever, JetSubStructureMomentToolsBase, JetSplitter, JetToolRunner, JetPruner, JetPseudojetRetriever, JetTrimmer, AsgHelloTool, and KtDeltaRTool.
```c++
virtual void AthOnnx::IOnnxRuntimeInferenceTool::printModelInfo() const = 0
```
[pure virtual]

Implemented in AthOnnx::OnnxRuntimeInferenceTool.
```c++
virtual void AthOnnx::IOnnxRuntimeInferenceTool::setBatchSize(int64_t batchSize) = 0
```
[pure virtual]

Set batch size. If the model has dynamic batch size, the batchSize value is applied to both the input shapes and the output shapes.

Implemented in AthOnnx::OnnxRuntimeInferenceTool.
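The dynamic-batch behaviour described above can be illustrated with a small self-contained sketch of the assumed semantics (not the ATLAS code): every stored shape whose leading dimension is dynamic (-1) gets that dimension overwritten with the requested batch size, while fixed shapes are left untouched.

```cpp
#include <cstdint>
#include <vector>

// Sketch of the assumed semantics of setBatchSize: overwrite a dynamic
// (-1) leading dimension with the requested batch size for every shape.
void applyBatchSize(std::vector<std::vector<int64_t>>& shapes, int64_t batchSize) {
    for (auto& shape : shapes) {
        if (!shape.empty() && shape[0] == -1) {
            shape[0] = batchSize;  // only dynamic shapes are touched
        }
    }
}
```

In the real tool this would be applied to both m_inputShapes and m_outputShapes.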
Member Data Documentation (all [protected]; member names recovered from the header line numbers and the summary table above):

- `std::vector<std::vector<int64_t>> m_inputShapes` : definition at line 104 of file IOnnxRuntimeInferenceTool.h.
- `unsigned m_numInputs` : definition at line 102 of file IOnnxRuntimeInferenceTool.h.
- `unsigned m_numOutputs` : definition at line 103 of file IOnnxRuntimeInferenceTool.h.
- `std::vector<std::vector<int64_t>> m_outputShapes` : definition at line 105 of file IOnnxRuntimeInferenceTool.h.