ATLAS Offline Software
Loading...
Searching...
No Matches
InPlaceClusterization.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2026 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef ACTS_INPLACECLUSTERIZATION_H
5#define ACTS_INPLACECLUSTERIZATION_H
6#include <algorithm>
7#include <array>
8#include <cassert>
9#include <cmath>
10#include <concepts>
11#include <limits>
12#include <ranges>
13#include <span>
14#include <type_traits>
15
17
18// For cells that should be clusterized the following functions
19// need to be implemented:
20namespace traits {
21// @brief Get the number of dimensions of the grid the cell lives on
22template <typename cell_t>
23constexpr auto getCellDimension() {
24 return std::tuple_size<decltype(cell_t::coordinates)>();
25}
26
35template <typename cell_t>
36auto getCellCoordinate(const cell_t &a, unsigned int axis_i) {
37 return a.coordinates[axis_i];
38}
39
43template <typename cell_t, std::unsigned_integral index_t>
44void setLabel(cell_t &a, index_t label) {
45 a.label = label;
46}
47
49template <typename cell_t>
50auto getLabel(const cell_t &a) {
51 return a.label;
52}
53} // namespace traits
54
61template <typename coordinates_t, std::size_t N, std::unsigned_integral index_t>
62struct Cell {
63 Cell(const std::array<coordinates_t, N> &the_coordinates, index_t src_index)
64 : coordinates(the_coordinates), label(index_t{}), srcIndex(src_index) {}
65
66 std::array<coordinates_t, N>
68 index_t label;
69 index_t srcIndex;
70};
71
73template <typename cell_t>
74concept CellWithLabel = requires(const cell_t &a, const cell_t &b,
75 decltype(traits::getLabel(a)) idx,
76 unsigned int axis_i) {
77 { traits::getLabel(a) } -> std::unsigned_integral;
78 traits::setLabel(a, idx);
79 {
81 } -> std::convertible_to<bool>;
82 { traits::getCellDimension<cell_t>() } -> std::convertible_to<std::size_t>;
83};
84
86template <typename container_t>
87concept SequenceContainer = requires(container_t &cont, unsigned int index,
88 typename container_t::value_type) {
89 {
90 cont[index]
91 } -> std::convertible_to<const typename container_t::value_type &>;
92 { cont.empty() } -> std::convertible_to<bool>;
93 { cont.size() } -> std::convertible_to<std::size_t>;
94} && std::permutable<typename container_t::iterator>;
95
97template <typename cell_collection_t>
100
/// @brief Test whether two cells are connected through a common edge or a
/// common corner.
/// @param coordinates_diff per-axis absolute coordinate differences of the
///        two cells
/// @return true if no axis differs by more than one (true for an empty range)
template <typename coordinates_t>
bool isConnectedCommonEdgeOrCorner(const coordinates_t &coordinates_diff) {
  for (const auto &axis_diff : coordinates_diff) {
    // Written as !(d <= 1) rather than d > 1 so that a difference that does
    // not compare (e.g. NaN) still yields "not connected".
    if (!(axis_diff <= 1)) {
      return false;
    }
  }
  return true;
}
113
/// @brief Test whether two cells are connected through a common edge only.
///
/// Two cells share an edge if at most one axis differs, and that axis differs
/// by at most one. Identical coordinates also count as connected.
/// @param coordinates_diff per-axis absolute (non-negative) coordinate
///        differences of the two cells
/// @return true if the cells are edge-connected
template <typename coordinates_t>
bool isConnectedCommonEdge(const coordinates_t &coordinates_diff) {
  // Count the differing axes instead of summing the differences into an int.
  // A sum can overflow, or wrap when the coordinate type is unsigned or wider
  // than int (e.g. 0xFFFFFFFFu becomes -1). Two far-apart cells would then
  // look adjacent.
  unsigned int n_differing_axes = 0;
  for (const auto &a_coordinate_diff : coordinates_diff) {
    if (a_coordinate_diff > 1) {
      return false;
    }
    if (a_coordinate_diff != 0) {
      ++n_differing_axes;
      if (n_differing_axes > 1) {
        // differing in two axes means a common corner at best
        return false;
      }
    }
  }
  return true;
}
126
128
137template <CellWithLabel cell_t,
140 static constexpr std::size_t NDim = traits::getCellDimension<cell_t>();
141 using coordinate_t = std::remove_cvref_t<decltype(traits::getCellCoordinate(
142 std::declval<cell_t>(), 0u))>;
143
145 static bool isConnected(
146 const std::array<coordinate_t, NDim> &coordinates_diff) {
147 if constexpr (connection_type == EConnectionType::CommonEdgeOrCorner) {
148 return isConnectedCommonEdgeOrCorner(coordinates_diff);
149 } else {
150 return isConnectedCommonEdge(coordinates_diff);
151 }
152 }
153
158 static bool canAbortSearch(
159 const std::array<coordinate_t, NDim> &coordinates_diff,
160 unsigned int sort_axis_i) {
161 assert(sort_axis_i < coordinates_diff.size());
162 return coordinates_diff[sort_axis_i] > 1;
163 }
164};
165
166template <
168 CellCollection cell_container_t = std::span<Cell<int, 2, unsigned int> > >
169auto defaultConnectionHelper([[maybe_unused]] const cell_container_t &cells) {
170 return ConnectionHelper<typename cell_container_t::value_type,
171 connection_type>{};
172}
173
/// @brief Compute the absolute difference of two coordinates.
///
/// The result has the same type as the inputs. Unsigned values are subtracted
/// in the order larger minus smaller, so that no wrap-around occurs.
/// @param a first coordinate
/// @param b second coordinate
/// @return |a - b| converted to coordinate_t
template <typename coordinate_t>
auto absDifference(coordinate_t a, coordinate_t b) {
  if constexpr (!std::is_signed_v<coordinate_t>) {
    if (a > b) {
      return static_cast<coordinate_t>(a - b);
    }
    return static_cast<coordinate_t>(b - a);
  } else {
    // @TODO check overflow with C++26, or change return type to
    // an unsigned integer of same size.
    const auto signed_difference = a - b;
    return static_cast<coordinate_t>(std::abs(signed_difference));
  }
}
185
189template <unsigned int SORT_AXIS, CellCollection cell_collection_t,
190 std::unsigned_integral index_t = unsigned int,
191 typename connection_helper_t>
192void labelSortedCells(cell_collection_t &cells,
193 connection_helper_t &&connection_helper) {
194 using cell_t = typename cell_collection_t::value_type;
195 static constexpr std::size_t NDim = traits::getCellDimension<cell_t>();
196 static_assert(NDim > 0);
197 using label_t = std::remove_cvref_t<decltype(traits::getLabel(cells[0]))>;
198 using coordinate_t =
199 std::remove_cvref_t<decltype(traits::getCellCoordinate(cells[0], 0u))>;
200 static_assert(std::numeric_limits<index_t>::max() >=
201 std::numeric_limits<label_t>::max());
202 // cells are sorted in the first coordinate
203 // thus can stop searching for adjacent cells
204 // if distance in the sorted coordinate is too large
205 for (index_t idx_a = 0; idx_a < cells.size(); ++idx_a) {
206 traits::setLabel(cells[idx_a], idx_a);
207 for (index_t idx_b = idx_a; idx_b-- > 0;) {
208 // Unnecessary default initialization to satisfy clang-tidy:
209 std::array<coordinate_t, NDim> diff{};
210 for (unsigned int axis_i = 0; axis_i < NDim; ++axis_i) {
211 diff[axis_i] =
212 absDifference(traits::getCellCoordinate(cells[idx_a], axis_i),
213 traits::getCellCoordinate(cells[idx_b], axis_i));
214 }
215
216 if (connection_helper.isConnected(diff)) {
217 if (traits::getLabel(cells[idx_a]) < idx_a) {
218 // Unnecessary default initialization to satisfy clang-tidy:
219 label_t min_label{};
220 label_t max_label{};
221 if (traits::getLabel(cells[idx_a]) < traits::getLabel(cells[idx_b])) {
222 min_label = traits::getLabel(cells[idx_a]);
223 max_label = traits::getLabel(cells[idx_b]);
224 traits::setLabel(cells[idx_b], min_label);
225 } else {
226 max_label = traits::getLabel(cells[idx_a]);
227 min_label = traits::getLabel(cells[idx_b]);
228 traits::setLabel(cells[idx_a], min_label);
229 }
230 // nothing will be done if the min and the max label are identical
231 if (min_label != max_label) {
232 // can only encounter cells with label max_label down to index
233 // max_label
234 for (index_t idx = idx_a; idx-- > max_label;) {
235 if (traits::getLabel(cells[idx]) == max_label) {
236 traits::setLabel(cells[idx], min_label);
237 }
238 }
239 }
240 } else {
241 traits::setLabel(cells[idx_a], traits::getLabel(cells[idx_b]));
242 }
243 } else if (connection_helper.canAbortSearch(diff, SORT_AXIS)) {
244 // difference too large in sorted coordinate, there won't be any more
245 // candidates for merging. Can abort search for cell idx_a
246 break;
247 }
248 }
249 }
250}
251
253template <unsigned int AXIS, CellCollection cell_collection_t,
254 std::size_t NDim = 2, std::unsigned_integral index_t = unsigned int>
255 requires(NDim > 0 && AXIS < NDim)
256void sortCellsByCoordinate(cell_collection_t &cells) {
257 using cell_t = typename cell_collection_t::value_type;
258 std::sort(cells.begin(), cells.end(), [](const cell_t &a, const cell_t &b) {
259 return traits::getCellCoordinate(a, AXIS) <
260 traits::getCellCoordinate(b, AXIS);
261 });
262}
263
265template <CellCollection cell_collection_t>
266void groupCellsByLabel(cell_collection_t &cells) {
267 using cell_t = typename cell_collection_t::value_type;
268 // stable sort to retain coordinate ordering
269 std::stable_sort(cells.begin(), cells.end(),
270 [](const cell_t &a, const cell_t &b) {
271 return traits::getLabel(a) < traits::getLabel(b);
272 });
273}
274
286template <
287 unsigned int SORT_AXIS, std::unsigned_integral index_t = unsigned int,
288 CellCollection cell_collection_t = std::span<Cell<int, 2, unsigned int> >,
289 typename connection_helper_t =
290 ConnectionHelper<Cell<int, 2, unsigned int> > >
291void clusterize(cell_collection_t &cells,
292 connection_helper_t &&connection_helper =
293 ConnectionHelper<typename cell_collection_t::value_type,
295 assert(cells.size() <= std::numeric_limits<index_t>::max());
298 connection_helper);
299 groupCellsByLabel(cells);
300}
301
307template <CellCollection cell_collection_t,
308 std::unsigned_integral index_t = unsigned int>
309std::size_t countLabels(const cell_collection_t &cells) {
310 assert(cells.size() <= std::numeric_limits<index_t>::max());
311 std::size_t nlabels = cells.empty() ? 0 : 1;
312 for (index_t idx_a = 1; idx_a < cells.size(); ++idx_a) {
313 nlabels +=
314 traits::getLabel(cells[idx_a]) != traits::getLabel(cells[idx_a - 1]);
315 }
316 return nlabels;
317}
318
325template <CellCollection cell_collection_t, typename func_t>
326void for_each_cluster(cell_collection_t &cells, func_t func) {
327 if (cells.empty()) {
328 return;
329 }
330 using index_t = decltype(traits::getLabel(cells[0]));
331 index_t idx_begin = 0;
332 index_t idx = 1;
333 for (; idx < cells.size(); ++idx) {
334 if (traits::getLabel(cells[idx]) != traits::getLabel(cells[idx - 1])) {
335 func(cells, idx_begin, idx);
336 idx_begin = idx;
337 }
338 }
339 if (idx_begin < idx) {
340 func(cells, idx_begin, idx);
341 }
342}
343
353template <CellCollection cell_collection_t, typename range_collection_t>
354void addCellRanges(const cell_collection_t &cells, range_collection_t &ranges) {
356 cells, [&ranges]([[maybe_unused]] const cell_collection_t &all_cells,
357 unsigned int idx_begin, unsigned int idx_end) {
358 ranges.emplace_back(idx_begin, idx_end);
359 });
360}
361
372template <unsigned int AXIS, CellCollection cell_collection_t,
373 typename range_collection_t>
374void addCellRangesAndSort(cell_collection_t &cells,
375 range_collection_t &ranges) {
376 for_each_cluster(cells, [&ranges](cell_collection_t &all_cells,
377 unsigned int idx_begin,
378 unsigned int idx_end) {
379 auto cluster_range =
380 std::span(all_cells.begin() + idx_begin, all_cells.begin() + idx_end);
381 using cell_t = typename cell_collection_t::value_type;
382 std::ranges::sort(cluster_range, [](const cell_t &a, const cell_t &b) {
383 return traits::getCellCoordinate(a, AXIS) <
385 });
386 ranges.emplace_back(idx_begin, idx_end);
387 });
388}
389
390} // namespace Acts::InPlaceClusterization
391#endif
static Double_t a
void diff(const Jet &rJet1, const Jet &rJet2, std::map< std::string, double > varDiff)
Difference between jets - Non-Class function required by trigger.
Definition Jet.cxx:631
concept of a cell container that can be clustered
concept of a cell object that can be clustered.
base concept for the container that can be used for a cell collection
std::string label(const std::string &format, int i)
Definition label.h:19
auto getLabel(const cell_t &a)
Get the label associated to the given cell.
void setLabel(cell_t &a, index_t label)
Set a label for a given cell the label type must fit numbers as high as the number of cells which are...
auto getCellCoordinate(const cell_t &a, unsigned int axis_i)
Get the coordinates of a cell.
void labelSortedCells(cell_collection_t &cells, connection_helper_t &&connection_helper)
Label cells which are in ascending order of the coordinate of the given axis.
auto absDifference(coordinate_t a, coordinate_t b)
compute the absolute difference of two coordinates
void sortCellsByCoordinate(cell_collection_t &cells)
Sort the cells in ascending order of the coordinate of the specified axis.
bool isConnectedCommonEdge(const coordinates_t &coordinates_diff)
test whether cells are connected considering common edges only.
auto defaultConnectionHelper(const cell_container_t &cells)
void groupCellsByLabel(cell_collection_t &cells)
Sort the cells in ascending order of the associated label.
bool isConnectedCommonEdgeOrCorner(const coordinates_t &coordinates_diff)
test whether cells are connected considering common corners and edges.
void for_each_cluster(cell_collection_t &cells, func_t func)
call the given function for each cluster of a label sorted cell collection.
std::size_t countLabels(const cell_collection_t &cells)
determine the number of clusters.
void addCellRangesAndSort(cell_collection_t &cells, range_collection_t &ranges)
Sort the cells of each cluster by one coordinate and add cell ranges to the given range container.
void addCellRanges(const cell_collection_t &cells, range_collection_t &ranges)
Add element ranges for each cluster in the cell collection to the given range container.
void clusterize(cell_collection_t &cells, connection_helper_t &&connection_helper=ConnectionHelper< typename cell_collection_t::value_type, EConnectionType::CommonEdgeOrCorner >{})
Sort the cell collection in such a way that cells of a cluster are adjacent.
Definition index.py:1
void sort(typename DataModel_detail::iterator< DVL > beg, typename DataModel_detail::iterator< DVL > end)
Specialization of sort for DataVector/List.
void stable_sort(DataModel_detail::iterator< DVL > beg, DataModel_detail::iterator< DVL > end)
Specialization of stable_sort for DataVector/List.
index_t label
a label which will be assigned by the clusterization
std::array< coordinates_t, N > coordinates
the coordinates of the cell on the regular grid
index_t srcIndex
the index to find the source cell
Cell(const std::array< coordinates_t, N > &the_coordinates, index_t src_index)
default connection helper which should work for arbitrary cells which fulfil the CellWithLabel concept
std::remove_cvref_t< decltype(traits::getCellCoordinate( std::declval< cell_t >(), 0u))> coordinate_t
static bool canAbortSearch(const std::array< coordinate_t, NDim > &coordinates_diff, unsigned int sort_axis_i)
test whether the search for connections can be aborted.
static bool isConnected(const std::array< coordinate_t, NDim > &coordinates_diff)
test whether cells are connected considering common edges or common corners