ATLAS Offline Software
OnnxUtils.h
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
3 #ifndef Onnx_UTILS_H
4 #define Onnx_UTILS_H
5 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

// Onnx Runtime include(s).
#include <onnxruntime_cxx_api.h>
11 
12 namespace AthOnnx {
13 
14 // @author Xiangyang Ju <xiangyang.ju@cern.ch>
15 
// @brief Convert a vector of vectors to a single vector.
// @param features The vector of vectors to be flattened.
// @return A single vector containing all the elements of the inner vectors,
//         concatenated in order (all of features[0], then features[1], ...).
//         An empty input, or one holding only empty inner vectors, yields an
//         empty vector.
template<typename T>
inline std::vector<T> flattenNestedVectors( const std::vector<std::vector<T>>& features) {
  // 1. Compute the total size required. std::size_t matches vector::size();
  //    an int counter would narrow and could overflow for very large inputs,
  //    turning the reserve() below into a request for a huge allocation.
  std::size_t total_size = 0;
  for (const auto& feature : features) total_size += feature.size();

  // 2. Allocate once, then append each inner vector as a contiguous range.
  std::vector<T> flatten1D;
  flatten1D.reserve(total_size);
  for (const auto& feature : features)
    flatten1D.insert(flatten1D.end(), feature.begin(), feature.end());

  return flatten1D;
}
34 
35 // @brief Get the input data shape and node names (in the computational graph) from the onnx model
36 // @param session The onnx session.
37 // @param dataShape The shape of the input data. Note that there may be multiple inputs.
38 // @param nodeNames The names of the input nodes in the computational graph.
39 // the dataShape and nodeNames will be updated.
40 void getInputNodeInfo(
41  const Ort::Session& session,
42  std::vector<std::vector<int64_t> >& dataShape,
43  std::vector<std::string>& nodeNames);
44 
45 // @brief Get the output data shape and node names (in the computational graph) from the onnx model
46 // @param session The onnx session.
47 // @param dataShape The shape of the output data.
48 // @param nodeNames The names of the output nodes in the computational graph.
49 // the dataShape and nodeNames will be updated.
51  const Ort::Session& session,
52  std::vector<std::vector<int64_t> >& dataShape,
53  std::vector<std::string>& nodeNames);
54 
55 // Heleper function to get node info
56 void getNodeInfo(
57  const Ort::Session& session,
58  std::vector<std::vector<int64_t> >& dataShape,
59  std::vector<std::string>& nodeNames,
60  bool isInput
61 );
62 
// @brief Count the total number of elements in a tensor with the given shape.
// Useful for reserving space for output data.
// NOTE(review): behavior for dynamic (e.g. -1) dimensions is not visible from
// this header — confirm in OnnxUtils.cxx before passing unresolved shapes.
int64_t getTensorSize(const std::vector<int64_t>& dataShape);
66 
67 // Inference with IO binding. Better for performance, particularly for GPUs.
68 // See https://onnxruntime.ai/docs/performance/tune-performance/iobinding.html
70  const std::vector<std::string>& inputNames,
71  const std::vector<Ort::Value>& inputData,
72  const std::vector<std::string>& outputNames,
73  const std::vector<Ort::Value>& outputData
74 );
75 
76 // @brief Create a tensor from a vector of data and its shape.
77 Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape);
78 
79 
80 }
81 #endif
AthOnnx::flattenNestedVectors
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T >> &features)
Definition: OnnxUtils.h:20
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthOnnx::createTensor
Ort::Value createTensor(std::vector< float > &data, const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:81
python.oracle.Session
Session
Definition: oracle.py:78
AthOnnx::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
AthOnnx::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:18
AthOnnx::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
AthOnnx::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnx::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
AthOnnx
Namespace holding all of the Onnx Runtime example code.
Definition: EvaluateModel.cxx:13