ATLAS Offline Software
Loading...
Searching...
No Matches
TritonTool.cxx
Go to the documentation of this file.
1// Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
2
3// Local include(s).
5
6// Project include(s).
9
10// External include(s).
11#include <grpc_client.h>
12#include <grpc_service.pb.h>
13
14// System include(s).
15#include <cassert>
16#include <cstring>
17#include <string>
18#include <vector>
19
21namespace tc = triton::client;
22
/// Evaluate an expression yielding a `tc::Error` and bail out on failure.
///
/// If the error is not OK, the stringified expression and the message carried
/// by the error object are logged with ATH_MSG_ERROR, and the *enclosing
/// function* returns StatusCode::FAILURE. The macro may therefore only be used
/// inside functions returning StatusCode, in a scope where ATH_MSG_ERROR is
/// usable.
///
/// The local variable carries a deliberately unusual name so that it cannot
/// shadow an identifier (e.g. `err`) appearing inside EXP.
#define TRITON_CHECK(EXP)                                        \
  do {                                                           \
    const tc::Error tritonCheckStatus_ = (EXP);                  \
    if (!tritonCheckStatus_.IsOk()) {                            \
      ATH_MSG_ERROR("Failed to execute: "                        \
                    << #EXP << " ("                              \
                    << tritonCheckStatus_.Message() << ")");     \
      return StatusCode::FAILURE;                                \
    }                                                            \
  } while (false)
