ATLAS Offline Software
Loading...
Searching...
No Matches
ExampleMLInferenceWithTriton.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3// Local include(s).
5
6#include "EvaluateUtils.h"
7
8// Framework include(s).
9#include <arpa/inet.h>
10
12
13// Library include(s)
14#include <fmt/format.h>
15#include <fmt/ranges.h>
16
17// Standard include(s)
18#include <ranges>
19#include <utility> //std::pair
20
21namespace AthInfer {
22
24 if (m_batchSize.value() < 1) {
25 ATH_MSG_ERROR("Requested an invalid batch size: " << m_batchSize.value());
26 return StatusCode::FAILURE;
27 }
28
29 // Fetch tools
30 ATH_CHECK(m_tritonTool.retrieve());
31
32 // read input file, and the target file for comparison.
33 std::string pixelFilePath =
35 ATH_MSG_INFO("Using pixel file: " << pixelFilePath);
36
37 try {
41 "Total no. of samples: " << m_input_tensor_values_notFlat.size());
42 } catch (const std::exception& e) {
43 ATH_MSG_ERROR(e.what());
44 return StatusCode::FAILURE;
45 }
46
47 if (std::size_t(m_batchSize.value()) > m_input_tensor_values_notFlat.size()) {
48 ATH_MSG_ERROR("The batch size requested ("
49 << m_batchSize.value()
50 << ") is greater than the number of available "
51 "samples ("
52 << m_input_tensor_values_notFlat.size() << ")");
53 return StatusCode::FAILURE;
54 }
55
56 if (m_input_tensor_values_notFlat.size() % m_batchSize.value() != 0) {
57 ATH_MSG_ERROR("The number of samples ("
59 << ") is not a multiple of the requested batch size ("
60 << m_batchSize.value() << ")");
61 return StatusCode::FAILURE;
62 }
63 ATH_MSG_INFO("Running " << m_input_tensor_values_notFlat.size() /
64 m_batchSize.value()
65 << " batches of " << m_batchSize.value());
66 return StatusCode::SUCCESS;
67}
68
70 [[maybe_unused]] const EventContext& ctx) const {
71 // We know we have at least one image, otherwise we would have errored out
72 // earlier
73 const std::size_t n_batches =
75 const auto n_rows = std::int64_t(m_input_tensor_values_notFlat[0].size());
76 const auto n_cols = std::int64_t(m_input_tensor_values_notFlat[0][0].size());
77
78 for (std::size_t batch_idx = 0; batch_idx < n_batches; ++batch_idx) {
79 // prepare inputs
80 std::vector<float> inputDataVector;
81 inputDataVector.reserve(m_batchSize.value() * n_rows * n_cols);
82 for (const std::vector<std::vector<float>>& imageData :
84 std::views::drop(batch_idx * m_batchSize.value()) |
85 std::views::take(m_batchSize.value())) {
86 std::vector<float> flatten =
88 inputDataVector.insert(inputDataVector.end(), flatten.begin(),
89 flatten.end());
90 }
91
92 std::vector<int64_t> inputShape = {m_batchSize.value(), n_rows, n_cols};
93
94 AthInfer::InputDataMap inputData;
95 inputData["flatten_input:0"] =
96 std::make_pair(inputShape, std::move(inputDataVector));
97
98 const std::int64_t n_scores = 10;
99 AthInfer::OutputDataMap outputData;
100 outputData["dense_1/Softmax:0"] = std::make_pair(
101 std::vector<int64_t>{m_batchSize, n_scores}, std::vector<float>{});
102
103 ATH_CHECK(m_tritonTool->inference(inputData, outputData));
104
105 auto const& outputScores =
106 std::get<std::vector<float>>(outputData["dense_1/Softmax:0"].second);
107
108 if (outputScores.size() != std::size_t(n_scores * m_batchSize.value())) {
109 ATH_MSG_ERROR("Got back " << outputScores.size()
110 << " scores when it should have been "
111 << n_scores << " * " << m_batchSize.value()
112 << " = " << n_scores * m_batchSize.value());
113 return StatusCode::FAILURE;
114 }
115
116 for (int img_idx = 0; img_idx < m_batchSize.value(); img_idx++) {
117 std::span scores(outputScores.begin() + img_idx * n_scores,
118 outputScores.begin() + (img_idx + 1) * n_scores);
119 ATH_MSG_DEBUG("Scores for img " << img_idx << " of batch " << batch_idx
120 << ": "
121 << fmt::format("{::.2e}", scores));
122 const auto max_elem = std::ranges::max_element(scores);
123 ATH_MSG_DEBUG("Class: " << max_elem - scores.begin()
124 << " has the highest score: " << *max_elem
125 << " in img " << img_idx << " of batch "
126 << batch_idx);
127 }
128 }
129 return StatusCode::SUCCESS;
130}
131} // namespace AthInfer
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
size_t size() const
Number of registered mappings.
virtual StatusCode execute(const EventContext &ctx) const override
Function executing the algorithm for a single event.
virtual StatusCode initialize() override
Function initialising the algorithm.
Gaudi::Property< int > m_batchSize
The following properties need to be considered if the .onnx model is evaluated in batch mode.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
ToolHandle< AthInfer::IAthInferenceTool > m_tritonTool
Tool handle for the Triton client.
Gaudi::Property< std::string > m_pixelFileName
Name of the pixel (input data) file to load.
static std::string find_calib_file(const std::string &logical_file_name)
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat(const std::string &full_path)
std::vector< float > flattenNestedVectors(const std::vector< std::vector< float > > &nestedVector)