ATLAS Offline Software
OnnxRuntimeInferenceTool.cxx
Go to the documentation of this file.
/*
  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
*/
// Includes restored from the Doxygen cross-reference footer, which names both
// project headers; <cassert> backs the assert() calls in inference().
#include "OnnxRuntimeInferenceTool.h"

#include "OnnxUtils.h"

#include <cassert>
9  : asg::AsgTool ( name )
10 {
11  declareProperty("OnnxSessionTool", m_onnxSessionTool, "The Onnx session tool");
12  declareProperty("OnnxRuntimeSvc", m_onnxRuntimeSvc, "The Onnx runtime service");
13 }
14 
16 {
17  // Get the Onnx Runtime service.
18  ATH_CHECK(m_onnxRuntimeSvc.retrieve());
19 
20  // Create the session.
21  ATH_CHECK(m_onnxSessionTool.retrieve());
22 
24 
25  return StatusCode::SUCCESS;
26 }
27 
29 {
30  auto& session = m_onnxSessionTool->session();
31  // obtain the model information
32  m_numInputs = session.GetInputCount();
33  m_numOutputs = session.GetOutputCount();
34 
35  AthOnnx::getInputNodeInfo(session, m_inputShapes, m_inputNodeNames);
36  AthOnnx::getOutputNodeInfo(session, m_outputShapes, m_outputNodeNames);
37 
38  return StatusCode::SUCCESS;
39 }
40 
41 
43 {
44  if (batchSize <= 0) {
45  ATH_MSG_ERROR("Batch size should be positive");
46  return;
47  }
48 
49  for (auto& shape : m_inputShapes) {
50  if (shape[0] == -1) {
51  shape[0] = batchSize;
52  }
53  }
54 
55  for (auto& shape : m_outputShapes) {
56  if (shape[0] == -1) {
57  shape[0] = batchSize;
58  }
59  }
60 }
61 
62 int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
63 {
64  auto tensorSize = AthOnnx::getTensorSize(m_inputShapes[idx]);
65  if (tensorSize < 0) {
66  return inputDataSize / abs(tensorSize);
67  } else {
68  return -1;
69  }
70 }
71 
72 StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
73 {
74  assert (inputTensors.size() == m_numInputs);
75  assert (outputTensors.size() == m_numOutputs);
76 
77  // Run the model.
79  m_onnxSessionTool->session(),
80  m_inputNodeNames, inputTensors,
81  m_outputNodeNames, outputTensors);
82 
83  return StatusCode::SUCCESS;
84 }
85 
87 {
88  ATH_MSG_INFO("Number of inputs: " << m_numInputs);
89  ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
90 
91  ATH_MSG_INFO("Input node names: ");
92  for (const auto& name : m_inputNodeNames) {
93  ATH_MSG_INFO("\t" << name);
94  }
95 
96  ATH_MSG_INFO("Output node names: ");
97  for (const auto& name : m_outputNodeNames) {
98  ATH_MSG_INFO("\t" << name);
99  }
100 
101  ATH_MSG_INFO("Input shapes: ");
102  for (const auto& shape : m_inputShapes) {
103  std::string shapeStr = "\t";
104  for (const auto& dim : shape) {
105  shapeStr += std::to_string(dim) + " ";
106  }
107  ATH_MSG_INFO(shapeStr);
108  }
109 
110  ATH_MSG_INFO("Output shapes: ");
111  for (const auto& shape : m_outputShapes) {
112  std::string shapeStr = "\t";
113  for (const auto& dim : shape) {
114  shapeStr += std::to_string(dim) + " ";
115  }
116  ATH_MSG_INFO(shapeStr);
117  }
118 }
AthOnnx::OnnxRuntimeInferenceTool::inference
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
perform inference
Definition: OnnxRuntimeInferenceTool.cxx:72
AthOnnx::OnnxRuntimeInferenceTool::initialize
virtual StatusCode initialize() override
Initialize the tool.
Definition: OnnxRuntimeInferenceTool.cxx:15
yodamerge_tmp.dim
dim
Definition: yodamerge_tmp.py:239
ATH_MSG_INFO
#define ATH_MSG_INFO(x)
Definition: AthMsgStreamMacros.h:31
AthCommonDataStore< AthCommonMsg< AlgTool > >::declareProperty
Gaudi::Details::PropertyBase & declareProperty(Gaudi::Property< T > &t)
Definition: AthCommonDataStore.h:145
AthOnnx::OnnxRuntimeInferenceTool::m_onnxRuntimeSvc
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
Definition: OnnxRuntimeInferenceTool.h:46
asg
Definition: DataHandleTestTool.h:28
AthOnnx::getTensorSize
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition: OnnxUtils.cxx:73
ATH_MSG_ERROR
#define ATH_MSG_ERROR(x)
Definition: AthMsgStreamMacros.h:33
AthOnnx::OnnxRuntimeInferenceTool::setBatchSize
virtual void setBatchSize(int64_t batchSize) override final
set batch size.
Definition: OnnxRuntimeInferenceTool.cxx:42
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
AthOnnx::getNodeInfo
void getNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames, bool isInput)
Definition: OnnxUtils.cxx:9
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
ActsTrk::to_string
std::string to_string(const DetectorType &type)
Definition: GeometryDefs.h:34
AthOnnx::inferenceWithIOBinding
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition: OnnxUtils.cxx:49
AthOnnx::OnnxRuntimeInferenceTool::printModelInfo
virtual void printModelInfo() const override final
Definition: OnnxRuntimeInferenceTool.cxx:86
AthOnnx::getInputNodeInfo
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:33
AthOnnx::OnnxRuntimeInferenceTool::OnnxRuntimeInferenceTool
OnnxRuntimeInferenceTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h
AthOnnx::OnnxRuntimeInferenceTool::m_onnxSessionTool
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
Definition: OnnxRuntimeInferenceTool.h:47
AthOnnx::getOutputNodeInfo
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition: OnnxUtils.cxx:41
OnnxRuntimeInferenceTool.h
AthOnnx::OnnxRuntimeInferenceTool::getNodeInfo
StatusCode getNodeInfo()
Definition: OnnxRuntimeInferenceTool.cxx:28
AthOnnx::OnnxRuntimeInferenceTool::getBatchSize
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
methods for determining batch size from the data size
Definition: OnnxRuntimeInferenceTool.cxx:62