ATLAS Offline Software
OnnxUtils.cxx
Go to the documentation of this file.
1 // Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2 
4 #include <cassert>
5 #include <string>
6 
7 namespace AthOnnx {
8 
10  const Ort::Session& session,
11  std::vector<std::vector<int64_t> >& dataShape,
12  std::vector<std::string>& nodeNames,
13  bool isInput
14 ){
15  dataShape.clear();
16  nodeNames.clear();
17 
18  size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
19  dataShape.reserve(numNodes);
20  nodeNames.reserve(numNodes);
21 
22  Ort::AllocatorWithDefaultOptions allocator;
23  for( std::size_t i = 0; i < numNodes; i++ ) {
24  Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
25  auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
26  dataShape.emplace_back(tensorInfo.GetShape());
27 
28  auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
29  nodeNames.emplace_back(nodeName.get());
30  }
31 }
32 
34  const Ort::Session& session,
35  std::vector<std::vector<int64_t> >& dataShape,
36  std::vector<std::string>& nodeNames
37 ){
38  getNodeInfo(session, dataShape, nodeNames, true);
39 }
40 
42  const Ort::Session& session,
43  std::vector<std::vector<int64_t> >& dataShape,
44  std::vector<std::string>& nodeNames
45 ) {
46  getNodeInfo(session, dataShape, nodeNames, false);
47 }
48 
50  const std::vector<std::string>& inputNames,
51  const std::vector<Ort::Value>& inputData,
52  const std::vector<std::string>& outputNames,
53  const std::vector<Ort::Value>& outputData){
54 
55  if (inputNames.empty()) {
56  throw std::runtime_error("Onnxruntime input data maping cannot be empty");
57  }
58  assert(inputNames.size() == inputData.size());
59 
60  Ort::IoBinding iobinding(session);
61  for(size_t idx = 0; idx < inputNames.size(); ++idx){
62  iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
63  }
64 
65 
66  for(size_t idx = 0; idx < outputNames.size(); ++idx){
67  iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
68  }
69 
70  session.Run(Ort::RunOptions{nullptr}, iobinding);
71 }
72 
/// Number of elements in a tensor with the given shape.
///
/// Returns the product of all extents in @p dataShape; an empty shape (a scalar)
/// yields 1. Every extent is multiplied in unchanged. A shape that still contains a
/// dynamic/unknown dimension (onnxruntime reports these as negative values, e.g. -1)
/// therefore gives a negative or otherwise meaningless count; resolve such dimensions
/// before calling.
int64_t getTensorSize(const std::vector<int64_t>& dataShape){
  int64_t elementCount = 1;
  for (std::size_t axis = 0; axis != dataShape.size(); ++axis) {
    elementCount *= dataShape[axis];
  }
  return elementCount;
}
80 
81 Ort::Value createTensor(std::vector<float>& data, const std::vector<int64_t>& dataShape)
82 {
83  auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
84 
85  return Ort::Value::CreateTensor<float>(
86  memoryInfo,
87  data.data(),
88  data.size(),
89  dataShape.data(),
90  dataShape.size());
91 };
92 
93 
94 } // namespace AthOnnx
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthOnnx::createTensor
Ort::Value createTensor(std::vector< float > &data, const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:81
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
python.oracle.Session
Session
Definition: oracle.py:78
AthOnnx::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
lumiFormat.i
int i
Definition: lumiFormat.py:92
AthOnnx::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
XMLtoHeader.outputNames
outputNames
Definition: XMLtoHeader.py:18
AthOnnx::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
AthOnnx::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h
AthOnnx::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
AthOnnx
Namespace holding all of the Onnx Runtime example code.
Definition: EvaluateModel.cxx:13