ATLAS Offline Software
Loading...
Searching...
No Matches
ExaTrkXUtils Namespace Reference

Functions

void buildEdges (const std::vector< float > &embedFeatures, std::vector< int64_t > &senders, std::vector< int64_t > &receivers, int64_t numSpacepoints, int embeddingDim, float rVal, int kVal)
void CCandWalk (vertex_t numSpacepoints, const std::vector< int64_t > &rowIndices, const std::vector< int64_t > &colIndices, const std::vector< weight_t > &edgeWeights, std::vector< std::vector< uint32_t > > &tracks, float ccCut, float walkMin, float walkMax)
std::vector< std::vector< vertex_t > > getSimplePath (const UndirectedGraph &G)
std::vector< vertex_t > findNextNode (const Graph &G, vertex_t current_hit, float th_min, float th_add)
std::vector< std::vector< vertex_t > > buildRoads (const Graph &G, vertex_t starting_node, std::function< std::vector< vertex_t >(const Graph &, vertex_t)> next_node_fn, std::map< vertex_t, bool > &used_hits)
Graph cleanupGraph (const Graph &G, float ccCut)
void calculateEdgeFeatures (const std::vector< float > &gNodeFeatures, int64_t numSpacepoints, const std::vector< int64_t > &rowIndices, const std::vector< int64_t > &colIndices, std::vector< float > &edgeFeatures)

Function Documentation

◆ buildEdges()

namespace ExaTrkXUtils {

/// Build a directed edge list from a radius-limited k-nearest-neighbour search
/// in the embedding space.
///
/// For every spacepoint i only candidates j > i are examined, so each unordered
/// pair is visited once and every produced edge has sender < receiver.
/// A candidate is kept when its squared L2 distance is <= rVal^2; among those,
/// at most kVal nearest are kept per i (a negative kVal means "no cap").
/// Edges are APPENDED to senders/receivers (the vectors are not cleared). For a
/// given i they are emitted farthest-first.
///
/// Cost: O(N^2 * D) time for N spacepoints and embedding dimension D.
///
/// @param embedFeatures  embeddings stored row-major as [numSpacepoints, embeddingDim]
///                       (from the `i * embeddingDim + d` indexing); must hold at
///                       least numSpacepoints * embeddingDim values
/// @param senders        output, source index of each edge (appended)
/// @param receivers      output, target index of each edge (appended)
/// @param numSpacepoints number of embedded spacepoints
/// @param embeddingDim   embedding dimension D
/// @param rVal           search radius (inclusive)
/// @param kVal           maximum neighbours kept per spacepoint; negative = unlimited
///
/// (Definition at line 18 of file ExaTrkXUtils.cxx.)
void buildEdges(const std::vector<float> &embedFeatures,
                std::vector<int64_t> &senders,
                std::vector<int64_t> &receivers,
                int64_t numSpacepoints,
                int embeddingDim,
                float rVal,
                int kVal) {
  // Squared L2 distance between spacepoints i and j.
  auto squaredDistance = [&](int64_t i, int64_t j) {
    float dist = 0.0;
    for (int d = 0; d < embeddingDim; ++d) {
      float diff = embedFeatures[i * embeddingDim + d] - embedFeatures[j * embeddingDim + d];
      dist += diff * diff;
    }
    return dist;
  };

  // Compare squared distances to avoid a square root per pair.
  const float radiusSquared = rVal * rVal;

  // Historically kVal was C-cast to unsigned long, so a negative value wrapped to
  // a huge bound, i.e. no cap. Keep that meaning, but spell it out.
  const std::size_t maxNeighbours =
      kVal < 0 ? static_cast<std::size_t>(SIZE_MAX) : static_cast<std::size_t>(kVal);

  // std::priority_queue is a MAX-heap: top() is the farthest kept candidate.
  // Evicting top() whenever the heap exceeds the cap therefore keeps the nearest ones.
  // Declared once so its storage is reused across spacepoints; it is drained
  // (emptied) at the end of every outer iteration.
  std::priority_queue<std::pair<float, int64_t>> nearestNeighbors;

  for (int64_t i = 0; i < numSpacepoints; ++i) {
    for (int64_t j = i + 1; j < numSpacepoints; ++j) {
      const float distSquared = squaredDistance(i, j);
      if (distSquared <= radiusSquared) {
        nearestNeighbors.emplace(distSquared, j);
        if (nearestNeighbors.size() > maxNeighbours) {
          nearestNeighbors.pop();
        }
      }
    }
    // Emit the surviving neighbours (farthest first) and leave the heap empty.
    while (!nearestNeighbors.empty()) {
      senders.push_back(i);
      receivers.push_back(nearestNeighbors.top().second);
      nearestNeighbors.pop();
    }
  }
}

}  // namespace ExaTrkXUtils
void diff(const Jet &rJet1, const Jet &rJet2, std::map< std::string, double > varDiff)
Difference between jets - Non-Class function required by trigger.
Definition Jet.cxx:631

