// Trace marker for the start of initialize().
88 ACTS_DEBUG(
"TrackFindingGNNAlg::initialize() - begin");
// Configuration of the CUDA module-map graph construction stage.
// r and z are scaled by 1000 and phi by pi.
// NOTE(review): execute() divides the input features by the same constants
// (1000.f and pi) — presumably these must stay in sync; confirm.
107 ActsPlugins::ModuleMapCuda::Config gcCfg;
108 gcCfg.rScale = 1000.f;
109 gcCfg.zScale = 1000.f;
110 gcCfg.phiScale = std::numbers::pi_v<float>;
// gpuBlocks is hard-coded to 512 here.
112 gcCfg.gpuBlocks = 512;
// The graph constructor is held through its base-class pointer and gets a
// logger cloned from m_logger with the suffix "ModuleMap".
113 std::shared_ptr<ActsPlugins::GraphConstructionBase> gc =
114 std::make_shared<ActsPlugins::ModuleMapCuda>(
115 gcCfg,
m_logger->cloneWithSuffix(
"ModuleMap"));
// Select the edge classifier (GNN) backend from the model path in m_gnnPath.
// The path is searched for ".onnx", ".pt" and ".engine" in that order; each
// backend is guarded by its own compile-time macro, and each branch also
// contains an error message plus a FAILURE return for builds lacking it.
// NOTE(review): the #else/#endif lines and several closing braces are not
// visible in this excerpt, so the exact branch structure must be confirmed.
// NOTE(review): find(...) != npos matches the substring anywhere in the
// path, not only as a trailing file extension — confirm this is intended.
117 std::shared_ptr<ActsPlugins::EdgeClassificationBase> gnn;
118 if (
m_gnnPath.value().find(
".onnx") != std::string::npos) {
119#ifdef ACTS_GNN_ONNX_BACKEND
// ONNX backend; the logger is cloned with the suffix "GNN".
120 ActsPlugins::OnnxEdgeClassifier::Config gnnCfg;
123 gnn = std::make_shared<ActsPlugins::OnnxEdgeClassifier>(
124 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
126 ATH_MSG_ERROR(
"GNN .onnx selected but build lacks ONNX backend");
127 return StatusCode::FAILURE;
129 }
else if (
m_gnnPath.value().find(
".pt") != std::string::npos) {
130#ifdef ACTS_GNN_TORCH_BACKEND
// Torch backend; edge features are switched on for this classifier.
131 ActsPlugins::TorchEdgeClassifier::Config gnnCfg;
134 gnnCfg.useEdgeFeatures =
true;
135 gnn = std::make_shared<ActsPlugins::TorchEdgeClassifier>(
136 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
138 ATH_MSG_ERROR(
"GNN .pt selected but build lacks libtorch backend");
139 return StatusCode::FAILURE;
141 }
else if (
m_gnnPath.value().find(
".engine") != std::string::npos) {
142#ifdef ACTS_GNN_WITH_TENSORRT
// TensorRT backend.
143 ActsPlugins::TensorRTEdgeClassifier::Config gnnCfg;
147 gnn = std::make_shared<ActsPlugins::TensorRTEdgeClassifier>(
148 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
150 ATH_MSG_ERROR(
"GNN .engine selected but build lacks TensorRT backend");
151 return StatusCode::FAILURE;
// NOTE(review): presumably the fallback for a path matching none of the
// known extensions — the enclosing else branch is not visible here.
155 return StatusCode::FAILURE;
// CUDA track building stage with junction removal enabled; its logger is
// cloned from m_logger with the suffix "GraphSeg".
158 ActsPlugins::CudaTrackBuilding::Config tbCfg;
159 tbCfg.doJunctionRemoval =
true;
160 std::shared_ptr<ActsPlugins::TrackBuildingBase> tb =
161 std::make_shared<ActsPlugins::CudaTrackBuilding>(
162 tbCfg,
m_logger->cloneWithSuffix(
"GraphSeg"));
// Arguments of the pipeline construction (the head of this statement is not
// visible in this excerpt): the graph constructor, a single-element vector
// holding the GNN, the track builder, and a logger suffixed "Pipeline".
165 gc, std::vector{gnn}, tb,
m_logger->cloneWithSuffix(
"Pipeline"));
// Report the phi-overlap spacepoint setting (rest of the statement is not
// visible in this excerpt).
173 ACTS_INFO(
"Use phi overlap spacepoints: " << std::boolalpha
// Shorthand for the track selector configuration type.
176 using TSC = Acts::TrackSelector::Config;
// Settings shared by every cut set; the visible part requires a reference
// surface (further shared settings may exist on lines not shown here).
179 auto commonConfig = [&](TSC &config) {
180 config.requireReferenceSurface =
true;
// First cut set: ptMin 900 MeV, loc0 within [-2 mm, 2 mm].
// NOTE(review): the range each cut set applies to is not visible here.
188 commonConfig(config);
191 config.ptMin = 900_MeV;
192 config.loc0Max = 2_mm;
193 config.loc0Min = -2_mm;
// Second cut set: ptMin 400 MeV, loc0 within [-2 mm, 2 mm].
197 commonConfig(config);
200 config.ptMin = 400_MeV;
201 config.loc0Max = 2_mm;
202 config.loc0Min = -2_mm;
// Third cut set (added via addCuts): at least 7 measurements, ptMin 400 MeV,
// loc0 within [-10 mm, 10 mm].
204 .addCuts([&](TSC &config) {
205 commonConfig(config);
207 config.minMeasurements = 7;
208 config.ptMin = 400_MeV;
209 config.loc0Max = 10_mm;
210 config.loc0Min = -10_mm;
// Trace marker for the end of initialize(); reports success.
215 ACTS_DEBUG(
"TrackFindingGNNAlg::initialize() - end");
216 return StatusCode::SUCCESS;
// Trace marker for the start of execute().
222 ACTS_DEBUG(
"TrackFindingGNNAlg::execute() - begin");
// The timer lives in a std::optional so each phase can replace it with
// emplace(); emplace() destroys the previously held Chrono before building
// the new one (presumably ending the previous timing section — confirm
// against Athena::Chrono).
224 std::optional<Athena::Chrono> timer;
225 timer.emplace(
"GNN get spacepoint handles",
m_chronoSvc.get());
// Geometry and magnetic-field contexts (their initialisers are not visible
// in this excerpt).
227 Acts::GeometryContext gctx =
229 Acts::MagneticFieldContext mctx =
// Dereference the handles to obtain the pixel, strip and strip-overlap
// spacepoint containers.
238 const auto &pixelSPContainer = *pixelSPHandle.cptr();
242 const auto &stripSPContainer = *stripSPHandle.cptr();
244 auto stripSPOVHandle =
247 const auto &stripSPOVContainer = *stripSPOVHandle.cptr();
// Each spacepoint is described by nFeatures (12) floats in the input tensor.
249 constexpr std::size_t nFeatures = 12;
// Total number of spacepoints over all three containers; this is reduced
// later, once phi-overlap spacepoints have been skipped.
250 std::size_t nSP = pixelSPContainer.size() + stripSPContainer.size() +
251 stripSPOVContainer.size();
253 ACTS_DEBUG(
"Number spacepoints: "
254 << nSP <<
" (" <<
"pixel: " << pixelSPContainer.size() <<
", "
255 <<
"strip: " << stripSPContainer.size() <<
", "
256 <<
"strip overlap: " << stripSPOVContainer.size() <<
")");
// Phase: collect module ids, spacepoint pointers and geometry ids.
259 timer.emplace(
"GNN extract data",
m_chronoSvc.get());
// Capacity is reserved for the upper bound nSP (before any skipping).
261 std::vector<std::uint64_t> moduleIds;
262 moduleIds.reserve(nSP);
263 std::vector<const xAOD::SpacePoint *> allSPPtrs;
264 allSPPtrs.reserve(nSP);
// NOTE(review): sortedGeoIds is sized with the pre-skip nSP, but nSP is later
// reset to allSPPtrs.size(); if spacepoints are skipped, sortedGeoIds keeps
// trailing default-constructed entries — confirm consumers tolerate that.
265 std::vector<Acts::GeometryIdentifier> geoIds, sortedGeoIds(nSP);
// Counter of spacepoints dropped because of their phi-overlap flag.
268 std::size_t skipped = 0;
// NOTE(review): the braced list holds the containers themselves; a
// braced-init-list in a range-for becomes a std::initializer_list whose
// elements are copies — confirm the container copies are intended/cheap, or
// whether pointers were meant.
269 for (
const auto &spc :
270 {pixelSPContainer, stripSPContainer, stripSPOVContainer}) {
271 for (
auto sp : spc) {
// First cluster of the spacepoint; front() assumes at least one measurement.
272 auto cl1 =
sp->measurements().front();
275 Identifier atlasIdCl1(
static_cast<Identifier::value_type
>(cl1->identifier()));
// Spacepoints with two measurements also carry a second cluster.
277 if (
sp->measurements().size() == 2) {
278 auto cl2 =
sp->measurements().at(1);
279 Identifier atlasIdCl2(
static_cast<Identifier::value_type
>(cl2->identifier()));
// Overlap flags 2 and 3 lead to the skip path (how overlapFlag is computed
// is not visible in this excerpt).
287 if (overlapFlag == 2 || overlapFlag == 3) {
289 ACTS_VERBOSE(
"Skip phi overlap spacepoint (flag=" << overlapFlag
// Kept spacepoints: record the geometry id and the spacepoint pointer.
295 geoIds.push_back(geoIdCl1);
297 allSPPtrs.push_back(
sp);
301 ACTS_DEBUG(
"Skipped " << skipped <<
" SPs because of phi overlap");
// From here on nSP is the number of spacepoints actually kept.
302 nSP = allSPPtrs.size();
303 ACTS_DEBUG(
"Keep " << nSP <<
" SPs for feature creation");
// Phase: build the flat feature tensor, ordered by module id.
305 timer.emplace(
"GNN build input tensor",
m_chronoSvc.get());
// idxs starts as the identity permutation 0..nSP-1.
307 std::vector<std::size_t> idxs(nSP);
308 std::iota(idxs.begin(), idxs.end(), 0);
// Arguments of a call ordering idxs by ascending module id (the head of the
// call is not visible in this excerpt); moduleIds itself is sorted afterwards.
311 idxs, [&](
auto a,
auto b) {
return moduleIds.at(
a) < moduleIds.at(b); });
312 std::ranges::sort(moduleIds);
// Flat storage: nFeatures consecutive floats per spacepoint.
314 std::vector<float> features(nFeatures * nSP);
315 std::vector<boost::container::static_vector<Acts::SourceLink, 2>> sourceLinks(
317 std::vector<int> id(nSP);
319 for (
auto k = 0ul; k < nSP; k++) {
// f is a view onto the k-th feature row.
323 std::span<float> f(features.data() + k * nFeatures, nFeatures);
// NOTE(review): i is presumably idxs[k]; its definition is not visible here.
324 const auto &
sp = *allSPPtrs.at(i);
326 using namespace Acts::VectorHelpers;
327 using namespace Acts::AngleHelpers;
329 Acts::Vector3 spp{
sp.x(),
sp.y(),
sp.z()};
// Single-measurement spacepoint: the spacepoint's own r/1000, phi/pi and
// z/1000 are written into every group of four feature slots (the fourth
// slot of each group is set on a line not visible here).
331 if (
sp.measurements().size() == 1) {
332 for (
auto j = 0ul; j < nFeatures; j += 4) {
333 f[j + 0] =
perp(spp) / 1000.f;
334 f[j + 1] =
phi(spp) / std::numbers::pi_v<float>;
335 f[j + 2] =
sp.z() / 1000.f;
// Otherwise: one group is filled from the spacepoint position ...
340 f[j + 0] =
perp(spp) / 1000.f;
341 f[j + 1] =
phi(spp) / std::numbers::pi_v<float>;
342 f[j + 2] =
sp.z() / 1000.f;
// ... and further groups from the global position of each measurement's
// cluster (how cl and j are derived is not visible in this excerpt).
345 for (
auto m :
sp.measurements()) {
347 auto gp = cl->globalPosition();
349 f[j + 0] =
perp(gp) / 1000.f;
350 f[j + 1] =
phi(gp) / std::numbers::pi_v<float>;
351 f[j + 2] = gp.z() / 1000.f;
// Geometry ids are stored in the same sorted order as the feature rows.
360 sortedGeoIds.at(k) = geoIds.at(i);
// m_gpuInstanceCount is acquired before and released after the code in
// between (not visible in this excerpt; presumably the GPU pipeline run).
// NOTE(review): if that code can throw, release() is skipped — consider a
// scope guard; confirm against the full source.
366 m_gpuInstanceCount->acquire();
369 m_gpuInstanceCount->release();
371 ACTS_DEBUG(
"Have " << candidates.size() <<
" candidates after GNN");
// Predicate passed to remove_if below: returns true for candidates that
// should be DROPPED.
374 auto candidateSelector = [&](
const std::vector<int> &c) {
// Sums the measurement counts of all spacepoints in the candidate; the
// comparison that turns the sum into a bool is on a line not visible here.
375 bool tooFewMeasurements = std::accumulate(c.begin(), c.end(), 0ul, [&](
auto sum,
auto spi) {
376 return sum + allSPPtrs.at(spi)->measurements().size();
// True when no spacepoint in the candidate has exactly one measurement
// (single-measurement spacepoints are treated as pixel hits here).
378 bool noPixelHits = !std::ranges::any_of(c, [&](
auto spi) {
return allSPPtrs.at(spi)->measurements().size() == 1; });
379 return tooFewMeasurements || noPixelHits;
// Remove the rejected candidates (the second erase argument is on a line
// not visible in this excerpt).
382 candidates.erase(
std::remove_if(candidates.begin(), candidates.end(), candidateSelector),
385 <<
" measurements: " << candidates.size());
// Phase: seed each candidate, estimate initial parameters and fit.
388 timer.emplace(
"GNN parameter estimation + fit",
m_chronoSvc.get());
// Track and track-state backends, pre-sized for 3000 tracks with 30 states
// per track.
390 Acts::VectorTrackContainer trackBackend;
391 Acts::VectorMultiTrajectory trackStateBackend;
392 constexpr std::size_t nTracksExpected = 3000;
393 trackBackend.reserve(nTracksExpected);
394 trackStateBackend.reserve(nTracksExpected * 30);
// Picks up to three spacepoints of a candidate to form a seed. Returns
// std::nullopt for an empty candidate or when fewer than three spacepoints
// were picked (the selection criterion is on lines not visible here).
400 auto makeSeedFromCandidate = [&](
const std::vector<int> &cand) -> std::optional<boost::container::small_vector<const xAOD::SpacePoint*, 3>> {
402 boost::container::small_vector<const xAOD::SpacePoint*, 3> picked;
403 if (cand.empty())
return std::nullopt;
405 Acts::Vector3 v{
sp->x(),
sp->y(),
sp->z()};
409 picked.push_back(last);
410 for (std::size_t i = 1; i < cand.size() && picked.size() < 3; ++i) {
413 picked.push_back(
sp);
417 if (picked.size() < 3)
return std::nullopt;
// Callback handed to the parameter estimation: returns the surface for one
// of the seed's spacepoints and throws std::runtime_error when no surface
// is found for the geometry identifier.
421 auto retrieveSurface = [&](
const ActsTrk::Seed& seed,
bool useTopSp) ->
const Acts::Surface& {
426 throw std::runtime_error(
"retrieveSurface: no Acts surface for GeometryIdentifier " + std::to_string(geoId.value()));
// NOTE(review): the three-argument hypot is the 3D distance from the origin,
// not the transverse radius; presumably this is the R_of helper used for
// sorting below — confirm that 3D distance is the intended ordering.
432 return Acts::fastHypot(
sp->x(),
sp->y(),
sp->z());
435 for (
const auto &cand : candidates) {
// Candidates for which no seed could be built are skipped.
436 auto pickedOpt = makeSeedFromCandidate(cand);
437 if (!pickedOpt.has_value())
continue;
// Copy of the picked spacepoints, ordered by ascending R_of.
439 auto picked = *pickedOpt;
442 return R_of(a) < R_of(b);
// Candidates without initial parameters are skipped as well.
448 seed,
true, gctx, mctx, retrieveSurface);
449 if (!initialParamsOpt.has_value())
continue;
// All spacepoints of the candidate, ordered by ascending R_of.
451 boost::container::small_vector<const xAOD::SpacePoint*, 16> sortedSP;
452 sortedSP.reserve(cand.size());
453 for (
int spi : cand) sortedSP.push_back(allSPPtrs.at(spi));
454 std::sort(sortedSP.begin(), sortedSP.end(),
456 return R_of(a) < R_of(b);
// Flattened measurement list; reserves two measurements per spacepoint.
459 std::vector<const xAOD::UncalibratedMeasurement*> measList;
460 measList.reserve(sortedSP.size() * 2);
463 measList.push_back(m);
// NOTE(review): *fitted is dereferenced below; whether fitted is checked for
// null on the line in between is not visible in this excerpt.
467 auto fitted =
m_fitterTool->fit(measList, *initialParamsOpt, gctx, mctx, cctx);
// Copy every fitted track into the output track container.
469 for (
auto track : *fitted) {
470 auto newTrack = tracks.makeTrack();
471 newTrack.copyFrom(track);
476 ACTS_DEBUG(
"After track fit: " << tracks.size() <<
" / " << candidates.size()
// Extra debug output when exactly one candidate produced exactly one track.
480 if (candidates.size() == 1 && tracks.size() == 1) {
481 const auto &t = *tracks.begin();
482 ACTS_DEBUG(
"Single particle case: " << candidates.front().size() <<
" -> "
// Phase: apply the track selector and convert to the output container.
488 timer.emplace(
"Track selection & conversion",
m_chronoSvc.get());
// Backend for the selected tracks, reserved for the number of fitted tracks.
490 Acts::VectorTrackContainer selTrackBackend;
491 selTrackBackend.reserve(trackBackend.size());
// Copy only the tracks accepted by the selector.
495 for (
auto track : tracks) {
496 if (selector.isValidTrack(track)) {
497 auto newTrack = selectedTracks.makeTrack();
500 newTrack.copyFrom(track);
504 ACTS_DEBUG(
"GNN cand: " << candidates.size() <<
", fitted: " << tracks.size()
505 <<
", selected: " << selectedTracks.size());
// Both backends are moved into their const counterparts and then into the
// output container; selTrackBackend and trackStateBackend must not be used
// after this point.
// NOTE(review): trackStateBackend is moved as a whole, so it presumably still
// holds states of tracks that failed the selection — confirm this is accepted.
508 Acts::ConstVectorTrackContainer constTrackBackend(std::move(selTrackBackend));
509 Acts::ConstVectorMultiTrajectory constTrackStateBackend(std::move(trackStateBackend));
510 std::unique_ptr<ActsTrk::TrackContainer> constTracksContainer
511 = std::make_unique<ActsTrk::TrackContainer>(std::move(constTrackBackend), std::move(constTrackStateBackend) );
// Record the container through the write handle; ATH_CHECK wraps the status
// returned by record().
515 ATH_CHECK(trackContainerHandle.
record(std::move(constTracksContainer)));
517 return StatusCode::SUCCESS;