ATLAS Offline Software
Loading...
Searching...
No Matches
IOnnxRuntimeInferenceTool.icc
Go to the documentation of this file.
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
2template <typename T>
3Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
4{
5 std::vector<int64_t> dataShapeCopy = dataShape;
6
7 if (batchSize > 0) {
8 for (auto& shape: dataShapeCopy) {
9 if (shape == -1) {
10 shape = batchSize;
11 break;
12 }
13 }
14 }
15
16 auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
17 return Ort::Value::CreateTensor<T>(
18 memoryInfo, data.data(), data.size(), dataShapeCopy.data(), dataShapeCopy.size()
19 );
20}
21
22template <typename T>
23StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, unsigned idx, int64_t batchSize) const
24{
25 if (idx >= m_numInputs ) {
26 return StatusCode::FAILURE;
27 }
28
29 inputTensors.push_back(std::move(createTensor(data, m_inputShapes[idx], batchSize)));
30 return StatusCode::SUCCESS;
31}
32
33template <typename T>
34StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, unsigned idx, int64_t batchSize) const
35{
36 if (idx >= m_numOutputs ) {
37 return StatusCode::FAILURE;
38 }
39 auto tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), 1, std::multiplies<int64_t>());
40 if (tensorSize < 0) {
41 tensorSize = abs(tensorSize) * batchSize;
42 }
43 data.resize(tensorSize);
44 outputTensors.push_back(std::move(createTensor(data, m_outputShapes[idx], batchSize)));
45 return StatusCode::SUCCESS;
46}