ATLAS Offline Software
Loading...
Searching...
No Matches
EFTrackingXrtAlgorithm.cxx
Go to the documentation of this file.
1/*
2 * Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4
6
8
10 const std::string& name,
11 ISvcLocator* pSvcLocator
12) : AthReentrantAlgorithm(name, pSvcLocator)
// The constructor only forwards the algorithm name and service locator to the
// AthReentrantAlgorithm base class. Services, XRT kernels, runs and buffers are
// all set up later, in initialize().
13{}
14
// initialize() does four things, and execute() reuses the resulting XRT state
// for every event:
//  - retrieves the device-management and chrono services;
//  - initializes the StoreGate key arrays;
//  - creates at most one xrt::kernel (m_kernels) and one xrt::run (m_runs) per
//    kernel name, lazily, in whichever loop below first mentions that name;
//  - creates one device buffer of sizeof(unsigned long) * m_bufferSize bytes
//    per entry of m_inputInterfaces (m_inputBuffers) and of m_outputInterfaces
//    (m_outputBuffers), appended in configuration order and bound to kernel
//    arguments via xrt::run::set_arg.
// Kernels and buffers are always created on devices[0] of the list returned by
// get_xrt_devices_by_kernel_name. Kernels are opened with exclusive
// compute-unit access.
16 ATH_MSG_INFO("Initializing " << name());
17
18 ATH_CHECK(m_DeviceMgmtSvc.retrieve());
19 ATH_CHECK(m_chronoSvc.retrieve());
20 ATH_CHECK(m_inputDataStreamKeys.initialize());
21 ATH_CHECK(m_vSizeDataStreamKeys.initialize());
23
// Input interfaces: (kernelName, storeGateKey, argumentIndex).
// execute() pairs m_inputBuffers with m_inputDataStreamKeys by position, so
// both are assumed to be configured in the same order; confirm against the
// job options.
24 for (const auto& [kernelName, storeGateKey, argumentIndex] : m_inputInterfaces) {
26 "Setting up " <<
27 kernelName <<
28 " to read " <<
29 storeGateKey <<
30 " into argument " <<
31 argumentIndex
32 );
33
// Only devices[0] is used, even if several devices are returned for this name.
34 const std::vector<std::shared_ptr<xrt::device>> devices =
35 m_DeviceMgmtSvc->get_xrt_devices_by_kernel_name(kernelName);
36
37 ATH_CHECK(devices.size() != 0);
38
39 if (!m_kernels.contains(kernelName)) {
40 m_kernels[kernelName] = std::make_unique<xrt::kernel>(
41 *(devices[0]),
42 (devices[0])->get_xclbin_uuid(),
43 kernelName,
44 xrt::kernel::cu_access_mode::exclusive
45 );
46 }
47
48 ATH_CHECK(m_kernels[kernelName].get() != nullptr);
49
// Buffer flags come from the xclbin metadata for this argument. When they
// cannot be determined, the code falls back to xrt::bo::flags::normal.
// Both branches append exactly one buffer.
50 const std::optional<xrt::bo::flags> mem_flags =
51 determine_mem_flags(m_kernels[kernelName], argumentIndex);
52
53 if (!mem_flags.has_value()) {
55 "Unable to determine mem_flags for argument with index " <<
56 argumentIndex <<
57 " in kernel named " <<
58 kernelName <<
59 ". Defaulting to xrt::bo::normal. Good luck!"
60 );
61
62 m_inputBuffers.emplace_back(
63 *(devices[0]),
64 sizeof(unsigned long) * m_bufferSize,
65 xrt::bo::flags::normal,
66 m_kernels[kernelName]->group_id(argumentIndex)
67 );
68 }
69 else {
70 m_inputBuffers.emplace_back(
71 *(devices[0]),
72 sizeof(unsigned long) * m_bufferSize,
73 mem_flags.value(),
74 m_kernels[kernelName]->group_id(argumentIndex)
75 );
76 }
77
78 if (!m_runs.contains(kernelName)) {
79 m_runs[kernelName] = std::make_unique<xrt::run>(*m_kernels[kernelName]);
80 }
81
82 ATH_CHECK(m_runs[kernelName].get() != nullptr);
83 m_runs[kernelName]->set_arg(argumentIndex, m_inputBuffers.back());
84 }
85
// vSize interfaces: this loop only makes sure the kernel and its run exist.
// The scalar argument itself is set per event in execute(), from the size of
// the vector read through m_vSizeDataStreamKeys; entries are matched to
// m_vSizeInterfaces by position. storeGateKey is only used for logging here.
86 for (const auto& [kernelName, storeGateKey, argumentIndex] : m_vSizeInterfaces) {
88 "Setting up " <<
89 kernelName <<
90 " to get input size from " <<
91 storeGateKey <<
92 " for argument " <<
93 argumentIndex
94 );
95
96 const std::vector<std::shared_ptr<xrt::device>> devices =
97 m_DeviceMgmtSvc->get_xrt_devices_by_kernel_name(kernelName);
98
99 ATH_CHECK(devices.size() != 0);
100
101 if (!m_kernels.contains(kernelName)) {
102 m_kernels[kernelName] = std::make_unique<xrt::kernel>(
103 *(devices[0]),
104 (devices[0])->get_xclbin_uuid(),
105 kernelName,
106 xrt::kernel::cu_access_mode::exclusive
107 );
108 }
109
110 ATH_CHECK(m_kernels[kernelName].get() != nullptr);
111
112 if (!m_runs.contains(kernelName)) {
113 m_runs[kernelName] = std::make_unique<xrt::run>(*m_kernels[kernelName]);
114 }
115
116 ATH_CHECK(m_runs[kernelName].get() != nullptr);
117 }
118
// Output interfaces use the same buffer setup as the inputs. Buffers are
// appended to m_outputBuffers in m_outputInterfaces order, and both the
// shared-interface loop below and execute() rely on this positional
// correspondence.
119 for (const auto& [kernelName, storeGateKey, argumentIndex] : m_outputInterfaces) {
121 "Setting up " <<
122 kernelName <<
123 " to write " <<
124 storeGateKey <<
125 " from argument " <<
126 argumentIndex
127 );
128
129 const std::vector<std::shared_ptr<xrt::device>> devices =
130 m_DeviceMgmtSvc->get_xrt_devices_by_kernel_name(kernelName);
131
132 ATH_CHECK(devices.size() != 0);
133
134 if (!m_kernels.contains(kernelName)) {
135 m_kernels[kernelName] = std::make_unique<xrt::kernel>(
136 *(devices[0]),
137 devices[0]->get_xclbin_uuid(),
138 kernelName,
139 xrt::kernel::cu_access_mode::exclusive
140 );
141 }
142
143 const std::optional<xrt::bo::flags> mem_flags =
144 determine_mem_flags(m_kernels[kernelName], argumentIndex);
145
146 if (!mem_flags.has_value()) {
148 "Unable to determine mem_flags for argument with index " <<
149 argumentIndex <<
150 " in kernel named " <<
151 kernelName <<
152 ". Defaulting to xrt::bo::normal. Good luck!"
153 );
154
155 m_outputBuffers.emplace_back(
156 *(devices[0]),
157 sizeof(unsigned long) * m_bufferSize,
158 xrt::bo::flags::normal,
159 m_kernels[kernelName]->group_id(argumentIndex)
160 );
161 }
162 else {
163 m_outputBuffers.emplace_back(
164 *(devices[0]),
165 sizeof(unsigned long) * m_bufferSize,
166 mem_flags.value(),
167 m_kernels[kernelName]->group_id(argumentIndex)
168 );
169 }
170
171 if (!m_runs.contains(kernelName)) {
172 m_runs[kernelName] = std::make_unique<xrt::run>(*m_kernels[kernelName]);
173 }
174
175 ATH_CHECK(m_runs[kernelName].get() != nullptr);
176 m_runs[kernelName]->set_arg(argumentIndex, m_outputBuffers.back());
177 }
178
// Shared interfaces bind argumentIndex of kernelName to the output buffer
// already allocated for (sourceKernelName, sourceArgumentIndex), so both
// kernels use the same xrt::bo. This is presumably done to chain kernels
// without a host round trip.
// NOTE(review): if no output interface matches, the inner loop ends without
// calling set_arg and nothing is reported; consider failing initialize() here.
// NOTE(review): the shared buffer was allocated on devices[0] of the source
// kernel. This assumes both kernels are on the same device; confirm.
179 for (const auto& [kernelName, argumentIndex, sourceKernelName, sourceArgumentIndex] : m_sharedInterfaces) {
181 "Setting up shared buffer between " <<
182 kernelName <<
183 " argument " <<
184 argumentIndex <<
185 " and " <<
186 sourceKernelName <<
187 " argument " <<
188 sourceArgumentIndex
189 );
190
191 const std::vector<std::shared_ptr<xrt::device>> devices =
192 m_DeviceMgmtSvc->get_xrt_devices_by_kernel_name(kernelName);
193
194 ATH_CHECK(devices.size() != 0);
195
196 if (!m_kernels.contains(kernelName)) {
197 m_kernels[kernelName] = std::make_unique<xrt::kernel>(
198 *(devices[0]),
199 devices[0]->get_xclbin_uuid(),
200 kernelName,
201 xrt::kernel::cu_access_mode::exclusive
202 );
203 }
204
205 if (!m_runs.contains(kernelName)) {
206 m_runs[kernelName] = std::make_unique<xrt::run>(*m_kernels[kernelName]);
207 }
208
209 ATH_CHECK(m_runs[kernelName].get() != nullptr);
210
// index tracks the position in m_outputBuffers, which holds exactly one buffer
// per m_outputInterfaces entry.
211 std::size_t index = 0;
212 for (const auto& [outputKernelName, outputStoreGateKey, outputArgumentIndex] : m_outputInterfaces) {
213 if (
214 outputKernelName == sourceKernelName &&
215 outputArgumentIndex == sourceArgumentIndex
216 ) {
217 m_runs[kernelName]->set_arg(argumentIndex, m_outputBuffers[index]);
218
219 break;
220 }
221
222 index++;
223 }
224 }
225
// m_kernelOrder lists the groups (stages) of kernels that execute() runs. This
// loop makes sure every kernel named there has a kernel object and a run, even
// if no interface refers to it. A run is only created together with a new
// kernel; every loop above that creates a kernel also creates its run.
226 for (const auto& kernelNames : m_kernelOrder) {
227 for (const auto& kernelName : kernelNames) {
228 const std::vector<std::shared_ptr<xrt::device>> devices =
229 m_DeviceMgmtSvc->get_xrt_devices_by_kernel_name(kernelName);
230
231 ATH_CHECK(devices.size() != 0);
232
233 if (!m_kernels.contains(kernelName)) {
234 ATH_MSG_DEBUG("Creating kernel: " << kernelName);
235
236 m_kernels[kernelName] = std::make_unique<xrt::kernel>(
237 *(devices[0]),
238 devices[0]->get_xclbin_uuid(),
239 kernelName,
240 xrt::kernel::cu_access_mode::exclusive
241 );
242
243 if (!m_runs.contains(kernelName)) {
244 m_runs[kernelName] = std::make_unique<xrt::run>(*m_kernels[kernelName]);
245 }
246
247 ATH_CHECK(m_runs[kernelName].get() != nullptr);
248 }
249 }
250 }
251
252 return StatusCode::SUCCESS;
253}
254
// execute() performs four steps per event:
//  1. Copies each input StoreGate vector into its mapped host buffer and syncs
//     it to the device.
//  2. Sets each vSize kernel argument to the element count of its StoreGate
//     vector.
//  3. Starts every kernel of one m_kernelOrder stage, then waits for all of
//     them before moving on to the next stage.
//  4. Syncs each output buffer back from the device and records its full
//     contents (m_bufferSize words) as a new vector in StoreGate.
// NOTE(review): this const method maps, syncs and re-arguments XRT buffers and
// runs that were allocated once in initialize() and are shared by all calls.
// If this reentrant algorithm can process several events concurrently, these
// calls would race. Confirm this never happens, or add locking.
255StatusCode EFTrackingXrtAlgorithm::execute(const EventContext& ctx) const
256{
// Stage inputs: the i-th input key is copied into m_inputBuffers[i].
257 ATH_CHECK(m_inputDataStreamKeys.size() == m_inputBuffers.size());
258 std::size_t inputHandleIndex = 0;
259 for (
260 const SG::ReadHandleKey<std::vector<unsigned long>>& inputDataStreamKey :
262 ) {
263 SG::ReadHandle<std::vector<unsigned long>> inputDataStream(inputDataStreamKey, ctx);
264 ATH_MSG_DEBUG("Writing: " << inputDataStream.name());
265 unsigned long* inputMap = m_inputBuffers.at(inputHandleIndex).map<unsigned long*>();
266
// Each input must fit in the buffer allocated in initialize().
267 ATH_CHECK(inputDataStream->size() <= m_bufferSize);
268
269 ATH_MSG_DEBUG("Copy " + inputDataStream.name() + " from storegate to host side map");
270 {
271 Athena::Chrono chrono(
272 "Copy " + inputDataStream.name() + " from storegate to host side map",
273 m_chronoSvc.get()
274 );
275
// NOTE(review): only the first inputDataStream->size() words are written.
// The rest of the buffer keeps whatever a previous event left there.
276 for (std::size_t index = 0; index < inputDataStream->size(); index++) {
277 inputMap[index] = inputDataStream->at(index);
278 }
279 }
280
281 ATH_MSG_DEBUG("Copy " + inputDataStream.name() + " from host side map to device");
282 {
283 Athena::Chrono chrono(
284 "Copy " + inputDataStream.name() + " from host side map to device",
285 m_chronoSvc.get()
286 );
287
288 m_inputBuffers.at(inputHandleIndex).sync(XCL_BO_SYNC_BO_TO_DEVICE);
289 }
290
291 inputHandleIndex++;
292 }
293
// Set scalar "size" arguments. Each value is the element count (not the byte
// count) of a vector read through a vSize key, and is applied to the kernel
// argument described by the m_vSizeInterfaces entry at the same position.
// NOTE(review): m_vSizeInterfaces is indexed with an unchecked operator[]. This
// assumes it has at least as many entries as there are vSize keys.
295 std::size_t vSizeHandleIndex = 0;
296 for (
297 const SG::ReadHandleKey<std::vector<unsigned long>>& vSizeDataStreamKey :
299 ) {
300 SG::ReadHandle<std::vector<unsigned long>> vSizeDataStream(vSizeDataStreamKey, ctx);
301 const auto& [kernelName, storeGateKey, argumentIndex] = m_vSizeInterfaces[vSizeHandleIndex];
302 ATH_MSG_DEBUG("Setting VSize: " << kernelName << ", " << vSizeDataStream.name() << ", " << vSizeDataStream->size());
303
304 m_runs.at(kernelName)->set_arg(argumentIndex, vSizeDataStream->size());
305 vSizeHandleIndex++;
306 }
307
// Each inner vector of m_kernelOrder is a stage. All runs of a stage are
// started before any of them is waited on, so they may overlap. The next stage
// only starts after wait() has returned for every run of the current one.
308 ATH_MSG_DEBUG("Run kernels");
309 {
310 Athena::Chrono chrono("Run kernels", m_chronoSvc.get());
311
312 for (const auto& kernelNames : m_kernelOrder) {
313 for (const auto& kernelName : kernelNames) {
314 ATH_MSG_DEBUG("Running: " << kernelName);
315 m_runs.at(kernelName)->start();
316 }
317
318 for (const auto& kernelName : kernelNames) {
319 ATH_MSG_DEBUG("Waiting: " << kernelName);
320 m_runs.at(kernelName)->wait();
321 }
322 }
323 }
324
// Retrieve outputs: the i-th output key receives the contents of
// m_outputBuffers[i]. Each recorded vector has m_bufferSize elements, so the
// whole buffer is copied regardless of how many words the kernel produced.
// NOTE(review): unlike the inputs, nothing checks that the number of output
// keys matches m_outputBuffers.size(); .at() throws if there are more keys.
325 std::size_t outputHandleIndex = 0;
326 for (
327 const SG::WriteHandleKey<std::vector<unsigned long>>& outputDataStreamKey :
329 ) {
330 SG::WriteHandle<std::vector<unsigned long>> outputDataStream(outputDataStreamKey, ctx);
331 ATH_CHECK(outputDataStream.record(std::make_unique<std::vector<unsigned long>>(m_bufferSize)));
332
333 ATH_MSG_DEBUG("Copy " + outputDataStream.name() + " from device to host side map");
334 {
335 Athena::Chrono chrono(
336 "Copy " + outputDataStream.name() + " from device to host side map",
337 m_chronoSvc.get()
338 );
339
340 m_outputBuffers.at(outputHandleIndex).sync(XCL_BO_SYNC_BO_FROM_DEVICE);
341 }
342
343 const unsigned long* outputMap = m_outputBuffers.at(outputHandleIndex).map<unsigned long*>();
344 ATH_MSG_DEBUG("Copy " + outputDataStream.name() + " from host side map to storegate");
345 {
346 Athena::Chrono chrono(
347 "Copy " + outputDataStream.name() + " from host side map to storegate",
348 m_chronoSvc.get()
349 );
350
351 for (std::size_t index = 0; index < outputDataStream->size(); index++) {
352 outputDataStream->at(index) = outputMap[index];
353 }
354 }
355
356 outputHandleIndex++;
357 }
358
359 return StatusCode::SUCCESS;
360}
361
363 const std::unique_ptr<xrt::kernel>& kernel,
364 const std::size_t index
365) const {
// Chooses the xrt::bo flags for argument `index` of `kernel` from the xclbin
// metadata:
//  - host_only if the tag of the argument's first memory contains "HOST";
//  - normal otherwise.
// Returns std::nullopt when:
//  - the kernel or the argument is not found in the metadata, or
//  - the argument has no associated memories (this case also logs a message).
// Only the first memory listed for the argument is inspected.
366 for (
367 const xrt::xclbin::kernel& kernelMetaData :
368 kernel->get_xclbin().get_kernels()
369 ) {
370 if (kernel->get_name() != kernelMetaData.get_name()) {
371 continue;
372 }
373
374 for (const xrt::xclbin::arg& arg : kernelMetaData.get_args()) {
375 if (arg.get_index() != index) {
376 continue;
377 }
378
379 if (arg.get_mems().size() == 0) {
381 "No mems associated with argument " <<
382 index <<
383 " of " <<
384 kernel->get_name() <<
385 ". Expect more warnings."
386 );
387
388 return std::nullopt;
389 }
390
391 if (arg.get_mems()[0].get_tag().find("HOST") != std::string::npos) {
392 return xrt::bo::flags::host_only;
393 }
394
395 return xrt::bo::flags::normal;
396 }
397 }
398
399 return std::nullopt;
400}
401
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_INFO(x)
#define ATH_MSG_WARNING(x)
#define ATH_MSG_DEBUG(x)
Exception-safe IChronoSvc caller.
An algorithm that can be simultaneously executed in multiple threads.
Exception-safe IChronoSvc caller.
Definition Chrono.h:50
Gaudi::Property< std::vector< std::tuple< std::string, std::string, int > > > m_vSizeInterfaces
std::optional< xrt::bo::flags > determine_mem_flags(const std::unique_ptr< xrt::kernel > &kernel, const std::size_t index) const
SG::ReadHandleKeyArray< std::vector< unsigned long > > m_vSizeDataStreamKeys
Gaudi::Property< std::vector< std::tuple< std::string, std::string, int > > > m_inputInterfaces
std::map< std::string, std::unique_ptr< xrt::kernel > > m_kernels
std::map< std::string, std::unique_ptr< xrt::run > > m_runs
StatusCode execute(const EventContext &ctx) const override final
Gaudi::Property< std::vector< std::vector< std::string > > > m_kernelOrder
SG::ReadHandleKeyArray< std::vector< unsigned long > > m_inputDataStreamKeys
Keys to access encoded 64-bit words following the EFTracking specification.
Gaudi::Property< std::vector< std::tuple< std::string, int, std::string, int > > > m_sharedInterfaces
Gaudi::Property< std::vector< std::tuple< std::string, std::string, int > > > m_outputInterfaces
ServiceHandle< AthXRT::IDeviceMgmtSvc > m_DeviceMgmtSvc
SG::WriteHandleKeyArray< std::vector< unsigned long > > m_outputDataStreamKeys
Gaudi::Property< std::size_t > m_bufferSize
StatusCode initialize() override final
EFTrackingXrtAlgorithm(const std::string &name, ISvcLocator *pSvcLocator)
ServiceHandle< IChronoSvc > m_chronoSvc
Property holding a SG store/key/clid from which a ReadHandle is made.
const std::string & name() const
Return the StoreGate ID for the referenced object.
Property holding a SG store/key/clid from which a WriteHandle is made.
StatusCode record(std::unique_ptr< T > data)
Record a const object to the store.
T * get(TKey *tobj)
get a TObject* from a TKey* (why can't a TObject be a TKey?)
Definition hcg.cxx:130
Definition index.py:1