ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxUtils.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2
4#include <cassert>
5#include <string>
6
7namespace AthOnnxUtils {
8
10 const Ort::Session& session,
11 std::vector<std::vector<int64_t> >& dataShape,
12 std::vector<std::string>& nodeNames,
13 bool isInput
14){
15 dataShape.clear();
16 nodeNames.clear();
17
18 size_t numNodes = isInput? session.GetInputCount(): session.GetOutputCount();
19 dataShape.reserve(numNodes);
20 nodeNames.reserve(numNodes);
21
22 Ort::AllocatorWithDefaultOptions allocator;
23 for( std::size_t i = 0; i < numNodes; i++ ) {
24 Ort::TypeInfo typeInfo = isInput? session.GetInputTypeInfo(i): session.GetOutputTypeInfo(i);
25 auto tensorInfo = typeInfo.GetTensorTypeAndShapeInfo();
26 dataShape.emplace_back(tensorInfo.GetShape());
27
28 auto nodeName = isInput? session.GetInputNameAllocated(i, allocator) : session.GetOutputNameAllocated(i, allocator);
29 nodeNames.emplace_back(nodeName.get());
30 }
31}
32
34 const Ort::Session& session,
35 std::vector<std::vector<int64_t> >& dataShape,
36 std::vector<std::string>& nodeNames
37){
38 getNodeInfo(session, dataShape, nodeNames, true);
39}
40
42 const Ort::Session& session,
43 std::vector<std::vector<int64_t> >& dataShape,
44 std::vector<std::string>& nodeNames
45) {
46 getNodeInfo(session, dataShape, nodeNames, false);
47}
48
49void inferenceWithIOBinding(Ort::Session& session,
50 const std::vector<std::string>& inputNames,
51 const std::vector<Ort::Value>& inputData,
52 const std::vector<std::string>& outputNames,
53 const std::vector<Ort::Value>& outputData){
54
55 if (inputNames.empty()) {
56 throw std::runtime_error("Onnxruntime input data maping cannot be empty");
57 }
58 assert(inputNames.size() == inputData.size());
59
60 Ort::IoBinding iobinding(session);
61 for(size_t idx = 0; idx < inputNames.size(); ++idx){
62 iobinding.BindInput(inputNames[idx].data(), inputData[idx]);
63 }
64
65
66 for(size_t idx = 0; idx < outputNames.size(); ++idx){
67 iobinding.BindOutput(outputNames[idx].data(), outputData[idx]);
68 }
69
70 session.Run(Ort::RunOptions{nullptr}, iobinding);
71}
72
/// @brief Multiply together all entries of a tensor shape.
///
/// @param dataShape  Dimensions of the tensor.
/// @return The product of every entry of @p dataShape; 1 when it is empty.
int64_t getTensorSize(const std::vector<int64_t>& dataShape){
    int64_t numElements = 1;
    for (auto it = dataShape.begin(); it != dataShape.end(); ++it) {
        numElements *= *it;
    }
    return numElements;
}
80
81
82} // namespace AthOnnxUtils
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition OnnxUtils.cxx:49
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition OnnxUtils.cxx:73
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition OnnxUtils.cxx:9
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:41
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:33