ATLAS Offline Software
SiGNNTrackFinderTool.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
5 #include "SiGNNTrackFinderTool.h"
6 #include "ExaTrkXUtils.hpp"
7 
8 // Framework include(s).
10 #include "AthOnnxUtils/OnnxUtils.h"
11 
#include <algorithm> // std::sort, std::unique, std::shuffle
#include <cmath>
#include <random> // std::random_device, std::mt19937
14 
16  const std::string& type, const std::string& name, const IInterface* parent):
17  base_class(type, name, parent)
18  {
// Forward the standard tool arguments (type, name, parent) to base_class and
// declare IGNNTrackFinder as an interface implemented by this tool.
// NOTE(review): if base_class already lists IGNNTrackFinder (e.g. through
// Gaudi's extends<>), this declareInterface call may be redundant — confirm
// against SiGNNTrackFinderTool.h.
19  declareInterface<IGNNTrackFinder>(this);
20  }
21 
// Retrieve the three ONNX inference tools (embedding, filtering, GNN) in
// pipeline order and print a summary of the model each one holds.
// ATH_CHECK makes initialize() fail as soon as one retrieve() fails.
23  ATH_CHECK( m_embedSessionTool.retrieve() );
24  m_embedSessionTool->printModelInfo();
25 
26  ATH_CHECK( m_filterSessionTool.retrieve() );
27  m_filterSessionTool->printModelInfo();
28 
29  ATH_CHECK( m_gnnSessionTool.retrieve() );
30  m_gnnSessionTool->printModelInfo();
31 
// Tokenize a comma-separated property string. Each token has its surrounding
// spaces/tabs trimmed and is passed through convert_fn; results keep input
// order. Blank tokens (e.g. from "a,,b" or a trailing comma) are skipped
// instead of making substr() throw std::out_of_range.
auto split_fn = [](const std::string& s, auto convert_fn) {
  using ReturnType = std::decay_t<decltype(convert_fn(std::declval<std::string>()))>;
  std::vector<ReturnType> tokens;
  std::string token;
  std::istringstream tokenStream(s);
  while (std::getline(tokenStream, token, ',')) {
    const std::string::size_type first = token.find_first_not_of(" \t");
    if (first == std::string::npos) {
      continue;  // blank token: nothing to convert
    }
    const std::string::size_type last = token.find_last_not_of(" \t");
    // substr takes (position, count): the trimmed token is last - first + 1 long.
    tokens.push_back(convert_fn(token.substr(first, last - first + 1)));
  }
  return tokens;
};
44  auto convert_to_float = [](const std::string& s) -> float { return std::stof(s); };
45  auto convert_to_str = [](const std::string& s) -> std::string { return s; };
46 
47  m_embeddingFeatureNamesVec = split_fn(m_embeddingFeatureNames, convert_to_str);
48  m_embeddingFeatureScalesVec = split_fn(m_embeddingFeatureScales, convert_to_float);
50 
51  m_filterFeatureNamesVec = split_fn(m_filterFeatureNames, convert_to_str);
52  m_filterFeatureScalesVec = split_fn(m_filterFeatureScales, convert_to_float);
53  assert(m_filterFeatureNamesVec.size() == m_filterFeatureScalesVec.size());
54 
55  m_gnnFeatureNamesVec = split_fn(m_gnnFeatureNames, convert_to_str);
56  m_gnnFeatureScalesVec = split_fn(m_gnnFeatureScales, convert_to_float);
57  assert(m_gnnFeatureNamesVec.size() == m_gnnFeatureScalesVec.size());
58 
59  return StatusCode::SUCCESS;
60 }
61 
62 MsgStream& InDet::SiGNNTrackFinderTool::dump( MsgStream& out ) const
63 {
64  out<<std::endl;
65  return dumpevent(out);
66 }
67 
68 std::ostream& InDet::SiGNNTrackFinderTool::dump( std::ostream& out ) const
69 {
70  return out;
71 }
72 
73 MsgStream& InDet::SiGNNTrackFinderTool::dumpevent( MsgStream& out ) const
74 {
75  out<<"|---------------------------------------------------------------------|"
76  <<std::endl;
77  out<<"| Number output tracks | "<<std::setw(12)
78  <<" |"<<std::endl;
79  out<<"|---------------------------------------------------------------------|"
80  <<std::endl;
81  return out;
82 }
83 
85  const std::vector<const Trk::SpacePoint*>& spacepoints,
86  std::vector<std::vector<uint32_t> >& tracks) const
87 {
// getTracks: builds track candidates from `spacepoints` through the chain
//   node features -> embedding model -> edge building (buildEdges)
//   -> edge filtering model -> GNN edge scoring -> track labelling.
// `tracks` is cleared and refilled by the labelling step; each inner vector
// presumably lists node indices, i.e. positions in `spacepoints` (nodes are
// numbered in input order) — confirm against ExaTrkXUtils::CCandWalk.
// Any failing ONNX addInput/addOutput/inference call makes this return the
// failure StatusCode through ATH_CHECK.
88  int64_t numSpacepoints = (int64_t)spacepoints.size();
// Flat per-node feature buffers, one per model, appended space point by
// space point: row-major, numSpacepoints rows x that model's feature count.
89  std::vector<float> eNodeFeatures;
90  std::vector<float> fNodeFeatures;
91  std::vector<float> gNodeFeatures;
// NOTE(review): spacepointIDs is filled below but never read in the visible code.
92  std::vector<uint32_t> spacepointIDs;
// Per-node "region" value; used later to patch GNN features of regions 2 and 6.
93  std::vector<int> regions;
94 
95  int sp_idx = 0;
96  for(const auto& sp: spacepoints){
// Feature values of this space point, looked up by feature name (e.g. "region").
97  auto featureMap = m_spacepointFeatureTool->getFeatures(sp);
98  regions.push_back(featureMap["region"]);
99  // fill embedding node features.
100  for(size_t i = 0; i < m_embeddingFeatureNamesVec.size(); i++){
101  eNodeFeatures.push_back(
103  }
104 
105  // fill filtering node features.
106  for(size_t i = 0; i < m_filterFeatureNamesVec.size(); i++){
107  fNodeFeatures.push_back(
109  }
110 
111  // fill gnn node features.
112  for(size_t i = 0; i < m_gnnFeatureNamesVec.size(); i++){
113  gNodeFeatures.push_back(
115  }
116 
117  spacepointIDs.push_back(sp_idx++);
118  }
119  // ************
120  // Embedding
121  // ************
122  std::vector<int64_t> eInputShape{numSpacepoints, (long int) m_embeddingFeatureNamesVec.size()};
123  std::vector<Ort::Value> eInputTensor;
124  ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, eNodeFeatures, 0, numSpacepoints) );
125 
126  std::vector<Ort::Value> eOutputTensor;
127  std::vector<float> eOutputData;
128  ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );
129 
130  ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );
131 
132  // ************
133  // Building Edges
134  // ************
135  std::vector<int64_t> senders;
136  std::vector<int64_t> receivers;
137  ExaTrkXUtils::buildEdges(eOutputData, senders, receivers, numSpacepoints, m_embeddingDim, m_rVal, m_knnVal);
138  int64_t numEdges = senders.size();
139 
140  // clean up embedding data.
141  eNodeFeatures.clear();
142  eInputTensor.clear();
143  eOutputData.clear();
144  eOutputTensor.clear();
145 
146  // sort the edge list and remove duplicate edges.
147  std::vector<std::pair<int64_t, int64_t>> edgePairs;
148  for(int64_t idx = 0; idx < numEdges; idx ++ ) {
149  edgePairs.push_back({senders[idx], receivers[idx]});
150  }
151  std::sort(edgePairs.begin(), edgePairs.end());
152  edgePairs.erase(std::unique(edgePairs.begin(), edgePairs.end()), edgePairs.end());
153 
154  // random shuffle the edge list.
155  std::random_device rd;
156  std::mt19937 rdm_gen(rd());
157  std::random_shuffle(edgePairs.begin(), edgePairs.end());
158 
159  // sort the edge list by the sender * numSpacepoints + receiver.
160  std::sort(edgePairs.begin(), edgePairs.end(),
161  [numSpacepoints](const std::pair<int64_t, int64_t>& a, const std::pair<int64_t, int64_t>& b){
162  return a.first * numSpacepoints + a.second < b.first * numSpacepoints + b.second;
163  });
164 
165  // convert the edge list to senders and receivers.
166  senders.clear();
167  receivers.clear();
168  for(const auto& edge: edgePairs){
169  senders.push_back(edge.first);
170  receivers.push_back(edge.second);
171  }
172 
173  edgePairs.clear();
174 
175  // ************
176  // Filtering
177  // ************
178  std::vector<Ort::Value> fInputTensor;
179  ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, fNodeFeatures, 0, numSpacepoints) );
180 
181  std::vector<int64_t> edgeList(numEdges * 2);
182  std::copy(senders.begin(), senders.end(), edgeList.begin());
183  std::copy(receivers.begin(), receivers.end(), edgeList.begin() + senders.size());
184 
185 
186  ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
187 
188  std::vector<float> fOutputData;
189  std::vector<Ort::Value> fOutputTensor;
190  ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
191 
192  ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );
193 
194  // apply sigmoid to the filtering output data
195  // and remove edges with score < filterCut
196  // and sort the edge list so that sender idx < receiver.
197  std::vector<int64_t> rowIndices;
198  std::vector<int64_t> colIndices;
199  for (int64_t i = 0; i < numEdges; i++){
200  float v = 1.f / (1.f + std::exp(-fOutputData[i])); // sigmoid, float type
201  if (v >= m_filterCut){
202  auto src = edgeList[i];
203  auto dst = edgeList[numEdges + i];
204  if (src > dst) {
205  std::swap(src, dst);
206  }
207  rowIndices.push_back(src);
208  colIndices.push_back(dst);
209  };
210  };
211  int64_t numEdgesAfterF = rowIndices.size();
212 
213  // clean up filtering data.
214  fNodeFeatures.clear();
215  fInputTensor.clear();
216  fOutputData.clear();
217  fOutputTensor.clear();
218  // clean up sender and receiver list.
219  senders.clear();
220  receivers.clear();
221 
222  std::vector<int64_t> edgesAfterFiltering(numEdgesAfterF * 2);
223  std::copy(rowIndices.begin(), rowIndices.end(), edgesAfterFiltering.begin());
224  std::copy(colIndices.begin(), colIndices.end(), edgesAfterFiltering.begin() + senders.size());
225 
226  // ************
227  // GNN
228  // ************
229 
230  // use the same features for regions (2, 6)
231  for(size_t idx = 0; idx < static_cast<size_t>(numSpacepoints); idx++){
232  if (regions[idx] == 2 || regions[idx] == 6){
233  for(size_t i = 4; i < m_gnnFeatureNamesVec.size(); i++){
234  gNodeFeatures[idx * m_gnnFeatureNamesVec.size() + i] = gNodeFeatures[idx * m_gnnFeatureNamesVec.size() + i % 4];
235  }
236  }
237  }
238 
239  std::vector<Ort::Value> gInputTensor;
240  ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, gNodeFeatures, 0, numSpacepoints) );
241  ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
242 
243  // calculate the edge features.
244  std::vector<float> gnnEdgeFeatures;
245  ExaTrkXUtils::calculateEdgeFeatures(gNodeFeatures, numSpacepoints, rowIndices, colIndices, gnnEdgeFeatures);
246  ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, gnnEdgeFeatures, 2, numEdgesAfterF) );
247 
248  // gnn outputs
249  std::vector<float> gOutputData;
250  std::vector<Ort::Value> gOutputTensor;
251  ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );
252 
253  ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );
254  // apply sigmoid to the gnn output data
255  for(auto& v : gOutputData){
256  v = 1.f / (1.f + std::exp(-v));
257  };
258 
259  // clean up GNN data.
260  gNodeFeatures.clear();
261  gInputTensor.clear();
262  edgesAfterFiltering.clear();
263 
264  // ************
265  // Track Labeling with cugraph::connected_components
266  // ************
// Build the track candidates from the scored edge graph: edge i joins nodes
// rowIndices[i] and colIndices[i] and carries GNN score gOutputData[i];
// m_ccCut, m_walkMin and m_walkMax are passed through as cut values.
// NOTE(review): the call below is presumably ExaTrkXUtils::CCandWalk (its
// first line is missing from this listing), so the "cugraph" wording in the
// banner above may be outdated — confirm.
267  tracks.clear();
269  numSpacepoints,
270  rowIndices, colIndices, gOutputData,
271  tracks, m_ccCut, m_walkMin, m_walkMax
272  );
273 
274  return StatusCode::SUCCESS;
275 }
InDet::SiGNNTrackFinderTool::m_embeddingFeatureNamesVec
std::vector< std::string > m_embeddingFeatureNamesVec
Definition: SiGNNTrackFinderTool.h:117
python.SystemOfUnits.s
int s
Definition: SystemOfUnits.py:131
InDet::SiGNNTrackFinderTool::m_gnnFeatureScalesVec
std::vector< float > m_gnnFeatureScalesVec
Definition: SiGNNTrackFinderTool.h:122
ExaTrkXUtils::CCandWalk
void CCandWalk(vertex_t numSpacepoints, const std::vector< int64_t > &rowIndices, const std::vector< int64_t > &colIndices, const std::vector< weight_t > &edgeWeights, std::vector< std::vector< uint32_t > > &tracks, float ccCut, float walkMin, float walkMax)
Definition: ExaTrkXUtils.cxx:73
CaloCellPos2Ntuple.int
int
Definition: CaloCellPos2Ntuple.py:24
ExaTrkXUtils::calculateEdgeFeatures
void calculateEdgeFeatures(const std::vector< float > &gNodeFeatures, int64_t numSpacepoints, const std::vector< int64_t > &rowIndices, const std::vector< int64_t > &colIndices, std::vector< float > &edgeFeatures)
Definition: ExaTrkXUtils.cxx:322
WriteCellNoiseToCool.src
src
Definition: WriteCellNoiseToCool.py:513
InDet::SiGNNTrackFinderTool::dumpevent
MsgStream & dumpevent(MsgStream &out) const
Definition: SiGNNTrackFinderTool.cxx:73
python.AthDsoLogger.out
out
Definition: AthDsoLogger.py:71
InDet::SiGNNTrackFinderTool::m_filterFeatureScalesVec
std::vector< float > m_filterFeatureScalesVec
Definition: SiGNNTrackFinderTool.h:120
beamspotman.tokens
tokens
Definition: beamspotman.py:1284
InDet::SiGNNTrackFinderTool::m_gnnSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
Definition: SiGNNTrackFinderTool.h:111
drawFromPickle.exp
exp
Definition: drawFromPickle.py:36
InDet::SiGNNTrackFinderTool::m_gnnFeatureNames
StringProperty m_gnnFeatureNames
Definition: SiGNNTrackFinderTool.h:92
InDet::SiGNNTrackFinderTool::m_walkMin
FloatProperty m_walkMin
Definition: SiGNNTrackFinderTool.h:71
InDet::SiGNNTrackFinderTool::m_filterCut
FloatProperty m_filterCut
Definition: SiGNNTrackFinderTool.h:69
InDet::SiGNNTrackFinderTool::m_embeddingFeatureNames
StringProperty m_embeddingFeatureNames
Definition: SiGNNTrackFinderTool.h:74
lumiFormat.i
int i
Definition: lumiFormat.py:85
InDet::SiGNNTrackFinderTool::m_gnnFeatureScales
StringProperty m_gnnFeatureScales
Definition: SiGNNTrackFinderTool.h:96
EL::StatusCode
::StatusCode StatusCode
StatusCode definition for legacy code.
Definition: PhysicsAnalysis/D3PDTools/EventLoop/EventLoop/StatusCode.h:22
InDet::SiGNNTrackFinderTool::m_embedSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
Definition: SiGNNTrackFinderTool.h:105
InDet::SiGNNTrackFinderTool::m_knnVal
UnsignedIntegerProperty m_knnVal
Definition: SiGNNTrackFinderTool.h:68
InDet::SiGNNTrackFinderTool::getTracks
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
Definition: SiGNNTrackFinderTool.cxx:84
test_pyathena.parent
parent
Definition: test_pyathena.py:15
InDet::SiGNNTrackFinderTool::m_filterSessionTool
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
Definition: SiGNNTrackFinderTool.h:108
InDet::SiGNNTrackFinderTool::dump
virtual MsgStream & dump(MsgStream &out) const override
Definition: SiGNNTrackFinderTool.cxx:62
ATH_CHECK
#define ATH_CHECK
Definition: AthCheckMacros.h:40
InDet::SiGNNTrackFinderTool::m_walkMax
FloatProperty m_walkMax
Definition: SiGNNTrackFinderTool.h:72
WriteCalibToCool.swap
swap
Definition: WriteCalibToCool.py:94
InDet::SiGNNTrackFinderTool::m_rVal
FloatProperty m_rVal
Definition: SiGNNTrackFinderTool.h:67
InDet::SiGNNTrackFinderTool::m_embeddingFeatureScales
StringProperty m_embeddingFeatureScales
Definition: SiGNNTrackFinderTool.h:78
InDet::SiGNNTrackFinderTool::m_filterFeatureNames
StringProperty m_filterFeatureNames
Definition: SiGNNTrackFinderTool.h:83
InDet::SiGNNTrackFinderTool::m_ccCut
FloatProperty m_ccCut
Definition: SiGNNTrackFinderTool.h:70
PathResolver.h
InDet::SiGNNTrackFinderTool::initialize
virtual StatusCode initialize() override
Definition: SiGNNTrackFinderTool.cxx:22
SiGNNTrackFinderTool.h
name
std::string name
Definition: Control/AthContainers/Root/debug.cxx:228
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
InDet::SiGNNTrackFinderTool::m_filterFeatureNamesVec
std::vector< std::string > m_filterFeatureNamesVec
Definition: SiGNNTrackFinderTool.h:119
InDet::SiGNNTrackFinderTool::m_filterFeatureScales
StringProperty m_filterFeatureScales
Definition: SiGNNTrackFinderTool.h:87
python.PyAthena.v
v
Definition: PyAthena.py:154
a
TList * a
Definition: liststreamerinfos.cxx:10
python.CaloScaleNoiseConfig.type
type
Definition: CaloScaleNoiseConfig.py:78
InDet::SiGNNTrackFinderTool::SiGNNTrackFinderTool
SiGNNTrackFinderTool()=delete
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
OnnxUtils.h
calibdata.copy
bool copy
Definition: calibdata.py:27
InDet::SiGNNTrackFinderTool::m_embeddingDim
UnsignedIntegerProperty m_embeddingDim
Definition: SiGNNTrackFinderTool.h:66
InDet::SiGNNTrackFinderTool::m_embeddingFeatureScalesVec
std::vector< float > m_embeddingFeatureScalesVec
Definition: SiGNNTrackFinderTool.h:118
ExaTrkXUtils::buildEdges
void buildEdges(const std::vector< float > &embedFeatures, std::vector< int64_t > &senders, std::vector< int64_t > &receivers, int64_t numSpacepoints, int embeddingDim, float rVal, int kVal)
Definition: ExaTrkXUtils.cxx:18
InDet::SiGNNTrackFinderTool::m_spacepointFeatureTool
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
Definition: SiGNNTrackFinderTool.h:114
InDet::SiGNNTrackFinderTool::m_gnnFeatureNamesVec
std::vector< std::string > m_gnnFeatureNamesVec
Definition: SiGNNTrackFinderTool.h:121