ATLAS Offline Software
IOnnxRuntimeInferenceTool.icc
// Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration

// Create an ONNX Runtime tensor over the caller-owned `data` buffer.
// A shape entry of -1 (dynamic batch dimension) is replaced by `batchSize`.
template <typename T>
Ort::Value AthOnnx::IOnnxRuntimeInferenceTool::createTensor(std::vector<T>& data, const std::vector<int64_t>& dataShape, int64_t batchSize) const
{
   std::vector<int64_t> dataShapeCopy = dataShape;

   if (batchSize > 0) {
      // Resolve the first dynamic dimension (marked as -1) to the requested batch size.
      for (auto& shape : dataShapeCopy) {
         if (shape == -1) {
            shape = batchSize;
            break;
         }
      }
   }

   // The tensor does not copy the data; it references the memory owned by `data`.
   auto memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
   return Ort::Value::CreateTensor<T>(
      memoryInfo, data.data(), data.size(), dataShapeCopy.data(), dataShapeCopy.size()
   );
}
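
// Illustrative sketch (not part of the original file): how createTensor
// resolves a dynamic batch dimension. The concrete numbers below are
// assumptions made up for this example only.
//
//   std::vector<float> data(8 * 3, 0.f);   // 8 events with 3 features each
//   std::vector<int64_t> shape{-1, 3};     // -1 marks the dynamic batch dimension
//   Ort::Value tensor = createTensor(data, shape, /*batchSize=*/8);
//   // The tensor is created with the concrete shape {8, 3} and points at `data`.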

// Wrap the user-provided input buffer `data` as the idx-th model input tensor.
template <typename T>
StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addInput(std::vector<Ort::Value>& inputTensors, std::vector<T>& data, unsigned idx, int64_t batchSize) const
{
   if (idx >= m_numInputs) {
      return StatusCode::FAILURE;
   }

   inputTensors.push_back(createTensor(data, m_inputShapes[idx], batchSize));
   return StatusCode::SUCCESS;
}
32 
33 template <typename T>
34 StatusCode AthOnnx::IOnnxRuntimeInferenceTool::addOutput(std::vector<Ort::Value>& outputTensors, std::vector<T>& data, unsigned idx, int64_t batchSize) const
35 {
36  if (idx >= m_numOutputs ) {
37  return StatusCode::FAILURE;
38  }
39  auto tensorSize = std::accumulate(m_outputShapes[idx].begin(), m_outputShapes[idx].end(), 1, std::multiplies<int64_t>());
40  if (tensorSize < 0) {
41  tensorSize = abs(tensorSize) * batchSize;
42  }
43  data.resize(tensorSize);
44  outputTensors.push_back(std::move(createTensor(data, m_outputShapes[idx], batchSize)));
45  return StatusCode::SUCCESS;
46 }
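
// Illustrative usage sketch (not part of the original file): a client algorithm
// could prepare the input/output tensors with addInput/addOutput and then run
// the model. The inference(...) call and the m_onnxTool handle name are
// assumptions about the rest of the interface, shown only to put the two
// helpers above in context.
//
//   std::vector<float> inputData = /* flattened input for N events */;
//   std::vector<float> outputData;                 // resized by addOutput
//   std::vector<Ort::Value> inputTensors, outputTensors;
//   ATH_CHECK( m_onnxTool->addInput(inputTensors, inputData, /*idx=*/0, /*batchSize=*/N) );
//   ATH_CHECK( m_onnxTool->addOutput(outputTensors, outputData, /*idx=*/0, /*batchSize=*/N) );
//   ATH_CHECK( m_onnxTool->inference(inputTensors, outputTensors) );   // assumed method
//   // After inference, outputData holds the flattened model output for the batch.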