ATLAS Offline Software
Functions
ExaTrkXUtils.cxx File Reference
#include "ExaTrkXUtils.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>
Include dependency graph for ExaTrkXUtils.cxx:

Go to the source code of this file.

Functions

void buildEdges (std::vector< float > &embedFeatures, std::vector< int64_t > &edgeList, int64_t numSpacepoints, int embeddingDim, float rVal, int kVal)
 

Function Documentation

◆ buildEdges()

void buildEdges ( std::vector< float > &  embedFeatures,
std::vector< int64_t > &  edgeList,
int64_t  numSpacepoints,
int  embeddingDim,
float  rVal,
int  kVal 
)

Definition at line 10 of file ExaTrkXUtils.cxx.

17  {
18  // calculate the distances between two spacepoints in the embedding space
19  // and keep the k-nearest neighbours within the radius r
20  // the distance is calculated using the L2 norm
21  // the edge list is with dimension of [2, number-of-edges]
22  // the first row is the source node index
23  // the second row is the target node index
24  // TODO: use the KDTree to speed up the calculation
25 
26  std::vector<float> dists;
27  dists.reserve(numSpacepoints);
28  std::vector<int> idx(numSpacepoints);
29  std::vector<int64_t> senders;
30  std::vector<int64_t> receivers;
31  for (int64_t i = 0; i < numSpacepoints; i++) {
32  dists.clear();
33  for (int64_t j = 0; j < numSpacepoints; j++) {
34  if (i == j) {
35  dists.push_back(0);
36  continue;
37  }
38  float dist = 0;
39  for (int k = 0; k < embeddingDim; k++) {
40  float dist_k = embedFeatures[i * embeddingDim + k] - embedFeatures[j * embeddingDim + k];
41  dist += dist_k * dist_k;
42  }
43  dists.push_back(sqrt(dist));
44  }
45  std::iota(idx.begin(), idx.end(), 0);
46  std::sort(idx.begin(), idx.end(), [&dists](int i1, int i2) {return dists[i1] < dists[i2];});
47  int numFilled = -1;
48  for (int j = 0; j < numSpacepoints; j++) {
49  if (i == j) continue;
50  if (dists[idx[j]] > rVal) break;
51  numFilled++;
52  senders.push_back(i);
53  receivers.push_back(idx[j]);
54  if (numFilled >= kVal) break;
55  }
56  }
57  edgeList.resize(2 * senders.size());
58  std::copy(senders.begin(), senders.end(), edgeList.begin());
59  std::copy(receivers.begin(), receivers.end(), edgeList.begin() + senders.size());
60 }
python.compareTCTs.rVal
rVal
Definition: compareTCTs.py:118
lumiFormat.i
int i
Definition: lumiFormat.py:92
std::sort
void sort(typename std::reverse_iterator< DataModel_detail::iterator< DVL > > beg, typename std::reverse_iterator< DataModel_detail::iterator< DVL > > end, const Compare &comp)
Specialization of sort for DataVector/List.
Definition: DVL_algorithms.h:623
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
calibdata.copy
bool copy
Definition: calibdata.py:27
fitman.k
k
Definition: fitman.py:528