◆ buildRoads()

std::vector< std::vector< vertex_t > > ExaTrkXUtils::buildRoads ( const Graph & G,
vertex_t starting_node,
std::function< std::vector< vertex_t >(const Graph &, vertex_t)> next_node_fn,
std::map< vertex_t, bool > & used_hits )

Definition at line 243 of file ExaTrkXUtils.cxx.

248 {
249 std::vector<std::vector<int>> path = {{starting_node}};
250
251 while (true) {
252 std::vector<std::vector<int>> new_path;
253 bool is_all_done = true;
254 // loop over each path and extend it.
255 for (const auto &pp : path) {
256 vertex_t start = pp.back();
257
258 if (start == -1) {
259 new_path.push_back(pp);
260 continue;
261 }
262
263 auto next_hits = next_node_fn(G, start);
264 // remove used hits.
265 next_hits.erase(std::remove_if(next_hits.begin(), next_hits.end(),
266 [&](int node_id) {
267 auto hit_id = boost::get(boost::vertex_name, G, node_id);
268 return used_hits[hit_id];
269 }), next_hits.end());
270
271 if (next_hits.empty()) {
272 new_path.push_back(pp);
273 } else {
274 is_all_done = false;
275 for (int nh : next_hits) {
276 std::vector<int> pp_extended = pp;
277 pp_extended.push_back(nh);
278 new_path.push_back(std::move(pp_extended));
279 }
280 }
281 }
282
283 path = std::move(new_path);
284 if (is_all_done) break;
285 }
286 return path;
287}
#define G(x, y, z)
Definition MD5.cxx:113
path
python interpreter configuration --------------------------------------—
Definition athena.py:128
DataModel_detail::iterator< DVL > remove_if(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end, Predicate pred)
Specialization of remove_if for DataVector/List.

◆ calculateEdgeFeatures()

