Loading [MathJax]/extensions/tex2jax.js
ATLAS Offline Software
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Modules Pages
OnnxRuntimeInferenceTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9  : asg::AsgTool ( name )
10 {
11 }
12 
14 {
   // Initialize the tool: retrieve the ONNX Runtime service and the session
   // tool. ATH_CHECK propagates a failed StatusCode from either retrieve()
   // straight back to the caller.
15  // Get the Onnx Runtime service.
16  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
17 
18  // Create the session.
19  ATH_CHECK(m_onnxSessionTool.retrieve());
20 
   // NOTE(review): source line 21 is not present in this listing (the
   // generated page dropped it). It presumably calls getNodeInfo() so the
   // cached input/output counts, shapes and node names are populated before
   // any inference call -- confirm against the actual .cxx.
22 
23  return StatusCode::SUCCESS;
24 }
25 
27 {
28  auto& session = m_onnxSessionTool->session();
29  // obtain the model information
30  m_numInputs = session.GetInputCount();
31  m_numOutputs = session.GetOutputCount();
32 
33  AthOnnxUtils::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
34  AthOnnxUtils::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
35 
36  return StatusCode::SUCCESS;
37 }
38 
39 
41 {
42  if (batchSize <= 0) {
43  ATH_MSG_ERROR("Batch size should be positive");
44  return;
45  }
46 
47  for (auto& shape : m_inputShapes) {
48  if (shape[0] == -1) {
49  shape[0] = batchSize;
50  }
51  }
52 
53  for (auto& shape : m_outputShapes) {
54  if (shape[0] == -1) {
55  shape[0] = batchSize;
56  }
57  }
58 }
59 
60 int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
61 {
62  auto tensorSize = AthOnnxUtils::getTensorSize(m_inputShapes[idx]);
63  if (tensorSize < 0) {
64  return inputDataSize / abs(tensorSize);
65  } else {
66  return -1;
67  }
68 }
69 
// Run the model on caller-built tensors. The tensors are passed to the
// session together with the node names cached by getNodeInfo().
//  inputTensors:  expected to hold one tensor per model input.
//  outputTensors: expected to hold one tensor per model output.
// The size checks below are asserts only, so they vanish in NDEBUG builds.
// The callee returns void, so this always returns SUCCESS; any failure would
// have to surface as an exception -- NOTE(review) confirm.
70 StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
71 {
72  assert (inputTensors.size() == m_numInputs);
73  assert (outputTensors.size() == m_numOutputs);
74 
75  // Run the model.
   // NOTE(review): source line 76 (the function being called) is missing from
   // this listing; the cross-reference index identifies it as
   // AthOnnxUtils::inferenceWithIOBinding, whose parameters match the
   // arguments below.
77  m_onnxSessionTool->session(),
78  m_inputNodeNames, inputTensors,
79  m_outputNodeNames, outputTensors);
80 
81  return StatusCode::SUCCESS;
82 }
83 
85 {
86  ATH_MSG_INFO("Number of inputs: " << m_numInputs);
87  ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
88 
89  ATH_MSG_INFO("Input node names: ");
90  for (const auto& name : m_inputNodeNames) {
91  ATH_MSG_INFO("\t" << name);
92  }
93 
94  ATH_MSG_INFO("Output node names: ");
95  for (const auto& name : m_outputNodeNames) {
96  ATH_MSG_INFO("\t" << name);
97  }
98 
99  ATH_MSG_INFO("Input shapes: ");
100  for (const auto& shape : m_inputShapes) {
101  std::string shapeStr = "\t";
102  for (const auto& dim : shape) {
103  shapeStr += std::to_string(dim) + " ";
104  }
105  ATH_MSG_INFO(shapeStr);
106  }
107 
108  ATH_MSG_INFO("Output shapes: ");
109  for (const auto& shape : m_outputShapes) {
110  std::string shapeStr = "\t";
111  for (const auto& dim : shape) {
112  shapeStr += std::to_string(dim) + " ";
113  }
114  ATH_MSG_INFO(shapeStr);
115  }
116 }
117 
119 {
120  // Create input tensors.
121  std::vector<Ort::Value> inputTensors;
122  for (auto& [inputName, inputInfo] : inputData) {
123  const std::vector<int64_t>& shape = inputInfo.first;
124  if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
125  auto& data = std::get<std::vector<float>>(inputInfo.second);
126  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
127  } else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
128  auto& data = std::get<std::vector<int64_t>>(inputInfo.second);
129  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
130  } else {
131  ATH_MSG_ERROR("Unsupported data type");
132  return StatusCode::FAILURE;
133  }
134  }
135 
136  // Create output tensors.
137  std::vector<Ort::Value> outputTensors;
138  outputTensors.reserve(inputData.size());
139  for (auto& [outputName, outputInfo] : outputData) {
140  auto& shape = outputInfo.first;
141  auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
142 
143  if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
144  auto& data = std::get<std::vector<float>>(outputInfo.second);
145  data.resize(tensorSize);
146  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
147  } else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
148  auto& data = std::get<std::vector<int64_t>>(outputInfo.second);
149  data.resize(tensorSize);
150  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
151  } else {
152  ATH_MSG_ERROR("Unsupported data type");
153  return StatusCode::FAILURE;
154  }
155  }
156 
157  ATH_CHECK(inference(inputTensors, outputTensors));
158 
159  return StatusCode::SUCCESS;
160 }
AthOnnx::OnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
Perform inference.
Definition: OnnxRuntimeInferenceTool.cxx:70
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthOnnx::OnnxRuntimeInferenceTool::initialize
virtual StatusCode initialize() override
Initialize the tool.
Definition: OnnxRuntimeInferenceTool.cxx:13
AthOnnxUtils::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
asg
Definition: DataHandleTestTool.h:28
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
AthOnnx::OnnxRuntimeInferenceTool::setBatchSize
virtual void setBatchSize(int64_t batchSize) override final
Set the batch size.
Definition: OnnxRuntimeInferenceTool.cxx:40
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthInfer::OutputDataMap
std::map< std::string, InferenceData > OutputDataMap
Definition: IAthInferenceTool.h:17
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
AthOnnxUtils::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
AthOnnxUtils::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:240
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
lumiFormat.outputName
string outputName
Definition: lumiFormat.py:65
AthOnnx::OnnxRuntimeInferenceTool::printModelInfo
virtual void printModelInfo() const override final
Definition: OnnxRuntimeInferenceTool.cxx:84
runIDAlign.accumulate
accumulate
Update flags based on parser line args.
Definition: runIDAlign.py:63
AthOnnxUtils::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnxUtils::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool
OnnxRuntimeInferenceTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h
OnnxRuntimeInferenceTool.h
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:26
AthOnnxUtils::createTensor
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.h:78
AthInfer::InputDataMap
std::map< std::string, InferenceData > InputDataMap
Definition: IAthInferenceTool.h:16
AthOnnx::OnnxRuntimeInferenceTool::getBatchSize
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Determine the batch size from the input data size.
Definition: OnnxRuntimeInferenceTool.cxx:60