SourceXtractorPlusPlus 0.21
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
KdTree.icpp
Go to the documentation of this file.
1/** Copyright © 2021 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
2 *
3 * This library is free software; you can redistribute it and/or modify it under
4 * the terms of the GNU Lesser General Public License as published by the Free
5 * Software Foundation; either version 3.0 of the License, or (at your option)
6 * any later version.
7 *
8 * This library is distributed in the hope that it will be useful, but WITHOUT
9 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
10 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
11 * details.
12 *
13 * You should have received a copy of the GNU Lesser General Public License
14 * along with this library; if not, write to the Free Software Foundation, Inc.,
15 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 */
17
18namespace SourceXtractor {
19
20template<typename T, size_t N, size_t S>
21class KdTree<T, N, S>::Node {
22public:
23 virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const = 0;
24 virtual ~Node() = default;
25};
26
27template<typename T, size_t N, size_t S>
28class KdTree<T, N, S>::Leaf : public KdTree::Node {
29public:
30 explicit Leaf(const std::vector<T>&& data) : m_data(data) {}
31 virtual ~Leaf() = default;
32
33 virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const {
34 std::vector<T> selection;
35 for (auto& entry : m_data) {
36 double square_dist = 0.0;
37 for (size_t i =0; i < N; i++) {
38 double delta = Traits::getCoord(entry, i) - coord.coord[i];
39 square_dist += delta * delta;
40 }
41 if (square_dist < radius*radius) {
42 selection.push_back(entry);
43 }
44 }
45 return selection;
46 }
47
48private:
49 const std::vector<T> m_data;
50};
51
52template<typename T, size_t N, size_t S>
53class KdTree<T, N, S>::Split : public KdTree::Node {
54public:
55 virtual ~Split() = default;
56 explicit Split(std::vector<T> data, size_t axis) : m_axis(axis) {
57 std::sort(data.begin(), data.end(), [axis](const T& a, const T& b) -> bool {
58 return Traits::getCoord(a, axis) < Traits::getCoord(b, axis);
59 });
60
61 double a = Traits::getCoord(data.at(data.size() / 2 - 1), axis);
62 double b = Traits::getCoord(data.at(data.size() / 2), axis);
63
64 if (a == b) {
65 // avoid a possible rounding issue
66 m_split_value = a;
67 } else {
68 m_split_value = (a + b) / 2.0;
69 }
70
71 std::vector<T> left(data.begin(), data.begin() + data.size() / 2);
72 std::vector<T> right(data.begin() + data.size() / 2, data.end());
73
74 if (left.size() > S) {
75 m_left_child = std::make_shared<Split>(std::move(left), (axis+1) % N);
76 } else {
77 m_left_child = std::make_shared<Leaf>(std::move(left));
78 }
79 if (right.size() > S) {
80 m_right_child = std::make_shared<Split>(std::move(right), (axis+1) % N);
81 } else {
82 m_right_child = std::make_shared<Leaf>(std::move(right));
83 }
84 }
85
86 virtual std::vector<T> findPointsWithinRadius(Coord coord, double radius) const {
87 if (coord.coord[m_axis] + radius < m_split_value) {
88 return m_left_child->findPointsWithinRadius(coord, radius);
89 } else if (coord.coord[m_axis] - radius > m_split_value) {
90 return m_right_child->findPointsWithinRadius(coord, radius);
91 } else {
92 auto left = m_left_child->findPointsWithinRadius(coord, radius);
93 auto right = m_right_child->findPointsWithinRadius(coord, radius);
94
95 std::vector<T> merge;
96 merge.reserve(left.size() + right.size());
97 merge.insert(merge.end(), left.begin(), left.end());
98 merge.insert(merge.end(), right.begin(), right.end());
99
100 return merge;
101 }
102 }
103
104private:
105 size_t m_axis;
106 double m_split_value;
107
108 std::shared_ptr<Node> m_left_child;
109 std::shared_ptr<Node> m_right_child;
110};
111
112template<typename T, size_t N, size_t S>
113KdTree<T, N, S>::KdTree(const std::vector<T>& data) {
114 if (data.size() > S) {
115 m_root = std::make_shared<Split>(data, 0);
116 } else {
117 std::vector<T> data_copy(data);
118 m_root = std::make_shared<Leaf>(std::move(data_copy));
119 }
120}
121
122template<typename T, size_t N, size_t S>
123std::vector<T> KdTree<T, N, S>::findPointsWithinRadius(Coord coord, double radius) const {
124 return m_root->findPointsWithinRadius(coord, radius);
125}
126
127}