From 1ce3f6926172027d0d0df810dbb52ec3a5232741 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 2 Oct 2023 16:00:49 -0700 Subject: [PATCH] HNSW bug fixes (#230) Fix bug with HNSW.copy(). --- .gitignore | 6 ++++++ datasketch/hnsw.py | 43 ++++++++++++++----------------------------- datasketch/version.py | 2 +- test/test_hnsw.py | 5 ++++- 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/.gitignore b/.gitignore index 803bb448..c3818686 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,9 @@ benchmark/**/*.pdf # Virtual env .venv + +# IDE +.vscode + +# MacOS +.DS_Store \ No newline at end of file diff --git a/datasketch/hnsw.py b/datasketch/hnsw.py index 7229b51b..e57a78c6 100644 --- a/datasketch/hnsw.py +++ b/datasketch/hnsw.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections import OrderedDict import heapq -from itertools import dropwhile from typing import ( Hashable, Callable, @@ -59,7 +58,7 @@ def __iter__(self) -> Iterable[Hashable]: def copy(self) -> _Layer: """Create a copy of the layer.""" new_layer = _Layer(None) - new_layer._graph = {k: v.copy() for k, v in self._graph.items()} + new_layer._graph = {k: dict(v) for k, v in self._graph.items()} return new_layer def get_reverse_edges(self, key: Hashable) -> Set[Hashable]: @@ -91,6 +90,8 @@ def __setitem__(self, key: Hashable, value: Dict[Hashable, float]) -> None: self._reverse_edges[neighbor].discard(key) for neighbor in value: self._reverse_edges.setdefault(neighbor, set()).add(key) + if key not in self._reverse_edges: + self._reverse_edges[key] = set() def __delitem__(self, key: Hashable) -> None: old_neighbors = self._graph.get(key, {}) @@ -115,8 +116,8 @@ def __iter__(self) -> Iterable[Hashable]: def copy(self) -> _LayerWithReversedEdges: """Create a copy of the layer.""" new_layer = _LayerWithReversedEdges(None) - new_layer._graph = {k: v.copy() for k, v in self._graph.items()} - new_layer._reverse_edges = self._reverse_edges.copy() + new_layer._graph = {k: dict(v) for k, v in self._graph.items()} + new_layer._reverse_edges = {k: set(v) for k, v in self._reverse_edges.items()} return new_layer def get_reverse_edges(self, key: Hashable) -> Set[Hashable]: @@ -169,6 +170,9 @@ class HNSW(MutableMapping): the 0th level. If None, defaults to 2 * m. seed (Optional[int]): The random seed to use for the random number generator. + reverse_edges (bool): Whether to maintain reverse edges in the graph. + This speeds up hard remove (:meth:`remove`) but increases memory + usage and slows down :meth:`insert`. Examples: @@ -400,7 +404,9 @@ def copy(self) -> HNSW: ef_construction=self._ef_construction, m0=self._m0, ) - new_index._nodes = self._nodes.copy() + new_index._nodes = OrderedDict( + (key, node.copy()) for key, node in self._nodes.items() + ) new_index._graphs = [layer.copy() for layer in self._graphs] new_index._entry_point = self._entry_point new_index._random.set_state(self._random.get_state()) @@ -608,6 +614,7 @@ def _repair_connections( entry_point, entry_point_dist, layer, + # We allow soft-deleted points to be returned and used as entry point. allow_soft_deleted=True, key_to_hard_delete=key_to_delete, ) @@ -620,6 +627,8 @@ def _repair_connections( entry_points, layer, ef + 1, # We add 1 to ef to account for the point itself. + # We allow soft-deleted points to be returned and used as entry point + # and neighbor candidates. allow_soft_deleted=True, key_to_hard_delete=key_to_delete, ) @@ -1045,27 +1054,3 @@ def merge(self, other: HNSW) -> HNSW: new_index = self.copy() new_index.update(other) return new_index - - def get_non_reachable_keys(self, ef: Optional[int] = None) -> List[Hashable]: - """Return a list of keys of points that are not reachable from the entry - point using the given ``ef`` value. - - Args: - ef (Optional[int]): The number of neighbors to consider during - search. If None, use the construction ef. - - Returns: - List[Hashable]: A list of keys of points that are not reachable. - """ - if ef is None: - ef = self._ef_construction - non_reachable = [] - if self._entry_point is None: - return non_reachable - for key, node in self._nodes.items(): - if node.is_deleted: - continue - neighbors = self.query(node.point, ef=ef) - if key not in [k for k, _ in neighbors]: - non_reachable.append(key) - return non_reachable diff --git a/datasketch/version.py b/datasketch/version.py index 31e744e4..dc79f8f0 100644 --- a/datasketch/version.py +++ b/datasketch/version.py @@ -1 +1 @@ -__version__ = "1.6.3" +__version__ = "1.6.4" diff --git a/test/test_hnsw.py b/test/test_hnsw.py index fc50ddd2..1f4e80f0 100644 --- a/test/test_hnsw.py +++ b/test/test_hnsw.py @@ -133,6 +133,10 @@ def test_copy(self): hnsw2 = hnsw.copy() self.assertEqual(hnsw, hnsw2) + hnsw.remove(0) + self.assertTrue(0 not in hnsw) + self.assertTrue(0 in hnsw2) + def test_soft_remove_and_pop_and_clean(self): data = self._create_random_points() hnsw = self._create_index(data) @@ -162,7 +166,6 @@ def test_soft_remove_and_pop_and_clean(self): "Potential graph connectivity issue." ) # NOTE: we are not getting the expected number of results. - # This may be because the graph is not connected anymore. # Try hard remove all previous soft removed points. hnsw.clean() results = hnsw.query(data[i], 10)