From 5a548408843094a05debe1c03545f8030f12690d Mon Sep 17 00:00:00 2001 From: Pascal Leroy Date: Sun, 9 Oct 2022 18:18:48 +0200 Subject: [PATCH] A test. --- numerics/nearest_neighbour_test.cpp | 62 ++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/numerics/nearest_neighbour_test.cpp b/numerics/nearest_neighbour_test.cpp index 118d0d4a6b..369df8c4ec 100644 --- a/numerics/nearest_neighbour_test.cpp +++ b/numerics/nearest_neighbour_test.cpp @@ -26,11 +26,16 @@ class PrincipalComponentPartitioningTreeTest : public ::testing::Test { using V = Vector; // Computes the nearest point using the brute force algorithm. - V const* BruteForceNearestNeighbour(V const& query_value, - std::vector const& values) { + V const* BruteForceNearestNeighbour( + V const& query_value, + std::vector const& values, + PrincipalComponentPartitioningTree::Filter const& filter = nullptr) { V const* nearest = nullptr; double nearest_distance = Infinity; for (auto const& value : values) { + if (filter != nullptr && !filter(&value)) { + continue; + } double const distance = (value - query_value).Norm(); if (distance < nearest_distance) { nearest_distance = distance; @@ -142,8 +147,8 @@ TEST_F(PrincipalComponentPartitioningTreeTest, RandomConstructor) { auto* const nearest1 = tree1.FindNearestNeighbour(query_point); auto* const nearest3 = tree3.FindNearestNeighbour(query_point); - EXPECT_THAT(nearest1, Eq(nearest)); - EXPECT_THAT(nearest3, Eq(nearest)); + EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest; + EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest; } } @@ -181,9 +186,54 @@ TEST_F(PrincipalComponentPartitioningTreeTest, RandomAdd) { auto* const nearest1 = tree1.FindNearestNeighbour(query_point); auto* const nearest3 = tree3.FindNearestNeighbour(query_point); - EXPECT_THAT(nearest1, Eq(nearest)); - EXPECT_THAT(nearest3, Eq(nearest)); + EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest; + EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest; + } +} + +// Random points with a filter, validated against the brute force algorithm. +TEST_F(PrincipalComponentPartitioningTreeTest, RandomFilter) { + static constexpr int points_in_tree = 100; + static constexpr int points_to_test = 10; + std::mt19937_64 random(42); + std::uniform_real_distribution coordinate_distribution(-10, 10); + + // Build two trees with the same points but different leaf sizes. + std::vector tree_points; + std::vector> tree_pointers; + MakeValues(points_in_tree, + tree_points, + tree_pointers, + random, + coordinate_distribution); + PrincipalComponentPartitioningTree tree1(tree_pointers, + /*max_values_per_cell=*/1); + PrincipalComponentPartitioningTree tree3(tree_pointers, + /*max_values_per_cell=*/3); + + const PrincipalComponentPartitioningTree::Filter filter = + [](V const* const point) { + return point->NormĀ²() < 100; + }; + + bool filtering_was_effective = false; + for (int i = 0; i < points_to_test; ++i) { + auto const query_point = V({coordinate_distribution(random), + coordinate_distribution(random), + coordinate_distribution(random)}); + + auto* const nearest = + BruteForceNearestNeighbour(query_point, tree_points, filter); + auto* const nearest1 = tree1.FindNearestNeighbour(query_point, filter); + auto* const nearest3 = tree3.FindNearestNeighbour(query_point, filter); + + EXPECT_THAT(nearest1, Eq(nearest)) << *nearest1 << " " << *nearest; + EXPECT_THAT(nearest3, Eq(nearest)) << *nearest3 << " " << *nearest; + + filtering_was_effective |= + nearest != BruteForceNearestNeighbour(query_point, tree_points); } + EXPECT_TRUE(filtering_was_effective) << "Filtering did nothing"; } } // namespace numerics