ATLAS Offline Software
SiGNNTrackFinderTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "SiGNNTrackFinderTool.h"
6 #include "ExaTrkXUtils.hpp"
7 
8 // Framework include(s).
10 #include "AthOnnxUtils/OnnxUtils.h"
11 #include <cmath>
12 
14  const std::string& type, const std::string& name, const IInterface* parent):
15  base_class(type, name, parent)
16  {
17  declareInterface<IGNNTrackFinder>(this);
18  }
19 
21  ATH_CHECK( m_embedSessionTool.retrieve() );
22  ATH_CHECK( m_filterSessionTool.retrieve() );
23  ATH_CHECK( m_gnnSessionTool.retrieve() );
24  return StatusCode::SUCCESS;
25 }
26 
27 MsgStream& InDet::SiGNNTrackFinderTool::dump( MsgStream& out ) const
28 {
29  out<<std::endl;
30  return dumpevent(out);
31 }
32 
33 std::ostream& InDet::SiGNNTrackFinderTool::dump( std::ostream& out ) const
34 {
35  return out;
36 }
37 
38 MsgStream& InDet::SiGNNTrackFinderTool::dumpevent( MsgStream& out ) const
39 {
40  out<<"|---------------------------------------------------------------------|"
41  <<std::endl;
42  out<<"| Number output tracks | "<<std::setw(12)
43  <<" |"<<std::endl;
44  out<<"|---------------------------------------------------------------------|"
45  <<std::endl;
46  return out;
47 }
48 
50  const std::vector<const Trk::SpacePoint*>& spacepoints,
51  std::vector<std::vector<uint32_t> >& tracks) const
52 {
53  int64_t numSpacepoints = (int64_t)spacepoints.size();
54  std::vector<float> inputValues;
55  std::vector<uint32_t> spacepointIDs;
56 
57  int64_t spacepointFeatures = 3;
58  int sp_idx = 0;
59  for(const auto& sp: spacepoints){
60  // depending on the trained embedding and GNN models, the input features
61  // may need to be updated.
62 
63  float z = sp->globalPosition().z() / 1000.;
64  float r = sp->r() / 1000.;
65  float phi = sp->phi() / M_PI;
66  inputValues.push_back(r);
67  inputValues.push_back(phi);
68  inputValues.push_back(z);
69 
70 
71  spacepointIDs.push_back(sp_idx++);
72  }
73 
74  // ************
75  // Embedding
76  // ************
77 
78  std::vector<int64_t> eInputShape{numSpacepoints, spacepointFeatures};
79  std::vector<Ort::Value> eInputTensor;
80  ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, inputValues, 0, numSpacepoints) );
81 
82  std::vector<Ort::Value> eOutputTensor;
83  std::vector<float> eOutputData;
84  ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );
85 
86  ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );
87 
88  // ************
89  // Building Edges
90  // ************
91  std::vector<int64_t> edgeList;
92  buildEdges(eOutputData, edgeList, numSpacepoints, m_embeddingDim, m_rVal, m_knnVal);
93  int64_t numEdges = edgeList.size() / 2;
94 
95  // ************
96  // Filtering
97  // ************
98  std::vector<Ort::Value> fInputTensor;
99  fInputTensor.push_back(
100  std::move(eInputTensor[0])
101  );
102  ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
103  std::vector<int64_t> fEdgeShape{2, numEdges};
104 
105  std::vector<float> fOutputData;
106  std::vector<Ort::Value> fOutputTensor;
107  ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
108 
109  ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );
110 
111  // apply sigmoid to the filtering output data
112  // and remove edges with score < filterCut
113  std::vector<int64_t> rowIndices;
114  std::vector<int64_t> colIndices;
115  for (int64_t i = 0; i < numEdges; i++){
116  float v = 1.f / (1.f + std::exp(-fOutputData[i])); // sigmoid, float type
117  if (v > m_filterCut){
118  rowIndices.push_back(edgeList[i]);
119  colIndices.push_back(edgeList[numEdges + i]);
120  };
121  };
122  std::vector<int64_t> edgesAfterFiltering;
123  edgesAfterFiltering.insert(edgesAfterFiltering.end(), rowIndices.begin(), rowIndices.end());
124  edgesAfterFiltering.insert(edgesAfterFiltering.end(), colIndices.begin(), colIndices.end());
125 
126  int64_t numEdgesAfterF = edgesAfterFiltering.size() / 2;
127 
128  // ************
129  // GNN
130  // ************
131  std::vector<Ort::Value> gInputTensor;
132  gInputTensor.push_back(
133  std::move(fInputTensor[0])
134  );
135  ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
136 
137  // gnn outputs
138  std::vector<float> gOutputData;
139  std::vector<Ort::Value> gOutputTensor;
140  ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );
141 
142  ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );
143  // apply sigmoid to the gnn output data
144  for(auto& v : gOutputData){
145  v = 1.f / (1.f + std::exp(-v));
146  };
147 
148  // ************
149  // Track Labeling with cugraph::connected_components
150  // ************
151  std::vector<int32_t> trackLabels(numSpacepoints);
152  weaklyConnectedComponents<int64_t,float,int32_t>(numSpacepoints, rowIndices, colIndices, gOutputData, trackLabels);
153 
154  if (trackLabels.size() == 0) return StatusCode::SUCCESS;
155 
156  tracks.clear();
157 
158  int existTrkIdx = 0;
159  // map labeling from MCC to customized track id.
160  std::map<int32_t, int32_t> trackLableToIds;
161 
162  for(int32_t idx=0; idx < numSpacepoints; ++idx) {
163  int32_t trackLabel = trackLabels[idx];
164  uint32_t spacepointID = spacepointIDs[idx];
165 
166  int trkId;
167  if(trackLableToIds.find(trackLabel) != trackLableToIds.end()) {
168  trkId = trackLableToIds[trackLabel];
169  tracks[trkId].push_back(spacepointID);
170  } else {
171  // a new track, assign the track id
172  // and create a vector
173  trkId = existTrkIdx;
174  tracks.push_back(std::vector<uint32_t>{spacepointID});
175  trackLableToIds[trackLabel] = trkId;
176  existTrkIdx++;
177  }
178  }
179  return StatusCode::SUCCESS;
180 }
181 
beamspotman.r
def r
Definition: beamspotman.py:676
phi
Scalar phi() const
phi method
Definition: AmgMatrixBasePlugin.h:64
xAOD::uint32_t
setEventNumber uint32_t
Definition: EventInfo_v1.cxx:127
InDet::SiGNNTrackFinderTool::dumpevent
MsgStream & dumpevent(MsgStream &out) const
Definition: SiGNNTrackFinderTool.cxx:38
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
M_PI
#define M_PI
Definition: ActiveFraction.h:11
InDet::SiGNNTrackFinderTool::m_gnnSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
Definition: SiGNNTrackFinderTool.h:81
drawFromPickle.exp
exp
Definition: drawFromPickle.py:36
InDet::SiGNNTrackFinderTool::m_filterCut
FloatProperty m_filterCut
Definition: SiGNNTrackFinderTool.h:67
lumiFormat.i
int i
Definition: lumiFormat.py:92
z
#define z
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
InDet::SiGNNTrackFinderTool::m_embedSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
Definition: SiGNNTrackFinderTool.h:75
InDet::SiGNNTrackFinderTool::m_knnVal
UnsignedIntegerProperty m_knnVal
Definition: SiGNNTrackFinderTool.h:66
InDet::SiGNNTrackFinderTool::getTracks
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
Definition: SiGNNTrackFinderTool.cxx:49
buildEdges
void buildEdges(std::vector< float > &embedFeatures, std::vector< int64_t > &edgeList, int64_t numSpacepoints, int embeddingDim, float rVal, int kVal)
Definition: ExaTrkXUtils.cxx:10
test_pyathena.parent
parent
Definition: test_pyathena.py:15
InDet::SiGNNTrackFinderTool::m_filterSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
Definition: SiGNNTrackFinderTool.h:78
InDet::SiGNNTrackFinderTool::dump
virtual MsgStream & dump(MsgStream &out) const override
Definition: SiGNNTrackFinderTool.cxx:27
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
InDet::SiGNNTrackFinderTool::m_rVal
FloatProperty m_rVal
Definition: SiGNNTrackFinderTool.h:65
PathResolver.h
InDet::SiGNNTrackFinderTool::initialize
virtual StatusCode initialize() override
Definition: SiGNNTrackFinderTool.cxx:20
SiGNNTrackFinderTool.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:195
python.PyAthena.v
v
Definition: PyAthena.py:157
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool
SiGNNTrackFinderTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h
InDet::SiGNNTrackFinderTool::m_embeddingDim
UnsignedIntegerProperty m_embeddingDim
Definition: SiGNNTrackFinderTool.h:64