ATLAS Offline Software
Loading...
Searching...
No Matches
InDet::ActsGnnModuleMapFinderTool Class Reference

Tool that produces track candidates using the ACTS GNN pipeline with module-map graph construction, edge classification, and track building. More...

#include <ActsGnnModuleMapFinderTool.h>

Inheritance diagram for InDet::ActsGnnModuleMapFinderTool:
Collaboration diagram for InDet::ActsGnnModuleMapFinderTool:

Public Member Functions

virtual StatusCode initialize () override
virtual StatusCode getTracks (const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks, std::unordered_map< int, std::unordered_map< int, float > > *edgeMap=nullptr) const override
virtual MsgStream & dump (MsgStream &out) const override
virtual std::ostream & dump (std::ostream &out) const override

Private Attributes

StringProperty m_moduleMapPath {this, "moduleMapPath", "", "Path to module map ROOT files"}
StringProperty m_gnnPath {this, "gnnPath", "", "Path to GNN model (.onnx, .pt, or .engine)"}
FloatProperty m_edgeCut {this, "edgeCut", 0.5, "Edge classification cut"}
UnsignedIntegerProperty m_numTrtContexts {this, "numTrtContexts", 1, "Number of TensorRT execution contexts (controls concurrency)"}
UnsignedIntegerProperty m_minCandidateMeasurements {this, "minCandidateMeasurements", 7, "Min measurements per candidate"}
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
std::unique_ptr< ActsPlugins::GnnPipeline > m_gnnPipeline
std::unique_ptr< const Acts::Logger > m_logger
std::optional< std::mutex > m_runMutex ATLAS_THREAD_SAFE {}

Static Private Attributes

static constexpr std::size_t NUM_FEATURES = 12
static constexpr std::array< const char *, NUM_FEATURES > FEATURE_NAMES
static constexpr float kScaleR = 1000.f
static constexpr float kScalePhi = 3.14159265359f
static constexpr float kScaleZ = 1000.f
static constexpr float kScaleEta = 1.f
static constexpr std::array< float, NUM_FEATURES > FEATURE_SCALES

Detailed Description

Tool that produces track candidates using the ACTS GNN pipeline with module-map graph construction, edge classification, and track building.

Implements the IGNNTrackFinder interface for use with SiSPGNNTrackMaker.

Definition at line 32 of file ActsGnnModuleMapFinderTool.h.

Member Function Documentation

◆ dump() [1/2]

MsgStream & InDet::ActsGnnModuleMapFinderTool::dump ( MsgStream & out) const
override virtual

Definition at line 185 of file ActsGnnModuleMapFinderTool.cxx.

185 {
186 out << std::endl;
187 out << "|---------------------------------------------------------------------|" << std::endl;
188 out << "| ActsGnnModuleMapFinderTool |" << std::endl;
189 out << "|---------------------------------------------------------------------|" << std::endl;
190 return out;
191}

◆ dump() [2/2]

std::ostream & InDet::ActsGnnModuleMapFinderTool::dump ( std::ostream & out) const
override virtual

Definition at line 193 of file ActsGnnModuleMapFinderTool.cxx. Unlike the MsgStream overload, this overload writes nothing and returns the stream unchanged.

193 {
194 return out;
195}

◆ getTracks()

StatusCode InDet::ActsGnnModuleMapFinderTool::getTracks ( const std::vector< const Trk::SpacePoint * > & spacepoints,
std::vector< std::vector< uint32_t > > & tracks,
std::unordered_map< int, std::unordered_map< int, float > > * edgeMap = nullptr ) const
override virtual

Definition at line 101 of file ActsGnnModuleMapFinderTool.cxx.

