ATLAS Offline Software
Loading...
Searching...
No Matches
ActsGnnHookTool.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4
5#ifndef ACTS_GNNHOOK_TOOL_H
6#define ACTS_GNNHOOK_TOOL_H
7
8#include "ActsPlugins/Gnn/GnnPipeline.hpp"
9#include "ActsPlugins/Gnn/Stages.hpp"
10#include <vector>
11#include <cstdint>
12
13
14namespace InDet {
15
16 class ScoredGraphHook : public ActsPlugins::GnnHook {
17 public:
18 void operator()(const ActsPlugins::PipelineTensors& tensors,
19 const ActsPlugins::ExecutionContext& ctx) const override {
20
21 ActsPlugins::ExecutionContext cpuCtx;
22 cpuCtx.device = ActsPlugins::Device::Cpu();
23 cpuCtx.stream = ctx.stream;
24
25 if (tensors.edgeScores.has_value()) {
26 const auto& scores = tensors.edgeScores.value();
27 auto cpuScores = scores.clone(cpuCtx);
28 m_edgeScores.assign(cpuScores.data(), cpuScores.data() + cpuScores.size());
29 }
30
31 if (tensors.edgeIndex.size() > 0) {
32 const auto& edgeIndex = tensors.edgeIndex;
33 auto cpuIndex = edgeIndex.clone(cpuCtx);
34 m_edgeIndex.assign(cpuIndex.data(), cpuIndex.data() + cpuIndex.size());
35 }
36 }
37
38 const std::vector<float>& getEdgeScores() const {
39 return m_edgeScores;
40 }
41 const std::vector<int64_t>& getEdgeIndex() const {
42 return m_edgeIndex;
43 }
44
45 private:
46 mutable std::vector<float> m_edgeScores ATLAS_THREAD_SAFE{};
47 mutable std::vector<int64_t> m_edgeIndex ATLAS_THREAD_SAFE{};
48 };
49}
50
51#endif
const std::vector< float > & getEdgeScores() const
void operator()(const ActsPlugins::PipelineTensors &tensors, const ActsPlugins::ExecutionContext &ctx) const override
std::vector< float > m_edgeScores ATLAS_THREAD_SAFE
const std::vector< int64_t > & getEdgeIndex() const
Primary Vertex Finder.