13#include "GaudiKernel/IAlgTool.h"
31 #ifndef XAOD_STANDALONE
35 ATH_MSG_INFO(
"Session does not support asynchronous inference");
37 return StatusCode::SUCCESS;
40 const IAlgTool* p =
dynamic_cast<const IAlgTool*
>(
this);
42 const IInterface* myParent =
nullptr;
43 while (p !=
nullptr) {
44 myParent = p->parent();
45 p =
dynamic_cast<const IAlgTool*
>(myParent);
50 ATH_MSG_INFO(
"Owned by an AthAsynchronousAlgorithm, using asynchronous inference");
53 ATH_MSG_INFO(
"Not owned by an AthAsynchronousAlgorithm, not using asynchronous inference");
57 return StatusCode::SUCCESS;
70 return StatusCode::SUCCESS;
98 return inputDataSize / abs(tensorSize);
111 #ifndef XAOD_STANDALONE
120 #ifndef XAOD_STANDALONE
127 if (!errorMsg.empty()) {
129 return StatusCode::FAILURE;
134 return StatusCode::SUCCESS;
154 std::string shapeStr =
"\t";
155 for (
const auto& dim : shape) {
156 shapeStr += std::to_string(dim) +
" ";
163 std::string shapeStr =
"\t";
164 for (
const auto& dim : shape) {
165 shapeStr += std::to_string(dim) +
" ";
174 std::vector<Ort::Value> inputTensors;
175 for (
auto& [inputName, inputInfo] : inputData) {
176 const std::vector<int64_t>& shape = inputInfo.first;
177 if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
178 auto&
data = std::get<std::vector<float>>(inputInfo.second);
180 }
else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
181 auto&
data = std::get<std::vector<int64_t>>(inputInfo.second);
185 return StatusCode::FAILURE;
190 std::vector<Ort::Value> outputTensors;
191 outputTensors.reserve(inputData.size());
193 if (outputData.find(outName) == outputData.end()) {
194 ATH_MSG_ERROR(
"Output name " << outName <<
" not found in output data map");
195 return StatusCode::FAILURE;
197 auto& outputInfo = outputData.at(outName);
198 auto& shape = outputInfo.first;
199 auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
201 if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
202 auto&
data = std::get<std::vector<float>>(outputInfo.second);
203 data.resize(tensorSize);
205 }
else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
206 auto&
data = std::get<std::vector<int64_t>>(outputInfo.second);
207 data.resize(tensorSize);
211 return StatusCode::FAILURE;
217 return StatusCode::SUCCESS;
#define ATH_CHECK
Evaluate an expression and check for errors.
char data[hepevt_bytes_allocation_ATLAS]
An algorithm that can be suspended while work is offloaded to an accelerator.
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
std::string asyncInference(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, std::vector< Ort::Value > &outputData, const AthAsynchronousAlgorithm *parentAlg)
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)