Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
OnnxUtils.cxx
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>
6 
7 namespace AthOnnxUtils {
8 
10  const Ort::Session& session,
11  std::vector<std::vector<int64_t> >& dataShape,
12  std::vector<std::string>& nodeNames,
13  bool isInput
14 ){
15  dataShape.clear();
16  nodeNames.clear();
17 
18  size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
19  dataShape.reserve(numNodes);
20  nodeNames.reserve(numNodes);
21 
22  Ort::AllocatorWithDefaultOptions allocator;
23  for( std::size_t i = 0; i < numNodes; i++ ) {
24  Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
25  auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
26  dataShape.emplace_back(tensorInfo.GetShape());
27 
28  auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
29  nodeNames.emplace_back(nodeName.get());
30  }
31 }
32 
34  const Ort::Session& session,
35  std::vector<std::vector<int64_t> >& dataShape,
36  std::vector<std::string>& nodeNames
37 ){
38  getNodeInfo(session, dataShape, nodeNames, true);
39 }
40 
42  const Ort::Session& session,
43  std::vector<std::vector<int64_t> >& dataShape,
44  std::vector<std::string>& nodeNames
45 ) {
46  getNodeInfo(session, dataShape, nodeNames, false);
47 }
48 
50  const std::vector<std::string>& inputNames,
51  const std::vector<Ort::Value>& inputData,
52  const std::vector<std::string>& outputNames,
53  const std::vector<Ort::Value>& outputData){
54 
55  if (inputNames.empty()) {
56  throw std::runtime_error("Onnxruntime input data maping cannot be empty");
57  }
58  assert(inputNames.size() == inputData.size());
59 
60  Ort::IoBinding iobinding(session);
61  for(size_t idx = 0; idx < inputNames.size(); ++idx){
62  iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
63  }
64 
65 
66  for(size_t idx = 0; idx < outputNames.size(); ++idx){
67  iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
68  }
69 
70  session.Run(Ort::RunOptions{nullptr}, iobinding);
71 }
72 
/// Number of elements in a tensor of the given shape: the product of all dimensions.
/// An empty shape (a scalar) yields 1; any zero dimension yields 0.
/// NOTE(review): dimensions are multiplied as-is, so a negative entry (ONNX Runtime
/// presumably reports dynamic dimensions that way — confirm) makes the result
/// meaningless. Callers are expected to pass fully resolved shapes.
int64_t getTensorSize(const std::vector<int64_t>& dataShape){
  int64_t nElements = 1;
  for (auto it = dataShape.cbegin(); it != dataShape.cend(); ++it) {
    nElements *= *it;
  }
  return nElements;
}
80 
81 
} // namespace AthOnnxUtils
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthOnnxUtils::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
python.oracle.Session
Session
Definition: oracle.py:78
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
lumiFormat.i
int i
Definition: lumiFormat.py:85
AthOnnxUtils::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
AthOnnxUtils::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:18
AthOnnxUtils::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnxUtils::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
AthOnnxUtils
Definition: OnnxUtils.h:12
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h