Skip to content

Commit

Permalink
kd tree data structure implementation (#11532)
Browse files Browse the repository at this point in the history
* Implemented KD-Tree Data Structure

* Implemented KD-Tree Data Structure. updated DIRECTORY.md.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Create __init__.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replaced legacy `np.random.rand` call with `np.random.Generator` in kd_tree/example_usage.py

* Replaced legacy `np.random.rand` call with `np.random.Generator` in kd_tree/hypercube_points.py

* added typehints and docstrings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* docstring for search()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added tests. Updated docstrings/typehints

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated tests and used | for type annotations

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* E501 for build_kdtree.py, hypercube_points.py, nearest_neighbour_search.py

* I001 for example_usage.py and test_kdtree.py

* I001 for example_usage.py and test_kdtree.py

* Update data_structures/kd_tree/build_kdtree.py

Co-authored-by: Christian Clauss <[email protected]>

* Update data_structures/kd_tree/example/hypercube_points.py

Co-authored-by: Christian Clauss <[email protected]>

* Update data_structures/kd_tree/example/hypercube_points.py

Co-authored-by: Christian Clauss <[email protected]>

* Added new test cases requested in Review. Refactored the test_build_kdtree() to include various checks.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Considered ruff errors

* Considered ruff errors

* Apply suggestions from code review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update kd_node.py

* imported annotations from __future__

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Christian Clauss <[email protected]>
  • Loading branch information
3 people committed Sep 3, 2024
1 parent bd8085c commit f16d38f
Show file tree
Hide file tree
Showing 10 changed files with 301 additions and 0 deletions.
6 changes: 6 additions & 0 deletions DIRECTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,12 @@
* Trie
* [Radix Tree](data_structures/trie/radix_tree.py)
* [Trie](data_structures/trie/trie.py)
* KD Tree
* [KD Tree Node](data_structures/kd_tree/kd_node.py)
* [Build KD Tree](data_structures/kd_tree/build_kdtree.py)
* [Nearest Neighbour Search](data_structures/kd_tree/nearest_neighbour_search.py)
* [Hypercube Points](data_structures/kd_tree/example/hypercube_points.py)
* [Example Usage](data_structures/kd_tree/example/example_usage.py)

## Digital Image Processing
* [Change Brightness](digital_image_processing/change_brightness.py)
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions data_structures/kd_tree/build_kdtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from data_structures.kd_tree.kd_node import KDNode


def build_kdtree(points: list[list[float]], depth: int = 0) -> KDNode | None:
    """
    Builds a KD-Tree from a list of points.

    The list passed in is not modified: every level works on a sorted copy,
    picks the median along the current splitting axis as the node's point and
    recurses on the points before and after it.

    Args:
        points: The list of points to build the KD-Tree from.
        depth: The current depth in the tree
               (used to determine axis for splitting).

    Returns:
        The root node of the KD-Tree,
        or None if no points are provided.
    """
    if not points:
        return None

    # Dimensionality is read from the first point only.
    k = len(points[0])
    # Cycle through the axes as the tree gets deeper.
    axis = depth % k

    # sorted() instead of list.sort(): sorting in place would silently
    # reorder the list owned by the caller.
    ordered = sorted(points, key=lambda point: point[axis])
    median_idx = len(ordered) // 2

    # The median becomes this node; it is excluded from both halves.
    left_points = ordered[:median_idx]
    right_points = ordered[median_idx + 1 :]

    return KDNode(
        point=ordered[median_idx],
        left=build_kdtree(left_points, depth + 1),
        right=build_kdtree(right_points, depth + 1),
    )
Empty file.
38 changes: 38 additions & 0 deletions data_structures/kd_tree/example/example_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.example.hypercube_points import hypercube_points
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search


def main() -> None:
    """
    Demonstrates the use of KD-Tree by building it from random points
    in a 10-dimensional hypercube and performing a nearest neighbor search.
    """
    num_points: int = 5000
    cube_size: float = 10.0  # Size of the hypercube (edge length)
    num_dimensions: int = 10

    # Generate random points within the hypercube
    points: np.ndarray = hypercube_points(num_points, cube_size, num_dimensions)
    hypercube_kdtree = build_kdtree(points.tolist())

    # Generate a random query point within the same hypercube.
    # rng.random() samples from [0, 1), so it is scaled by the edge length;
    # unscaled, the query would always fall in the unit-sized corner.
    rng = np.random.default_rng()
    query_point: list[float] = (cube_size * rng.random(num_dimensions)).tolist()

    # Perform nearest neighbor search
    nearest_point, nearest_sq_dist, nodes_visited = nearest_neighbour_search(
        hypercube_kdtree, query_point
    )

    # Print the results. The search returns the squared distance, so take
    # the square root to report the value the "Distance" label promises.
    print(f"Query point: {query_point}")
    print(f"Nearest point: {nearest_point}")
    print(f"Distance: {nearest_sq_dist ** 0.5:.4f}")
    print(f"Nodes visited: {nodes_visited}")


if __name__ == "__main__":
    main()
21 changes: 21 additions & 0 deletions data_structures/kd_tree/example/hypercube_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np


def hypercube_points(
num_points: int, hypercube_size: float, num_dimensions: int
) -> np.ndarray:
"""
Generates random points uniformly distributed within an n-dimensional hypercube.
Args:
num_points: Number of points to generate.
hypercube_size: Size of the hypercube.
num_dimensions: Number of dimensions of the hypercube.
Returns:
An array of shape (num_points, num_dimensions)
with generated points.
"""
rng = np.random.default_rng()
shape = (num_points, num_dimensions)
return hypercube_size * rng.random(shape)
30 changes: 30 additions & 0 deletions data_structures/kd_tree/kd_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations


class KDNode:
    """
    Represents a node in a KD-Tree.

    Attributes:
        point: The point stored in this node.
        left: The left child node, or None if there is no left child.
        right: The right child node, or None if there is no right child.
    """

    def __init__(
        self,
        point: list[float],
        left: KDNode | None = None,
        right: KDNode | None = None,
    ) -> None:
        """
        Initializes a KDNode with the given point and child nodes.

        Args:
            point (list[float]): The point stored in this node.
            left (KDNode | None): The left child node.
            right (KDNode | None): The right child node.
        """
        self.point = point
        self.left = left
        self.right = right

    def __repr__(self) -> str:
        """
        Returns a shallow, readable description of the node.

        Children are abbreviated rather than expanded so that printing the
        root of a large tree stays short.
        """
        name = type(self).__name__
        left = "None" if self.left is None else f"{name}(...)"
        right = "None" if self.right is None else f"{name}(...)"
        return f"{name}(point={self.point!r}, left={left}, right={right})"
71 changes: 71 additions & 0 deletions data_structures/kd_tree/nearest_neighbour_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from data_structures.kd_tree.kd_node import KDNode


def nearest_neighbour_search(
    root: KDNode | None, query_point: list[float]
) -> tuple[list[float] | None, float, int]:
    """
    Performs a nearest neighbor search in a KD-Tree for a given query point.

    The tree is walked with an explicit stack instead of recursion, so the
    search does not depend on the interpreter's recursion limit even when the
    tree is very deep. Subtrees are visited in depth-first order, nearer side
    first, and the further side is skipped when the splitting plane is already
    at least as far away as the best match found so far.

    Args:
        root (KDNode | None): The root node of the KD-Tree.
        query_point (list[float]): The point for which the nearest neighbor
                                   is being searched.

    Returns:
        tuple[list[float] | None, float, int]:
            - The nearest point found in the KD-Tree to the query point,
              or None if no point is found.
            - The squared distance to the nearest point.
            - The number of nodes visited during the search.
    """
    nearest_point: list[float] | None = None
    nearest_dist: float = float("inf")
    nodes_visited: int = 0

    # Dimensionality of points; constant for the whole search.
    k = len(query_point)

    # Each entry is (node, depth, gate). ``gate`` is the squared distance from
    # the query to the splitting plane that guards a "further" subtree; it is
    # None for subtrees that must be searched unconditionally.
    pending: list[tuple[KDNode | None, int, float | None]] = [(root, 0, None)]

    while pending:
        node, depth, gate = pending.pop()
        if node is None:
            continue

        # The gate is evaluated only now, after the nearer subtree has been
        # fully processed, so it sees the tightest bound found so far.
        if gate is not None and not gate < nearest_dist:
            continue

        nodes_visited += 1

        # Calculate the current distance (squared distance)
        current_point = node.point
        current_dist = sum(
            (query_coord - point_coord) ** 2
            for query_coord, point_coord in zip(query_point, current_point)
        )

        # Update nearest point if the current node is closer
        if nearest_point is None or current_dist < nearest_dist:
            nearest_point = current_point
            nearest_dist = current_dist

        # Determine which subtree to search first (based on axis and query point)
        axis = depth % k
        if query_point[axis] <= current_point[axis]:
            nearer_subtree, further_subtree = node.left, node.right
        else:
            nearer_subtree, further_subtree = node.right, node.left

        plane_dist = (query_point[axis] - current_point[axis]) ** 2

        # The stack is last-in-first-out: push the further side first so the
        # nearer side (and everything beneath it) is handled before it.
        if further_subtree is not None:
            pending.append((further_subtree, depth + 1, plane_dist))
        if nearer_subtree is not None:
            pending.append((nearer_subtree, depth + 1, None))

    return nearest_point, nearest_dist, nodes_visited
Empty file.
100 changes: 100 additions & 0 deletions data_structures/kd_tree/tests/test_kdtree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import numpy as np
import pytest

from data_structures.kd_tree.build_kdtree import build_kdtree
from data_structures.kd_tree.example.hypercube_points import hypercube_points
from data_structures.kd_tree.kd_node import KDNode
from data_structures.kd_tree.nearest_neighbour_search import nearest_neighbour_search


@pytest.mark.parametrize(
    ("num_points", "cube_size", "num_dimensions", "depth", "expected_result"),
    [
        (0, 10.0, 2, 0, None),  # Empty points list
        (10, 10.0, 2, 2, KDNode),  # Depth = 2, 2D points
        (10, 10.0, 3, -2, KDNode),  # Depth = -2, 3D points
    ],
)
def test_build_kdtree(num_points, cube_size, num_dimensions, depth, expected_result):
    """
    Check the root returned by build_kdtree for several inputs.

    Cases:
        - Empty points list.
        - Positive depth value.
        - Negative depth value.
    """
    points = []
    if num_points > 0:
        points = hypercube_points(num_points, cube_size, num_dimensions).tolist()

    kdtree = build_kdtree(points, depth=depth)

    # No points in means no tree out; nothing further to inspect.
    if expected_result is None:
        assert kdtree is None, f"Expected None for empty points list, got {kdtree}"
        return

    # A non-empty input must produce a root node
    assert kdtree is not None, "Expected a KDNode, got None"

    # The root's point must have one coordinate per dimension
    assert (
        len(kdtree.point) == num_dimensions
    ), f"Expected point dimension {num_dimensions}, got {len(kdtree.point)}"

    # The root must be an actual KDNode instance
    assert isinstance(
        kdtree, KDNode
    ), f"Expected KDNode instance, got {type(kdtree)}"


def test_nearest_neighbour_search():
    """
    Test the nearest neighbor search function against a brute-force scan.
    """
    num_points = 10
    cube_size = 10.0
    num_dimensions = 2
    points = hypercube_points(num_points, cube_size, num_dimensions).tolist()
    kdtree = build_kdtree(points)

    # Scale the query so it can land anywhere in the cube, not only in the
    # unit-sized corner that rng.random() covers on its own.
    rng = np.random.default_rng()
    query_point = (cube_size * rng.random(num_dimensions)).tolist()

    nearest_point, nearest_dist, nodes_visited = nearest_neighbour_search(
        kdtree, query_point
    )

    def squared_distance(point):
        """Squared Euclidean distance from ``point`` to the query point."""
        return sum((q - p) ** 2 for q, p in zip(query_point, point))

    expected_dist = min(squared_distance(point) for point in points)

    # The result must be one of the input points
    assert nearest_point is not None
    assert nearest_point in points

    # The reported (squared) distance must match both the returned point
    # and the true minimum over all points
    assert squared_distance(nearest_point) == pytest.approx(nearest_dist)
    assert nearest_dist == pytest.approx(expected_dist)

    # A non-empty tree visits at least the root and never a node twice
    assert 1 <= nodes_visited <= num_points


def test_edge_cases():
    """
    Searching an empty KD-Tree must report "nothing found".
    """
    query_point = [0.0] * 2  # Using a default 2D query point
    result = nearest_neighbour_search(build_kdtree([]), query_point)
    nearest_point, nearest_dist, nodes_visited = result

    # No nodes exist, so there is no nearest point, the distance keeps its
    # initial value of infinity and no node is ever visited.
    assert nodes_visited == 0
    assert nearest_dist == float("inf")
    assert nearest_point is None


if __name__ == "__main__":
    # pytest is already imported at the top of this module, so no second
    # import is needed. Hand pytest's return code to the process so that a
    # failing run is visible to the shell / CI.
    raise SystemExit(pytest.main())

0 comments on commit f16d38f

Please sign in to comment.