ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxUtils.h
Go to the documentation of this file.
1// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2
3#ifndef Onnx_UTILS_H
4#define Onnx_UTILS_H
5
6#include <memory>
7#include <vector>
8
9// Onnx Runtime include(s).
10#include <onnxruntime_cxx_api.h>
11
12namespace AthOnnxUtils {
13
14// @author Xiangyang Ju <xiangyang.ju@cern.ch>
15
// @brief Convert a vector of vectors to a single vector.
// @param features The vector of vectors to be flattened.
// @return A single vector containing all the elements of the input vector of vectors,
//         in order: all elements of features[0], then features[1], and so on.
//         An empty input (or one whose inner vectors are all empty) yields an empty vector.
template<typename T>
inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& features) {
  // 1. Compute the total size required, so the output is allocated exactly once.
  //    Accumulate in the container's own size type: an int would narrow each
  //    size_t term and could overflow for very large inputs.
  typename std::vector<T>::size_type total_size = 0;
  for (const auto& feature : features) total_size += feature.size();

  std::vector<T> flatten1D;
  flatten1D.reserve(total_size);

  // 2. Append each inner vector, preserving order.
  for (const auto& feature : features)
    flatten1D.insert(flatten1D.end(), feature.begin(), feature.end());

  return flatten1D;
}
34
35// @brief Get the input data shape and node names (in the computational graph) from the onnx model
36// @param session The onnx session.
37// @param dataShape The shape of the input data. Note that there may be multiple inputs.
38// @param nodeNames The names of the input nodes in the computational graph.
39// the dataShape and nodeNames will be updated.
41 const Ort::Session& session,
42 std::vector<std::vector<int64_t> >& dataShape,
43 std::vector<std::string>& nodeNames);
44
45// @brief Get the output data shape and node names (in the computational graph) from the onnx model
46// @param session The onnx session.
47// @param dataShape The shape of the output data.
48// @param nodeNames The names of the output nodes in the computational graph.
49// the dataShape and nodeNames will be updated.
51 const Ort::Session& session,
52 std::vector<std::vector<int64_t> >& dataShape,
53 std::vector<std::string>& nodeNames);
54
55// Heleper function to get node info
56void getNodeInfo(
57 const Ort::Session& session,
58 std::vector<std::vector<int64_t> >& dataShape,
59 std::vector<std::string>& nodeNames,
60 bool isInput
61);
62
// @brief Count the total number of elements in a tensor with the given shape.
// Useful for reserving space for the output data.
// @param dataShape The tensor shape, one extent per dimension.
// @return The total element count (presumably the product of the extents; how dynamic
//         dimensions such as -1 are handled should be confirmed in OnnxUtils.cxx).
int64_t getTensorSize(
    const std::vector<int64_t>& dataShape);
66
67// Inference with IO binding. Better for performance, particularly for GPUs.
68// See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
69void inferenceWithIOBinding(Ort::Session& session,
70 const std::vector<std::string>& inputNames,
71 const std::vector<Ort::Value>& inputData,
72 const std::vector<std::string>& outputNames,
73 const std::vector<Ort::Value>& outputData
74);
75
76// @brief Create a tensor from a vector of data and its shape.
77template<typename T>
78Ort::Value createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape) {
79 // Create a tensor from the data.
80 Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
81 return Ort::Value::CreateTensor<T>(memoryInfo, data.data(), data.size(), dataShape.data(), dataShape.size());
82}
83
84} // namespace AthOnnxUtils
85#endif
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition OnnxUtils.cxx:49
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition OnnxUtils.cxx:73
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition OnnxUtils.cxx:9
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:41
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:33
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T > > &features)
Definition OnnxUtils.h:20
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition OnnxUtils.h:78