From 188d71e43e24b1581f43d4a025daef4001b8347d Mon Sep 17 00:00:00 2001 From: Pascal Leroy Date: Sat, 8 Oct 2022 18:37:28 +0200 Subject: [PATCH] Filtering when looking for the nearest neighbour. --- numerics/nearest_neighbour.hpp | 15 ++++++++++++--- numerics/nearest_neighbour_body.hpp | 29 +++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/numerics/nearest_neighbour.hpp b/numerics/nearest_neighbour.hpp index 188081e0f4..f6ec04a38f 100644 --- a/numerics/nearest_neighbour.hpp +++ b/numerics/nearest_neighbour.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -31,6 +32,8 @@ class PrincipalComponentPartitioningTree { public: using Value = Value_; + using Filter = std::function; + // We stop subdividing a cell when it contains |max_values_per_cell| or fewer // values. This API takes (non-owning) pointers so that the client can relate // the values given here to the ones it gets from |FindNearestNeighbour|. The @@ -44,8 +47,10 @@ class PrincipalComponentPartitioningTree { void Add(not_null value); // Finds the nearest neighbour of the given |value|. Returns nullptr if the - // tree is empty. - Value const* FindNearestNeighbour(Value const& value) const; + // tree is empty or if no value passes |filter|. Only the values for which + // |filter| returns true are considered; a null |filter| accepts all values. + Value const* FindNearestNeighbour(Value const& value, + Filter const& filter = nullptr) const; private: // A frame used to compute the principal components. @@ -138,6 +143,7 @@ class PrincipalComponentPartitioningTree { // true. That pointer may be null if the client doesn't want to check this // condition. |parent| should be null for the root of the tree. void Find(Displacement const& displacement, + Filter const& filter, Internal const* parent, Node const& node, Norm²& min_distance², @@ -146,12 +152,14 @@ class PrincipalComponentPartitioningTree { // Specializations for internal nodes and leaves, respectively. 
void Find(Displacement const& displacement, + Filter const& filter, Internal const* parent, Internal const& internal, Norm²& min_distance², std::int32_t& min_index, bool* must_check_other_side) const; void Find(Displacement const& displacement, + Filter const& filter, Internal const* parent, Leaf const& leaf, Norm²& min_distance², @@ -166,7 +174,8 @@ class PrincipalComponentPartitioningTree { // passed to |Add| if no value is passed at construction. Value centroid_; - // The displacements from the centroid. + // The displacements from the centroid. The indices are the same as for + // |values_|. std::vector displacements_; std::unique_ptr root_; diff --git a/numerics/nearest_neighbour_body.hpp b/numerics/nearest_neighbour_body.hpp index 89909e2aa9..4bae7e82d6 100644 --- a/numerics/nearest_neighbour_body.hpp +++ b/numerics/nearest_neighbour_body.hpp @@ -25,6 +25,8 @@ using geometry::Wedge; using quantities::Infinity; using quantities::Pow; +constexpr std::int32_t no_min_index = -1; + template PrincipalComponentPartitioningTree::PrincipalComponentPartitioningTree( std::vector> const& values, @@ -57,13 +59,15 @@ void PrincipalComponentPartitioningTree::Add( template Value_ const* PrincipalComponentPartitioningTree::FindNearestNeighbour( - Value const& value) const { + Value const& value, + Filter const& filter) const { if (displacements_.empty()) { return nullptr; } Norm² min_distance²; std::int32_t min_index; Find(value - centroid_, + filter, /*parent=*/nullptr, *root_, min_distance², min_index, @@ -71,7 +75,7 @@ Value_ const* PrincipalComponentPartitioningTree::FindNearestNeighbour( // In the end, this is why we retain the values: we want to return a pointer // that the client gave us. - return values_[min_index]; + return min_index == no_min_index ? 
nullptr : values_[min_index]; } template @@ -226,6 +230,7 @@ void PrincipalComponentPartitioningTree::Add(std::int32_t const index, template void PrincipalComponentPartitioningTree::Find( Displacement const& displacement, + Filter const& filter, Internal const* const parent, Node const& node, Norm²& min_distance², @@ -233,12 +238,14 @@ void PrincipalComponentPartitioningTree::Find( bool* const must_check_other_side) const { if (std::holds_alternative(node)) { Find(displacement, + filter, parent, std::get(node), min_distance², min_index, must_check_other_side); } else if (std::holds_alternative(node)) { Find(displacement, + filter, parent, std::get(node), min_distance², min_index, @@ -251,6 +258,7 @@ void PrincipalComponentPartitioningTree::Find( template void PrincipalComponentPartitioningTree::Find( Displacement const& displacement, + Filter const& filter, Internal const* parent, Internal const& internal, Norm²& min_distance², @@ -273,6 +281,7 @@ void PrincipalComponentPartitioningTree::Find( Norm² preferred_min_distance²; bool preferred_must_check_other_side; Find(displacement, + filter, &internal, *preferred_side, preferred_min_distance², preferred_min_index, @@ -293,10 +302,11 @@ void PrincipalComponentPartitioningTree::Find( std::int32_t other_min_index; Norm² other_min_distance²; Find(displacement, - parent, - *other_side, - other_min_distance², other_min_index, - /*must_check_other_side=*/nullptr); + filter, + parent, + *other_side, + other_min_distance², other_min_index, + /*must_check_other_side=*/nullptr); if (other_min_distance² < preferred_min_distance²) { min_distance² = other_min_distance²; @@ -322,6 +332,7 @@ void PrincipalComponentPartitioningTree::Find( template void PrincipalComponentPartitioningTree::Find( Displacement const& displacement, + Filter const& filter, Internal const* const parent, Leaf const& leaf, Norm²& min_distance², @@ -331,7 +342,13 @@ void PrincipalComponentPartitioningTree::Find( // Find the point in this leaf which is the 
closest to |displacement|. min_distance² = Infinity; + min_index = no_min_index; for (auto const index : leaf) { + // Skip the values that are filtered out. Note that *all* the values may be + // filtered out. + if (filter != nullptr && !filter(values_[index])) { + continue; + } auto const distance² = (displacements_[index] - displacement).Norm²(); if (distance² < min_distance²) { min_distance² = distance²;