ATLAS Offline Software
Loading...
Searching...
No Matches
ExampleAsyncMLInferenceWithTriton.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3// Local include(s).
5
6// Framework include(s).
7#include <arpa/inet.h>
8
9#include <fstream>
10#include <utility> //std::pair
11
13
14namespace AthInfer {
15
17 // Fetch tools
18 ATH_CHECK(m_tritonTool.retrieve());
19
20 if (m_batchSize > 10000) {
22 "The total no. of sample crossed the no. of available sample ....");
23 return StatusCode::FAILURE;
24 }
25 // read input file, and the target file for comparison.
26 std::string pixelFilePath = PathResolver::find_calib_file(m_pixelFileName);
27 ATH_MSG_INFO("Using pixel file: " << pixelFilePath);
28
31 "Total no. of samples: " << m_input_tensor_values_notFlat.size());
32
33 return StatusCode::SUCCESS;
34}
35
37 [[maybe_unused]] const EventContext& ctx) const {
38
39 // prepare inputs
40 std::vector<float> inputDataVector;
41 inputDataVector.reserve(m_input_tensor_values_notFlat.size());
42 for (const std::vector<std::vector<float>>& imageData :
44
45 std::vector<float> flatten;
46 int total_size = 0;
47 for (const auto& feature : imageData)
48 total_size += feature.size();
49 flatten.reserve(total_size);
50 for (const auto& feature : imageData)
51 for (const auto& elem : feature)
52 flatten.push_back(elem);
53
54 inputDataVector.insert(inputDataVector.end(), flatten.begin(),
55 flatten.end());
56 }
57 std::vector<int64_t> inputShape = {m_batchSize, 28, 28};
58
59 AthInfer::InputDataMap inputData;
60 inputData["flatten_input:0"] =
61 std::make_pair(inputShape, std::move(inputDataVector));
62
63 AthInfer::OutputDataMap outputData;
64 outputData["dense_1/Softmax:0"] = std::make_pair(
65 std::vector<int64_t>{m_batchSize, 10}, std::vector<float>{});
66
67 ATH_CHECK(m_tritonTool->inference(inputData, outputData));
68
69 auto& outputScores =
70 std::get<std::vector<float>>(outputData["dense_1/Softmax:0"].second);
71 auto inRange = [&outputScores](int idx) -> bool {
72 return (idx >= 0) and (idx < std::ssize(outputScores));
73 };
74 ATH_MSG_DEBUG("Label for the input test data: ");
75 for (int ibatch = 0; ibatch < m_batchSize; ibatch++) {
76 float max = -999;
77 int max_index{-1};
78 for (int i = 0; i < 10; i++) {
79 ATH_MSG_DEBUG("Score for class " << i << " = " << outputScores[i]
80 << " in batch " << ibatch);
81 int index = i + ibatch * 10;
82 if (not inRange(index))
83 continue;
84 if (max < outputScores[index]) {
85 max = outputScores[index];
86 max_index = index;
87 }
88 }
89 if (not inRange(max_index)) {
91 "No maximum found in ExampleAsyncMLInferenceWithTriton::execute");
92 return StatusCode::FAILURE;
93 }
94 ATH_MSG_DEBUG("Class: " << max_index << " has the highest score: "
95 << outputScores[max_index] << " in batch "
96 << ibatch);
97 }
98
99 return StatusCode::SUCCESS;
100}
101
102std::vector<std::vector<std::vector<float>>>
104 const std::string& full_path) const {
105 std::vector<std::vector<std::vector<float>>> input_tensor_values;
106 input_tensor_values.resize(
107 10000, std::vector<std::vector<float>>(28, std::vector<float>(28)));
108 std::ifstream file(full_path.c_str(), std::ios::binary);
109 int magic_number = 0;
110 int number_of_images = 0;
111 int n_rows = 0;
112 int n_cols = 0;
113 file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
114 magic_number = ntohl(magic_number);
115 file.read(reinterpret_cast<char*>(&number_of_images),
116 sizeof(number_of_images));
117 number_of_images = ntohl(number_of_images);
118 file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));
119 n_rows = ntohl(n_rows);
120 file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));
121 n_cols = ntohl(n_cols);
122 for (int i = 0; i < number_of_images; ++i) {
123 for (int r = 0; r < n_rows; ++r) {
124 for (int c = 0; c < n_cols; ++c) {
125 unsigned char temp = 0;
126 file.read(reinterpret_cast<char*>(&temp), sizeof(temp));
127 input_tensor_values[i][r][c] = float(temp) / 255;
128 }
129 }
130 }
131 return input_tensor_values;
132}
133
134} // namespace AthInfer
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_ERROR(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
bool inRange(const double *boundaries, const double value, const double tolerance=0.02)
#define max(a, b)
Definition cfImp.cxx:41
ToolHandle< AthInfer::IAthInferenceTool > m_tritonTool
Tool handle for the Triton client.
virtual StatusCode initialize() override
Function initialising the algorithm.
Gaudi::Property< std::string > m_pixelFileName
Name of the pixel (input image) file to load.
std::vector< std::vector< std::vector< float > > > m_input_tensor_values_notFlat
std::vector< std::vector< std::vector< float > > > read_mnist_pixel_notFlat(const std::string &full_path) const
virtual StatusCode execute(const EventContext &ctx) const override
Function executing the algorithm for a single event.
Gaudi::Property< int > m_batchSize
The following properties need to be considered if the .onnx model is evaluated in batch mode.
static std::string find_calib_file(const std::string &logical_file_name)
int r
Definition globals.cxx:22
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
Definition index.py:1
TFile * file