23 {
25
27
28
29
30
31 ActsPlugins::ModuleMapCuda::Config gcCfg;
36 gcCfg.gpuBlocks = 512;
37 auto gc = std::make_shared<ActsPlugins::ModuleMapCuda>(
38 gcCfg,
m_logger->cloneWithSuffix(
"ModuleMap"));
39
40
41
42
43 std::shared_ptr<ActsPlugins::EdgeClassificationBase> gnn;
44 if (
m_gnnPath.value().find(
".onnx") != std::string::npos) {
45#ifdef ACTS_GNN_ONNX_BACKEND
46 ActsPlugins::OnnxEdgeClassifier::Config gnnCfg;
49 gnn = std::make_shared<ActsPlugins::OnnxEdgeClassifier>(
50 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
51 m_runMutex.emplace();
52#else
53 ATH_MSG_FATAL(
"Not compiled with ONNX, cannot interpret *.onnx files");
54 return StatusCode::FAILURE;
55#endif
56 }
else if (
m_gnnPath.value().find(
".pt") != std::string::npos) {
57#ifdef ACTS_GNN_TORCH_BACKEND
58 ActsPlugins::TorchEdgeClassifier::Config gnnCfg;
61 gnn = std::make_shared<ActsPlugins::TorchEdgeClassifier>(
62 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
63 m_runMutex.emplace();
64#else
65 ATH_MSG_FATAL(
"Not compiled with Torch, cannot interpret *.pt files");
66 return StatusCode::FAILURE;
67#endif
68 }
else if (
m_gnnPath.value().find(
".engine") != std::string::npos) {
69#ifdef ACTS_GNN_WITH_TENSORRT
70 ActsPlugins::TensorRTEdgeClassifier::Config gnnCfg;
74 gnn = std::make_shared<ActsPlugins::TensorRTEdgeClassifier>(
75 gnnCfg,
m_logger->cloneWithSuffix(
"GNN"));
76#else
77 ATH_MSG_FATAL(
"Not compiled with TensorRT, cannot interpret *.engine files");
78 return StatusCode::FAILURE;
79#endif
80 } else {
82 return StatusCode::FAILURE;
83 }
84
85
86 ATH_MSG_INFO(
"Configure CC&JunctionRemoval as graph segmentation algorithm");
87 ActsPlugins::CudaTrackBuilding::Config tbCfg;
88 tbCfg.doJunctionRemoval = true;
89 auto tb = std::make_shared<ActsPlugins::CudaTrackBuilding>(
90 tbCfg,
m_logger->cloneWithSuffix(
"CC&JR"));
91
92
94 gc, std::vector{gnn},
tb,
m_logger->cloneWithSuffix(
"Pipeline"));
95
96 return StatusCode::SUCCESS;
97}
#define ATH_CHECK
Evaluate an expression and check for errors.
std::unique_ptr< const Acts::Logger > makeActsAthenaLogger(IMessageSvc *svc, const std::string &name, int level, std::optional< std::string > parent_name)