ATLAS Offline Software
OnnxRuntimeInferenceTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 
7 
9  : asg::AsgTool ( name )
10 {
11  declareProperty("OnnxSessionTool", m_onnxSessionTool, "The Onnx session tool");
12  declareProperty("OnnxRuntimeSvc", m_onnxRuntimeSvc, "The Onnx runtime service");
13 }
14 
16 {
17  // Get the Onnx Runtime service.
18  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
19 
20  // Create the session.
21  ATH_CHECK(m_onnxSessionTool.retrieve());
22 
24 
25  return StatusCode::SUCCESS;
26 }
27 
29 {
30  auto& session = m_onnxSessionTool->session();
31  // obtain the model information
32  m_numInputs = session.GetInputCount();
33  m_numOutputs = session.GetOutputCount();
34 
35  AthOnnxUtils::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
36  AthOnnxUtils::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
37 
38  return StatusCode::SUCCESS;
39 }
40 
41 
43 {
44  if (batchSize <= 0) {
45  ATH_MSG_ERROR("Batch size should be positive");
46  return;
47  }
48 
49  for (auto& shape : m_inputShapes) {
50  if (shape[0] == -1) {
51  shape[0] = batchSize;
52  }
53  }
54 
55  for (auto& shape : m_outputShapes) {
56  if (shape[0] == -1) {
57  shape[0] = batchSize;
58  }
59  }
60 }
61 
62 int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
63 {
64  auto tensorSize = AthOnnxUtils::getTensorSize(m_inputShapes[idx]);
65  if (tensorSize < 0) {
66  return inputDataSize / abs(tensorSize);
67  } else {
68  return -1;
69  }
70 }
71 
72 StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
73 {
74  assert (inputTensors.size() == m_numInputs);
75  assert (outputTensors.size() == m_numOutputs);
76 
77  // Run the model.
79  m_onnxSessionTool->session(),
80  m_inputNodeNames, inputTensors,
81  m_outputNodeNames, outputTensors);
82 
83  return StatusCode::SUCCESS;
84 }
85 
87 {
88  ATH_MSG_INFO("Number of inputs: " << m_numInputs);
89  ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
90 
91  ATH_MSG_INFO("Input node names: ");
92  for (const auto& name : m_inputNodeNames) {
93  ATH_MSG_INFO("\t" << name);
94  }
95 
96  ATH_MSG_INFO("Output node names: ");
97  for (const auto& name : m_outputNodeNames) {
98  ATH_MSG_INFO("\t" << name);
99  }
100 
101  ATH_MSG_INFO("Input shapes: ");
102  for (const auto& shape : m_inputShapes) {
103  std::string shapeStr = "\t";
104  for (const auto& dim : shape) {
105  shapeStr += std::to_string(dim) + " ";
106  }
107  ATH_MSG_INFO(shapeStr);
108  }
109 
110  ATH_MSG_INFO("Output shapes: ");
111  for (const auto& shape : m_outputShapes) {
112  std::string shapeStr = "\t";
113  for (const auto& dim : shape) {
114  shapeStr += std::to_string(dim) + " ";
115  }
116  ATH_MSG_INFO(shapeStr);
117  }
118 }
119 
121 {
122  // Create input tensors.
123  std::vector<Ort::Value> inputTensors;
124  for (auto& [inputName, inputInfo] : inputData) {
125  const std::vector<int64_t>& shape = inputInfo.first;
126  if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
127  auto& data = std::get<std::vector<float>>(inputInfo.second);
128  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
129  } else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
130  auto& data = std::get<std::vector<int64_t>>(inputInfo.second);
131  inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
132  } else {
133  ATH_MSG_ERROR("Unsupported data type");
134  return StatusCode::FAILURE;
135  }
136  }
137 
138  // Create output tensors.
139  std::vector<Ort::Value> outputTensors;
140  outputTensors.reserve(inputData.size());
141  for (auto& [outputName, outputInfo] : outputData) {
142  auto& shape = outputInfo.first;
143  auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
144 
145  if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
146  auto& data = std::get<std::vector<float>>(outputInfo.second);
147  data.resize(tensorSize);
148  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
149  } else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
150  auto& data = std::get<std::vector<int64_t>>(outputInfo.second);
151  data.resize(tensorSize);
152  outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
153  } else {
154  ATH_MSG_ERROR("Unsupported data type");
155  return StatusCode::FAILURE;
156  }
157  }
158 
159  ATH_CHECK(inference(inputTensors, outputTensors));
160 
161  return StatusCode::SUCCESS;
162 }
AthOnnx::OnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
Perform inference.
Definition: OnnxRuntimeInferenceTool.cxx:72
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
AthOnnx::OnnxRuntimeInferenceTool::initialize
virtual StatusCode initialize() override
Initialize the tool.
Definition: OnnxRuntimeInferenceTool.cxx:15
AthOnnxUtils::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
Definition: OnnxRuntimeInferenceTool.h:50
accumulate
bool accumulate(AccumulateMap &map, std::vector< module_t > const &modules, FPGATrackSimMatrixAccumulator const &acc)
Accumulates an accumulator (e.g.
Definition: FPGATrackSimMatrixAccumulator.cxx:22
asg
Definition: DataHandleTestTool.h:28
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
AthOnnx::OnnxRuntimeInferenceTool::setBatchSize
virtual void setBatchSize(int64_t batchSize) override final
set batch size.
Definition: OnnxRuntimeInferenceTool.cxx:42
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
AthInfer::OutputDataMap
std::map< std::string, InferenceData > OutputDataMap
Definition: IAthInferenceTool.h:17
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
AthOnnxUtils::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
AthOnnxUtils::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:221
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
lumiFormat.outputName
string outputName
Definition: lumiFormat.py:65
AthOnnx::OnnxRuntimeInferenceTool::printModelInfo
virtual void printModelInfo() const override final
Definition: OnnxRuntimeInferenceTool.cxx:86
AthOnnxUtils::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnxUtils::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool
OnnxRuntimeInferenceTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h
AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: OnnxRuntimeInferenceTool.h:51
OnnxRuntimeInferenceTool.h
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:28
AthOnnxUtils::createTensor
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.h:78
AthInfer::InputDataMap
std::map< std::string, InferenceData > InputDataMap
Definition: IAthInferenceTool.h:16
AthOnnx::OnnxRuntimeInferenceTool::getBatchSize
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Method for determining the batch size from the data size.
Definition: OnnxRuntimeInferenceTool.cxx:62