Skip to content

Commit

Permalink
Merge pull request #3446 from pleroy/Filter
Browse files Browse the repository at this point in the history
Support for applying a filter when looking for the nearest neighbour
  • Loading branch information
pleroy authored Oct 10, 2022
2 parents 07b97b8 + 5a54840 commit 7251a89
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 15 deletions.
15 changes: 12 additions & 3 deletions numerics/nearest_neighbour.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <utility>
#include <variant>
#include <vector>
Expand Down Expand Up @@ -31,6 +32,8 @@ class PrincipalComponentPartitioningTree {
public:
using Value = Value_;

using Filter = std::function<bool(Value const*)>;

// We stop subdividing a cell when it contains |max_values_per_cell| or fewer
// values. This API takes (non-owning) pointers so that the client can relate
// the values given here to the ones it gets from |FindNearestNeighbour|. The
Expand All @@ -44,8 +47,10 @@ class PrincipalComponentPartitioningTree {
void Add(not_null<Value const*> value);

// Finds the nearest neighbour of the given |value|. Returns nullptr if the
// tree is empty.
Value const* FindNearestNeighbour(Value const& value) const;
// tree is empty. Only the values for which |filter| returns true are
// considered.
Value const* FindNearestNeighbour(Value const& value,
Filter const& filter = nullptr) const;

private:
// A frame used to compute the principal components.
Expand Down Expand Up @@ -138,6 +143,7 @@ class PrincipalComponentPartitioningTree {
// true. That pointer may be null if the client doesn't want to check this
// condition. |parent| should be null for the root of the tree.
void Find(Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Node const& node,
Norm²& min_distance²,
Expand All @@ -146,12 +152,14 @@ class PrincipalComponentPartitioningTree {

// Specializations for internal nodes and leaves, respectively.
void Find(Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Internal const& internal,
Norm²& min_distance²,
std::int32_t& min_index,
bool* must_check_other_side) const;
void Find(Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Leaf const& leaf,
Norm²& min_distance²,
Expand All @@ -166,7 +174,8 @@ class PrincipalComponentPartitioningTree {
// passed to |Add| if no value is passed at construction.
Value centroid_;

// The displacements from the centroid.
// The displacements from the centroid. The indices are the same as for
// |values_|.
std::vector<Displacement> displacements_;

std::unique_ptr<Node> root_;
Expand Down
29 changes: 23 additions & 6 deletions numerics/nearest_neighbour_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ using geometry::Wedge;
using quantities::Infinity;
using quantities::Pow;

constexpr std::int32_t no_min_index = -1;

template<typename Value_>
PrincipalComponentPartitioningTree<Value_>::PrincipalComponentPartitioningTree(
std::vector<not_null<Value const*>> const& values,
Expand Down Expand Up @@ -57,21 +59,23 @@ void PrincipalComponentPartitioningTree<Value_>::Add(

template<typename Value_>
Value_ const* PrincipalComponentPartitioningTree<Value_>::FindNearestNeighbour(
Value const& value) const {
Value const& value,
Filter const& filter) const {
if (displacements_.empty()) {
return nullptr;
}
Norm² min_distance²;
std::int32_t min_index;
Find(value - centroid_,
filter,
/*parent=*/nullptr,
*root_,
min_distance², min_index,
/*must_check_other_side=*/nullptr);

// In the end, this is why we retain the values: we want to return a pointer
// that the client gave us.
return values_[min_index];
return min_index == no_min_index ? nullptr : values_[min_index];
}

template<typename Value_>
Expand Down Expand Up @@ -226,19 +230,22 @@ void PrincipalComponentPartitioningTree<Value_>::Add(std::int32_t const index,
template<typename Value_>
void PrincipalComponentPartitioningTree<Value_>::Find(
Displacement const& displacement,
Filter const& filter,
Internal const* const parent,
Node const& node,
Norm²& min_distance²,
std::int32_t& min_index,
bool* const must_check_other_side) const {
if (std::holds_alternative<Internal>(node)) {
Find(displacement,
filter,
parent,
std::get<Internal>(node),
min_distance², min_index,
must_check_other_side);
} else if (std::holds_alternative<Leaf>(node)) {
Find(displacement,
filter,
parent,
std::get<Leaf>(node),
min_distance², min_index,
Expand All @@ -251,6 +258,7 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
template<typename Value_>
void PrincipalComponentPartitioningTree<Value_>::Find(
Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Internal const& internal,
Norm²& min_distance²,
Expand All @@ -273,6 +281,7 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
Norm² preferred_min_distance²;
bool preferred_must_check_other_side;
Find(displacement,
filter,
&internal,
*preferred_side,
preferred_min_distance², preferred_min_index,
Expand All @@ -293,10 +302,11 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
std::int32_t other_min_index;
Norm² other_min_distance²;
Find(displacement,
parent,
*other_side,
other_min_distance², other_min_index,
/*must_check_other_side=*/nullptr);
filter,
parent,
*other_side,
other_min_distance², other_min_index,
/*must_check_other_side=*/nullptr);

if (other_min_distance² < preferred_min_distance²) {
min_distance² = other_min_distance²;
Expand All @@ -322,6 +332,7 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
template<typename Value_>
void PrincipalComponentPartitioningTree<Value_>::Find(
Displacement const& displacement,
Filter const& filter,
Internal const* const parent,
Leaf const& leaf,
Norm²& min_distance²,
Expand All @@ -331,7 +342,13 @@ void PrincipalComponentPartitioningTree<Value_>::Find(

// Find the point in this leaf which is the closest to |displacement|.
min_distance² = Infinity<Norm²>;
min_index = no_min_index;
for (auto const index : leaf) {
// Skip the values that are filtered out. Note that *all* the values may be
// filtered out.
if (filter != nullptr && !filter(values_[index])) {
continue;
}
auto const distance² = (displacements_[index] - displacement).Norm²();
if (distance² < min_distance²) {
min_distance² = distance²;
Expand Down
62 changes: 56 additions & 6 deletions numerics/nearest_neighbour_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ class PrincipalComponentPartitioningTreeTest : public ::testing::Test {
using V = Vector<double, World>;

// Computes the nearest point using the brute force algorithm.
V const* BruteForceNearestNeighbour(V const& query_value,
std::vector<V> const& values) {
V const* BruteForceNearestNeighbour(
V const& query_value,
std::vector<V> const& values,
PrincipalComponentPartitioningTree<V>::Filter const& filter = nullptr) {
V const* nearest = nullptr;
double nearest_distance = Infinity<double>;
for (auto const& value : values) {
if (filter != nullptr && !filter(&value)) {
continue;
}
double const distance = (value - query_value).Norm();
if (distance < nearest_distance) {
nearest_distance = distance;
Expand Down Expand Up @@ -142,8 +147,8 @@ TEST_F(PrincipalComponentPartitioningTreeTest, RandomConstructor) {
auto* const nearest1 = tree1.FindNearestNeighbour(query_point);
auto* const nearest3 = tree3.FindNearestNeighbour(query_point);

EXPECT_THAT(nearest1, Eq(nearest));
EXPECT_THAT(nearest3, Eq(nearest));
EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest;
EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest;
}
}

Expand Down Expand Up @@ -181,9 +186,54 @@ TEST_F(PrincipalComponentPartitioningTreeTest, RandomAdd) {
auto* const nearest1 = tree1.FindNearestNeighbour(query_point);
auto* const nearest3 = tree3.FindNearestNeighbour(query_point);

EXPECT_THAT(nearest1, Eq(nearest));
EXPECT_THAT(nearest3, Eq(nearest));
EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest;
EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest;
}
}

// Random points with a filter, validated against the brute force algorithm.
TEST_F(PrincipalComponentPartitioningTreeTest, RandomFilter) {
static constexpr int points_in_tree = 100;
static constexpr int points_to_test = 10;
std::mt19937_64 random(42);
std::uniform_real_distribution<double> coordinate_distribution(-10, 10);

// Build two trees with the same points but different leaf sizes.
std::vector<V> tree_points;
std::vector<not_null<V const*>> tree_pointers;
MakeValues(points_in_tree,
tree_points,
tree_pointers,
random,
coordinate_distribution);
PrincipalComponentPartitioningTree<V> tree1(tree_pointers,
/*max_values_per_cell=*/1);
PrincipalComponentPartitioningTree<V> tree3(tree_pointers,
/*max_values_per_cell=*/3);

const PrincipalComponentPartitioningTree<V>::Filter filter =
[](V const* const point) {
return point->Norm²() < 100;
};

bool filtering_was_effective = false;
for (int i = 0; i < points_to_test; ++i) {
auto const query_point = V({coordinate_distribution(random),
coordinate_distribution(random),
coordinate_distribution(random)});

auto* const nearest =
BruteForceNearestNeighbour(query_point, tree_points, filter);
auto* const nearest1 = tree1.FindNearestNeighbour(query_point, filter);
auto* const nearest3 = tree3.FindNearestNeighbour(query_point, filter);

EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest;
EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest;

filtering_was_effective |=
nearest != BruteForceNearestNeighbour(query_point, tree_points);
}
EXPECT_TRUE(filtering_was_effective) << "Filtering did nothing";
}

} // namespace numerics
Expand Down

0 comments on commit 7251a89

Please sign in to comment.