ATLAS Offline Software
Loading...
Searching...
No Matches
InDet::ActsGnnModuleMapFinderTool Class Reference

Tool that produces track candidates using the ACTS GNN pipeline with module-map graph construction, edge classification, and track building. More...

#include <ActsGnnModuleMapFinderTool.h>

Inheritance diagram for InDet::ActsGnnModuleMapFinderTool:
Collaboration diagram for InDet::ActsGnnModuleMapFinderTool:

Public Member Functions

virtual StatusCode initialize () override
virtual StatusCode getTracks (const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
virtual MsgStream & dump (MsgStream &out) const override
virtual std::ostream & dump (std::ostream &out) const override

Private Attributes

StringProperty m_moduleMapPath {this, "moduleMapPath", "", "Path to module map ROOT files"}
StringProperty m_gnnPath {this, "gnnPath", "", "Path to GNN model (.onnx, .pt, or .engine)"}
FloatProperty m_edgeCut {this, "edgeCut", 0.5, "Edge classification cut"}
UnsignedIntegerProperty m_numTrtContexts {this, "numTrtContexts", 1, "Number of TensorRT execution contexts (controls concurrency)"}
UnsignedIntegerProperty m_minCandidateMeasurements {this, "minCandidateMeasurements", 7, "Min measurements per candidate"}
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
std::unique_ptr< ActsPlugins::GnnPipeline > m_gnnPipeline
std::unique_ptr< const Acts::Logger > m_logger
std::optional< std::mutex > m_runMutex ATLAS_THREAD_SAFE {}

Static Private Attributes

static constexpr std::size_t NUM_FEATURES = 12
static constexpr std::array< const char *, NUM_FEATURES > FEATURE_NAMES
static constexpr float kScaleR = 1000.f
static constexpr float kScalePhi = 3.14159265359f
static constexpr float kScaleZ = 1000.f
static constexpr float kScaleEta = 1.f
static constexpr std::array< float, NUM_FEATURES > FEATURE_SCALES

Detailed Description

Tool that produces track candidates using the ACTS GNN pipeline with module-map graph construction, edge classification, and track building.

Implements the IGNNTrackFinder interface for use with SiSPGNNTrackMaker.

Definition at line 32 of file ActsGnnModuleMapFinderTool.h.

Member Function Documentation

◆ dump() [1/2]

MsgStream & InDet::ActsGnnModuleMapFinderTool::dump ( MsgStream & out) const
[override] [virtual]

Definition at line 163 of file ActsGnnModuleMapFinderTool.cxx.

163 {
164 out << std::endl;
165 out << "|---------------------------------------------------------------------|" << std::endl;
166 out << "| ActsGnnModuleMapFinderTool |" << std::endl;
167 out << "|---------------------------------------------------------------------|" << std::endl;
168 return out;
169}

◆ dump() [2/2]

std::ostream & InDet::ActsGnnModuleMapFinderTool::dump ( std::ostream & out) const
[override] [virtual]

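Definition at line 171 of file ActsGnnModuleMapFinderTool.cxx. Unlike the MsgStream overload, this overload writes nothing to the stream; it returns `out` unchanged.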

171 {
172 return out;
173}

◆ getTracks()

StatusCode InDet::ActsGnnModuleMapFinderTool::getTracks ( const std::vector< const Trk::SpacePoint * > & spacepoints,
std::vector< std::vector< uint32_t > > & tracks ) const
[override] [virtual]

Definition at line 99 of file ActsGnnModuleMapFinderTool.cxx.

101 {
102
103 const std::size_t nSP = spacepoints.size();
104
105 ATH_MSG_DEBUG("Processing " << nSP << " spacepoints with " << NUM_FEATURES << " features");
106
107 // Sort spacepoint indices by module ID (required by module map graph construction)
108 std::vector<std::size_t> sortIdx(nSP);
109 std::iota(sortIdx.begin(), sortIdx.end(), 0);
110 std::ranges::sort(sortIdx, std::less{}, [&](std::size_t i) {
111 return spacepoints[i]->clusterList().first->detectorElement()->identify().get_compact();
112 });
113
114 // Build features, module IDs, and IDs directly in sorted order
115 std::vector<float> features(NUM_FEATURES * nSP);
116 std::vector<std::uint64_t> moduleIds(nSP);
117 std::vector<int> ids(nSP);
118
119 for (std::size_t k = 0; k < nSP; ++k) {
120 const std::size_t origIdx = sortIdx[k];
121 auto featureMap = m_spacepointFeatureTool->getFeatures(spacepoints[origIdx]);
122 // Use detector element identifier, not cluster identifier, to get the module ID
123 moduleIds[k] = spacepoints[origIdx]->clusterList().first->detectorElement()->identify().get_compact();
124 ids[k] = static_cast<int>(k);
125 for (std::size_t j = 0; j < NUM_FEATURES; ++j) {
126 features[k * NUM_FEATURES + j] = featureMap[FEATURE_NAMES[j]] / FEATURE_SCALES[j];
127 }
128 }
129
130 // Run GNN pipeline (mutex present for ONNX/Torch, absent for TRT)
131 auto candidates = [&] {
132 std::unique_lock<std::mutex> lock;
133 if (m_runMutex) lock = std::unique_lock<std::mutex>(*m_runMutex);
134 return m_gnnPipeline->run(features, moduleIds, ids, ActsPlugins::Device::Cuda(0));
135 }();
136
137 ATH_MSG_DEBUG("GNN pipeline returned " << candidates.size() << " candidates");
138
139 // Filter by minimum measurements and convert indices back to original ordering
140 tracks.clear();
141 tracks.reserve(candidates.size());
142
143 for (const auto& candidate : candidates) {
144 if (candidate.size() < m_minCandidateMeasurements.value()) {
145 continue;
146 }
147
148 // Map sorted indices back to original spacepoint indices
149 std::vector<uint32_t> track;
150 track.reserve(candidate.size());
151 for (int sortedIdx : candidate) {
152 track.push_back(static_cast<uint32_t>(sortIdx[sortedIdx]));
153 }
154 tracks.push_back(std::move(track));
155 }
156
157 ATH_MSG_DEBUG("Returning " << tracks.size() << " track candidates after filtering (>= "
158 << m_minCandidateMeasurements.value() << " measurements)");
159
160 return StatusCode::SUCCESS;
161}
#define ATH_MSG_DEBUG(x)
virtual void lock()=0
Interface to allow an object to lock itself when made const in SG.
static constexpr std::array< float, NUM_FEATURES > FEATURE_SCALES
static constexpr std::array< const char *, NUM_FEATURES > FEATURE_NAMES
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
static constexpr std::size_t NUM_FEATURES
std::unique_ptr< ActsPlugins::GnnPipeline > m_gnnPipeline
float j(const xAOD::IParticle &, const xAOD::TrackMeasurementValidation &hit, const Eigen::Matrix3d &jab_inv)
setEventNumber uint32_t

◆ initialize()

StatusCode InDet::ActsGnnModuleMapFinderTool::initialize ( )
[override] [virtual]

Definition at line 23 of file ActsGnnModuleMapFinderTool.cxx.

23 {
25
26 m_logger = makeActsAthenaLogger(this, "ActsGnnModuleMapFinderTool");
27
28 // Build ACTS GNN pipeline components
29
30 // 1. Graph constructor (ModuleMapCuda)
31 ActsPlugins::ModuleMapCuda::Config gcCfg;
32 gcCfg.rScale = kScaleR;
33 gcCfg.zScale = kScaleZ;
34 gcCfg.phiScale = kScalePhi;
35 gcCfg.moduleMapPath = m_moduleMapPath.value();
36 gcCfg.gpuBlocks = 512;
37 auto gc = std::make_shared<ActsPlugins::ModuleMapCuda>(
38 gcCfg, m_logger->cloneWithSuffix("ModuleMap"));
39
40 // 2. Edge classifier (ONNX / Torch / TensorRT)
41 // ONNX and Torch are not thread-safe: a mutex is emplaced to serialise run().
42 // TensorRT manages concurrency internally via execution contexts: no guard needed.
43 std::shared_ptr<ActsPlugins::EdgeClassificationBase> gnn;
44 if (m_gnnPath.value().find(".onnx") != std::string::npos) {
45#ifdef ACTS_GNN_ONNX_BACKEND
46 ActsPlugins::OnnxEdgeClassifier::Config gnnCfg;
47 gnnCfg.modelPath = m_gnnPath.value();
48 gnnCfg.cut = m_edgeCut.value();
49 gnn = std::make_shared<ActsPlugins::OnnxEdgeClassifier>(
50 gnnCfg, m_logger->cloneWithSuffix("GNN"));
51 m_runMutex.emplace();
52#else
53 ATH_MSG_FATAL("Not compiled with ONNX, cannot interpret *.onnx files");
54 return StatusCode::FAILURE;
55#endif
56 } else if (m_gnnPath.value().find(".pt") != std::string::npos) {
57#ifdef ACTS_GNN_TORCH_BACKEND
58 ActsPlugins::TorchEdgeClassifier::Config gnnCfg;
59 gnnCfg.modelPath = m_gnnPath.value();
60 gnnCfg.cut = m_edgeCut.value();
61 gnn = std::make_shared<ActsPlugins::TorchEdgeClassifier>(
62 gnnCfg, m_logger->cloneWithSuffix("GNN"));
63 m_runMutex.emplace();
64#else
65 ATH_MSG_FATAL("Not compiled with Torch, cannot interpret *.pt files");
66 return StatusCode::FAILURE;
67#endif
68 } else if (m_gnnPath.value().find(".engine") != std::string::npos) {
69#ifdef ACTS_GNN_WITH_TENSORRT
70 ActsPlugins::TensorRTEdgeClassifier::Config gnnCfg;
71 gnnCfg.modelPath = m_gnnPath.value();
72 gnnCfg.cut = m_edgeCut.value();
73 gnnCfg.numExecutionContexts = m_numTrtContexts.value();
74 gnn = std::make_shared<ActsPlugins::TensorRTEdgeClassifier>(
75 gnnCfg, m_logger->cloneWithSuffix("GNN"));
76#else
77 ATH_MSG_FATAL("Not compiled with TensorRT, cannot interpret *.engine files");
78 return StatusCode::FAILURE;
79#endif
80 } else {
81 ATH_MSG_FATAL("Unknown extension for GNN model: " << m_gnnPath.value());
82 return StatusCode::FAILURE;
83 }
84
85 // 3. Track builder
86 ATH_MSG_INFO("Configure CC&JunctionRemoval as graph segmentation algorithm");
87 ActsPlugins::CudaTrackBuilding::Config tbCfg;
88 tbCfg.doJunctionRemoval = true;
89 auto tb = std::make_shared<ActsPlugins::CudaTrackBuilding>(
90 tbCfg, m_logger->cloneWithSuffix("CC&JR"));
91
92 // 4. Assemble pipeline
93 m_gnnPipeline = std::make_unique<ActsPlugins::GnnPipeline>(
94 gc, std::vector{gnn}, tb, m_logger->cloneWithSuffix("Pipeline"));
95
96 return StatusCode::SUCCESS;
97}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_FATAL(x)
#define ATH_MSG_INFO(x)
std::unique_ptr< const Acts::Logger > makeActsAthenaLogger(IMessageSvc *svc, const std::string &name, int level, std::optional< std::string > parent_name)
std::unique_ptr< const Acts::Logger > m_logger

Member Data Documentation

◆ ATLAS_THREAD_SAFE

std::optional<std::mutex> InDet::ActsGnnModuleMapFinderTool::m_runMutex ATLAS_THREAD_SAFE {}
[mutable] [private]

Definition at line 83 of file ActsGnnModuleMapFinderTool.h.

83{};

◆ FEATURE_NAMES

std::array<const char*, NUM_FEATURES> InDet::ActsGnnModuleMapFinderTool::FEATURE_NAMES
[static] [constexpr] [private]
Initial value:
= {{
"r", "phi", "z", "eta",
"cluster_r_1", "cluster_phi_1", "cluster_z_1", "cluster_eta_1",
"cluster_r_2", "cluster_phi_2", "cluster_z_2", "cluster_eta_2"
}}

Definition at line 50 of file ActsGnnModuleMapFinderTool.h.

50 {{
51 "r", "phi", "z", "eta",
52 "cluster_r_1", "cluster_phi_1", "cluster_z_1", "cluster_eta_1",
53 "cluster_r_2", "cluster_phi_2", "cluster_z_2", "cluster_eta_2"
54 }};

◆ FEATURE_SCALES

std::array<float, NUM_FEATURES> InDet::ActsGnnModuleMapFinderTool::FEATURE_SCALES
[static] [constexpr] [private]

◆ kScaleEta

float InDet::ActsGnnModuleMapFinderTool::kScaleEta = 1.f
[static] [constexpr] [private]

Definition at line 59 of file ActsGnnModuleMapFinderTool.h.

◆ kScalePhi

float InDet::ActsGnnModuleMapFinderTool::kScalePhi = 3.14159265359f
[static] [constexpr] [private]

Definition at line 57 of file ActsGnnModuleMapFinderTool.h.

◆ kScaleR

float InDet::ActsGnnModuleMapFinderTool::kScaleR = 1000.f
[static] [constexpr] [private]

Definition at line 56 of file ActsGnnModuleMapFinderTool.h.

◆ kScaleZ

float InDet::ActsGnnModuleMapFinderTool::kScaleZ = 1000.f
[static] [constexpr] [private]

Definition at line 58 of file ActsGnnModuleMapFinderTool.h.

◆ m_edgeCut

FloatProperty InDet::ActsGnnModuleMapFinderTool::m_edgeCut {this, "edgeCut", 0.5, "Edge classification cut"}
private

Definition at line 70 of file ActsGnnModuleMapFinderTool.h.

70{this, "edgeCut", 0.5, "Edge classification cut"};

◆ m_gnnPath

StringProperty InDet::ActsGnnModuleMapFinderTool::m_gnnPath {this, "gnnPath", "", "Path to GNN model (.onnx, .pt, or .engine)"}
private

Definition at line 69 of file ActsGnnModuleMapFinderTool.h.

69{this, "gnnPath", "", "Path to GNN model (.onnx, .pt, or .engine)"};

◆ m_gnnPipeline

std::unique_ptr<ActsPlugins::GnnPipeline> InDet::ActsGnnModuleMapFinderTool::m_gnnPipeline
private

Definition at line 79 of file ActsGnnModuleMapFinderTool.h.

◆ m_logger

std::unique_ptr<const Acts::Logger> InDet::ActsGnnModuleMapFinderTool::m_logger
private

Definition at line 80 of file ActsGnnModuleMapFinderTool.h.

◆ m_minCandidateMeasurements

UnsignedIntegerProperty InDet::ActsGnnModuleMapFinderTool::m_minCandidateMeasurements {this, "minCandidateMeasurements", 7, "Min measurements per candidate"}
private

Definition at line 72 of file ActsGnnModuleMapFinderTool.h.

72{this, "minCandidateMeasurements", 7, "Min measurements per candidate"};

◆ m_moduleMapPath

StringProperty InDet::ActsGnnModuleMapFinderTool::m_moduleMapPath {this, "moduleMapPath", "", "Path to module map ROOT files"}
private

Definition at line 68 of file ActsGnnModuleMapFinderTool.h.

68{this, "moduleMapPath", "", "Path to module map ROOT files"};

◆ m_numTrtContexts

UnsignedIntegerProperty InDet::ActsGnnModuleMapFinderTool::m_numTrtContexts {this, "numTrtContexts", 1, "Number of TensorRT execution contexts (controls concurrency)"}
private

Definition at line 71 of file ActsGnnModuleMapFinderTool.h.

71{this, "numTrtContexts", 1, "Number of TensorRT execution contexts (controls concurrency)"};

◆ m_spacepointFeatureTool

ToolHandle<ISpacepointFeatureTool> InDet::ActsGnnModuleMapFinderTool::m_spacepointFeatureTool
private
Initial value:
{
this, "SpacepointFeatureTool", "InDet::SpacepointFeatureTool"}

Definition at line 75 of file ActsGnnModuleMapFinderTool.h.

75 {
76 this, "SpacepointFeatureTool", "InDet::SpacepointFeatureTool"};

◆ NUM_FEATURES

std::size_t InDet::ActsGnnModuleMapFinderTool::NUM_FEATURES = 12
[static] [constexpr] [private]

Definition at line 48 of file ActsGnnModuleMapFinderTool.h.


The documentation for this class was generated from the following files: