ATLAS Offline Software
KDPoint.h
Go to the documentation of this file.
1 /*
2  Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3 */
4 #ifndef TRIGTOOLS_TRIG_VSI_KDPOINT
5 #define TRIGTOOLS_TRIG_VSI_KDPOINT
6 
#include "TMath.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
19 
20 namespace TrigVSI {
/**
 * @brief Point with D coordinates of type T and a scalar weight.
 *
 * The arithmetic operators act on the coordinates only. Every point they
 * return is built from a coordinate array and therefore carries the default
 * weight of 1; the compound assignments leave the weight of `*this` untouched.
 *
 * @tparam T coordinate type
 * @tparam D number of coordinates
 */
template<typename T, size_t D>
class KDPoint {
  private :
    std::array<T,D> m_point;   //!< Coordinates.
    double          m_weight;  //!< Weight, 1 unless changed through setWeight().

  public :
    //! Create a point with value-initialised coordinates and weight 1.
    KDPoint() : m_point(), m_weight(1.) {}

    //! Create a point by copying the coordinates in @p arr; weight 1.
    KDPoint( const std::array<T,D>& arr ) : m_point(arr), m_weight(1.) {}

    //! Create a point by moving the coordinates out of @p arr; weight 1.
    KDPoint( std::array<T,D>&& arr ) : m_point(std::move(arr)), m_weight(1.) {}

    /**
     * @brief Create a point from the first D entries of @p v; weight 1.
     *
     * Entries of @p v beyond the D-th are ignored. When @p v holds fewer than
     * D entries the remaining coordinates are value-initialised instead of
     * being left indeterminate.
     */
    KDPoint(const std::vector<T>& v) : m_point(), m_weight(1.) {
      const size_t n = (v.size() < D) ? v.size() : D;
      for (size_t i = 0; i < n; ++i) { m_point[i] = v[i]; }
    }

    /**
     * @brief Create a point from the first D entries of @p list; weight 1.
     *
     * Only min(D, list.size()) entries are read, so a list shorter than D no
     * longer causes a read past its end; the missing coordinates are
     * value-initialised. Entries beyond the D-th are ignored.
     */
    KDPoint(std::initializer_list<T> list) : m_point(), m_weight(1.) {
      const size_t n = (list.size() < D) ? list.size() : D;
      auto src = list.begin();
      for (size_t i = 0; i < n; ++i, ++src) { m_point[i] = *src; }
    }

    //! Add each coordinate; the result has weight 1.
    const KDPoint<T,D> operator + (const KDPoint<T,D>& other) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] + other.m_point[i];
      }
      return KDPoint<T,D>(tmp);
    }

    //! Subtract each coordinate; the result has weight 1.
    const KDPoint<T,D> operator - (const KDPoint<T,D>& other) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] - other.m_point[i];
      }
      return KDPoint<T,D>(tmp);
    }

    //! Add the coordinates of @p other in place; the weight is not modified.
    KDPoint<T,D>& operator += (const KDPoint<T,D>& other)
    {
      for ( size_t i = 0; i < D; i++ ) {
        m_point[i] += other.m_point[i];
      }
      return *this;
    }

    //! Subtract the coordinates of @p other in place; the weight is not modified.
    KDPoint<T,D>& operator -= (const KDPoint<T,D>& other)
    {
      for ( size_t i = 0; i < D; i++ ) {
        m_point[i] -= other.m_point[i];
      }
      return *this;
    }

    //! Multiply each coordinate by @p other; the result has weight 1.
    template<typename I>
    const KDPoint<T,D> operator * (const I& other) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] * other;
      }
      return KDPoint<T,D>(tmp);
    }

    //! Divide each coordinate by @p other; the result has weight 1.
    template<typename I>
    const KDPoint<T,D> operator / (const I& other) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] / other;
      }
      return KDPoint<T,D>(tmp);
    }

    //! Return a mutable reference to the i-th coordinate (no bounds check).
    T& operator [] (size_t i)
    {
      return m_point[i];
    }

    //! Return a read-only reference to the i-th coordinate (no bounds check).
    const T& operator [] (size_t i) const
    {
      return m_point[i];
    }

    //! Return the coordinate array.
    inline const std::array<T,D>& getPos() const { return m_point; }

    /**
     * @brief Return the i-th coordinate by value, or a quiet NaN when i >= D.
     *
     * The NaN is taken from std::numeric_limits<T>, so the out-of-range
     * result is well defined for every T (it is T() for types without a NaN)
     * rather than the result of converting a double NaN to T.
     */
    inline T at(size_t i) const { return (i < D) ? m_point[i] : std::numeric_limits<T>::quiet_NaN(); }

    //! Return the weight of the point.
    inline double getWeight() const { return m_weight; }

    //! Set the weight to the given value.
    inline void setWeight(double w) { m_weight = w; }

    /**
     * @brief Return the weighted mean position of this point and @p p.
     *
     * Each point contributes its coordinates scaled by its weight and the sum
     * is divided by the sum of the two weights. The returned point has
     * weight 1. No check is made for the two weights summing to zero.
     */
    KDPoint<T,D> average(const KDPoint<T,D>& p) const
    {
      KDPoint<T,D> tmp;
      tmp = ( *this * m_weight + p * p.m_weight ) / ( m_weight + p.m_weight );
      return tmp;
    }

    //! Average of a list of points. Only declared here; the definition is not part of this listing.
    static inline KDPoint<T, D> average(const std::vector<KDPoint<T,D>>&);
};
146 
147 template<typename I, typename T, size_t D>
148 const KDPoint<T,D> operator * (const I& b, const KDPoint<T,D>& p)
149 {
150  return p * b;
151 }
152 
153 
158 template<typename T, size_t D>
159 class KDTree {
160  public :
164  class Node {
165  public :
167  int dataIdx;
168  int axis;
169  std::unique_ptr<Node> leftPtr;
170  std::unique_ptr<Node> rightPtr;
171 
173  Node(){};
174  };
175 
176  KDTree(std::vector<KDPoint<T,D>>& v_data) : m_datas( v_data ), m_idLength( m_datas.size() ), m_locked( false )
177  {
178  m_indices.clear();
179  for (size_t i = 0; i < m_idLength; i++) { m_indices.emplace_back(i); }
180  };
181 
182  KDTree() : m_idLength(0), m_locked(false){};
183 
184  void genTree();
185  inline void lock() { m_locked = true; };
186  inline void unlock() { m_locked = false; };
187 
188  inline KDPoint<T,D> at(size_t n) { return m_datas.at(n); };
189 
190  private :
191  std::unique_ptr<Node> m_rootNode;
192  std::vector<KDPoint<T,D>> m_datas;
193  std::vector<size_t> m_indices;
194  size_t m_idLength;
195  bool m_locked;
196 
197  std::unique_ptr<Node> buildTree(int, int, int);
198 
199  void nearestNeighborRec( const KDPoint<T,D>&, const Node*, double&, int&, std::function<double(const KDPoint<T,D>&, const KDPoint<T,D>&)>& );
200 };
201 
202 
206 template<typename T, size_t D>
208 {
209  if (m_locked) return;
210  m_rootNode = buildTree(0, m_idLength, 0);
211 }
212 
213 
221 template<typename T, size_t D>
222 std::unique_ptr<typename KDTree<T,D>::Node> KDTree<T,D>::buildTree(int l, int r, int depth)
223 {
224  if ( l >= r ) return std::make_unique<Node>(nullptr);
225 
226  const int axis_ = depth % D;
227  const int mid = ( l + r ) >> 1;
228 
229  std::nth_element( m_indices.begin()+l, m_indices.begin()+l+mid, m_indices.begin()+r,
230  [this, axis_](size_t lcnt, size_t rcnt) {
231  return m_datas[lcnt].getId(axis_) < m_datas[rcnt].getId(axis_);
232  } );
233 
234  std::unique_ptr<Node> node_ptr = std::make_unique<Node> ( m_datas.at(m_indices[mid]), m_indices[mid] );
235  node_ptr->axis = axis_;
236 
237  node_ptr->leftPtr = buildTree( l, mid, depth+1 );
238  node_ptr->rightPtr = buildTree( mid+1, r, depth+1 );
239 
240  return node_ptr;
241 }
242 
243 
252 template<typename T, size_t D>
254  double& r, int& idx, std::function<double(const KDPoint<T,D>&, const KDPoint<T,D>&)>& dist_func )
255 {
256  // end processing when reach leaf
257  if ( node == nullptr ) {
258  return;
259  }
260 
261  const KDPoint<T,D>& point = node->dataRef;
262 
263  // update minimum distance and candidate point
264  const double dist = dist_func(query, point);
265  if (dist < r) {
266  idx = node->dataIdx;
267  r = dist;
268  }
269 
270  const int axis_ = node->axis;
271  const Node* next_node = (query.at(axis_) < point.at(axis_))? node->leftPtr.get() : node->rightPtr.get();
272 
273  // recursively call self until reach leaf node
274  nearestNeighborRec( query, next_node, r, idx, dist_func );
275 
276  // Check if hyper-sphere with radius r overlaps neighbor hyper-plane.
278  proj[axis_] = point.at(axis_); // proj is the projected point of query on hyperplane x[axis_]=point[axis_]
279  double diff = dist_func(proj, query);
280  if ( diff < r ) {
281  // if hyper-sphere overlaps, check the other region separated by the hyper-plane.
282  const Node* next_node_opps = (query.at(axis_) < point.at(axis_))? node->rightPtr.get() : node->leftPtr.get();
283  nearestNeighborRec( query, next_node_opps, r, idx, dist_func );
284  }
285  return;
286 }
287 
288 
289 } // end of namespace TrigVSI
290 
291 
292 #endif
beamspotman.r
def r
Definition: beamspotman.py:676
data
char data[hepevt_bytes_allocation_ATLAS]
Definition: HepEvt.cxx:11
TrigVSI::KDPoint::operator*
const KDPoint< T, D > operator*(const I &other) const
Multiply each element except the weight.
Definition: KDPoint.h:90
egammaParameters::depth
@ depth
pointing depth of the shower as calculated in egammaqgcld
Definition: egammaParamDefs.h:276
TrigVSI::KDPoint::average
KDPoint< T, D > average(const KDPoint< T, D > &p)
Return average point of this point and given point.
Definition: KDPoint.h:137
StandaloneBunchgroupHandler.bg
bg
Definition: StandaloneBunchgroupHandler.py:243
TrigVSI::KDPoint::KDPoint
KDPoint(std::array< T, D > &&arr)
Definition: KDPoint.h:38
TrigVSI::KDPoint::at
T at(size_t i) const
Return i-th element. If given i exceeds the size, return NaN.
Definition: KDPoint.h:128
TrigVSI::KDTree::m_indices
std::vector< size_t > m_indices
A list of indices of points in m_datas.
Definition: KDPoint.h:193
TrigVSI::KDPoint::operator-
const KDPoint< T, D > operator-(const KDPoint< T, D > &other) const
Subtract each element except the weight.
Definition: KDPoint.h:63
TrigVSI::KDPoint::KDPoint
KDPoint(std::initializer_list< T > list)
Definition: KDPoint.h:44
TrigVSI::KDTree::Node::axis
int axis
Definition: KDPoint.h:168
mc.diff
diff
Definition: mc.SFGenPy8_MuMu_DD.py:14
TrigVSI::KDPoint::m_point
std::array< T, D > m_point
Definition: KDPoint.h:31
UploadAMITag.l
list l
Definition: UploadAMITag.larcaf.py:158
TrigVSI::KDTree::KDTree
KDTree(std::vector< KDPoint< T, D >> &v_data)
Definition: KDPoint.h:176
TrigVSI::KDTree::Node::Node
Node(const KDPoint< T, D > &data, int idx)
Definition: KDPoint.h:172
TrigVSI::KDTree::at
KDPoint< T, D > at(size_t n)
Definition: KDPoint.h:188
TrigVSI::KDTree::Node::leftPtr
std::unique_ptr< Node > leftPtr
Definition: KDPoint.h:169
TrigVSI::KDTree::m_rootNode
std::unique_ptr< Node > m_rootNode
The root node of the tree.
Definition: KDPoint.h:188
TrigVSI::KDTree::lock
void lock()
Definition: KDPoint.h:185
query
Definition: query.py:1
TrigVSI::KDTree::m_locked
bool m_locked
Definition: KDPoint.h:195
python.setupRTTAlg.size
int size
Definition: setupRTTAlg.py:39
TrigVSI::KDPoint::KDPoint
KDPoint(const std::vector< T > &v)
Definition: KDPoint.h:40
TrigVSI::KDTree::Node::dataRef
const KDPoint< T, D > & dataRef
Definition: KDPoint.h:166
python.utils.AtlRunQueryDQUtils.p
p
Definition: AtlRunQueryDQUtils.py:210
TrigVSI::operator*
const KDPoint< T, D > operator*(const I &b, const KDPoint< T, D > &p)
Definition: KDPoint.h:148
TrigVSI::KDTree::buildTree
std::unique_ptr< Node > buildTree(int, int, int)
recursive function to create tree structure.
Definition: KDPoint.h:222
lumiFormat.i
int i
Definition: lumiFormat.py:85
TrigVSI::KDTree::m_datas
std::vector< KDPoint< T, D > > m_datas
Container of the points.
Definition: KDPoint.h:192
TrigVSI::KDPoint::KDPoint
KDPoint(std::array< T, D > &arr)
Definition: KDPoint.h:37
TrigVSI::KDTree::nearestNeighborRec
void nearestNeighborRec(const KDPoint< T, D > &, const Node *, double &, int &, std::function< double(const KDPoint< T, D > &, const KDPoint< T, D > &)> &)
recursive function for nearest neighbor searching.
Definition: KDPoint.h:253
beamspotman.n
n
Definition: beamspotman.py:731
TrigVSI::KDTree::Node::Node
Node()
Definition: KDPoint.h:173
histSizes.list
def list(name, path='/')
Definition: histSizes.py:38
TrigVSI::KDPoint::operator+
const KDPoint< T, D > operator+(const KDPoint< T, D > &other) const
Add each element except the weight.
Definition: KDPoint.h:53
DeMoUpdate.tmp
string tmp
Definition: DeMoUpdate.py:1167
make_coralServer_rep.proj
proj
Definition: make_coralServer_rep.py:48
TrigVSI::KDPoint::getWeight
double getWeight() const
Return the weight of the point.
Definition: KDPoint.h:131
query_example.query
query
Definition: query_example.py:15
TrigVSI::KDTree::unlock
void unlock()
Definition: KDPoint.h:186
lumiFormat.array
array
Definition: lumiFormat.py:91
TrigVSI::KDTree::Node::dataIdx
int dataIdx
Definition: KDPoint.h:167
TrigVSI::KDTree
KDTree.
Definition: KDPoint.h:159
private
#define private
Definition: DetDescrConditionsDict_dict_fixes.cxx:13
plotBeamSpotMon.b
b
Definition: plotBeamSpotMon.py:77
TrigVSI::KDPoint::setWeight
void setWeight(double w)
Set the weight to given value.
Definition: KDPoint.h:133
TrigVSI::KDPoint::operator/
const KDPoint< T, D > operator/(const I &other) const
Divide each element except the weight.
Definition: KDPoint.h:101
TrigVSI::KDPoint
Class for k-dimensional point.
Definition: KDPoint.h:29
python.PyAthena.v
v
Definition: PyAthena.py:154
TrigVSI::KDPoint::operator[]
T & operator[](size_t i)
Return i-th element.
Definition: KDPoint.h:111
InDetDD::other
@ other
Definition: InDetDD_Defs.h:16
TrigVSI::KDPoint::m_weight
double m_weight
Definition: KDPoint.h:32
LArNewCalib_DelayDump_OFC_Cali.idx
idx
Definition: LArNewCalib_DelayDump_OFC_Cali.py:69
TrigVSI::KDPoint::average
static KDPoint< T, D > average(const std::vector< KDPoint< T, D >> &)
TrigVSI::KDPoint::operator+=
KDPoint< T, D > & operator+=(const KDPoint< T, D > &other)
Definition: KDPoint.h:72
TrigVSI::KDTree::genTree
void genTree()
Command to generate tree.
Definition: KDPoint.h:207
TrigVSI::KDTree::Node
Node class for KDTree.
Definition: KDPoint.h:164
python.IoTestsLib.w
def w
Definition: IoTestsLib.py:200
TrigVSI::KDTree::KDTree
KDTree()
Definition: KDPoint.h:182
Amg::distance
float distance(const Amg::Vector3D &p1, const Amg::Vector3D &p2)
calculates the distance between two point in 3D space
Definition: GeoPrimitivesHelpers.h:54
TrigVSI::KDPoint::getPos
const std::array< T, D > & getPos() const
Definition: KDPoint.h:125
TrigVSI::KDPoint::operator-=
KDPoint< T, D > & operator-=(const KDPoint< T, D > &other)
Definition: KDPoint.h:80
TrigVSI::KDPoint::KDPoint
KDPoint()
Definition: KDPoint.h:36
TrigVSI
Definition: TrigVrtSecInclusive.cxx:27
TrigVSI::KDTree::m_idLength
size_t m_idLength
Definition: KDPoint.h:194
node
Definition: memory_hooks-stdcmalloc.h:74
TrigVSI::KDTree::Node::rightPtr
std::unique_ptr< Node > rightPtr
Definition: KDPoint.h:170