namespace ExaTrkXUtils {

/// Compute per-edge geometric features from node features.
///
/// For each edge idx (src = rowIndices[idx], dst = colIndices[idx]) six floats
/// are appended, in this order:
///   dr        = r[dst] - r[src]
///   dphi      = phi[dst] - phi[src], wrapped into (-1, 1]
///   dz        = z[dst] - z[src]
///   deta      = eta[dst] - eta[src]
///   phislope  = dphi / dr            (0 when |dr| <= 1e-7)
///   rphislope = (r[dst] + r[src]) / 2 * phislope  (0 when |phislope| <= 1e-7)
/// The output layout is therefore [number-of-edges, 6]: six consecutive floats
/// per edge.
///
/// The first 4 node features are assumed to be r, phi, z, eta. The per-node
/// stride is gNodeFeatures.size() / numSpacepoints. The wrap in diffPhi
/// multiplies by M_PI and divides back, which suggests phi is stored scaled by
/// 1/pi, i.e. in [-1, 1]. NOTE(review): confirm against the producer of
/// gNodeFeatures.
///
/// @param gNodeFeatures  node features, numSpacepoints rows of equal length (>= 4)
/// @param numSpacepoints number of nodes; <= 0 yields no features
/// @param rowIndices     edge source node indices
/// @param colIndices     edge target node indices; assumed same length as rowIndices
/// @param edgeFeatures   output, cleared then filled with 6 * rowIndices.size() values
///
/// (Definition at line 322 of file ExaTrkXUtils.cxx.)
void calculateEdgeFeatures(const std::vector<float> &gNodeFeatures,
                           int64_t numSpacepoints,
                           const std::vector<int64_t> &rowIndices,
                           const std::vector<int64_t> &colIndices,
                           std::vector<float> &edgeFeatures) {
  edgeFeatures.clear();
  // Without nodes there is no feature stride (the division below would be by
  // zero) and no edge can legitimately reference a node.
  if (numSpacepoints <= 0) return;

  const int64_t numFeatures = gNodeFeatures.size() / numSpacepoints;
  edgeFeatures.reserve(rowIndices.size() * 6);

  // Difference of phi values given in units of pi, wrapped into (-1, 1].
  auto diffPhi = [](float phi1, float phi2) {
    float dphi = (phi1 - phi2) * M_PI;
    if (dphi > M_PI) {
      dphi -= 2 * M_PI;
    } else if (dphi <= -M_PI) {
      dphi += 2 * M_PI;
    }
    return dphi / M_PI;
  };

  for (std::size_t idx = 0; idx < rowIndices.size(); ++idx) {
    // Read the node features in place rather than copying them into two
    // temporary vectors (two heap allocations) per edge.
    const float *src = gNodeFeatures.data() + rowIndices[idx] * numFeatures;
    const float *dst = gNodeFeatures.data() + colIndices[idx] * numFeatures;

    float dr = dst[0] - src[0];
    float dphi = diffPhi(dst[1], src[1]);
    float dz = dst[2] - src[2];
    float deta = dst[3] - src[3];
    // Guard the slopes against (near-)zero denominators.
    float phislope = (std::fabs(dr) > 1e-7) ? dphi / dr : 0.;
    float rphislope = (std::fabs(phislope) > 1e-7) ? (dst[0] + src[0]) / 2 * phislope : 0.0;

    edgeFeatures.push_back(dr);
    edgeFeatures.push_back(dphi);
    edgeFeatures.push_back(dz);
    edgeFeatures.push_back(deta);
    edgeFeatures.push_back(phislope);
    edgeFeatures.push_back(rphislope);
  }
}

}  // namespace ExaTrkXUtils
#define M_PI

◆ CCandWalk()

