ATLAS Offline Software
Loading...
Searching...
No Matches
ActsGnnModuleMapFinderTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
8
9#include "ActsPlugins/Gnn/CudaTrackBuilding.hpp"
10#include "ActsPlugins/Gnn/GnnPipeline.hpp"
11#include "ActsPlugins/Gnn/ModuleMapCuda.hpp"
12#include "ActsPlugins/Gnn/OnnxEdgeClassifier.hpp"
13#include "ActsPlugins/Gnn/TensorRTEdgeClassifier.hpp"
14#include "ActsPlugins/Gnn/TorchEdgeClassifier.hpp"
15
19#include "ActsGnnHookTool.h"
20
#include <algorithm>
#include <numeric>
#include <optional>
23
24
27
28 m_logger = makeActsAthenaLogger(this, "ActsGnnModuleMapFinderTool");
29
30 // Build ACTS GNN pipeline components
31
32 // 1. Graph constructor (ModuleMapCuda)
33 ActsPlugins::ModuleMapCuda::Config gcCfg;
34 gcCfg.rScale = kScaleR;
35 gcCfg.zScale = kScaleZ;
36 gcCfg.phiScale = kScalePhi;
37 gcCfg.moduleMapPath = m_moduleMapPath.value();
38 gcCfg.gpuBlocks = 512;
39 auto gc = std::make_shared<ActsPlugins::ModuleMapCuda>(
40 gcCfg, m_logger->cloneWithSuffix("ModuleMap"));
41
42 // 2. Edge classifier (ONNX / Torch / TensorRT)
43 // ONNX and Torch are not thread-safe: a mutex is emplaced to serialise run().
44 // TensorRT manages concurrency internally via execution contexts: no guard needed.
45 std::shared_ptr<ActsPlugins::EdgeClassificationBase> gnn;
46 if (m_gnnPath.value().find(".onnx") != std::string::npos) {
47#ifdef ACTS_GNN_ONNX_BACKEND
48 ActsPlugins::OnnxEdgeClassifier::Config gnnCfg;
49 gnnCfg.modelPath = m_gnnPath.value();
50 gnnCfg.cut = m_edgeCut.value();
51 gnn = std::make_shared<ActsPlugins::OnnxEdgeClassifier>(
52 gnnCfg, m_logger->cloneWithSuffix("GNN"));
53 m_runMutex.emplace();
54#else
55 ATH_MSG_FATAL("Not compiled with ONNX, cannot interpret *.onnx files");
56 return StatusCode::FAILURE;
57#endif
58 } else if (m_gnnPath.value().find(".pt") != std::string::npos) {
59#ifdef ACTS_GNN_TORCH_BACKEND
60 ActsPlugins::TorchEdgeClassifier::Config gnnCfg;
61 gnnCfg.modelPath = m_gnnPath.value();
62 gnnCfg.cut = m_edgeCut.value();
63 gnn = std::make_shared<ActsPlugins::TorchEdgeClassifier>(
64 gnnCfg, m_logger->cloneWithSuffix("GNN"));
65 m_runMutex.emplace();
66#else
67 ATH_MSG_FATAL("Not compiled with Torch, cannot interpret *.pt files");
68 return StatusCode::FAILURE;
69#endif
70 } else if (m_gnnPath.value().find(".engine") != std::string::npos) {
71#ifdef ACTS_GNN_WITH_TENSORRT
72 ActsPlugins::TensorRTEdgeClassifier::Config gnnCfg;
73 gnnCfg.modelPath = m_gnnPath.value();
74 gnnCfg.cut = m_edgeCut.value();
75 gnnCfg.numExecutionContexts = m_numTrtContexts.value();
76 gnn = std::make_shared<ActsPlugins::TensorRTEdgeClassifier>(
77 gnnCfg, m_logger->cloneWithSuffix("GNN"));
78#else
79 ATH_MSG_FATAL("Not compiled with TensorRT, cannot interpret *.engine files");
80 return StatusCode::FAILURE;
81#endif
82 } else {
83 ATH_MSG_FATAL("Unknown extension for GNN model: " << m_gnnPath.value());
84 return StatusCode::FAILURE;
85 }
86
87 // 3. Track builder
88 ATH_MSG_INFO("Configure CC&JunctionRemoval as graph segmentation algorithm");
89 ActsPlugins::CudaTrackBuilding::Config tbCfg;
90 tbCfg.doJunctionRemoval = true;
91 auto tb = std::make_shared<ActsPlugins::CudaTrackBuilding>(
92 tbCfg, m_logger->cloneWithSuffix("CC&JR"));
93
94 // 4. Assemble pipeline
95 m_gnnPipeline = std::make_unique<ActsPlugins::GnnPipeline>(
96 gc, std::vector{gnn}, tb, m_logger->cloneWithSuffix("Pipeline"));
97
98 return StatusCode::SUCCESS;
99}
100
102 const std::vector<const Trk::SpacePoint*>& spacepoints,
103 std::vector<std::vector<uint32_t>>& tracks,
104 std::unordered_map<int, std::unordered_map<int, float>>* edgeMap) const {
105
106 const std::size_t nSP = spacepoints.size();
107
108 ATH_MSG_DEBUG("Processing " << nSP << " spacepoints with " << NUM_FEATURES << " features");
109
110 // Sort spacepoint indices by module ID (required by module map graph construction)
111 std::vector<std::size_t> sortIdx(nSP);
112 std::iota(sortIdx.begin(), sortIdx.end(), 0);
113 std::ranges::sort(sortIdx, std::less{}, [&](std::size_t i) {
114 return spacepoints[i]->clusterList().first->detectorElement()->identify().get_compact();
115 });
116
117 // Build features, module IDs, and IDs directly in sorted order
118 std::vector<float> features(NUM_FEATURES * nSP);
119 std::vector<std::uint64_t> moduleIds(nSP);
120 std::vector<int> ids(nSP);
121
122 for (std::size_t k = 0; k < nSP; ++k) {
123 const std::size_t origIdx = sortIdx[k];
124 auto featureMap = m_spacepointFeatureTool->getFeatures(spacepoints[origIdx]);
125 // Use detector element identifier, not cluster identifier, to get the module ID
126 moduleIds[k] = spacepoints[origIdx]->clusterList().first->detectorElement()->identify().get_compact();
127 ids[k] = static_cast<int>(k);
128 for (std::size_t j = 0; j < NUM_FEATURES; ++j) {
129 features[k * NUM_FEATURES + j] = featureMap[FEATURE_NAMES[j]] / FEATURE_SCALES[j];
130 }
131 }
132
133 // Run GNN pipeline (mutex present for ONNX/Torch, absent for TRT)
134 auto candidates = [&] {
135 std::unique_lock<std::mutex> lock;
136 if (m_runMutex) lock = std::unique_lock<std::mutex>(*m_runMutex);
137
138 if (edgeMap != nullptr) {
139 ScoredGraphHook hook;
140 auto result = m_gnnPipeline->run(features, moduleIds, ids, ActsPlugins::Device::Cuda(0), hook);
141
142 // Retrieve edgeScores and edgeIndex from hook
143 const std::vector<float>& edgeScores = hook.getEdgeScores();
144 const std::vector<std::int64_t>& edgeIndex = hook.getEdgeIndex();
145 const std::size_t nEdges = edgeScores.size();
146
147 // Create a map to acces edge score (sorted indices back to original spacepoint indices)
148 for (std::size_t i = 0; i < nEdges; ++i) {
149 std::int64_t src = edgeIndex[i];
150 std::int64_t dst = edgeIndex[nEdges + i];
151 (*edgeMap)[sortIdx[src]][sortIdx[dst]] = edgeScores[i];
152 }
153 return result;
154 }
155
156 return m_gnnPipeline->run(features, moduleIds, ids, ActsPlugins::Device::Cuda(0));;
157 }();
158
159 ATH_MSG_DEBUG("GNN pipeline returned " << candidates.size() << " candidates");
160
161 // Filter by minimum measurements and convert indices back to original ordering
162 tracks.clear();
163 tracks.reserve(candidates.size());
164
165 for (const auto& candidate : candidates) {
166 if (candidate.size() < m_minCandidateMeasurements.value()) {
167 continue;
168 }
169
170 // Map sorted indices back to original spacepoint indices
171 std::vector<uint32_t> track;
172 track.reserve(candidate.size());
173 for (int sortedIdx : candidate) {
174 track.push_back(static_cast<uint32_t>(sortIdx[sortedIdx]));
175 }
176 tracks.push_back(std::move(track));
177 }
178
179 ATH_MSG_DEBUG("Returning " << tracks.size() << " track candidates after filtering (>= "
180 << m_minCandidateMeasurements.value() << " measurements)");
181
182 return StatusCode::SUCCESS;
183}
184
185MsgStream& InDet::ActsGnnModuleMapFinderTool::dump(MsgStream& out) const {
186 out << std::endl;
187 out << "|---------------------------------------------------------------------|" << std::endl;
188 out << "| ActsGnnModuleMapFinderTool |" << std::endl;
189 out << "|---------------------------------------------------------------------|" << std::endl;
190 return out;
191}
192
193std::ostream& InDet::ActsGnnModuleMapFinderTool::dump(std::ostream& out) const {
194 return out;
195}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
#define ATH_MSG_DEBUG(x)
virtual void lock()=0
Interface to allow an object to lock itself when made const in SG.
std::unique_ptr< const Acts::Logger > makeActsAthenaLogger(IMessageSvc *svc, const std::string &name, int level, std::optional< std::string > parent_name)
static constexpr std::array< float, NUM_FEATURES > FEATURE_SCALES
static constexpr std::array< const char *, NUM_FEATURES > FEATURE_NAMES
virtual MsgStream & dump(MsgStream &out) const override
std::unique_ptr< const Acts::Logger > m_logger
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
static constexpr std::size_t NUM_FEATURES
std::unique_ptr< ActsPlugins::GnnPipeline > m_gnnPipeline
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks, std::unordered_map< int, std::unordered_map< int, float > > *edgeMap=nullptr) const override
const std::vector< float > & getEdgeScores() const
const std::vector< int64_t > & getEdgeIndex() const