ATLAS Offline Software
ExaTrkXUtils.cxx
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2023 CERN for the benefit of the ATLAS collaboration
3 */
4 
#include "ExaTrkXUtils.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>
8 
9 
/**
 * @brief Build a radius-limited k-nearest-neighbour graph in embedding space
 *        by brute force (O(N^2 * D) distance evaluations).
 *
 * For every spacepoint i the Euclidean (L2) distance to every other
 * spacepoint j is computed from the flattened embedding. Neighbours with
 * distance <= rVal are kept, ordered by increasing distance (ties broken by
 * the lower index), and at most kVal of them are emitted as edges i -> j.
 * Self-loops are never produced.
 *
 * @param embedFeatures  Flattened embeddings, row-major [numSpacepoints, embeddingDim].
 *                       Must hold at least numSpacepoints * embeddingDim values
 *                       (not checked here).
 * @param edgeList       Output; overwritten with a flattened [2, numEdges] array.
 *                       The first numEdges entries are the source (sender)
 *                       indices, the following numEdges entries the target
 *                       (receiver) indices.
 * @param numSpacepoints Number of spacepoints (rows of embedFeatures).
 * @param embeddingDim   Dimension of the embedding space.
 * @param rVal           Radius of the ball; neighbours farther away are dropped.
 * @param kVal           Maximum number of neighbours kept per spacepoint.
 *
 * TODO: use a KD-tree to speed up the neighbour search.
 */
void buildEdges(
    std::vector<float>& embedFeatures,
    std::vector<int64_t>& edgeList,
    int64_t numSpacepoints,
    int embeddingDim,  // dimension of embedding space
    float rVal,        // radius of the ball
    int kVal           // number of nearest neighbors
){
  std::vector<int64_t> senders;
  std::vector<int64_t> receivers;

  if (numSpacepoints > 0 && kVal > 0) {
    const auto maxNeighbours = static_cast<std::size_t>(kVal);
    // (distance, index) of the in-radius neighbours of the current spacepoint;
    // hoisted out of the loop so its capacity is reused.
    std::vector<std::pair<float, int64_t>> candidates;
    candidates.reserve(static_cast<std::size_t>(numSpacepoints));

    for (int64_t i = 0; i < numSpacepoints; ++i) {
      candidates.clear();
      for (int64_t j = 0; j < numSpacepoints; ++j) {
        // Exclude the point itself by index, so no self-loop is ever produced.
        if (j == i) continue;
        float dist2 = 0.f;
        for (int k = 0; k < embeddingDim; ++k) {
          const float d = embedFeatures[i * embeddingDim + k] - embedFeatures[j * embeddingDim + k];
          dist2 += d * d;
        }
        const float dist = std::sqrt(dist2);
        // `<=` keeps points exactly on the sphere and rejects NaN distances,
        // which would otherwise violate the ordering required by the sort.
        if (dist <= rVal) candidates.emplace_back(dist, j);
      }

      // Only the kVal closest need ordering; comparing (distance, index)
      // pairs makes the tie-break deterministic (lower index first).
      const std::size_t numKept = std::min(candidates.size(), maxNeighbours);
      std::partial_sort(candidates.begin(),
                        candidates.begin() + static_cast<std::ptrdiff_t>(numKept),
                        candidates.end());
      for (std::size_t n = 0; n < numKept; ++n) {
        senders.push_back(i);
        receivers.push_back(candidates[n].second);
      }
    }
  }

  // Flatten to [2, numEdges]: senders row followed by receivers row.
  edgeList.assign(senders.begin(), senders.end());
  edgeList.insert(edgeList.end(), receivers.begin(), receivers.end());
}
python.compareTCTs.rVal
rVal
Definition: compareTCTs.py:118
lumiFormat.i
int i
Definition: lumiFormat.py:92
buildEdges
void buildEdges(std::vector< float > &embedFeatures, std::vector< int64_t > &edgeList, int64_t numSpacepoints, int embeddingDim, float rVal, int kVal)
Definition: ExaTrkXUtils.cxx:10
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
calibdata.copy
bool copy
Definition: calibdata.py:27
fitman.k
k
Definition: fitman.py:528