ATLAS Offline Software
Loading...
Searching...
No Matches
EvaluateModelWithAsyncInfer.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3// Local include(s).
5
6// Framework include(s).
8#include "EvaluateUtils.h"
10
11// Standard include(s)
12#include <algorithm>
13#include <ranges>
14#include <span>
15
16namespace AthOnnx {
17
19 if (m_batchSize.value() < 1) {
20 ATH_MSG_ERROR("Requested an invalid batch size: " << m_batchSize.value());
21 return StatusCode::FAILURE;
22 }
23
24 // Fetch tools
25 ATH_CHECK(m_onnxTool.retrieve());
26
27 // read input file, and the target file for comparison.
28 std::string pixelFilePath =
30 ATH_MSG_INFO("Using pixel file: " << pixelFilePath);
31
32 try {
36 "Total no. of samples: " << m_input_tensor_values_notFlat.size());
37 } catch (const std::exception& e) {
38 ATH_MSG_ERROR(e.what());
39 return StatusCode::FAILURE;
40 }
41
42 if (std::size_t(m_batchSize.value()) > m_input_tensor_values_notFlat.size()) {
43 ATH_MSG_ERROR("The batch size requested ("
44 << m_batchSize.value()
45 << ") is greater than the number of available "
46 "samples ("
47 << m_input_tensor_values_notFlat.size() << ")");
48 return StatusCode::FAILURE;
49 }
50
51 if (m_input_tensor_values_notFlat.size() % m_batchSize.value() != 0) {
52 ATH_MSG_ERROR("The number of samples ("
54 << ") is not a multiple of the requested batch size ("
55 << m_batchSize.value() << ")");
56 return StatusCode::FAILURE;
57 }
58 return StatusCode::SUCCESS;
59}
60
62 [[maybe_unused]] const EventContext& ctx) const {
63 // We know we have at least one image, otherwise we would have errored out
64 // earlier
65 const std::size_t n_batches =
67 const auto n_rows = std::int64_t(m_input_tensor_values_notFlat[0].size());
68 const auto n_cols = std::int64_t(m_input_tensor_values_notFlat[0][0].size());
69
70 for (std::size_t batch_idx = 0; batch_idx < n_batches; ++batch_idx) {
71 // prepare inputs
72 std::vector<float> inputDataVector;
73 inputDataVector.reserve(m_batchSize.value() * n_rows * n_cols);
74 for (const std::vector<std::vector<float>>& imageData :
76 std::views::drop(batch_idx * m_batchSize.value()) |
77 std::views::take(m_batchSize.value())) {
78 std::vector<float> flatten =
80 inputDataVector.insert(inputDataVector.end(), flatten.begin(),
81 flatten.end());
82 }
83
84 std::vector<int64_t> inputShape = {m_batchSize.value(), n_rows, n_cols};
85
86 AthInfer::InputDataMap inputData;
87 inputData["flatten_input:0"] =
88 std::make_pair(inputShape, std::move(inputDataVector));
89
90 const std::int64_t n_scores = 10;
91 AthInfer::OutputDataMap outputData;
92 outputData["dense_1/Softmax:0"] = std::make_pair(
93 std::vector<int64_t>{m_batchSize, n_scores}, std::vector<float>{});
94
95 ATH_CHECK(m_onnxTool->inference(inputData, outputData));
96
97 auto const& outputScores =
98 std::get<std::vector<float>>(outputData["dense_1/Softmax:0"].second);
99
100 if (outputScores.size() != std::size_t(n_scores * m_batchSize.value())) {
101 ATH_MSG_ERROR("Got back " << outputScores.size()
102 << " scores when it should have been "
103 << n_scores << " * " << m_batchSize.value()
104 << " = " << n_scores * m_batchSize.value());
105 return StatusCode::FAILURE;
106 }
107
108 for (int img_idx = 0; img_idx < m_batchSize.value(); img_idx++) {
109 std::span scores(outputScores.begin() + img_idx * n_scores,
110 outputScores.begin() + (img_idx + 1) * n_scores);
111 ATH_MSG_DEBUG("Scores for img " << img_idx << " of batch " << batch_idx
112 << ": "
113 << EvaluateUtils::spanToString(scores));
114 const auto max_elem = std::ranges::max_element(scores);
115 ATH_MSG_DEBUG("Class: " << max_elem - scores.begin()
116 << " has the highest score: " << *max_elem
117 << " in img " << img_idx << " of batch "
118 << batch_idx);
119 }
120 }
121 return StatusCode::SUCCESS;
122}
123
124} // namespace AthOnnx
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
size_t size() const
Number of registered mappings.
virtual StatusCode execute(const EventContext &ctx) const override
Function executing the algorithm for a single event.
Gaudi::Property< std::string > m_pixelFileName
Name of the input file to load.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
Gaudi::Property< int > m_batchSize
The following properties need to be considered if the .onnx model is evaluated in batch mode.
virtual StatusCode initialize() override
Function initialising the algorithm.
ToolHandle< AthInfer::IAthInferenceTool > m_onnxTool
Tool handler for onnx inference session.
static std::string find_calib_file(const std::string &logical_file_name)
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T > > &features)
Definition OnnxUtils.h:24
Namespace holding all of the Onnx Runtime example code.
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat(const std::string &full_path)