ATLAS Offline Software
Loading...
Searching...
No Matches
EvaluateModelWithAsyncInfer.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3// Local include(s).
5
6// Framework include(s).
8#include "EvaluateUtils.h"
10
11// Library include(s)
12#include <fmt/format.h>
13#include <fmt/ranges.h>
14
15// Standard include(s)
16#include <algorithm>
17#include <ranges>
18#include <span>
19
20namespace AthOnnx {
21
// initialize(): validate the configured batch size, retrieve m_onnxTool and
// load the input samples into m_input_tensor_values_notFlat.
// Returns StatusCode::FAILURE if:
//   - m_batchSize < 1,
//   - anything inside the try block throws a std::exception,
//   - the number of loaded samples is smaller than m_batchSize, or
//   - the number of loaded samples is not a multiple of m_batchSize.
// NOTE(review): this is a Doxygen listing scrape; the signature (line 22) and
// body lines 33, 37-39 and 57 of the original are not visible here.
23 if (m_batchSize.value() < 1) {
24 ATH_MSG_ERROR("Requested an invalid batch size: " << m_batchSize.value());
25 return StatusCode::FAILURE;
26 }
27
28 // Fetch tools
29 ATH_CHECK(m_onnxTool.retrieve());
30
31 // read input file, and the target file for comparison.
// NOTE(review): only the pixel file is visible below. The right-hand side of
// pixelFilePath (line 33) is missing; presumably it resolves m_pixelFileName
// via find_calib_file — TODO confirm.
32 std::string pixelFilePath =
34 ATH_MSG_INFO("Using pixel file: " << pixelFilePath);
35
// Any std::exception raised while loading the samples is logged and turned
// into FAILURE. Lines 37-39 (presumably the read into
// m_input_tensor_values_notFlat and the start of an info message) are not
// visible here.
36 try {
40 "Total no. of samples: " << m_input_tensor_values_notFlat.size());
41 } catch (const std::exception& e) {
42 ATH_MSG_ERROR(e.what());
43 return StatusCode::FAILURE;
44 }
45
// execute() feeds the samples in chunks of exactly m_batchSize. The sample
// count must therefore be at least the batch size and divisible by it.
46 if (std::size_t(m_batchSize.value()) > m_input_tensor_values_notFlat.size()) {
47 ATH_MSG_ERROR("The batch size requested ("
48 << m_batchSize.value()
49 << ") is greater than the number of available "
50 "samples ("
51 << m_input_tensor_values_notFlat.size() << ")");
52 return StatusCode::FAILURE;
53 }
54
// The batch size is >= 1 here (checked at the top), so the modulo is safe.
55 if (m_input_tensor_values_notFlat.size() % m_batchSize.value() != 0) {
56 ATH_MSG_ERROR("The number of samples ("
58 << ") is not a multiple of the requested batch size ("
59 << m_batchSize.value() << ")");
60 return StatusCode::FAILURE;
61 }
62 return StatusCode::SUCCESS;
63}
64
66 [[maybe_unused]] const EventContext& ctx) const {
// execute(): run the model over the loaded samples, m_batchSize samples per
// m_onnxTool->inference() call. For each image, log its scores and the
// highest-scoring class at DEBUG level.
// Returns StatusCode::FAILURE if inference fails or if the tool returns a
// number of scores different from n_scores * m_batchSize.
// NOTE(review): ctx is unused, and the whole loaded dataset appears to be
// processed on every call. Lines 65 (signature), 70, 79 and 83 of the
// original are not visible here.
67 // We know we have at least one image, otherwise we would have errored out
68 // earlier
// Presumably the sample count divided by the batch size; the continuation
// (line 70) is not visible here — TODO confirm.
69 const std::size_t n_batches =
// Image dimensions are taken from the first sample. All samples are assumed
// to share them — TODO confirm.
71 const auto n_rows = std::int64_t(m_input_tensor_values_notFlat[0].size());
72 const auto n_cols = std::int64_t(m_input_tensor_values_notFlat[0][0].size());
73
74 for (std::size_t batch_idx = 0; batch_idx < n_batches; ++batch_idx) {
75 // prepare inputs
// Concatenate the flattened images of this batch into one contiguous buffer.
// drop/take selects samples [batch_idx * batch, (batch_idx + 1) * batch).
// The range source (line 79) and the flatten call (line 83) are not visible;
// presumably flattenNestedVectors(imageData) — TODO confirm.
76 std::vector<float> inputDataVector;
77 inputDataVector.reserve(m_batchSize.value() * n_rows * n_cols);
78 for (const std::vector<std::vector<float>>& imageData :
80 std::views::drop(batch_idx * m_batchSize.value()) |
81 std::views::take(m_batchSize.value())) {
82 std::vector<float> flatten =
84 inputDataVector.insert(inputDataVector.end(), flatten.begin(),
85 flatten.end());
86 }
87
// Input tensor shape: {batch size, rows, cols}.
88 std::vector<int64_t> inputShape = {m_batchSize.value(), n_rows, n_cols};
89
// The tensor names presumably must match the model's input/output node
// names — confirm against the .onnx model.
90 AthInfer::InputDataMap inputData;
91 inputData["flatten_input:0"] =
92 std::make_pair(inputShape, std::move(inputDataVector));
93
// The model is expected to produce n_scores (10) scores per image, so the
// output is pre-declared with shape {batch size, n_scores} and an empty buffer.
// NOTE(review): m_batchSize is used here through the implicit Property
// conversion, while every other use calls m_batchSize.value().
94 const std::int64_t n_scores = 10;
95 AthInfer::OutputDataMap outputData;
96 outputData["dense_1/Softmax:0"] = std::make_pair(
97 std::vector<int64_t>{m_batchSize, n_scores}, std::vector<float>{});
98
99 ATH_CHECK(m_onnxTool->inference(inputData, outputData));
100
// This key was inserted above, so operator[] only looks it up here.
101 auto const& outputScores =
102 std::get<std::vector<float>>(outputData["dense_1/Softmax:0"].second);
103
// Guard against the tool returning an output of unexpected size before
// slicing it per image below.
104 if (outputScores.size() != std::size_t(n_scores * m_batchSize.value())) {
105 ATH_MSG_ERROR("Got back " << outputScores.size()
106 << " scores when it should have been "
107 << n_scores << " * " << m_batchSize.value()
108 << " = " << n_scores * m_batchSize.value());
109 return StatusCode::FAILURE;
110 }
111
// Each image's scores occupy a contiguous slice of n_scores floats. The
// "{::.2e}" format prints every element in scientific notation with 2
// decimals (fmt/ranges).
112 for (int img_idx = 0; img_idx < m_batchSize.value(); img_idx++) {
113 std::span scores(outputScores.begin() + img_idx * n_scores,
114 outputScores.begin() + (img_idx + 1) * n_scores);
115 ATH_MSG_DEBUG("Scores for img " << img_idx << " of batch " << batch_idx
116 << ": "
117 << fmt::format("{::.2e}", scores));
118 const auto max_elem = std::ranges::max_element(scores);
119 ATH_MSG_DEBUG("Class: " << max_elem - scores.begin()
120 << " has the highest score: " << *max_elem
121 << " in img " << img_idx << " of batch "
122 << batch_idx);
123 }
124 }
125 return StatusCode::SUCCESS;
126}
127
128} // namespace AthOnnx
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
size_t size() const
Number of registered mappings.
virtual StatusCode execute(const EventContext &ctx) const override
Function executing the algorithm for a single event.
Gaudi::Property< std::string > m_pixelFileName
Name of the input file to load.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
Gaudi::Property< int > m_batchSize
The following properties need to be considered if the .onnx model is evaluated in batch mode.
virtual StatusCode initialize() override
Function initialising the algorithm.
ToolHandle< AthInfer::IAthInferenceTool > m_onnxTool
Tool handle for the ONNX inference session.
static std::string find_calib_file(const std::string &logical_file_name)
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
std::vector< T > flattenNestedVectors(const std::vector< std::vector< T > > &features)
Definition OnnxUtils.h:24
Namespace holding all of the Onnx Runtime example code.
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat(const std::string &full_path)