ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxUtils.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
4#include <cassert>
5#include <string>
6
7#ifndef XAOD_STANDALONE
8// AthAsynchronousAlgorithm to resume in asyncInference
10
11// Explicit include of boost fiber
12#include <boost/fiber/all.hpp>
13#endif // !XAOD_STANDALONE
14
15namespace AthOnnxUtils {
16
18 const Ort::Session& session,
19 std::vector<std::vector<int64_t> >& dataShape,
20 std::vector<std::string>& nodeNames,
21 bool isInput
22){
23 dataShape.clear();
24 nodeNames.clear();
25
26 size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
27 dataShape.reserve(numNodes);
28 nodeNames.reserve(numNodes);
29
30 Ort::AllocatorWithDefaultOptions allocator;
31 for( std::size_t i = 0; i < numNodes; i++ ) {
32 Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
33 auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
34 dataShape.emplace_back(tensorInfo.GetShape());
35
36 auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
37 nodeNames.emplace_back(nodeName.get());
38 }
39}
40
42 const Ort::Session& session,
43 std::vector<std::vector<int64_t> >& dataShape,
44 std::vector<std::string>& nodeNames
45){
46 getNodeInfo(session, dataShape, nodeNames, true);
47}
48
50 const Ort::Session& session,
51 std::vector<std::vector<int64_t> >& dataShape,
52 std::vector<std::string>& nodeNames
53) {
54 getNodeInfo(session, dataShape, nodeNames, false);
55}
56
57void inferenceWithIOBinding(Ort::Session& session,
58 const std::vector<std::string>& inputNames,
59 const std::vector<Ort::Value>& inputData,
60 const std::vector<std::string>& outputNames,
61 const std::vector<Ort::Value>& outputData){
62
63 if (inputNames.empty()) {
64 throw std::runtime_error("Onnxruntime input data maping cannot be empty");
65 }
66 assert(inputNames.size() == inputData.size());
67
68 Ort::IoBinding iobinding(session);
69 for(size_t idx = 0; idx < inputNames.size(); ++idx){
70 iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
71 }
72
73
74 for(size_t idx = 0; idx < outputNames.size(); ++idx){
75 iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
76 }
77
78 session.Run(Ort::RunOptions{nullptr}, iobinding);
79}
80
81#ifndef XAOD_STANDALONE
82std::string asyncInference(Ort::Session& session,
83 const std::vector<std::string>& inputNames,
84 const std::vector<Ort::Value>& inputData,
85 const std::vector<std::string>& outputNames,
86 std::vector<Ort::Value>& outputData,
87 const AthAsynchronousAlgorithm* parentAlg) {
88 if (inputNames.empty()) {
89 throw std::runtime_error("Onnxruntime input data mapping cannot be empty");
90 }
91 assert(inputNames.size() == inputData.size());
92 assert(outputNames.size() == outputData.size());
93
94 Ort::RunOptions runOptions{};
95
96 // Transform names into formats required by ORT
97 std::vector<const char*> inputNamesArray{};
98 std::vector<const char*> outputNamesArray{};
99 inputNamesArray.reserve(inputNames.size());
100 outputNamesArray.reserve(outputNames.size());
101 for (const auto& name : inputNames) {
102 inputNamesArray.push_back(name.c_str());
103 }
104 for (const auto& name : outputNames) {
105 outputNamesArray.push_back(name.c_str());
106 }
107
108 // Setup for async
109 using Promise_t = boost::fibers::promise<std::string>;
110 Promise_t promise{};
111 boost::fibers::future<std::string> future{promise.get_future()};
112
113 // callback in format required by ORT
114 const auto callback = [](void* promise, OrtValue**, std::size_t,
115 OrtStatusPtr statusPtr) mutable {
116 std::string errorMsg{};
117 if (statusPtr != nullptr) {
118 Ort::Status status{statusPtr};
119 if (!status.IsOK()) {
120 errorMsg = status.GetErrorMessage();
121 }
122 }
123 static_cast<Promise_t*>(promise)->set_value(errorMsg);
124 };
125
126 // Run inference
127 session.RunAsync(runOptions, inputNamesArray.data(), inputData.data(),
128 inputData.size(), outputNamesArray.data(), outputData.data(),
129 outputData.size(), callback, static_cast<void*>(&promise));
130 // Suspends fiber while waiting
131 std::string errorMsg = future.get();
132 parentAlg->restoreAfterSuspend().orThrow("Failed to restore after suspension", "AsyncAlg");
133 return errorMsg;
134}
135#endif
136
/// @brief Number of elements of a tensor with shape @p dataShape.
///
/// Returns the product of all entries of @p dataShape, or 1 for an empty
/// shape. Entries are multiplied as-is with no validation, so a negative entry
/// (presumably an unresolved dynamic axis) yields a meaningless count.
int64_t getTensorSize(const std::vector<int64_t>& dataShape){
  int64_t elementCount = 1;
  for (auto dimIt = dataShape.cbegin(); dimIt != dataShape.cend(); ++dimIt) {
    elementCount *= *dimIt;
  }
  return elementCount;
}
144
145
} // namespace AthOnnxUtils
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
size_t size() const
Number of registered mappings.
An algorithm that can be suspended while work is offloaded to an accelerator.
virtual StatusCode restoreAfterSuspend() const override
Restore after suspend.
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition OnnxUtils.cxx:57
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition OnnxUtils.cxx:17
std::string asyncInference(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, std::vector< Ort::Value > &outputData, const AthAsynchronousAlgorithm *parentAlg)
Definition OnnxUtils.cxx:82
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:49
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:41