ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeInferenceTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
7
8#ifndef XAOD_STANDALONE
9// AthAsynchronousAlgorithm for pointer
11
12// Gaudi include to figure out if something is an algorithm, a service, or a tool
13#include "GaudiKernel/IAlgTool.h"
14#endif // !XAOD_STANDALONE
15
17 : asg::AsgTool ( name )
18{
19}
20
22{
// Retrieves the ONNX Runtime service and the session tool (ATH_CHECK propagates
// a failed retrieval as this function's StatusCode). In Athena builds it then
// decides whether inference() may run asynchronously: that needs a session that
// supports async execution AND a top-level owner that is an
// AthAsynchronousAlgorithm. The outcome is cached in m_parentAsyncAlg, where
// nullptr means inference() uses the synchronous path.
23 // Get the Onnx Runtime service.
24 ATH_CHECK(m_onnxRuntimeSvc.retrieve());
25
26 // Create the session.
27 ATH_CHECK(m_onnxSessionTool.retrieve());
28
30
// The asynchronous-inference setup below only exists in Athena
// (non-XAOD_STANDALONE) builds.
31 #ifndef XAOD_STANDALONE
32 // If session doesn't support asynchronous inference we don't need to find our parent
33 if (!m_onnxSessionTool->supportsAsync())
34 {
35 ATH_MSG_INFO("Session does not support asynchronous inference");
// Cleared explicitly so inference() takes its synchronous branch.
36 m_parentAsyncAlg = nullptr;
37 return StatusCode::SUCCESS;
38 }
39 // Figure out if parent is an AthAsynchronousAlgorithm, and set pointer if it is
40 const IAlgTool* p = dynamic_cast<const IAlgTool*>(this);
41 // Follow chain of parents up until we hit one that can't be converted to an IAlgTool
// After the loop, myParent is the first owner in the chain that is not a tool
// (presumably the owning algorithm), or still nullptr if `this` was not an IAlgTool.
42 const IInterface* myParent = nullptr;
43 while (p != nullptr) {
44 myParent = p->parent();
45 p = dynamic_cast<const IAlgTool*>(myParent);
46 }
47 // If this ultimate ancestor can be converted to an AthAsynchronousAlgorithm, set the member variable
48 m_parentAsyncAlg = dynamic_cast<const AthAsynchronousAlgorithm*>(myParent);
49 if (m_parentAsyncAlg != nullptr) {
50 ATH_MSG_INFO("Owned by an AthAsynchronousAlgorithm, using asynchronous inference");
51 }
52 else {
53 ATH_MSG_INFO("Not owned by an AthAsynchronousAlgorithm, not using asynchronous inference");
54 }
55 #endif // !XAOD_STANDALONE
56
57 return StatusCode::SUCCESS;
58}
59
61{
// Caches the model's input and output node counts, read from the ONNX Runtime
// session held by m_onnxSessionTool. The tensor-level inference() checks the
// sizes of the tensor vectors it receives against these counts.
62 auto& session = m_onnxSessionTool->session();
63 // obtain the model information
64 m_numInputs = session.GetInputCount();
65 m_numOutputs = session.GetOutputCount();
66
69
70 return StatusCode::SUCCESS;
71}
72
73
75{
76 if (batchSize <= 0) {
77 ATH_MSG_ERROR("Batch size should be positive");
78 return;
79 }
80
81 for (auto& shape : m_inputShapes) {
82 if (shape[0] == -1) {
83 shape[0] = batchSize;
84 }
85 }
86
87 for (auto& shape : m_outputShapes) {
88 if (shape[0] == -1) {
89 shape[0] = batchSize;
90 }
91 }
92}
93
94int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
95{
96 auto tensorSize = AthOnnxUtils::getTensorSize(m_inputShapes[idx]);
97 if (tensorSize < 0) {
98 return inputDataSize / abs(tensorSize);
99 } else {
100 return -1;
101 }
102}
103
104StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
105{
106 assert (inputTensors.size() == m_numInputs);
107 assert (outputTensors.size() == m_numOutputs);
108
109 // Run the model.
110 // If we're in Athena and the parent is an asynchronous algorithm we do the inference asynchronously
111 #ifndef XAOD_STANDALONE
112 if (m_parentAsyncAlg == nullptr) {
113 #endif // !XAOD_STANDALONE
114
116 m_onnxSessionTool->session(),
117 m_inputNodeNames, inputTensors,
118 m_outputNodeNames, outputTensors);
119
120 #ifndef XAOD_STANDALONE
121 }
122 else {
123 // Asynchronous version
124 std::string errorMsg = AthOnnxUtils::asyncInference(
125 m_onnxSessionTool->session(), m_inputNodeNames, inputTensors,
126 m_outputNodeNames, outputTensors, m_parentAsyncAlg);
127 if (!errorMsg.empty()) {
128 ATH_MSG_ERROR("ONNX Runtime Error: " << errorMsg);
129 return StatusCode::FAILURE;
130 }
131 }
132 #endif // !XAOD_STANDALONE
133
134 return StatusCode::SUCCESS;
135}
136
138{
139 ATH_MSG_INFO("Number of inputs: " << m_numInputs);
140 ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
141
142 ATH_MSG_INFO("Input node names: ");
143 for (const auto& name : m_inputNodeNames) {
144 ATH_MSG_INFO("\t" << name);
145 }
146
147 ATH_MSG_INFO("Output node names: ");
148 for (const auto& name : m_outputNodeNames) {
149 ATH_MSG_INFO("\t" << name);
150 }
151
152 ATH_MSG_INFO("Input shapes: ");
153 for (const auto& shape : m_inputShapes) {
154 std::string shapeStr = "\t";
155 for (const auto& dim : shape) {
156 shapeStr += std::to_string(dim) + " ";
157 }
158 ATH_MSG_INFO(shapeStr);
159 }
160
161 ATH_MSG_INFO("Output shapes: ");
162 for (const auto& shape : m_outputShapes) {
163 std::string shapeStr = "\t";
164 for (const auto& dim : shape) {
165 shapeStr += std::to_string(dim) + " ";
166 }
167 ATH_MSG_INFO(shapeStr);
168 }
169}
170
172{
173 // Create input tensors.
174 std::vector<Ort::Value> inputTensors;
175 for (auto& [inputName, inputInfo] : inputData) {
176 const std::vector<int64_t>& shape = inputInfo.first;
177 if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
178 auto& data = std::get<std::vector<float>>(inputInfo.second);
179 inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
180 } else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
181 auto& data = std::get<std::vector<int64_t>>(inputInfo.second);
182 inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
183 } else {
184 ATH_MSG_ERROR("Unsupported data type");
185 return StatusCode::FAILURE;
186 }
187 }
188
189 // Create output tensors.
190 std::vector<Ort::Value> outputTensors;
191 outputTensors.reserve(inputData.size());
192 for (auto& outName : m_outputNodeNames) {
193 if (outputData.find(outName) == outputData.end()) {
194 ATH_MSG_ERROR("Output name " << outName << " not found in output data map");
195 return StatusCode::FAILURE;
196 }
197 auto& outputInfo = outputData.at(outName);
198 auto& shape = outputInfo.first;
199 auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
200
201 if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
202 auto& data = std::get<std::vector<float>>(outputInfo.second);
203 data.resize(tensorSize);
204 outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
205 } else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
206 auto& data = std::get<std::vector<int64_t>>(outputInfo.second);
207 data.resize(tensorSize);
208 outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
209 } else {
210 ATH_MSG_ERROR("Unsupported data type");
211 return StatusCode::FAILURE;
212 }
213 }
214
215 ATH_CHECK(inference(inputTensors, outputTensors));
216
217 return StatusCode::SUCCESS;
218}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
An algorithm that can be suspended while work is offloaded to an accelerator.
std::vector< std::vector< int64_t > > m_outputShapes
std::vector< std::vector< int64_t > > m_inputShapes
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
perform inference
std::vector< std::string > m_inputNodeNames
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
methods for determining batch size from the data size
virtual StatusCode initialize() override
Initialize the tool.
virtual void printModelInfo() const override final
const AthAsynchronousAlgorithm * m_parentAsyncAlg
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
std::vector< std::string > m_outputNodeNames
virtual void setBatchSize(int64_t batchSize) override final
set batch size.
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition OnnxUtils.cxx:57
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
std::string asyncInference(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, std::vector< Ort::Value > &outputData, const AthAsynchronousAlgorithm *parentAlg)
Definition OnnxUtils.cxx:82
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:49
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:41
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition OnnxUtils.h:92