From 188d71e43e24b1581f43d4a025daef4001b8347d Mon Sep 17 00:00:00 2001 From: Pascal Leroy Date: Sat, 8 Oct 2022 18:37:28 +0200 Subject: [PATCH 1/2] Filtering when looking for the nearest neighbour. --- numerics/nearest_neighbour.hpp | 15 ++++++++++++--- numerics/nearest_neighbour_body.hpp | 29 +++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/numerics/nearest_neighbour.hpp b/numerics/nearest_neighbour.hpp index 188081e0f4..f6ec04a38f 100644 --- a/numerics/nearest_neighbour.hpp +++ b/numerics/nearest_neighbour.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -31,6 +32,8 @@ class PrincipalComponentPartitioningTree { public: using Value = Value_; + using Filter = std::function; + // We stop subdividing a cell when it contains |max_values_per_cell| or fewer // values. This API takes (non-owning) pointers so that the client can relate // the values given here to the ones it gets from |FindNearestNeighbour|. The @@ -44,8 +47,10 @@ class PrincipalComponentPartitioningTree { void Add(not_null value); // Finds the nearest neighbour of the given |value|. Returns nullptr if the - // tree is empty. - Value const* FindNearestNeighbour(Value const& value) const; + // tree is empty. Only the values for which |filter| returns true are + // considered. + Value const* FindNearestNeighbour(Value const& value, + Filter const& filter = nullptr) const; private: // A frame used to compute the principal components. @@ -138,6 +143,7 @@ class PrincipalComponentPartitioningTree { // true. That pointer may be null if the client doesn't want to check this // condition. |parent| should be null for the root of the tree. void Find(Displacement const& displacement, + Filter const& filter, Internal const* parent, Node const& node, Norm²& min_distance², @@ -146,12 +152,14 @@ class PrincipalComponentPartitioningTree { // Specializations for internal nodes and leaves, respectively. void Find(Displacement const& displacement, + Filter const& filter, Internal const* parent, Internal const& internal, Norm²& min_distance², std::int32_t& min_index, bool* must_check_other_side) const; void Find(Displacement const& displacement, + Filter const& filter, Internal const* parent, Leaf const& leaf, Norm²& min_distance², @@ -166,7 +174,8 @@ class PrincipalComponentPartitioningTree { // passed to |Add| if no value is passed at construction. Value centroid_; - // The displacements from the centroid. + // The displacements from the centroid. The indices are the same as for + // |values_|. std::vector displacements_; std::unique_ptr root_; diff --git a/numerics/nearest_neighbour_body.hpp b/numerics/nearest_neighbour_body.hpp index 89909e2aa9..4bae7e82d6 100644 --- a/numerics/nearest_neighbour_body.hpp +++ b/numerics/nearest_neighbour_body.hpp @@ -25,6 +25,8 @@ using geometry::Wedge; using quantities::Infinity; using quantities::Pow; +constexpr std::int32_t no_min_index = -1; + template PrincipalComponentPartitioningTree::PrincipalComponentPartitioningTree( std::vector> const& values, @@ -57,13 +59,15 @@ void PrincipalComponentPartitioningTree::Add( template Value_ const* PrincipalComponentPartitioningTree::FindNearestNeighbour( - Value const& value) const { + Value const& value, + Filter const& filter) const { if (displacements_.empty()) { return nullptr; } Norm² min_distance²; std::int32_t min_index; Find(value - centroid_, + filter, /*parent=*/nullptr, *root_, min_distance², min_index, @@ -71,7 +75,7 @@ Value_ const* PrincipalComponentPartitioningTree::FindNearestNeighbour( // In the end, this is why we retain the values: we want to return a pointer // that the client gave us. - return values_[min_index]; + return min_index == no_min_index ? nullptr : values_[min_index]; } template @@ -226,6 +230,7 @@ void PrincipalComponentPartitioningTree::Add(std::int32_t const index, template void PrincipalComponentPartitioningTree::Find( Displacement const& displacement, + Filter const& filter, Internal const* const parent, Node const& node, Norm²& min_distance², @@ -233,12 +238,14 @@ void PrincipalComponentPartitioningTree::Find( bool* const must_check_other_side) const { if (std::holds_alternative(node)) { Find(displacement, + filter, parent, std::get(node), min_distance², min_index, must_check_other_side); } else if (std::holds_alternative(node)) { Find(displacement, + filter, parent, std::get(node), min_distance², min_index, @@ -251,6 +258,7 @@ void PrincipalComponentPartitioningTree::Find( template void PrincipalComponentPartitioningTree::Find( Displacement const& displacement, + Filter const& filter, Internal const* parent, Internal const& internal, Norm²& min_distance², @@ -273,6 +281,7 @@ void PrincipalComponentPartitioningTree::Find( Norm² preferred_min_distance²; bool preferred_must_check_other_side; Find(displacement, + filter, &internal, *preferred_side, preferred_min_distance², preferred_min_index, @@ -293,10 +302,11 @@ void PrincipalComponentPartitioningTree::Find( std::int32_t other_min_index; Norm² other_min_distance²; Find(displacement, - parent, - *other_side, - other_min_distance², other_min_index, - /*must_check_other_side=*/nullptr); + filter, + parent, + *other_side, + other_min_distance², other_min_index, + /*must_check_other_side=*/nullptr); if (other_min_distance² < preferred_min_distance²) { min_distance² = other_min_distance²; @@ -322,6 +332,7 @@ void PrincipalComponentPartitioningTree::Find( template void PrincipalComponentPartitioningTree::Find( Displacement const& displacement, + Filter const& filter, Internal const* const parent, Leaf const& leaf, Norm²& min_distance², @@ -331,7 +342,13 @@ void PrincipalComponentPartitioningTree::Find( // Find the point in this leaf which is the closest to |displacement|. min_distance² = Infinity; + min_index = no_min_index; for (auto const index : leaf) { + // Skip the values that are filtered out. Note that *all* the values may be + // filtered out. + if (filter != nullptr && !filter(values_[index])) { + continue; + } auto const distance² = (displacements_[index] - displacement).Norm²(); if (distance² < min_distance²) { min_distance² = distance²; From 5a548408843094a05debe1c03545f8030f12690d Mon Sep 17 00:00:00 2001 From: Pascal Leroy Date: Sun, 9 Oct 2022 18:18:48 +0200 Subject: [PATCH 2/2] A test. --- numerics/nearest_neighbour_test.cpp | 62 ++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/numerics/nearest_neighbour_test.cpp b/numerics/nearest_neighbour_test.cpp index 118d0d4a6b..369df8c4ec 100644 --- a/numerics/nearest_neighbour_test.cpp +++ b/numerics/nearest_neighbour_test.cpp @@ -26,11 +26,16 @@ class PrincipalComponentPartitioningTreeTest : public ::testing::Test { using V = Vector; // Computes the nearest point using the brute force algorithm. - V const* BruteForceNearestNeighbour(V const& query_value, - std::vector const& values) { + V const* BruteForceNearestNeighbour( + V const& query_value, + std::vector const& values, + PrincipalComponentPartitioningTree::Filter const& filter = nullptr) { V const* nearest = nullptr; double nearest_distance = Infinity; for (auto const& value : values) { + if (filter != nullptr && !filter(&value)) { + continue; + } double const distance = (value - query_value).Norm(); if (distance < nearest_distance) { nearest_distance = distance; @@ -142,8 +147,8 @@ TEST_F(PrincipalComponentPartitioningTreeTest, RandomConstructor) { auto* const nearest1 = tree1.FindNearestNeighbour(query_point); auto* const nearest3 = tree3.FindNearestNeighbour(query_point); - EXPECT_THAT(nearest1, Eq(nearest)); - EXPECT_THAT(nearest3, Eq(nearest)); + EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest; + EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest; } } @@ -181,9 +186,54 @@ TEST_F(PrincipalComponentPartitioningTreeTest, RandomAdd) { auto* const nearest1 = tree1.FindNearestNeighbour(query_point); auto* const nearest3 = tree3.FindNearestNeighbour(query_point); - EXPECT_THAT(nearest1, Eq(nearest)); - EXPECT_THAT(nearest3, Eq(nearest)); + EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest; + EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest; + } +} + +// Random points with a filter, validated against the brute force algorithm. +TEST_F(PrincipalComponentPartitioningTreeTest, RandomFilter) { + static constexpr int points_in_tree = 100; + static constexpr int points_to_test = 10; + std::mt19937_64 random(42); + std::uniform_real_distribution coordinate_distribution(-10, 10); + + // Build two trees with the same points but different leaf sizes. + std::vector tree_points; + std::vector> tree_pointers; + MakeValues(points_in_tree, + tree_points, + tree_pointers, + random, + coordinate_distribution); + PrincipalComponentPartitioningTree tree1(tree_pointers, + /*max_values_per_cell=*/1); + PrincipalComponentPartitioningTree tree3(tree_pointers, + /*max_values_per_cell=*/3); + + const PrincipalComponentPartitioningTree::Filter filter = + [](V const* const point) { + return point->Norm²() < 100; + }; + + bool filtering_was_effective = false; + for (int i = 0; i < points_to_test; ++i) { + auto const query_point = V({coordinate_distribution(random), + coordinate_distribution(random), + coordinate_distribution(random)}); + + auto* const nearest = + BruteForceNearestNeighbour(query_point, tree_points, filter); + auto* const nearest1 = tree1.FindNearestNeighbour(query_point, filter); + auto* const nearest3 = tree3.FindNearestNeighbour(query_point, filter); + + EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest; + EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest; + + filtering_was_effective |= + nearest != BruteForceNearestNeighbour(query_point, tree_points); } + EXPECT_TRUE(filtering_was_effective) << "Filtering did nothing"; } } // namespace numerics