/// Turn a scored spacepoint graph into track candidates (connected components
/// followed by a walk-through of the remaining hits).
///
///  1. Build a directed Graph with one vertex per spacepoint (vertex name =
///     spacepoint index) and one weighted edge per
///     (rowIndices, colIndices, edgeWeights) triple.
///  2. cleanupGraph(): drop edges with weight <= ccCut and vertices that had no
///     edge at all in the input graph. Vertex indices are renumbered, so a
///     vertex index is NOT a hit ID afterwards.
///  3. Copy the result into an UndirectedGraph. Accept every component that
///     getSimplePath() reports as a chain, and mark its hits as used.
///  4. Visit the directed graph in topological order. From each unused hit,
///     grow roads with buildRoads() / findNextNode(walkMin, walkMax). Keep the
///     longest road if it has at least 3 hits.
///
/// boost::topological_sort requires the cleaned graph to be acyclic (per the
/// Boost docs it throws boost::not_a_dag otherwise).
///
/// @param numSpacepoints number of spacepoints; vertex names are 0..numSpacepoints-1
/// @param rowIndices     edge source spacepoint indices
/// @param colIndices     edge target spacepoint indices; assumed same length as rowIndices
/// @param edgeWeights    per-edge score; assumed same length as rowIndices
/// @param tracks         output, cleared and filled with one hit-index list per candidate
/// @param ccCut          edges with weight <= ccCut are removed first
/// @param walkMin        findNextNode th_min: edges scoring <= walkMin are never followed
/// @param walkMax        findNextNode th_add: every edge scoring > walkMax spawns a branch
///
/// (Definition at line 73 of file ExaTrkXUtils.cxx.)
void ExaTrkXUtils::CCandWalk(vertex_t numSpacepoints,
                             const std::vector<int64_t> &rowIndices,
                             const std::vector<int64_t> &colIndices,
                             const std::vector<weight_t> &edgeWeights,
                             std::vector<std::vector<uint32_t>> &tracks,
                             float ccCut,
                             float walkMin,
                             float walkMax) {
  Graph G;
  // Keyed by hit ID (= spacepoint index = vertex name), never by vertex index.
  std::map<vertex_t, bool> used_hits;
  // Use the spacepoint ID as the vertex name.
  for (vertex_t i = 0; i < numSpacepoints; i++) {
    add_vertex(i, G);
    used_hits[i] = false;
  }
  for (std::size_t idx = 0; idx < rowIndices.size(); ++idx) {
    add_edge(rowIndices[idx], colIndices[idx], edgeWeights[idx], G);
  }

  // Remove low-score edges and vertices without any edge; renumbers vertices.
  Graph newG = cleanupGraph(G, ccCut);

  // Undirected copy of newG with identical vertex order, so vertex indices agree.
  UndirectedGraph ugraph;
  for (auto v : boost::make_iterator_range(vertices(newG))) {
    auto name = boost::get(boost::vertex_name, newG, v);
    add_vertex(name, ugraph);
  }
  auto [edge_b, edge_e] = boost::edges(newG);
  for (auto it = edge_b; it != edge_e; ++it) {
    auto source = boost::source(*it, newG);
    auto target = boost::target(*it, newG);
    add_edge(source, target, ugraph);
  }

  // Chain-like components become tracks directly (as hit IDs).
  std::vector<std::vector<vertex_t>> sub_graphs = getSimplePath(ugraph);
  for (const auto &track : sub_graphs) {
    for (auto hit_id : track) {
      used_hits[hit_id] = true;
    }
  }

  // boost::topological_sort emits vertices in reverse topological order.
  std::vector<Vertex> topo_order;
  boost::topological_sort(newG, std::back_inserter(topo_order));

  auto next_node_fn = [&](const Graph &G, vertex_t current_hit) {
    return findNextNode(G, current_hit, walkMin, walkMax);
  };

  // Reverse iteration => visit nodes in topological order.
  for (auto it = topo_order.rbegin(); it != topo_order.rend(); ++it) {
    auto node_id = *it;
    int hit_id = boost::get(boost::vertex_name, newG, node_id);
    if (used_hits[hit_id]) continue;

    auto roads = buildRoads(newG, node_id, next_node_fn, used_hits);
    // Mark the starting HIT as used. Indexing with node_id here would flag an
    // unrelated hit, because cleanupGraph renumbers vertices.
    used_hits[hit_id] = true;
    if (roads.empty()) {
      continue;
    }

    // Keep the longest road (first one on ties); roads include the start vertex.
    const auto &longest_road = *std::max_element(
        roads.begin(), roads.end(),
        [](const auto &a, const auto &b) { return a.size() < b.size(); });

    if (longest_road.size() >= 3) {
      std::vector<vertex_t> track;
      for (auto road_node : longest_road) {
        auto road_hit = boost::get(boost::vertex_name, newG, road_node);
        used_hits[road_hit] = true;
        track.push_back(road_hit);
      }
      sub_graphs.push_back(std::move(track));
    }
  }

  // Copy the candidates (hit IDs) into the output.
  tracks.clear();
  tracks.reserve(sub_graphs.size());
  for (const auto &track : sub_graphs) {
    tracks.emplace_back(track.begin(), track.end());
  }
}
static Double_t a
std::vector< std::vector< vertex_t > > buildRoads(const Graph &G, vertex_t starting_node, std::function< std::vector< vertex_t >(const Graph &, vertex_t)> next_node_fn, std::map< vertex_t, bool > &used_hits)
std::vector< vertex_t > findNextNode(const Graph &G, vertex_t current_hit, float th_min, float th_add)
Graph cleanupGraph(const Graph &G, float ccCut)
std::vector< std::vector< vertex_t > > getSimplePath(const UndirectedGraph &G)

◆ cleanupGraph()

