ATLAS Offline Software
Loading...
Searching...
No Matches
GNNTrackFinderTritonTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7#include "ExaTrkXUtils.hpp"
9
10// Framework include(s).
11#include <cmath>
12
14
18
19 // tokenize the feature names by comma and push to the vector
21 return StatusCode::SUCCESS;
22}
23
25 const std::vector<const Trk::SpacePoint*>& spacepoints,
26 std::vector<std::vector<uint32_t> >& tracks) const {
27 int64_t numSpacepoints = (int64_t)spacepoints.size();
28 std::vector<float> inputValues;
29 std::vector<uint32_t> spacepointIDs;
30
31 int64_t spacepointFeatures = m_featureNamesVec.size();
32 int sp_idx = 0;
33 for (const auto& sp : spacepoints) {
34 // depending on the trained embedding and GNN models, the input features
35 // may need to be updated.
36 auto featureMap = m_spacepointFeatureTool->getFeatures(sp);
37 for (int i = 0; i < spacepointFeatures; i++){
38 inputValues.push_back(featureMap[m_featureNamesVec[i]]);
39 }
40
41 spacepointIDs.push_back(sp_idx++);
42 }
43
44 AthInfer::InputDataMap inputData;
45 inputData["FEATURES"] = std::make_pair(
46 std::vector<int64_t>{numSpacepoints, spacepointFeatures}, std::move(inputValues));
47
48 AthInfer::OutputDataMap outputData;
49 outputData["LABELS"] = std::make_pair(std::vector<int64_t>{numSpacepoints, 1}, std::vector<int64_t>{});
50
51 ATH_CHECK(m_gnnTrackingTritonTool->inference(inputData, outputData));
52
53 auto& trackLabels = std::get<std::vector<int64_t>>(outputData["LABELS"].second);
54 if (trackLabels.size() == 0){
55 ATH_MSG_DEBUG("No tracks found in the event.");
56 return StatusCode::SUCCESS;
57 }
58
59 tracks.clear();
60 std::vector<uint32_t> this_track;
61 for (auto label : trackLabels) {
62 if (label == -1) {
63 if (this_track.size() > 0) {
64 tracks.push_back(this_track);
65 this_track.clear();
66 }
67 } else {
68 this_track.push_back(label);
69 }
70 }
71
72 return StatusCode::SUCCESS;
73}
74
75MsgStream& InDet::GNNTrackFinderTritonTool::dump( MsgStream& out ) const
76{
77 out<<std::endl;
78 return dumpevent(out);
79}
80
81std::ostream& InDet::GNNTrackFinderTritonTool::dump( std::ostream& out ) const
82{
83 return out;
84}
85
86MsgStream& InDet::GNNTrackFinderTritonTool::dumpevent( MsgStream& out ) const
87{
88 out<<"|---------------------------------------------------------------------|"
89 <<std::endl;
90 out<<"| Number output tracks | "<<std::setw(12)
91 <<" |"<<std::endl;
92 out<<"|---------------------------------------------------------------------|"
93 <<std::endl;
94 return out;
95}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_DEBUG(x)
static Double_t sp
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
MsgStream & dumpevent(MsgStream &out) const
ToolHandle< AthInfer::IAthInferenceTool > m_gnnTrackingTritonTool
std::vector< std::string > m_featureNamesVec
virtual StatusCode initialize() override
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
virtual MsgStream & dump(MsgStream &out) const override
std::string label(const std::string &format, int i)
Definition label.h:19
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
std::vector< std::string > tokenize(const std::string &the_str, std::string_view delimiters)
Splits the string into smaller substrings.