ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxUtils.h
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3#ifndef Onnx_UTILS_H
4#define Onnx_UTILS_H
5
6#include <cassert>
7#include <memory>
8#include <string>
9#include <vector>
10
11// Onnx Runtime include(s).
12#include <onnxruntime_cxx_api.h>
13
15
16namespace AthOnnxUtils {
17
18// @author Xiangyang Ju <xiangyang.ju@cern.ch>
19
20// @brief Convert a vector of vectors to a single vector.
21// @param features The vector of vectors to be flattened.
22// @return A single vector containing all the elements of the input vector of vectors.
23template<typename T>
24inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& features) {
25 // 1. Compute the total size required.
26 int total_size = 0;
27 for (const auto& feature : features) total_size += feature.size();
28
29 std::vector<T> flatten1D;
30 flatten1D.reserve(total_size);
31
32 for (const auto& feature : features)
33 for (const auto& elem : feature)
34 flatten1D.push_back(elem);
35
36 return flatten1D;
37}
38
39// @brief Get the input data shape and node names (in the computational graph) from the onnx model
40// @param session The onnx session.
41// @param dataShape The shape of the input data. Note that there may be multiple inputs.
42// @param nodeNames The names of the input nodes in the computational graph.
43// the dataShape and nodeNames will be updated.
45 const Ort::Session& session,
46 std::vector<std::vector<int64_t> >& dataShape,
47 std::vector<std::string>& nodeNames);
48
49// @brief Get the output data shape and node names (in the computational graph) from the onnx model
50// @param session The onnx session.
51// @param dataShape The shape of the output data.
52// @param nodeNames The names of the output nodes in the computational graph.
53// the dataShape and nodeNames will be updated.
55 const Ort::Session& session,
56 std::vector<std::vector<int64_t> >& dataShape,
57 std::vector<std::string>& nodeNames);
58
59// Heleper function to get node info
60void getNodeInfo(
61 const Ort::Session& session,
62 std::vector<std::vector<int64_t> >& dataShape,
63 std::vector<std::string>& nodeNames,
64 bool isInput
65);
66
67// @brief to count the total number of elements in a tensor
68// They are useful for reserving spaces for the output data.
69int64_t getTensorSize(const std::vector<int64_t>& dataShape);
70
71// Inference with IO binding. Better for performance, particularly for GPUs.
72// See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
73void inferenceWithIOBinding(Ort::Session& session,
74 const std::vector<std::string>& inputNames,
75 const std::vector<Ort::Value>& inputData,
76 const std::vector<std::string>& outputNames,
77 const std::vector<Ort::Value>& outputData
78);
79
80#ifndef XAOD_STANDALONE
81// Asynchronous inference
82std::string asyncInference(Ort::Session& session,
83 const std::vector<std::string>& inputNames,
84 const std::vector<Ort::Value>& inputData,
85 const std::vector<std::string>& outputNames,
86 std::vector<Ort::Value>& outputData,
87 const AthAsynchronousAlgorithm* parentAlg);
88#endif
89
90// @brief Create a tensor from a vector of data and its shape.
91template<typename T>
92Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape) {
93 // Create a tensor from the data.
94 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
95 return Ort::Value::CreateTensor<T>(memoryInfo, data.data(), data.size(), dataShape.data(), dataShape.size());
96}
97
98} // namespace AthOnnx
99#endif
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
An algorithm that can be suspended while work is offloaded to an accelerator.
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition OnnxUtils.cxx:57
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition OnnxUtils.cxx:17
std::string asyncInference(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, std::vector< Ort::Value > &outputData, const AthAsynchronousAlgorithm *parentAlg)
Definition OnnxUtils.cxx:82
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:49
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:41
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T > > &features)
Definition OnnxUtils.h:24
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition OnnxUtils.h:92