/// Return a copy of G without low-score edges and without edge-less vertices.
///
/// A vertex is copied (with its vertex_name) iff it has at least one incoming
/// or outgoing edge in G. This is decided BEFORE edges are cut, so a vertex
/// whose edges all fall below the cut is still kept. An edge is copied (with
/// its edge_weight) iff its weight is strictly greater than ccCut. Surviving
/// vertices are renumbered densely in their original iteration order.
///
/// @param G     input graph (vertex_name = hit ID, edge_weight = score)
/// @param ccCut edges with weight <= ccCut are dropped
/// @return the cleaned, renumbered graph
///
/// (Definition at line 289 of file ExaTrkXUtils.cxx.)
Graph ExaTrkXUtils::cleanupGraph(const Graph &G, float ccCut) {
  Graph newG;

  // Old position (in vertices(G) iteration order) -> index in newG.
  std::map<vertex_t, vertex_t> old_vertex_to_new;
  vertex_t position = 0;
  vertex_t next_free = 0;
  for (auto v : boost::make_iterator_range(vertices(G))) {
    const bool has_any_edge = in_degree(v, G) != 0 || out_degree(v, G) != 0;
    if (has_any_edge) {
      add_vertex(boost::get(boost::vertex_name, G, v), newG);
      old_vertex_to_new[position] = next_free;
      ++next_free;
    }
    ++position;
  }

  // Copy sufficiently scored edges, translating both endpoints into newG.
  for (auto e : boost::make_iterator_range(boost::edges(G))) {
    const auto weight = boost::get(boost::edge_weight, G, e);
    if (weight <= ccCut) {
      continue;
    }
    auto from = boost::source(e, G);
    auto to = boost::target(e, G);
    from = old_vertex_to_new[from];
    to = old_vertex_to_new[to];
    add_edge(from, to, weight, newG);
  }
  return newG;
}

◆ findNextNode()

std::vector< vertex_t > ExaTrkXUtils::findNextNode ( const Graph & G,
vertex_t current_hit,
float th_min,
float th_add )

Definition at line 202 of file ExaTrkXUtils.cxx.

207 {
208 std::vector<vertex_t> next_hits;
209 auto [begin, end] = boost::out_edges(current_hit, G);
210
211 std::vector<std::pair<vertex_t, double>> neighbors_scores;
212 for (auto it = begin; it != end; ++it) {
213 vertex_t neighbor = target(*it, G);
214 auto score = boost::get(boost::edge_weight, G, *it);
215
216 if (neighbor == current_hit || score <= th_min) continue;
217 neighbors_scores.push_back({neighbor, score});
218 }
219
220 if (neighbors_scores.empty()) return {};
221
222 // Find the best neighbor
223 auto best_neighbor = *std::max_element(neighbors_scores.begin(), neighbors_scores.end(),
224 [](const std::pair<int, double> &a, const std::pair<int, double> &b) {
225 return a.second < b.second;
226 });
227
228 // Add neighbors with score > th_add
229 for (const auto &neighbor : neighbors_scores) {
230 if (neighbor.second > th_add) {
231 next_hits.push_back(neighbor.first);
232 }
233 }
234
235 // If no neighbors were added, add the best neighbor
236 if (next_hits.empty()) {
237 next_hits.push_back(best_neighbor.first);
238 }
239
240 return next_hits;
241}

◆ getSimplePath()

std::vector< std::vector< vertex_t > > ExaTrkXUtils::getSimplePath ( const UndirectedGraph & G)

Definition at line 164 of file ExaTrkXUtils.cxx.

164 {
165 std::vector<std::vector<vertex_t>> final_tracks;
166 // Get weakly connected components
167 std::vector<vertex_t> component(num_vertices(G));
168 size_t num_components = boost::connected_components(G, &component[0]);
169
170 std::vector<std::vector<Vertex> > component_groups(num_components);
171 for(size_t i = 0; i < component.size(); ++i) {
172 component_groups[component[i]].push_back(i);
173 }
174
175 // loop over the sorted groups.
176 for(const auto& sub_graph : component_groups) {
177 if (sub_graph.size() < 3) {
178 continue;
179 }
180 bool is_signal_path = true;
181 // Check if all nodes in the sub_graph are signal paths
182 for (int node : sub_graph) {
183 if (degree(node, G) > 2) {
184 is_signal_path = false;
185 break;
186 }
187 }
188
189 // If it's a signal path, collect the hit_ids
190 if (is_signal_path) {
191 std::vector<vertex_t> track;
192 for (int node : sub_graph) {
193 vertex_t hit_id = boost::get(boost::vertex_name, G, node);
194 track.push_back(hit_id);
195 }
196 final_tracks.push_back(std::move(track));
197 }
198 }
199 return final_tracks;
200}