104 {
105
106 const std::size_t nSP = spacepoints.size();
107
108 ATH_MSG_DEBUG("Processing " << nSP << " spacepoints with " << NUM_FEATURES << " features");
109
110 // Sort spacepoint indices by module ID (required by module map graph construction)
111 std::vector<std::size_t> sortIdx(nSP);
112 std::iota(sortIdx.begin(), sortIdx.end(), 0);
113 std::ranges::sort(sortIdx, std::less{}, [&](std::size_t i) {
114 return spacepoints[i]->clusterList().first->detectorElement()->identify().get_compact();
115 });
116
117 // Build features, module IDs, and IDs directly in sorted order
118 std::vector<float> features(NUM_FEATURES * nSP);
119 std::vector<std::uint64_t> moduleIds(nSP);
120 std::vector<int> ids(nSP);
121
122 for (std::size_t k = 0; k < nSP; ++k) {
123 const std::size_t origIdx = sortIdx[k];
124 auto featureMap = m_spacepointFeatureTool->getFeatures(spacepoints[origIdx]);
125 // Use detector element identifier, not cluster identifier, to get the module ID
126 moduleIds[k] = spacepoints[origIdx]->clusterList().first->detectorElement()->identify().get_compact();
127 ids[k] = static_cast<int>(k);
128 for (std::size_t j = 0; j < NUM_FEATURES; ++j) {
129 features[k * NUM_FEATURES + j] = featureMap[FEATURE_NAMES[j]] / FEATURE_SCALES[j];
130 }
131 }
132
133 // Run GNN pipeline (mutex present for ONNX/Torch, absent for TRT)
134 auto candidates = [&] {
135 std::unique_lock<std::mutex> lock;
136 if (m_runMutex) lock = std::unique_lock<std::mutex>(*m_runMutex);
137
138 if (edgeMap != nullptr) {
139 ScoredGraphHook hook;
140 auto result = m_gnnPipeline->run(features, moduleIds, ids, ActsPlugins::Device::Cuda(0), hook);
141
142 // Retrieve edgeScores and edgeIndex from hook
143 const std::vector<float>& edgeScores = hook.getEdgeScores();
144 const std::vector<std::int64_t>& edgeIndex = hook.getEdgeIndex();
145 const std::size_t nEdges = edgeScores.size();
146
147 // Create a map to acces edge score (sorted indices back to original spacepoint indices)
148 for (std::size_t i = 0; i < nEdges; ++i) {
149 std::int64_t src = edgeIndex[i];
150 std::int64_t dst = edgeIndex[nEdges + i];
151 (*edgeMap)[sortIdx[src]][sortIdx[dst]] = edgeScores[i];
152 }
153 return result;
154 }
155
156 return m_gnnPipeline->run(features, moduleIds, ids, ActsPlugins::Device::Cuda(0));;
157 }();
158
159 ATH_MSG_DEBUG("GNN pipeline returned " << candidates.size() << " candidates");
160
161 // Filter by minimum measurements and convert indices back to original ordering
162 tracks.clear();
163 tracks.reserve(candidates.size());
164
165 for (const auto& candidate : candidates) {
166 if (candidate.size() < m_minCandidateMeasurements.value()) {
167 continue;
168 }
169
170 // Map sorted indices back to original spacepoint indices
171 std::vector<uint32_t> track;
172 track.reserve(candidate.size());
173 for (int sortedIdx : candidate) {
174 track.push_back(static_cast<uint32_t>(sortIdx[sortedIdx]));
175 }
176 tracks.push_back(std::move(track));
177 }
178
179 ATH_MSG_DEBUG("Returning " << tracks.size() << " track candidates after filtering (>= "
180 << m_minCandidateMeasurements.value() << " measurements)");
181
182 return StatusCode::SUCCESS;
183}
#define ATH_MSG_DEBUG(x)
virtual void lock()=0
Interface to allow an object to lock itself when made const in SG.
static constexpr std::array< float, NUM_FEATURES > FEATURE_SCALES
static constexpr std::array< const char *, NUM_FEATURES > FEATURE_NAMES
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
static constexpr std::size_t NUM_FEATURES
std::unique_ptr< ActsPlugins::GnnPipeline > m_gnnPipeline
float j(const xAOD::IParticle &, const xAOD::TrackMeasurementValidation &hit, const Eigen::Matrix3d &jab_inv)
setEventNumber uint32_t

◆ initialize()

StatusCode InDet::ActsGnnModuleMapFinderTool::initialize ( )
override virtual

Definition at line 25 of file ActsGnnModuleMapFinderTool.cxx.

