ATLAS Offline Software
GNNTrackFinderTritonTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2025 CERN for the benefit of the ATLAS collaboration
3 */
4 
6 
7 #include "ExaTrkXUtils.hpp"
8 #include "CxxUtils/StringUtils.h"
9 
10 // Framework include(s).
11 #include <cmath>
12 
14 
18 
19  // tokenize the feature names by comma and push to the vector
21  return StatusCode::SUCCESS;
22 }
23 
25  const std::vector<const Trk::SpacePoint*>& spacepoints,
26  std::vector<std::vector<uint32_t> >& tracks) const {
27  int64_t numSpacepoints = (int64_t)spacepoints.size();
28  std::vector<float> inputValues;
29  std::vector<uint32_t> spacepointIDs;
30 
31  int64_t spacepointFeatures = m_featureNamesVec.size();
32  int sp_idx = 0;
33  for (const auto& sp : spacepoints) {
34  // depending on the trained embedding and GNN models, the input features
35  // may need to be updated.
36  auto featureMap = m_spacepointFeatureTool->getFeatures(sp);
37  for (int i = 0; i < spacepointFeatures; i++){
38  inputValues.push_back(featureMap[m_featureNamesVec[i]]);
39  }
40 
41  spacepointIDs.push_back(sp_idx++);
42  }
43 
44  AthInfer::InputDataMap inputData;
45  inputData["FEATURES"] = std::make_pair(
46  std::vector<int64_t>{numSpacepoints, spacepointFeatures}, std::move(inputValues));
47 
48  AthInfer::OutputDataMap outputData;
49  outputData["LABELS"] = std::make_pair(std::vector<int64_t>{numSpacepoints, 1}, std::vector<int64_t>{});
50 
51  ATH_CHECK(m_gnnTrackingTritonTool->inference(inputData, outputData));
52 
53  auto& trackLabels = std::get<std::vector<int64_t>>(outputData["LABELS"].second);
54  if (trackLabels.size() == 0){
55  ATH_MSG_DEBUG("No tracks found in the event.");
56  return StatusCode::SUCCESS;
57  }
58 
59  tracks.clear();
60  std::vector<uint32_t> this_track;
61  for (auto label : trackLabels) {
62  if (label == -1) {
63  if (this_track.size() > 0) {
64  tracks.push_back(this_track);
65  this_track.clear();
66  }
67  } else {
68  this_track.push_back(label);
69  }
70  }
71 
72  return StatusCode::SUCCESS;
73 }
74 
75 MsgStream& InDet::GNNTrackFinderTritonTool::dump( MsgStream& out ) const
76 {
77  out<<std::endl;
78  return dumpevent(out);
79 }
80 
81 std::ostream& InDet::GNNTrackFinderTritonTool::dump( std::ostream& out ) const
82 {
83  return out;
84 }
85 
86 MsgStream& InDet::GNNTrackFinderTritonTool::dumpevent( MsgStream& out ) const
87 {
88  out<<"|---------------------------------------------------------------------|"
89  <<std::endl;
90  out<<"| Number output tracks | "<<std::setw(12)
91  <<" |"<<std::endl;
92  out<<"|---------------------------------------------------------------------|"
93  <<std::endl;
94  return out;
95 }
InDet::GNNTrackFinderTritonTool::initialize
virtual StatusCode initialize() override
Definition: GNNTrackFinderTritonTool.cxx:15
InDet::GNNTrackFinderTritonTool::m_spacepointFeatureTool
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
Definition: GNNTrackFinderTritonTool.h:56
InDet::GNNTrackFinderTritonTool::m_featureNames
StringProperty m_featureNames
Definition: GNNTrackFinderTritonTool.h:59
CxxUtils::tokenize
std::vector< std::string > tokenize(const std::string &the_str, std::string_view delimiters)
Splits the string into smaller substrings.
Definition: Control/CxxUtils/Root/StringUtils.cxx:15
StringUtils.h
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:70
InDet::GNNTrackFinderTritonTool::m_gnnTrackingTritonTool
ToolHandle< AthInfer::IAthInferenceTool > m_gnnTrackingTritonTool
Definition: GNNTrackFinderTritonTool.h:54
InDet::GNNTrackFinderTritonTool::dump
virtual MsgStream & dump(MsgStream &out) const override
Definition: GNNTrackFinderTritonTool.cxx:75
GNNTrackFinderTritonTool.h
lumiFormat.i
int i
Definition: lumiFormat.py:85
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
ATH_MSG_DEBUG
#define ATH_MSG_DEBUG(x)
Definition: AthMsgStreamMacros.h:29
AthInfer::OutputDataMap
std::map< std::string, InferenceData > OutputDataMap
Definition: IAthInferenceTool.h:17
add-xsec-uncert-quadrature-N.label
label
Definition: add-xsec-uncert-quadrature-N.py:104
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
InDet::GNNTrackFinderTritonTool::dumpevent
MsgStream & dumpevent(MsgStream &out) const
Definition: GNNTrackFinderTritonTool.cxx:86
PathResolver.h
InDet::GNNTrackFinderTritonTool::getTracks
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
Definition: GNNTrackFinderTritonTool.cxx:24
AthInfer::InputDataMap
std::map< std::string, InferenceData > InputDataMap
Definition: IAthInferenceTool.h:16
InDet::GNNTrackFinderTritonTool::m_featureNamesVec
std::vector< std::string > m_featureNamesVec
Definition: GNNTrackFinderTritonTool.h:63