ATLAS Offline Software
Loading...
Searching...
No Matches
SiGNNTrackFinderTool.cxx
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3*/
4
6#include "ExaTrkXUtils.hpp"
7
8// Framework include(s).
11
12#include <cmath>
13#include <random> // std::random_device, std::mt19937, std::random_shuffle
14
16 const std::string& type, const std::string& name, const IInterface* parent):
17 base_class(type, name, parent)
18 {
19 declareInterface<IGNNTrackFinder>(this);
20 }
21
23 ATH_CHECK( m_embedSessionTool.retrieve() );
24 m_embedSessionTool->printModelInfo();
25
26 ATH_CHECK( m_filterSessionTool.retrieve() );
27 m_filterSessionTool->printModelInfo();
28
29 ATH_CHECK( m_gnnSessionTool.retrieve() );
30 m_gnnSessionTool->printModelInfo();
31
32 // tokenize the feature names by comma and push to the vector
33 auto split_fn = [](const std::string& s, auto convert_fn) {
34 using ReturnType = std::decay_t<decltype(convert_fn(std::declval<std::string>()))>;
35 std::vector<ReturnType> tokens;
36 std::string token;
37 std::istringstream tokenStream(s);
38 while (std::getline(tokenStream, token, ',')) {
39 token = token.substr(token.find_first_not_of(" "), token.find_last_not_of(" ") + 1);
40 tokens.push_back(convert_fn(token));
41 }
42 return tokens;
43 };
44 auto convert_to_float = [](const std::string& s) -> float { return std::stof(s); };
45 auto convert_to_str = [](const std::string& s) -> std::string { return s; };
46
47 m_embeddingFeatureNamesVec = split_fn(m_embeddingFeatureNames, convert_to_str);
48 m_embeddingFeatureScalesVec = split_fn(m_embeddingFeatureScales, convert_to_float);
50
51 m_filterFeatureNamesVec = split_fn(m_filterFeatureNames, convert_to_str);
52 m_filterFeatureScalesVec = split_fn(m_filterFeatureScales, convert_to_float);
54
55 m_gnnFeatureNamesVec = split_fn(m_gnnFeatureNames, convert_to_str);
56 m_gnnFeatureScalesVec = split_fn(m_gnnFeatureScales, convert_to_float);
57 assert(m_gnnFeatureNamesVec.size() == m_gnnFeatureScalesVec.size());
58
59 return StatusCode::SUCCESS;
60}
61
62MsgStream& InDet::SiGNNTrackFinderTool::dump( MsgStream& out ) const
63{
64 out<<std::endl;
65 return dumpevent(out);
66}
67
68std::ostream& InDet::SiGNNTrackFinderTool::dump( std::ostream& out ) const
69{
70 return out;
71}
72
73MsgStream& InDet::SiGNNTrackFinderTool::dumpevent( MsgStream& out ) const
74{
75 out<<"|---------------------------------------------------------------------|"
76 <<std::endl;
77 out<<"| Number output tracks | "<<std::setw(12)
78 <<" |"<<std::endl;
79 out<<"|---------------------------------------------------------------------|"
80 <<std::endl;
81 return out;
82}
83
85 const std::vector<const Trk::SpacePoint*>& spacepoints,
86 std::vector<std::vector<uint32_t> >& tracks) const
87{
88 int64_t numSpacepoints = (int64_t)spacepoints.size();
89 std::vector<float> eNodeFeatures;
90 std::vector<float> fNodeFeatures;
91 std::vector<float> gNodeFeatures;
92 std::vector<uint32_t> spacepointIDs;
93 std::vector<int> regions;
94
95 int sp_idx = 0;
96 for(const auto& sp: spacepoints){
97 auto featureMap = m_spacepointFeatureTool->getFeatures(sp);
98 regions.push_back(featureMap["region"]);
99 // fill embedding node features.
100 for(size_t i = 0; i < m_embeddingFeatureNamesVec.size(); i++){
101 eNodeFeatures.push_back(
103 }
104
105 // fill filtering node features.
106 for(size_t i = 0; i < m_filterFeatureNamesVec.size(); i++){
107 fNodeFeatures.push_back(
109 }
110
111 // fill gnn node features.
112 for(size_t i = 0; i < m_gnnFeatureNamesVec.size(); i++){
113 gNodeFeatures.push_back(
114 featureMap[m_gnnFeatureNamesVec[i]] / m_gnnFeatureScalesVec[i]);
115 }
116
117 spacepointIDs.push_back(sp_idx++);
118 }
119 // ************
120 // Embedding
121 // ************
122 std::vector<int64_t> eInputShape{numSpacepoints, (long int) m_embeddingFeatureNamesVec.size()};
123 std::vector<Ort::Value> eInputTensor;
124 ATH_CHECK( m_embedSessionTool->addInput(eInputTensor, eNodeFeatures, 0, numSpacepoints) );
125
126 std::vector<Ort::Value> eOutputTensor;
127 std::vector<float> eOutputData;
128 ATH_CHECK( m_embedSessionTool->addOutput(eOutputTensor, eOutputData, 0, numSpacepoints) );
129
130 ATH_CHECK( m_embedSessionTool->inference(eInputTensor, eOutputTensor) );
131
132 // ************
133 // Building Edges
134 // ************
135 std::vector<int64_t> senders;
136 std::vector<int64_t> receivers;
137 ExaTrkXUtils::buildEdges(eOutputData, senders, receivers, numSpacepoints, m_embeddingDim, m_rVal, m_knnVal);
138 int64_t numEdges = senders.size();
139
140 // clean up embedding data.
141 eNodeFeatures.clear();
142 eInputTensor.clear();
143 eOutputData.clear();
144 eOutputTensor.clear();
145
146 // sort the edge list and remove duplicate edges.
147 std::vector<std::pair<int64_t, int64_t>> edgePairs;
148 for(int64_t idx = 0; idx < numEdges; idx ++ ) {
149 edgePairs.push_back({senders[idx], receivers[idx]});
150 }
151 std::sort(edgePairs.begin(), edgePairs.end());
152 edgePairs.erase(std::unique(edgePairs.begin(), edgePairs.end()), edgePairs.end());
153
154 // random shuffle the edge list.
155 std::random_device rd;
156 std::mt19937 rdm_gen(rd());
157 std::random_shuffle(edgePairs.begin(), edgePairs.end());
158
159 // sort the edge list by the sender * numSpacepoints + receiver.
160 std::sort(edgePairs.begin(), edgePairs.end(),
161 [numSpacepoints](const std::pair<int64_t, int64_t>& a, const std::pair<int64_t, int64_t>& b){
162 return a.first * numSpacepoints + a.second < b.first * numSpacepoints + b.second;
163 });
164
165 // convert the edge list to senders and receivers.
166 senders.clear();
167 receivers.clear();
168 for(const auto& edge: edgePairs){
169 senders.push_back(edge.first);
170 receivers.push_back(edge.second);
171 }
172
173 edgePairs.clear();
174
175 // ************
176 // Filtering
177 // ************
178 std::vector<Ort::Value> fInputTensor;
179 ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, fNodeFeatures, 0, numSpacepoints) );
180
181 std::vector<int64_t> edgeList(numEdges * 2);
182 std::copy(senders.begin(), senders.end(), edgeList.begin());
183 std::copy(receivers.begin(), receivers.end(), edgeList.begin() + senders.size());
184
185
186 ATH_CHECK( m_filterSessionTool->addInput(fInputTensor, edgeList, 1, numEdges) );
187
188 std::vector<float> fOutputData;
189 std::vector<Ort::Value> fOutputTensor;
190 ATH_CHECK( m_filterSessionTool->addOutput(fOutputTensor, fOutputData, 0, numEdges) );
191
192 ATH_CHECK( m_filterSessionTool->inference(fInputTensor, fOutputTensor) );
193
194 // apply sigmoid to the filtering output data
195 // and remove edges with score < filterCut
196 // and sort the edge list so that sender idx < receiver.
197 std::vector<int64_t> rowIndices;
198 std::vector<int64_t> colIndices;
199 for (int64_t i = 0; i < numEdges; i++){
200 float v = 1.f / (1.f + std::exp(-fOutputData[i])); // sigmoid, float type
201 if (v >= m_filterCut){
202 auto src = edgeList[i];
203 auto dst = edgeList[numEdges + i];
204 if (src > dst) {
205 std::swap(src, dst);
206 }
207 rowIndices.push_back(src);
208 colIndices.push_back(dst);
209 };
210 };
211 int64_t numEdgesAfterF = rowIndices.size();
212
213 // clean up filtering data.
214 fNodeFeatures.clear();
215 fInputTensor.clear();
216 fOutputData.clear();
217 fOutputTensor.clear();
218 // clean up sender and receiver list.
219 senders.clear();
220 receivers.clear();
221
222 std::vector<int64_t> edgesAfterFiltering(numEdgesAfterF * 2);
223 std::copy(rowIndices.begin(), rowIndices.end(), edgesAfterFiltering.begin());
224 std::copy(colIndices.begin(), colIndices.end(), edgesAfterFiltering.begin() + senders.size());
225
226 // ************
227 // GNN
228 // ************
229
230 // use the same features for regions (2, 6)
231 for(size_t idx = 0; idx < static_cast<size_t>(numSpacepoints); idx++){
232 if (regions[idx] == 2 || regions[idx] == 6){
233 for(size_t i = 4; i < m_gnnFeatureNamesVec.size(); i++){
234 gNodeFeatures[idx * m_gnnFeatureNamesVec.size() + i] = gNodeFeatures[idx * m_gnnFeatureNamesVec.size() + i % 4];
235 }
236 }
237 }
238
239 std::vector<Ort::Value> gInputTensor;
240 ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, gNodeFeatures, 0, numSpacepoints) );
241 ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, edgesAfterFiltering, 1, numEdgesAfterF) );
242
243 // calculate the edge features.
244 std::vector<float> gnnEdgeFeatures;
245 ExaTrkXUtils::calculateEdgeFeatures(gNodeFeatures, numSpacepoints, rowIndices, colIndices, gnnEdgeFeatures);
246 ATH_CHECK( m_gnnSessionTool->addInput(gInputTensor, gnnEdgeFeatures, 2, numEdgesAfterF) );
247
248 // gnn outputs
249 std::vector<float> gOutputData;
250 std::vector<Ort::Value> gOutputTensor;
251 ATH_CHECK( m_gnnSessionTool->addOutput(gOutputTensor, gOutputData, 0, numEdgesAfterF) );
252
253 ATH_CHECK( m_gnnSessionTool->inference(gInputTensor, gOutputTensor) );
254 // apply sigmoid to the gnn output data
255 for(auto& v : gOutputData){
256 v = 1.f / (1.f + std::exp(-v));
257 };
258
259 // clean up GNN data.
260 gNodeFeatures.clear();
261 gInputTensor.clear();
262 edgesAfterFiltering.clear();
263
264 // ************
265 // Track Labeling with cugraph::connected_components
266 // ************
267 tracks.clear();
269 numSpacepoints,
270 rowIndices, colIndices, gOutputData,
271 tracks, m_ccCut, m_walkMin, m_walkMax
272 );
273
274 return StatusCode::SUCCESS;
275}
#define ATH_CHECK
Evaluate an expression and check for errors.
static Double_t sp
static Double_t a
static const std::vector< std::string > regions
virtual StatusCode getTracks(const std::vector< const Trk::SpacePoint * > &spacepoints, std::vector< std::vector< uint32_t > > &tracks) const override
Get track candidates from a list of space points.
std::vector< float > m_gnnFeatureScalesVec
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_filterSessionTool
ToolHandle< ISpacepointFeatureTool > m_spacepointFeatureTool
virtual StatusCode initialize() override
UnsignedIntegerProperty m_knnVal
UnsignedIntegerProperty m_embeddingDim
std::vector< float > m_filterFeatureScalesVec
std::vector< std::string > m_embeddingFeatureNamesVec
MsgStream & dumpevent(MsgStream &out) const
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_embedSessionTool
std::vector< std::string > m_gnnFeatureNamesVec
std::vector< float > m_embeddingFeatureScalesVec
ToolHandle< AthOnnx::IOnnxRuntimeInferenceTool > m_gnnSessionTool
std::vector< std::string > m_filterFeatureNamesVec
virtual MsgStream & dump(MsgStream &out) const override
void CCandWalk(vertex_t numSpacepoints, const std::vector< int64_t > &rowIndices, const std::vector< int64_t > &colIndices, const std::vector< weight_t > &edgeWeights, std::vector< std::vector< uint32_t > > &tracks, float ccCut, float walkMin, float walkMax)
void calculateEdgeFeatures(const std::vector< float > &gNodeFeatures, int64_t numSpacepoints, const std::vector< int64_t > &rowIndices, const std::vector< int64_t > &colIndices, std::vector< float > &edgeFeatures)
void buildEdges(const std::vector< float > &embedFeatures, std::vector< int64_t > &senders, std::vector< int64_t > &receivers, int64_t numSpacepoints, int embeddingDim, float rVal, int kVal)
DataModel_detail::iterator< DVL > unique(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of unique for DataVector/List.
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
void swap(ElementLinkVector< DOBJ > &lhs, ElementLinkVector< DOBJ > &rhs)