25 {
27
28 m_logger = makeActsAthenaLogger(this, "ActsGnnModuleMapFinderTool");
29
30 // Build ACTS GNN pipeline components
31
32 // 1. Graph constructor (ModuleMapCuda)
33 ActsPlugins::ModuleMapCuda::Config gcCfg;
34 gcCfg.rScale = kScaleR;
35 gcCfg.zScale = kScaleZ;
36 gcCfg.phiScale = kScalePhi;
37 gcCfg.moduleMapPath = m_moduleMapPath.value();
38 gcCfg.gpuBlocks = 512;
39 auto gc = std::make_shared<ActsPlugins::ModuleMapCuda>(
40 gcCfg, m_logger->cloneWithSuffix("ModuleMap"));
41
42 // 2. Edge classifier (ONNX / Torch / TensorRT)
43 // ONNX and Torch are not thread-safe: a mutex is emplaced to serialise run().
44 // TensorRT manages concurrency internally via execution contexts: no guard needed.
45 std::shared_ptr<ActsPlugins::EdgeClassificationBase> gnn;
46 if (m_gnnPath.value().find(".onnx") != std::string::npos) {
47#ifdef ACTS_GNN_ONNX_BACKEND
48 ActsPlugins::OnnxEdgeClassifier::Config gnnCfg;
49 gnnCfg.modelPath = m_gnnPath.value();
50 gnnCfg.cut = m_edgeCut.value();
51 gnn = std::make_shared<ActsPlugins::OnnxEdgeClassifier>(
52 gnnCfg, m_logger->cloneWithSuffix("GNN"));
53 m_runMutex.emplace();
54#else
55 ATH_MSG_FATAL("Not compiled with ONNX, cannot interpret *.onnx files");
56 return StatusCode::FAILURE;
57#endif
58 } else if (m_gnnPath.value().find(".pt") != std::string::npos) {
59#ifdef ACTS_GNN_TORCH_BACKEND
60 ActsPlugins::TorchEdgeClassifier::Config gnnCfg;
61 gnnCfg.modelPath = m_gnnPath.value();
62 gnnCfg.cut = m_edgeCut.value();
63 gnn = std::make_shared<ActsPlugins::TorchEdgeClassifier>(
64 gnnCfg, m_logger->cloneWithSuffix("GNN"));
65 m_runMutex.emplace();
66#else
67 ATH_MSG_FATAL("Not compiled with Torch, cannot interpret *.pt files");
68 return StatusCode::FAILURE;
69#endif
70 } else if (m_gnnPath.value().find(".engine") != std::string::npos) {
71#ifdef ACTS_GNN_WITH_TENSORRT
72 ActsPlugins::TensorRTEdgeClassifier::Config gnnCfg;
73 gnnCfg.modelPath = m_gnnPath.value();
74 gnnCfg.cut = m_edgeCut.value();
75 gnnCfg.numExecutionContexts = m_numTrtContexts.value();
76 gnn = std::make_shared<ActsPlugins::TensorRTEdgeClassifier>(
77 gnnCfg, m_logger->cloneWithSuffix("GNN"));
78#else
79 ATH_MSG_FATAL("Not compiled with TensorRT, cannot interpret *.engine files");
80 return StatusCode::FAILURE;
81#endif
82 } else {
83 ATH_MSG_FATAL("Unknown extension for GNN model: " << m_gnnPath.value());
84 return StatusCode::FAILURE;
85 }
86
87 // 3. Track builder
88 ATH_MSG_INFO("Configure CC&JunctionRemoval as graph segmentation algorithm");
89 ActsPlugins::CudaTrackBuilding::Config tbCfg;
90 tbCfg.doJunctionRemoval = true;
91 auto tb = std::make_shared<ActsPlugins::CudaTrackBuilding>(
92 tbCfg, m_logger->cloneWithSuffix("CC&JR"));
93
94 // 4. Assemble pipeline
95 m_gnnPipeline = std::make_unique<ActsPlugins::GnnPipeline>(
96 gc, std::vector{gnn}, tb, m_logger->cloneWithSuffix("Pipeline"));
97
98 return StatusCode::SUCCESS;
99}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
std::unique_ptr< const Acts::Logger > makeActsAthenaLogger(IMessageSvc *svc, const std::string &name, int level, std::optional< std::string > parent_name)
std::unique_ptr< const Acts::Logger > m_logger

Member Data Documentation

◆ ATLAS_THREAD_SAFE

std::optional<std::mutex> m_runMutex InDet::ActsGnnModuleMapFinderTool::ATLAS_THREAD_SAFE {}
mutable private

Definition at line 84 of file ActsGnnModuleMapFinderTool.h.

84{};

◆ FEATURE_NAMES

std::array<const char*, NUM_FEATURES> InDet::ActsGnnModuleMapFinderTool::FEATURE_NAMES
static constexpr private
Initial value:
= {{
"r", "phi", "z", "eta",
"cluster_r_1", "cluster_phi_1", "cluster_z_1", "cluster_eta_1",
"cluster_r_2", "cluster_phi_2", "cluster_z_2", "cluster_eta_2"
}}