32
33namespace AthInfer {
34
/// Maps a C++ element type onto the datatype string passed to Triton when an
/// input tensor is created (see the use of `TritonDType<T>::value`).
///
/// The primary template is only declared, never defined, so instantiating it
/// for an element type without a specialization fails at compile time instead
/// of sending an arbitrary datatype string to the server.
template <typename T>
struct TritonDType;

/// 32-bit floating point elements.
template <>
struct TritonDType<float> {
  static constexpr const char* value = "FP32";
};

/// 64-bit signed integer elements.
template <>
struct TritonDType<int64_t> {
  static constexpr const char* value = "INT64";
};
46
48
49 // Inherit the constructor(s) from AthMessaging
51
52 StatusCode getClient(tc::InferenceServerGrpcClient*& client,
53 const std::string& url, int port, bool useSSL) const {
54
55 thread_local std::unique_ptr<tc::InferenceServerGrpcClient> threadClient;
56 if (!threadClient) {
57
58 const std::string urlAndPort =
59 url + ":" + std::to_string(port); // always use the gRPC port
60
61 constexpr bool verbose = false;
62 TRITON_CHECK(tc::InferenceServerGrpcClient::Create(
63 &threadClient, urlAndPort, verbose, useSSL));
64
65 ATH_MSG_INFO("Triton client created for url: " << urlAndPort);
66 }
67 client = threadClient.get();
68
69 return StatusCode::SUCCESS;
70 }
71
72 template <typename T>
73 StatusCode prepareInput(
74 const std::string& name, const std::vector<int64_t>& shape,
75 const std::vector<T>& data,
76 std::vector<std::unique_ptr<tc::InferInput>>& inputs) const {
77
78 const char* dtype = TritonDType<T>::value;
79 tc::InferInput* rawInputPtr = nullptr;
80
81 // create the InferInput object with the predefined name, shape, and data
82 // type.
83 TRITON_CHECK(tc::InferInput::Create(&rawInputPtr, name, shape, dtype));
84 assert(rawInputPtr != nullptr);
85
86 // Append tensor values for this input from a byte array.
87 // Note: The vector is not copied and so it must not be modified or
88 // destroyed until this input is no longer needed (that is until the Infer()
89 // call(s) that use the input have completed). Multiple calls can be made to
90 // this API to keep adding tensor data for this input. The data will be
91 // delivered in the order it was added.
92 std::unique_ptr<tc::InferInput> input{rawInputPtr};
93 TRITON_CHECK(input->AppendRaw(reinterpret_cast<const uint8_t*>(data.data()),
94 data.size() * sizeof(T)));
95
96 inputs.push_back(std::move(input));
97 return StatusCode::SUCCESS;
98 }
99
100 template <typename T>
101 StatusCode extractOutput(const std::string& name,
102 const tc::InferResult& result,
103 std::vector<T>& outputVec) const {
104
105 const uint8_t* rawData = nullptr;
106 size_t size = 0;
107
108 // Get access to the buffer holding raw results of specified output returned
109 // by the server. Note: the buffer is owned by InferResult instance. Users
110 // can copy out the data if required to extend the lifetime.
111 TRITON_CHECK(result.RawData(name, &rawData, &size));
112
113 outputVec.resize(size / sizeof(T));
114 std::memcpy(outputVec.data(), rawData, size);
115 return StatusCode::SUCCESS;
116 }
117
119 std::unique_ptr<tc::InferOptions> m_options;
120
121}; // struct TritonTool::Impl
122
123TritonTool::TritonTool(const std::string& type, const std::string& name,
124 const IInterface* parent)
125 : base_class(type, name, parent) {}
126
127TritonTool::~TritonTool() = default;
128
130
131 // Set up the implementation object.
132 m_impl = std::make_unique<Impl>(name() + "::Impl");
133 m_impl->m_options = std::make_unique<tc::InferOptions>(m_modelName.value());
134 m_impl->m_options->model_version_ = m_modelVersion;
135 m_impl->m_options->client_timeout_ = m_clientTimeout;
136
137 // Figure out if parent is an AthAsynchronousAlgorithm, and set pointer if it
138 // is
139 const IAlgTool* p = dynamic_cast<const IAlgTool*>(this);
140 // Follow chain of parents up until we hit one that can't be converted to an
141 // IAlgTool
142 const IInterface* myParent = nullptr;
143 while (p != nullptr) {
144 myParent = p->parent();
145 p = dynamic_cast<const IAlgTool*>(myParent);
146 }
147 // If this ultimate ancestor can be converted to an AthAsynchronousAlgorithm,
148 // set the member variable
149 m_impl->m_parentAsyncAlg =
150 dynamic_cast<const AthAsynchronousAlgorithm*>(myParent);
151 if (m_impl->m_parentAsyncAlg != nullptr) {
153 "Owned by an AthAsynchronousAlgorithm, using asynchronous inference");
154 } else {
156 "Not owned by an AthAsynchronousAlgorithm, not using asynchronous "
157 "inference");
158 }
159
160 // Make sure already during initialization that a client can be created.
161 tc::InferenceServerGrpcClient* dummyClient = nullptr;
162 ATH_CHECK(m_impl->getClient(dummyClient, m_url, m_port, m_useSSL));
163
164 // Return gracefully.
165 return StatusCode::SUCCESS;
166}
167
169 OutputDataMap& outputData) const {
170
171 assert(m_impl);
172
173 // Create the tensor for the input data.
174 // Use shared_ptr to manage the memory of the InferInput objects.
175 std::vector<std::unique_ptr<tc::InferInput>> inputs;
176 inputs.reserve(inputData.size());
177
178 for (auto& [inputName, inputInfo] : inputData) {
179
180 const std::vector<int64_t>& inputShape = inputInfo.first;
181 const DataVariant& variant = inputInfo.second;
182
183 ATH_CHECK(std::visit(
184 [&](const auto& dataVec) {
185 using T = std::decay_t<decltype(dataVec[0])>;
186 return m_impl->prepareInput<T>(inputName, inputShape, dataVec,
187 inputs);
188 },
189 variant));
190 }
191
192 // construct raw points for inference
193 std::vector<tc::InferInput*> rawInputs;
194 for (auto& input : inputs) {
195 rawInputs.push_back(input.get());
196 }
197
198 // Get the triton client object.
199 tc::InferenceServerGrpcClient* client = nullptr;
200 ATH_CHECK(m_impl->getClient(client, m_url, m_port, m_useSSL));
201 assert(client != nullptr);
202
203 // perform the inference.
204 std::shared_ptr<tc::InferResult> results;
205 tc::Headers http_headers;
206 grpc_compression_algorithm compression_algorithm =
207 grpc_compression_algorithm::GRPC_COMPRESS_NONE;
208
209 if (m_impl->m_parentAsyncAlg == nullptr) {
210 tc::InferResult* rawResultPtr = nullptr;
211 TRITON_CHECK(client->Infer(&rawResultPtr, *(m_impl->m_options), rawInputs,
212 {}, http_headers, compression_algorithm));
213 assert(rawResultPtr != nullptr);
214 results.reset(rawResultPtr);
215 } else {
216 // If m_impl->m_parentAsyncAlg is set, use asynchronous inference
217 using Promise_t = boost::fibers::promise<tc::InferResult*>;
218 using Future_t = boost::fibers::future<tc::InferResult*>;
219 Promise_t promise{};
220 Future_t future = promise.get_future();
221 auto callback = [&promise](tc::InferResult* resultPtr) {
222 assert(resultPtr != nullptr);
223 promise.set_value(resultPtr);
224 };
225 TRITON_CHECK(client->AsyncInfer(callback, *(m_impl->m_options), rawInputs,
226 {}, http_headers, compression_algorithm));
227 results.reset(future.get());
228 ATH_CHECK(m_impl->m_parentAsyncAlg->restoreAfterSuspend());
229 }
230
231 // Get the result of the inference.
232 for (auto& [outputName, outputInfo] : outputData) {
233
234 DataVariant& variant = outputInfo.second;
235
236 ATH_CHECK(std::visit(
237 [&](auto& dataVec) {
238 using T = std::decay_t<decltype(dataVec[0])>;
239 return m_impl->extractOutput<T>(outputName, *results, dataVec);
240 },
241 variant));
242 }
243
244 // Return gracefully.
245 return StatusCode::SUCCESS;
246}
247
248void TritonTool::print() const {}
249
250} // namespace AthInfer
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_INFO(x)
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
static Double_t tc
size_t size() const
Number of registered mappings.
#define TRITON_CHECK(EXP)
Shorthand for the Triton client namespace.
An algorithm that can be suspended while work is offloaded to an accelerator.
virtual StatusCode inference(InputDataMap &inputData, OutputDataMap &outputData) const override final
Run inference with multiple inputs and multiple outputs.
StringProperty m_modelVersion
Definition TritonTool.h:49
virtual StatusCode initialize() override
Initialize the tool.
virtual void print() const override
Print the tool's properties and configuration.
FloatProperty m_clientTimeout
Definition TritonTool.h:51
virtual ~TritonTool()
Destructor.
StringProperty m_modelName
Definition TritonTool.h:47
StringProperty m_url
Definition TritonTool.h:54
TritonTool(const std::string &type, const std::string &name, const IInterface *parent)
Constructor.
IntegerProperty m_port
Definition TritonTool.h:48
BooleanProperty m_useSSL
Definition TritonTool.h:55
std::unique_ptr< Impl > m_impl
Pointer to the implementation details.
Definition TritonTool.h:63
AthMessaging(IMessageSvc *msgSvc, const std::string &name)
Constructor.
bool verbose
Definition hcg.cxx:75
std::map< std::string, InferenceData > OutputDataMap
std::variant< std::vector< float >, std::vector< int64_t > > DataVariant
std::map< std::string, InferenceData > InputDataMap
static constexpr const char * value
static constexpr const char * value
DType traits for Triton.
StatusCode prepareInput(const std::string &name, const std::vector< int64_t > &shape, const std::vector< T > &data, std::vector< std::unique_ptr< tc::InferInput > > &inputs) const
const AthAsynchronousAlgorithm * m_parentAsyncAlg
std::unique_ptr< tc::InferOptions > m_options
StatusCode extractOutput(const std::string &name, const tc::InferResult &result, std::vector< T > &outputVec) const
AthMessaging(IMessageSvc *msgSvc, const std::string &name)
Constructor.
StatusCode getClient(tc::InferenceServerGrpcClient *&client, const std::string &url, int port, bool useSSL) const