ATLAS Offline Software
Loading...
Searching...
No Matches
GNNTrackFinderTritonTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3*/
4
6
7#include "ExaTrkXUtils.hpp"
9
10// Framework include(s).
11#include <cmath>
12
14
18
19 // tokenize the feature names by comma and push to the vector
21 return StatusCode::SUCCESS;
22}
23
25 const std::vector<const Trk::SpacePoint*>& spacepoints,
26 std::vector<std::vector<uint32_t> >& tracks) const {
27 int64_t numSpacepoints = (int64_t)spacepoints.size();
28 std::vector<float> inputValues;
29 std::vector<uint32_t> spacepointIDs;
30
31 int64_t spacepointFeatures = m_featureNamesVec.size();
32 int sp_idx = 0;
33 for (const auto& sp : spacepoints) {
34 // depending on the trained embedding and GNN models, the input features
35 // may need to be updated.
36 auto featureMap = m_spacepointFeatureTool->getFeatures(sp);
37 for (int i = 0; i < spacepointFeatures; i++){
38 // if the feature is "hit_id", use sp_idx as its value
39 if (m_featureNamesVec[i] == "hit_id"){
40 inputValues.push_back((float)sp_idx);
41 continue;
42 }
43 inputValues.push_back(featureMap[m_featureNamesVec[i]]);
44 }
45 sp_idx++;
46 }
47
48 AthInfer::InputDataMap inputData;
49 inputData["FEATURES"] = std::make_pair(
50 std::vector<int64_t>{numSpacepoints, spacepointFeatures}, std::move(inputValues));
51
52 AthInfer::OutputDataMap outputData;
53 outputData["LABELS"] = std::make_pair(std::vector<int64_t>{numSpacepoints, 1}, std::vector<int64_t>{});
54
55 ATH_CHECK(m_gnnTrackingTritonTool->inference(inputData, outputData));
56
57 auto& trackLabels = std::get<std::vector<int64_t>>(outputData["LABELS"].second);
58 if (trackLabels.size() == 0){
59 ATH_MSG_DEBUG("No tracks found in the event.");
60 return StatusCode::SUCCESS;
61 }
62
63 tracks.clear();
64 std::vector<uint32_t> this_track;
65 for (auto label : trackLabels) {
66 if (label == -1) {
67 if (this_track.size() > 0) {
68 tracks.push_back(this_track);
69 this_track.clear();
70 }
71 } else {
72 this_track.push_back(label);
73 }
74 }
75
76 return StatusCode::SUCCESS;
77}
78
79MsgStream& InDet::GNNTrackFinderTritonTool::dump( MsgStream& out ) const
80{
81 out<<std::endl;
82 return dumpevent(out);
83}
84
85std::ostream& InDet::GNNTrackFinderTritonTool::dump( std::ostream& out ) const
86{
87 return out;
88}
89
90MsgStream& InDet::GNNTrackFinderTritonTool::dumpevent( MsgStream& out ) const
91{
92 out<<"|---------------------------------------------------------------------|"
93 <<std::endl;
94 out<<"| Number output tracks | "<<std::setw(12)
95 <<" |"<<std::endl;
96 out<<"|---------------------------------------------------------------------|"
97 <<std::endl;
98 return out;
99}
#define ATH_CHECK
Evaluate an expression and check for errors.
#define ATH_MSG_DEBUG(x)
static Double_t sp
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
MsgStream & dumpevent(MsgStream &out) const
ToolHandle< AthInfer::IAthInferenceTool > m_gnnTrackingTritonTool
std::vector< std::string > m_featureNamesVec
virtual StatusCode initialize() override
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
virtual MsgStream & dump(MsgStream &out) const override
std::string label(const std::string &format, int i)
Definition label.h:19
std::map< std::string, InferenceData > OutputDataMap
std::map< std::string, InferenceData > InputDataMap
std::vector< std::string > tokenize(const std::string &the_str, std::string_view delimiters)
Splits the string into smaller substrings.