ATLAS Offline Software
Loading...
Searching...
No Matches
IOnnxRuntimeInferenceTool.h
Go to the documentation of this file.
1// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2#ifndef AthOnnx_IOnnxRuntimeInferenceTool_H
3#define AthOnnx_IOnnxRuntimeInferenceTool_H
5#include "AsgTools/IAsgTool.h"
6
7#include <memory>
8#include <numeric>
9#include <utility>
10
11#include <onnxruntime_cxx_api.h>
12
13
14namespace AthOnnx {
22
24 * std::vector<Ort::Value> inputTensors;
25 * std::vector<float> inputData_1; // The input data is filled by users, possibly from the event information.
26 * int64_t batchSize = m_onnxTool->getBatchSize(inputData_1.size(), 0); // The batch size is determined by the input data size to support dynamic batch size.
27 * m_onnxTool->addInput(inputTensors, inputData_1, 0, batchSize);
28 * std::vector<int64_t> inputData_2; // Some models may have multiple inputs. Add inputs one by one.
29 * int64_t batchSize_2 = m_onnxTool->getBatchSize(inputData_2.size(), 1);
30 * m_onnxTool->addInput(inputTensors, inputData_2, 1, batchSize_2);
31 * ```
32 * 2. create output tensors:
33 * ```c++
34 * std::vector<Ort::Value> outputTensors;
35 * std::vector<float> outputData; // The output data will be filled by the onnx session.
36 * m_onnxTool->addOutput(outputTensors, outputData, 0, batchSize);
37 * ```
38 * 3. perform inference:
39 * ```c++
40 * m_onnxTool->inference(inputTensors, outputTensors);
41 * ```
42 * 4. Model outputs will be automatically filled into outputData.
43 *
44 *
45 * @author Xiangyang Ju <xju@cern.ch>
46 */
48 {
50
51 public:
52
/// @brief Set the batch size.
/// Pure virtual; presumably stored by the concrete tool and used when a negative
/// batchSize is passed to addInput/addOutput -- TODO confirm in the implementation.
58 virtual void setBatchSize(int64_t batchSize) = 0;
59
/// @brief Determine the batch size from the data size for model input @p idx.
/// The batch size is derived from the size of the user-filled input data so that
/// models with a dynamic batch dimension are supported (see the usage example above).
/// @param dataSize total number of elements in the flat input data
/// @param idx index of the model input the data belongs to (default 0)
/// @return the batch size to use for tensor creation
66 virtual int64_t getBatchSize(int64_t dataSize, int idx = 0) const = 0;
67
/// @brief Add the input data to the input tensors.
/// Wraps the user-filled flat vector @p data as an Ort::Value for model input @p idx
/// and appends it to @p inputTensors.
/// NOTE(review): Ort tensors typically reference the caller's buffer rather than
/// copying it, so @p data must outlive the inference call -- confirm against createTensor.
/// @param batchSize batch size for the tensor; -1 presumably means "use the value
///        set via setBatchSize" -- TODO confirm in the implementation
76 template <typename T>
77 StatusCode addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, unsigned idx = 0, int64_t batchSize = -1) const;
78
/// @brief Add the output data to the output tensors.
/// Wraps the caller-owned flat vector @p data as an Ort::Value for model output @p idx
/// and appends it to @p outputTensors; after inference() the model output is written
/// into @p data (see step 4 of the usage example above).
/// @param batchSize batch size for the tensor; -1 presumably means "use the value
///        set via setBatchSize" -- TODO confirm in the implementation
87 template <typename T>
88 StatusCode addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, unsigned idx = 0, int64_t batchSize = -1) const;
89
90
/// @brief Perform inference.
/// Runs the Onnx session on @p inputTensors; results end up in the user buffers
/// that back @p outputTensors (filled via addOutput). Pure virtual.
97 virtual StatusCode inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const = 0;
98
/// @brief Print information about the loaded model. Pure virtual.
99 virtual void printModelInfo() const = 0;
100
101 protected:
// NOTE(review): these members are presumably filled by the concrete tool when the
// Onnx session is created; they are not initialized here -- confirm in the implementation.
102 unsigned m_numInputs; ///< number of model inputs
103 unsigned m_numOutputs; ///< number of model outputs
104 std::vector<std::vector<int64_t> > m_inputShapes; ///< one shape vector per model input
105 std::vector<std::vector<int64_t> > m_outputShapes; ///< one shape vector per model output
106
107 private:
/// @brief Build an Ort::Value viewing the flat @p data with shape @p dataShape.
/// @p batchSize presumably substitutes a dynamic (-1) batch dimension in
/// @p dataShape -- TODO confirm against the implementation. Shared helper for
/// addInput and addOutput.
108 template <typename T>
109 Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const;
110
111 };
112
114} // namespace AthOnnx
115
116#endif
#define ASG_TOOL_INTERFACE(CLASSNAME)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
Interface class for creating Onnx Runtime sessions.
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape, int64_t batchSize) const
std::vector< std::vector< int64_t > > m_outputShapes
virtual int64_t getBatchSize(int64_t dataSize, int idx=0) const =0
methods for determining batch size from the data size
virtual void setBatchSize(int64_t batchSize)=0
set batch size.
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const =0
perform inference
std::vector< std::vector< int64_t > > m_inputShapes
StatusCode addInput(std::vector< Ort::Value > &inputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
add the input data to the input tensors
StatusCode addOutput(std::vector< Ort::Value > &outputTensors, std::vector< T > &data, unsigned idx=0, int64_t batchSize=-1) const
add the output data to the output tensors
virtual void printModelInfo() const =0
Base class for the dual-use tool interface classes.
Definition IAsgTool.h:41
Namespace holding all of the Onnx Runtime example code.