25 {
27
29
30
31
32
33 ActsPlugins::ModuleMapCuda::Config gcCfg;
38 gcCfg.gpuBlocks = 512;
39 auto gc = std::make_shared<ActsPlugins::ModuleMapCuda>(
40 gcCfg,
m_logger->cloneWithSuffix(
"ModuleMap"));
41
42
43
44
45 std::shared_ptr<ActsPlugins::EdgeClassificationBase> gnn;
46 if (
m_gnnPath.value().find(
".onnx") != std::string::npos) {
47#ifdef ACTS_GNN_ONNX_BACKEND
48 ActsPlugins::OnnxEdgeClassifier::Config gnnCfg;
51 gnn = std::make_shared<ActsPlugins::OnnxEdgeClassifier>(
52 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
53 m_runMutex.emplace();
54#else
55 ATH_MSG_FATAL(
"Not compiled with ONNX, cannot interpret *.onnx files");
56 return StatusCode::FAILURE;
57#endif
58 }
else if (
m_gnnPath.value().find(
".pt") != std::string::npos) {
59#ifdef ACTS_GNN_TORCH_BACKEND
60 ActsPlugins::TorchEdgeClassifier::Config gnnCfg;
63 gnn = std::make_shared<ActsPlugins::TorchEdgeClassifier>(
64 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
65 m_runMutex.emplace();
66#else
67 ATH_MSG_FATAL(
"Not compiled with Torch, cannot interpret *.pt files");
68 return StatusCode::FAILURE;
69#endif
70 }
else if (
m_gnnPath.value().find(
".engine") != std::string::npos) {
71#ifdef ACTS_GNN_WITH_TENSORRT
72 ActsPlugins::TensorRTEdgeClassifier::Config gnnCfg;
76 gnn = std::make_shared<ActsPlugins::TensorRTEdgeClassifier>(
77 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
78#else
79 ATH_MSG_FATAL(
"Not compiled with TensorRT, cannot interpret *.engine files");
80 return StatusCode::FAILURE;
81#endif
82 } else {
84 return StatusCode::FAILURE;
85 }
86
87
88 ATH_MSG_INFO(
"Configure CC&JunctionRemoval as graph segmentation algorithm");
89 ActsPlugins::CudaTrackBuilding::Config tbCfg;
90 tbCfg.doJunctionRemoval = true;
91 auto tb = std::make_shared<ActsPlugins::CudaTrackBuilding>(
92 tbCfg,
m_logger->cloneWithSuffix(
"CC&JR"));
93
94
96 gc, std::vector{gnn},
tb,
m_logger->cloneWithSuffix(
"Pipeline"));
97
98 return StatusCode::SUCCESS;
99}
#define ATH_CHECK
Evaluate an expression and check for errors.
std::unique_ptr< const Acts::Logger > makeActsAthenaLogger(IMessageSvc *svc, const std::string &name, int level, std::optional< std::string > parent_name)