Definition at line 51 of file ActsGnnModuleMapFinderTool.h.

51 {{
52 "r", "phi", "z", "eta",
53 "cluster_r_1", "cluster_phi_1", "cluster_z_1", "cluster_eta_1",
54 "cluster_r_2", "cluster_phi_2", "cluster_z_2", "cluster_eta_2"
55 }};

◆ FEATURE_SCALES

std::array<float, NUM_FEATURES> InDet::ActsGnnModuleMapFinderTool::FEATURE_SCALES
static constexpr private

◆ kScaleEta

float InDet::ActsGnnModuleMapFinderTool::kScaleEta = 1.f
static constexpr private

Definition at line 60 of file ActsGnnModuleMapFinderTool.h.

◆ kScalePhi

float InDet::ActsGnnModuleMapFinderTool::kScalePhi = 3.14159265359f
static constexpr private

Definition at line 58 of file ActsGnnModuleMapFinderTool.h.

◆ kScaleR

float InDet::ActsGnnModuleMapFinderTool::kScaleR = 1000.f
static constexpr private

Definition at line 57 of file ActsGnnModuleMapFinderTool.h.

◆ kScaleZ

float InDet::ActsGnnModuleMapFinderTool::kScaleZ = 1000.f
static constexpr private

Definition at line 59 of file ActsGnnModuleMapFinderTool.h.

◆ m_edgeCut

FloatProperty InDet::ActsGnnModuleMapFinderTool::m_edgeCut {this, "edgeCut", 0.5, "Edge classification cut"}
private

Definition at line 71 of file ActsGnnModuleMapFinderTool.h.

71{this, "edgeCut", 0.5, "Edge classification cut"};

◆ m_gnnPath

StringProperty InDet::ActsGnnModuleMapFinderTool::m_gnnPath {this, "gnnPath", "", "Path to GNN model (.onnx, .pt, or .engine)"}
private

Definition at line 70 of file ActsGnnModuleMapFinderTool.h.

70{this, "gnnPath", "", "Path to GNN model (.onnx, .pt, or .engine)"};

◆ m_gnnPipeline

std::unique_ptr<ActsPlugins::GnnPipeline> InDet::ActsGnnModuleMapFinderTool::m_gnnPipeline
private

Definition at line 80 of file ActsGnnModuleMapFinderTool.h.

◆ m_logger

std::unique_ptr<const Acts::Logger> InDet::ActsGnnModuleMapFinderTool::m_logger
private

Definition at line 81 of file ActsGnnModuleMapFinderTool.h.

◆ m_minCandidateMeasurements

UnsignedIntegerProperty InDet::ActsGnnModuleMapFinderTool::m_minCandidateMeasurements {this, "minCandidateMeasurements", 7, "Min measurements per candidate"}
private

Definition at line 73 of file ActsGnnModuleMapFinderTool.h.

73{this, "minCandidateMeasurements", 7, "Min measurements per candidate"};

◆ m_moduleMapPath

StringProperty InDet::ActsGnnModuleMapFinderTool::m_moduleMapPath {this, "moduleMapPath", "", "Path to module map ROOT files"}
private

Definition at line 69 of file ActsGnnModuleMapFinderTool.h.

69{this, "moduleMapPath", "", "Path to module map ROOT files"};

◆ m_numTrtContexts

UnsignedIntegerProperty InDet::ActsGnnModuleMapFinderTool::m_numTrtContexts {this, "numTrtContexts", 1, "Number of TensorRT execution contexts (controls concurrency)"}
private

Definition at line 72 of file ActsGnnModuleMapFinderTool.h.

72{this, "numTrtContexts", 1, "Number of TensorRT execution contexts (controls concurrency)"};

◆ m_spacepointFeatureTool

ToolHandle<ISpacepointFeatureTool> InDet::ActsGnnModuleMapFinderTool::m_spacepointFeatureTool
private
Initial value:
{
this, "SpacepointFeatureTool", "InDet::SpacepointFeatureTool"}

Definition at line 76 of file ActsGnnModuleMapFinderTool.h.

76 {
77 this, "SpacepointFeatureTool", "InDet::SpacepointFeatureTool"};

◆ NUM_FEATURES

std::size_t InDet::ActsGnnModuleMapFinderTool::NUM_FEATURES = 12
static constexpr private

Definition at line 49 of file ActsGnnModuleMapFinderTool.h.


The documentation for this class was generated from the following files: