ATLAS Offline Software
Loading...
Searching...
No Matches
KDPoint.h
Go to the documentation of this file.
1/*
2 Copyright (C) 2002-2024 CERN for the benefit of the ATLAS collaboration
3*/
4#ifndef TRIGTOOLS_TRIG_VSI_KDPOINT
5#define TRIGTOOLS_TRIG_VSI_KDPOINT
6
#include "TMath.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
19
20namespace TrigVSI {
/**
 * @brief Class for k-dimensional point.
 *
 * Stores D coordinates of type T together with a weight (1 by default).
 * Arithmetic operators act on the coordinates only: results produced by
 * operator+, operator-, operator* and operator/ carry the default weight 1,
 * and operator+= / operator-= leave the left-hand weight unchanged.
 *
 * @tparam T  Coordinate type.
 * @tparam D  Number of dimensions.
 */
template<typename T, size_t D>
class KDPoint {
  private :
    std::array<T,D> m_point{};  ///< Coordinates; value-initialised, so entries never set are T{}.
    double m_weight = 1.;       ///< Weight of the point, used by average().

  public :
    /// Point with all coordinates T{} and weight 1.
    KDPoint() = default;
    /// Copy the coordinates from @p arr; weight 1.
    KDPoint( const std::array<T,D>& arr ) : m_point(arr) {}
    /// Take the coordinates from the array rvalue @p arr; weight 1.
    KDPoint( std::array<T,D>&& arr ) : m_point(std::move(arr)) {}

    /**
     * @brief Copy the first min(v.size(), D) entries of @p v; weight 1.
     *
     * Coordinates beyond v.size() are T{}; entries of @p v beyond D are ignored.
     */
    KDPoint( const std::vector<T>& v )
    {
      std::copy_n( v.begin(), std::min<size_t>( v.size(), D ), m_point.begin() );
    }

    /**
     * @brief Copy the first min(list.size(), D) entries of @p list; weight 1.
     *
     * A list shorter than D (e.g. {x} for a 3-D point) is allowed: the remaining
     * coordinates are T{}. Extra entries beyond D are ignored.
     */
    KDPoint( std::initializer_list<T> list )
    {
      std::copy_n( list.begin(), std::min<size_t>( list.size(), D ), m_point.begin() );
    }

    /// Add each element; the weights are not added (result has weight 1).
    const KDPoint<T,D> operator + ( const KDPoint<T,D>& other ) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] + other.m_point[i];
      }
      return KDPoint<T,D>( std::move(tmp) );
    }

    /// Subtract each element; the weights are not subtracted (result has weight 1).
    const KDPoint<T,D> operator - ( const KDPoint<T,D>& other ) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] - other.m_point[i];
      }
      return KDPoint<T,D>( std::move(tmp) );
    }

    /// Add @p other element-wise in place; this point's weight is unchanged.
    KDPoint<T,D>& operator += ( const KDPoint<T,D>& other )
    {
      for ( size_t i = 0; i < D; i++ ) {
        m_point[i] += other.m_point[i];
      }
      return *this;
    }

    /// Subtract @p other element-wise in place; this point's weight is unchanged.
    KDPoint<T,D>& operator -= ( const KDPoint<T,D>& other )
    {
      for ( size_t i = 0; i < D; i++ ) {
        m_point[i] -= other.m_point[i];
      }
      return *this;
    }

    /// Multiply each element by @p other; the weight is not scaled (result has weight 1).
    template<typename I>
    const KDPoint<T,D> operator * ( const I& other ) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] * other;
      }
      return KDPoint<T,D>( std::move(tmp) );
    }

    /// Divide each element by @p other; the weight is not scaled (result has weight 1).
    template<typename I>
    const KDPoint<T,D> operator / ( const I& other ) const
    {
      std::array<T,D> tmp;
      for ( size_t i = 0; i < D; i++ ) {
        tmp[i] = m_point[i] / other;
      }
      return KDPoint<T,D>( std::move(tmp) );
    }

    /// Return i-th element (unchecked).
    T& operator [] ( size_t i ) { return m_point[i]; }

    /// Return i-th element (unchecked).
    const T& operator [] ( size_t i ) const { return m_point[i]; }

    /// Return all coordinates.
    inline const std::array<T,D>& getPos() const { return m_point; }

    /**
     * @brief Return i-th element. If i >= D, return std::numeric_limits<T>::quiet_NaN().
     *
     * For floating-point T this is a quiet NaN; for types without a NaN
     * representation (e.g. integers) numeric_limits yields T{} instead.
     */
    inline T at( size_t i ) const
    {
      return ( i < D ) ? m_point[i] : std::numeric_limits<T>::quiet_NaN();
    }

    /// Return the weight of the point.
    inline double getWeight() const { return m_weight; }
    /// Set the weight to the given value.
    inline void setWeight( double w ) { m_weight = w; }

    /**
     * @brief Return the weighted average of this point and @p p.
     *
     * Element-wise (w_this * this + w_p * p) / (w_this + w_p). The result has
     * weight 1. There is no guard against w_this + w_p == 0.
     */
    KDPoint<T,D> average( const KDPoint<T,D>& p ) const
    {
      return ( *this * m_weight + p * p.m_weight ) / ( m_weight + p.m_weight );
    }

    // NOTE(review): only declared here; the definition is not part of this listing.
    static inline KDPoint<T, D> average(const std::vector<KDPoint<T,D>>&);
};
146
147template<typename I, typename T, size_t D>
148const KDPoint<T,D> operator * (const I& b, const KDPoint<T,D>& p)
149{
150 return p * b;
151}
152
153
158template<typename T, size_t D>
159class KDTree {
160 public :
164 struct Node {
167 std::unique_ptr<Node> leftPtr;
168 std::unique_ptr<Node> rightPtr;
169
170 Node(const KDPoint<T,D>& data, int idx) : dataRef(data), dataIdx(idx){};
171 Node(){};
172 };
173
174 KDTree(std::vector<KDPoint<T,D>>& v_data) : m_datas( v_data ), m_idLength( m_datas.size() ), m_locked( false )
175 {
176 m_indices.clear();
177 for (size_t i = 0; i < m_idLength; i++) { m_indices.emplace_back(i); }
178 };
179
180 KDTree() : m_idLength(0), m_locked(false){};
181
182 void genTree();
183 inline void lock() { m_locked = true; };
184 inline void unlock() { m_locked = false; };
185
186 inline KDPoint<T,D> at(size_t n) { return m_datas.at(n); };
187
188 private :
189 std::unique_ptr<Node> m_rootNode;
190 std::vector<KDPoint<T,D>> m_datas;
191 std::vector<size_t> m_indices;
194
195 std::unique_ptr<Node> buildTree(int, int, int);
196
197 void nearestNeighborRec( const KDPoint<T,D>&, const Node*, double&, int&, std::function<double(const KDPoint<T,D>&, const KDPoint<T,D>&)>& );
198};
199
200
// KDTree<T,D>::genTree — command to generate the tree.
// Returns immediately, without touching the tree, while m_locked is set
// (see lock() / unlock()).
// NOTE(review): this listing is missing the signature line (original line 205,
// `void KDTree<T,D>::genTree()` per the in-class declaration) and the body line
// after the lock check (original line 208) — presumably the call that builds
// m_rootNode via buildTree(); restore both from the original header before use.
204template<typename T, size_t D>
206{
207 if (m_locked) return;
209}
210
211
219template<typename T, size_t D>
220std::unique_ptr<typename KDTree<T,D>::Node> KDTree<T,D>::buildTree(int l, int r, int depth)
221{
222 if ( l >= r ) return std::make_unique<Node>(nullptr);
223
224 const int axis_ = depth % D;
225 const int mid = ( l + r ) >> 1;
226
227 std::nth_element( m_indices.begin()+l, m_indices.begin()+l+mid, m_indices.begin()+r,
228 [this, axis_](size_t lcnt, size_t rcnt) {
229 return m_datas[lcnt].getId(axis_) < m_datas[rcnt].getId(axis_);
230 } );
231
232 std::unique_ptr<Node> node_ptr = std::make_unique<Node> ( m_datas.at(m_indices[mid]), m_indices[mid] );
233 node_ptr->axis = axis_;
234
235 node_ptr->leftPtr = buildTree( l, mid, depth+1 );
236 node_ptr->rightPtr = buildTree( mid+1, r, depth+1 );
237
238 return node_ptr;
239}
240
241
250template<typename T, size_t D>
252 double& r, int& idx, std::function<double(const KDPoint<T,D>&, const KDPoint<T,D>&)>& dist_func )
253{
254 // end processing when reach leaf
255 if ( node == nullptr ) {
256 return;
257 }
258
259 const KDPoint<T,D>& point = node->dataRef;
260
261 // update minimum distance and candidate point
262 const double dist = dist_func(query, point);
263 if (dist < r) {
264 idx = node->dataIdx;
265 r = dist;
266 }
267
268 const int axis_ = node->axis;
269 const Node* next_node = (query.at(axis_) < point.at(axis_))? node->leftPtr.get() : node->rightPtr.get();
270
271 // recursively call self until reach leaf node
272 nearestNeighborRec( query, next_node, r, idx, dist_func );
273
274 // Check if hyper-sphere with radius r overlaps neighbor hyper-plane.
275 KDPoint<T,D> proj = query;
276 proj[axis_] = point.at(axis_); // proj is the projected point of query on hyperplane x[axis_]=point[axis_]
277 double diff = dist_func(proj, query);
278 if ( diff < r ) {
279 // if hyper-sphere overlaps, check the other region separated by the hyper-plane.
280 const Node* next_node_opps = (query.at(axis_) < point.at(axis_))? node->rightPtr.get() : node->leftPtr.get();
281 nearestNeighborRec( query, next_node_opps, r, idx, dist_func );
282 }
283 return;
284}
285
286
287} // end of namespace TrigVSI
288
289
290#endif
char data[hepevt_bytes_allocation_ATLAS]
Definition HepEvt.cxx:11
#define I(x, y, z)
Definition MD5.cxx:116
void diff(const Jet &rJet1, const Jet &rJet2, std::map< std::string, double > varDiff)
Difference between jets - Non-Class function required by trigger.
Definition Jet.cxx:631
Class for k-dimensional point.
Definition KDPoint.h:29
const KDPoint< T, D > operator+(const KDPoint< T, D > &other) const
Add each element; the weights are not added.
Definition KDPoint.h:53
void setWeight(double w)
Set the weight to given value.
Definition KDPoint.h:133
const KDPoint< T, D > operator*(const I &other) const
Multiply each element; the weight is not multiplied.
Definition KDPoint.h:90
KDPoint(std::initializer_list< T > list)
Definition KDPoint.h:44
const std::array< T, D > & getPos() const
Definition KDPoint.h:125
KDPoint(const std::vector< T > &v)
Definition KDPoint.h:40
KDPoint< T, D > & operator-=(const KDPoint< T, D > &other)
Definition KDPoint.h:80
T & operator[](size_t i)
Return i-th element.
Definition KDPoint.h:111
double m_weight
Definition KDPoint.h:32
T at(size_t i) const
Return i-th element. If given i exceeds the size, return NaN.
Definition KDPoint.h:128
const KDPoint< T, D > operator/(const I &other) const
Divide each element; the weight is not divided.
Definition KDPoint.h:101
double getWeight() const
Return the weight of the point.
Definition KDPoint.h:131
const KDPoint< T, D > operator-(const KDPoint< T, D > &other) const
Subtract each element; the weights are not subtracted.
Definition KDPoint.h:63
KDPoint(std::array< T, D > &arr)
Definition KDPoint.h:37
KDPoint< T, D > & operator+=(const KDPoint< T, D > &other)
Definition KDPoint.h:72
static KDPoint< T, D > average(const std::vector< KDPoint< T, D > > &)
KDPoint(std::array< T, D > &&arr)
Definition KDPoint.h:38
std::array< T, D > m_point
Definition KDPoint.h:31
KDPoint< T, D > average(const KDPoint< T, D > &p)
Return average point of this point and given point.
Definition KDPoint.h:137
KDTree(std::vector< KDPoint< T, D > > &v_data)
Definition KDPoint.h:174
std::vector< KDPoint< T, D > > m_datas
Container of the points.
Definition KDPoint.h:190
size_t m_idLength
Definition KDPoint.h:192
std::unique_ptr< Node > buildTree(int, int, int)
recursive function to create tree structure.
Definition KDPoint.h:220
void unlock()
Definition KDPoint.h:184
std::vector< size_t > m_indices
A list of indices of points in m_datas.
Definition KDPoint.h:191
std::unique_ptr< Node > m_rootNode
The root node of the tree.
Definition KDPoint.h:189
void genTree()
Command to generate tree.
Definition KDPoint.h:205
void nearestNeighborRec(const KDPoint< T, D > &, const Node *, double &, int &, std::function< double(const KDPoint< T, D > &, const KDPoint< T, D > &)> &)
recursive function for nearest neighbor searching.
Definition KDPoint.h:251
KDPoint< T, D > at(size_t n)
Definition KDPoint.h:186
Definition node.h:24
std::string depth
tag string for indentation
Definition fastadd.cxx:46
int r
Definition globals.cxx:22
const KDPoint< T, D > operator*(const I &b, const KDPoint< T, D > &p)
Definition KDPoint.h:148
Definition query.py:1
STL namespace.
Node class for KDTree.
Definition KDPoint.h:164
std::unique_ptr< Node > rightPtr
Definition KDPoint.h:168
Node(const KDPoint< T, D > &data, int idx)
Definition KDPoint.h:170
const KDPoint< T, D > & dataRef
Definition KDPoint.h:165
std::unique_ptr< Node > leftPtr
Definition KDPoint.h:167
#define private