ATLAS Offline Software
Loading...
Searching...
No Matches
OnnxRuntimeInferenceTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
7
9 : asg::AsgTool ( name )
10{
// Constructor: only forwards the tool instance name to the asg::AsgTool base;
// no further initialization is done here.
11}
12
14{
// Initialize the tool: retrieve the ONNX Runtime service and the session tool.
// ATH_CHECK evaluates each expression and checks it for errors.
15 // Get the Onnx Runtime service.
16 ATH_CHECK(m_onnxRuntimeSvc.retrieve());
17
18 // Create the session.
// (This retrieves the session tool; the Ort::Session itself is later obtained
// from it through m_onnxSessionTool->session().)
19 ATH_CHECK(m_onnxSessionTool.retrieve());
20
// NOTE(review): source line 21 is not shown in this listing; presumably it
// gathers the model's node information - confirm against the full source.
22
23 return StatusCode::SUCCESS;
24}
25
27{
// Query the session held by the session tool and record the number of model
// inputs and outputs in m_numInputs / m_numOutputs.
28 auto& session = m_onnxSessionTool->session();
29 // obtain the model information
30 m_numInputs = session.GetInputCount();
31 m_numOutputs = session.GetOutputCount();
32
// NOTE(review): source lines 33-34 are not shown in this listing; presumably
// they fill the node shapes/names (e.g. via AthOnnxUtils::getInputNodeInfo /
// getOutputNodeInfo) - confirm against the full source.
35
36 return StatusCode::SUCCESS;
37}
38
39
41{
42 if (batchSize <= 0) {
43 ATH_MSG_ERROR("Batch size should be positive");
44 return;
45 }
46
47 for (auto& shape : m_inputShapes) {
48 if (shape[0] == -1) {
49 shape[0] = batchSize;
50 }
51 }
52
53 for (auto& shape : m_outputShapes) {
54 if (shape[0] == -1) {
55 shape[0] = batchSize;
56 }
57 }
58}
59
60int64_t AthOnnx::OnnxRuntimeInferenceTool::getBatchSize(int64_t inputDataSize, int idx) const
61{
62 auto tensorSize = AthOnnxUtils::getTensorSize(m_inputShapes[idx]);
63 if (tensorSize < 0) {
64 return inputDataSize / abs(tensorSize);
65 } else {
66 return -1;
67 }
68}
69
70StatusCode AthOnnx::OnnxRuntimeInferenceTool::inference(std::vector<Ort::Value>& inputTensors, std::vector<Ort::Value>& outputTensors) const
71{
// Run the model on already-built tensors. inputTensors/outputTensors are handed
// over together with m_inputNodeNames/m_outputNodeNames as parallel vectors,
// presumably paired by index, so their order should follow the node-name order.
// NOTE(review): the size checks are assert()s, which are compiled out when
// NDEBUG is defined; release builds do not validate the tensor counts here.
72 assert (inputTensors.size() == m_numInputs);
73 assert (outputTensors.size() == m_numOutputs);
74
75 // Run the model.
// NOTE(review): source line 76 (the callee of the argument list below) is not
// shown in this listing; the arguments match the signature of
// AthOnnxUtils::inferenceWithIOBinding - confirm against the full source.
77 m_onnxSessionTool->session(),
78 m_inputNodeNames, inputTensors,
79 m_outputNodeNames, outputTensors);
80
81 return StatusCode::SUCCESS;
82}
83
85{
86 ATH_MSG_INFO("Number of inputs: " << m_numInputs);
87 ATH_MSG_INFO("Number of outputs: " << m_numOutputs);
88
89 ATH_MSG_INFO("Input node names: ");
90 for (const auto& name : m_inputNodeNames) {
91 ATH_MSG_INFO("\t" << name);
92 }
93
94 ATH_MSG_INFO("Output node names: ");
95 for (const auto& name : m_outputNodeNames) {
96 ATH_MSG_INFO("\t" << name);
97 }
98
99 ATH_MSG_INFO("Input shapes: ");
100 for (const auto& shape : m_inputShapes) {
101 std::string shapeStr = "\t";
102 for (const auto& dim : shape) {
103 shapeStr += std::to_string(dim) + " ";
104 }
105 ATH_MSG_INFO(shapeStr);
106 }
107
108 ATH_MSG_INFO("Output shapes: ");
109 for (const auto& shape : m_outputShapes) {
110 std::string shapeStr = "\t";
111 for (const auto& dim : shape) {
112 shapeStr += std::to_string(dim) + " ";
113 }
114 ATH_MSG_INFO(shapeStr);
115 }
116}
117
119{
120 // Create input tensors.
121 std::vector<Ort::Value> inputTensors;
122 for (auto& [inputName, inputInfo] : inputData) {
123 const std::vector<int64_t>& shape = inputInfo.first;
124 if (std::holds_alternative<std::vector<float>>(inputInfo.second)) {
125 auto& data = std::get<std::vector<float>>(inputInfo.second);
126 inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
127 } else if (std::holds_alternative<std::vector<int64_t>>(inputInfo.second)) {
128 auto& data = std::get<std::vector<int64_t>>(inputInfo.second);
129 inputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
130 } else {
131 ATH_MSG_ERROR("Unsupported data type");
132 return StatusCode::FAILURE;
133 }
134 }
135
136 // Create output tensors.
137 std::vector<Ort::Value> outputTensors;
138 outputTensors.reserve(inputData.size());
139 for (auto& outName : m_outputNodeNames) {
140 if (outputData.find(outName) == outputData.end()) {
141 ATH_MSG_ERROR("Output name " << outName << " not found in output data map");
142 return StatusCode::FAILURE;
143 }
144 auto& outputInfo = outputData.at(outName);
145 auto& shape = outputInfo.first;
146 auto tensorSize = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
147
148 if (std::holds_alternative<std::vector<float>>(outputInfo.second)) {
149 auto& data = std::get<std::vector<float>>(outputInfo.second);
150 data.resize(tensorSize);
151 outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
152 } else if (std::holds_alternative<std::vector<int64_t>>(outputInfo.second)) {
153 auto& data = std::get<std::vector<int64_t>>(outputInfo.second);
154 data.resize(tensorSize);
155 outputTensors.push_back(AthOnnxUtils::createTensor(data, shape));
156 } else {
157 ATH_MSG_ERROR("Unsupported data type");
158 return StatusCode::FAILURE;
159 }
160 }
161
162 ATH_CHECK(inference(inputTensors, outputTensors));
163
164 return StatusCode::SUCCESS;
165}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
std::vector< std::vector< int64_t > > m_outputShapes
std::vector< std::vector< int64_t > > m_inputShapes
ToolHandle< IOnnxRuntimeSessionTool > m_onnxSessionTool
virtual StatusCode inference(std::vector< Ort::Value > &inputTensors, std::vector< Ort::Value > &outputTensors) const override final
Perform inference.
std::vector< std::string > m_inputNodeNames
virtual int64_t getBatchSize(int64_t inputDataSize, int idx=0) const override final
Determine the batch size from the input data size.
virtual StatusCode initialize() override
Initialize the tool.
virtual void printModelInfo() const override final
ServiceHandle< IOnnxRuntimeSvc > m_onnxRuntimeSvc
std::vector< std::string > m_outputNodeNames
virtual void setBatchSize(int64_t batchSize) override final
Set the batch size.
AsgTool(const std::string &name)
Constructor specifying the tool instance's name.
Definition AsgTool.cxx:58
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
void inferenceWithIOBinding(Ort::Session &session, const std::vector< std::string > &inputNames, const std::vector< Ort::Value > &inputData, const std::vector< std::string > &outputNames, const std::vector< Ort::Value > &outputData)
Definition OnnxUtils.cxx:49
int64_t getTensorSize(const std::vector< int64_t > &dataShape)
Definition OnnxUtils.cxx:73
void getOutputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:41
void getInputNodeInfo(const Ort::Session &session, std::vector< std::vector< int64_t > > &dataShape, std::vector< std::string > &nodeNames)
Definition OnnxUtils.cxx:33
Ort::Value createTensor(std::vector< T > &data, const std::vector< int64_t > &dataShape)
Definition OnnxUtils.h:78