Skip to content

Commit

Permalink
Filtering when looking for the nearest neighbour.
Browse files Browse the repository at this point in the history
  • Loading branch information
pleroy committed Oct 8, 2022
1 parent 471d1e6 commit 188d71e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
15 changes: 12 additions & 3 deletions numerics/nearest_neighbour.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <utility>
#include <variant>
#include <vector>
Expand Down Expand Up @@ -31,6 +32,8 @@ class PrincipalComponentPartitioningTree {
public:
using Value = Value_;

using Filter = std::function<bool(Value const*)>;

// We stop subdividing a cell when it contains |max_values_per_cell| or fewer
// values. This API takes (non-owning) pointers so that the client can relate
// the values given here to the ones it gets from |FindNearestNeighbour|. The
Expand All @@ -44,8 +47,10 @@ class PrincipalComponentPartitioningTree {
void Add(not_null<Value const*> value);

// Finds the nearest neighbour of the given |value|. Returns nullptr if the
// tree is empty.
Value const* FindNearestNeighbour(Value const& value) const;
// tree is empty or if no value satisfies |filter|. Only the values for which
// |filter| returns true are considered; a null |filter| accepts every value.
Value const* FindNearestNeighbour(Value const& value,
Filter const& filter = nullptr) const;

private:
// A frame used to compute the principal components.
Expand Down Expand Up @@ -138,6 +143,7 @@ class PrincipalComponentPartitioningTree {
// true. That pointer may be null if the client doesn't want to check this
// condition. |parent| should be null for the root of the tree.
void Find(Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Node const& node,
NormĀ²& min_distanceĀ²,
Expand All @@ -146,12 +152,14 @@ class PrincipalComponentPartitioningTree {

// Specializations for internal nodes and leaves, respectively.
void Find(Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Internal const& internal,
NormĀ²& min_distanceĀ²,
std::int32_t& min_index,
bool* must_check_other_side) const;
void Find(Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Leaf const& leaf,
NormĀ²& min_distanceĀ²,
Expand All @@ -166,7 +174,8 @@ class PrincipalComponentPartitioningTree {
// passed to |Add| if no value is passed at construction.
Value centroid_;

// The displacements from the centroid.
// The displacements from the centroid. The indices are the same as for
// |values_|.
std::vector<Displacement> displacements_;

std::unique_ptr<Node> root_;
Expand Down
29 changes: 23 additions & 6 deletions numerics/nearest_neighbour_body.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ using geometry::Wedge;
using quantities::Infinity;
using quantities::Pow;

// Sentinel stored in a |min_index| when a search found no acceptable value,
// e.g., because the filter rejected every candidate. It is negative and
// therefore cannot be mistaken for an index into |values_|. Declared |inline|
// so that every translation unit including this header shares one definition.
inline constexpr std::int32_t no_min_index = -1;

template<typename Value_>
PrincipalComponentPartitioningTree<Value_>::PrincipalComponentPartitioningTree(
std::vector<not_null<Value const*>> const& values,
Expand Down Expand Up @@ -57,21 +59,23 @@ void PrincipalComponentPartitioningTree<Value_>::Add(

// Searches the tree, starting at the root, for the value nearest to |value|.
// |filter| is handed down unchanged to |Find|, which decides which values are
// eligible. Returns a pointer to one of the values stored in |values_|, or
// nullptr when there is nothing to return.
template<typename Value_>
Value_ const* PrincipalComponentPartitioningTree<Value_>::FindNearestNeighbour(
Value const& value) const {
Value const& value,
Filter const& filter) const {
// An empty tree has no neighbour to offer.
if (displacements_.empty()) {
return nullptr;
}
// Both of these are passed by reference to |Find|, which reports the best
// candidate in them.
NormĀ² min_distanceĀ²;
std::int32_t min_index;
// The search works on displacements from the centroid, hence the subtraction.
// The root has no parent, and the caller of the root search has no use for
// |must_check_other_side|, so both are null.
Find(value - centroid_,
filter,
/*parent=*/nullptr,
*root_,
min_distanceĀ², min_index,
/*must_check_other_side=*/nullptr);

// In the end, this is why we retain the values: we want to return a pointer
// that the client gave us.
// The leaf-level |Find| resets |min_index| to |no_min_index| before scanning,
// so that sentinel is what remains when no value was eligible (for instance
// because |filter| rejected them all); nullptr is returned in that case.
return values_[min_index];
return min_index == no_min_index ? nullptr : values_[min_index];
}

template<typename Value_>
Expand Down Expand Up @@ -226,19 +230,22 @@ void PrincipalComponentPartitioningTree<Value_>::Add(std::int32_t const index,
template<typename Value_>
void PrincipalComponentPartitioningTree<Value_>::Find(
Displacement const& displacement,
Filter const& filter,
Internal const* const parent,
Node const& node,
NormĀ²& min_distanceĀ²,
std::int32_t& min_index,
bool* const must_check_other_side) const {
if (std::holds_alternative<Internal>(node)) {
Find(displacement,
filter,
parent,
std::get<Internal>(node),
min_distanceĀ², min_index,
must_check_other_side);
} else if (std::holds_alternative<Leaf>(node)) {
Find(displacement,
filter,
parent,
std::get<Leaf>(node),
min_distanceĀ², min_index,
Expand All @@ -251,6 +258,7 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
template<typename Value_>
void PrincipalComponentPartitioningTree<Value_>::Find(
Displacement const& displacement,
Filter const& filter,
Internal const* parent,
Internal const& internal,
NormĀ²& min_distanceĀ²,
Expand All @@ -273,6 +281,7 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
NormĀ² preferred_min_distanceĀ²;
bool preferred_must_check_other_side;
Find(displacement,
filter,
&internal,
*preferred_side,
preferred_min_distanceĀ², preferred_min_index,
Expand All @@ -293,10 +302,11 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
std::int32_t other_min_index;
NormĀ² other_min_distanceĀ²;
Find(displacement,
parent,
*other_side,
other_min_distanceĀ², other_min_index,
/*must_check_other_side=*/nullptr);
filter,
parent,
*other_side,
other_min_distanceĀ², other_min_index,
/*must_check_other_side=*/nullptr);

if (other_min_distanceĀ² < preferred_min_distanceĀ²) {
min_distanceĀ² = other_min_distanceĀ²;
Expand All @@ -322,6 +332,7 @@ void PrincipalComponentPartitioningTree<Value_>::Find(
template<typename Value_>
void PrincipalComponentPartitioningTree<Value_>::Find(
Displacement const& displacement,
Filter const& filter,
Internal const* const parent,
Leaf const& leaf,
NormĀ²& min_distanceĀ²,
Expand All @@ -331,7 +342,13 @@ void PrincipalComponentPartitioningTree<Value_>::Find(

// Find the point in this leaf which is the closest to |displacement|.
min_distanceĀ² = Infinity<NormĀ²>;
min_index = no_min_index;
for (auto const index : leaf) {
// Skip the values that are filtered out. Note that *all* the values may be
// filtered out.
if (filter != nullptr && !filter(values_[index])) {
continue;
}
auto const distanceĀ² = (displacements_[index] - displacement).NormĀ²();
if (distanceĀ² < min_distanceĀ²) {
min_distanceĀ² = distanceĀ²;
Expand Down

0 comments on commit 188d71e

Please sign in to comment.