From 54a8a26036b6f501c909dbd284bb55b05c222b51 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 6 Sep 2023 15:50:08 -0700 Subject: [PATCH] HNSW as MutableMap (#223) --- datasketch/hnsw.py | 280 +++++++++++++++++++++++++++++++++++++++++++-- test/test_hnsw.py | 202 ++++++++++++++++++-------------- 2 files changed, 385 insertions(+), 97 deletions(-) diff --git a/datasketch/hnsw.py b/datasketch/hnsw.py index a1c2e298..a38fa32c 100644 --- a/datasketch/hnsw.py +++ b/datasketch/hnsw.py @@ -1,5 +1,18 @@ +from __future__ import annotations import heapq -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Set, + Tuple, + Union, +) import numpy as np @@ -63,8 +76,23 @@ def __setitem__(self, key: Any, value: Dict[Any, float]) -> None: for neighbor in value: self._reverse_edges.setdefault(neighbor, set()).add(key) + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, _Layer): + return False + return ( + self._graph == __value._graph + and self._reverse_edges == __value._reverse_edges + ) + + def copy(self) -> _Layer: + """Create a copy of the layer.""" + new_layer = _Layer(None) + new_layer._graph = {k: v.copy() for k, v in self._graph.items()} + new_layer._reverse_edges = self._reverse_edges.copy() + return new_layer + -class HNSW(object): +class HNSW(MutableMapping): """Hierarchical Navigable Small World (HNSW) graph index for approximate nearest neighbor search. This implementation is based on the paper "Efficient and robust approximate nearest neighbor search using Hierarchical @@ -82,16 +110,45 @@ class HNSW(object): seed (Optional[int]): The random seed to use for the random number generator. - Example: + Examples: + + Create an HNSW index with Euclidean distance and insert 1000 random + vectors of dimension 10. .. code-block:: python - import hnsw + from datasketch.hnsw import HNSW import numpy as np + data = np.random.random_sample((1000, 10)) - index = hnsw.HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) + index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) for i, d in enumerate(data): index.insert(i, d) + + # Query the index for the 10 nearest neighbors of the first vector. + index.query(data[0], k=10) + + Create an HNSW index with Jaccard distance and insert 1000 random + sets of 10 elements each. + + .. code-block:: python + + from datasketch.hnsw import HNSW + import numpy as np + + # Each set is represented as a 10-element vector of random integers + # between 0 and 100. + # Deduplication is handled by the distance function. + data = np.random.randint(0, 100, size=(1000, 10)) + jaccard_distance = lambda x, y: ( + 1.0 - float(len(np.intersect1d(x, y, assume_unique=False))) + / float(len(np.union1d(x, y))) + ) + index = HNSW(distance_func=jaccard_distance) + for i, d in enumerate(data): + index[i] = d + + # Query the index for the 10 nearest neighbors of the first set. index.query(data[0], k=10) """ @@ -115,33 +172,180 @@ def __init__( self._random = np.random.RandomState(seed) def __len__(self) -> int: + """Return the number of points in the index.""" return len(self._data) def __contains__(self, key: Any) -> bool: + """Return ``True`` if the index contains the key, else ``False``.""" return key in self._data def __getitem__(self, key: Any) -> np.ndarray: - """Get the point associated with the key.""" + """Get the point associated with the key. Raises KeyError if the key + does not exist in the index.""" return self._data[key] + def __setitem__(self, key: Any, value: np.ndarray) -> None: + """Set the point associated with the key and update the index. + This is equivalent to calling :meth:`insert` with the key and point.""" + self.insert(key, value) + + def __delitem__(self, key: Any) -> None: + """Delete the point associated with the key. Raises a KeyError if the + key does not exist in the index. + + NOTE: This method is not implemented yet. + """ + raise NotImplementedError("del is not implemented yet.") + + def __iter__(self) -> Iterable[Any]: + """Return an iterator over the keys of the index.""" + return iter(self._data.keys()) + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, HNSW): + return False + # Check if the index parameters are equal. + if ( + self._distance_func != __value._distance_func + or self._m != __value._m + or self._ef_construction != __value._ef_construction + or self._m0 != __value._m0 + or self._level_mult != __value._level_mult + or self._entry_point != __value._entry_point + ): + return False + # Check if the random states are equal. + rand_state_1 = self._random.get_state() + rand_state_2 = __value._random.get_state() + for i in range(len(rand_state_1)): + if isinstance(rand_state_1[i], np.ndarray): + if not np.array_equal(rand_state_1[i], rand_state_2[i]): + return False + else: + if rand_state_1[i] != rand_state_2[i]: + return False + # Check if keys and points are equal. + return ( + all(key in self._data for key in __value._data) + and all(key in __value._data for key in self._data) + and all( + np.array_equal(self._data[key], __value._data[key]) + for key in self._data + ) + and self._graphs == __value._graphs + ) + + def get(self, key: Any, default: Optional[np.ndarray] = None) -> np.ndarray: + """Return the point for key in the index, else default. If default is not + given and key is not in the index, return None.""" + return self._data.get(key) + def items(self) -> Iterable[Tuple[Any, np.ndarray]]: - """Get an iterator over (key, point) pairs in the index.""" + """Return a new view of the indexed points as (key, point) pairs.""" return self._data.items() def keys(self) -> Iterable[Any]: - """Get an iterator over keys in the index.""" + """Return a new view of the keys of the index points.""" return self._data.keys() def values(self) -> Iterable[np.ndarray]: - """Get an iterator over points in the index.""" + """Return a new view of the index points.""" return self._data.values() + def pop(self, key: Any, default: Optional[np.ndarray] = None) -> np.ndarray: + """If key is in the index, remove it and return its associated point, + else return default. If default is not given and key is not in the index, + raise KeyError. + + NOTE: This method is not implemented yet. + """ + raise NotImplementedError("pop is not implemented yet.") + + def popitem(self) -> Tuple[Any, np.ndarray]: + """Remove and return a (key, point) pair from the index. Pairs are + returned in LIFO order. If the index is empty, raise KeyError. + + NOTE: This method is not implemented yet. + """ + raise NotImplementedError("popitem is not implemented yet.") + def clear(self) -> None: - """Clear the index of all data points.""" + """Clear the index of all data points. This will not reset the random + number generator.""" self._data = {} self._graphs = [] self._entry_point = None + def copy(self) -> HNSW: + """Create a copy of the index. The copy will have the same parameters + as the original index and the same keys and points, but will not share + any index data structures (i.e., graphs) with the original index. + The new index's random state will start from a copy of the original + index's.""" + new_index = HNSW( + self._distance_func, + m=self._m, + ef_construction=self._ef_construction, + m0=self._m0, + ) + new_index._data = self._data.copy() + new_index._graphs = [layer.copy() for layer in self._graphs] + new_index._entry_point = self._entry_point + new_index._random.set_state(self._random.get_state()) + return new_index + + def update(self, other: Union[Mapping, HNSW]) -> None: + """Update the index with the points from the other Mapping or HNSW object, + overwriting existing keys. + + Args: + other (Union[Mapping, HNSW]): The other Mapping or HNSW object. + + Examples: + + Create an HNSW index with a dictionary of points. + + .. code-block:: python + + from datasketch.hnsw import HNSW + import numpy as np + + data = np.random.random_sample((1000, 10)) + index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) + + # Batch insert 1000 points. + index.update({i: d for i, d in enumerate(data)}) + + Create an HNSW index with another HNSW index. + + .. code-block:: python + + from datasketch.hnsw import HNSW + import numpy as np + + data = np.random.random_sample((1000, 10)) + index1 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) + index2 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) + + # Batch insert 1000 points. + index1.update({i: d for i, d in enumerate(data)}) + + # Update index2 with the points from index1. + index2.update(index1) + + """ + for key, point in other.items(): + self.insert(key, point) + + def setdefault(self, key: Any, default: np.ndarray) -> np.ndarray: + """If key is in the index, return its associated point. If not, insert + key with a value of default and return default. default cannot be None.""" + if default is None: + raise ValueError("Default value cannot be None.") + if key not in self._data: + self.insert(key, default) + return self._data[key] + def insert( self, key: Any, @@ -156,7 +360,9 @@ def insert( index, the point will be updated and the index will be repaired. new_point (np.ndarray): The new point to add to the index. ef (Optional[int]): The number of neighbors to consider during insertion. + If None, use the construction ef. level (Optional[int]): The level at which to insert the new point. + If None, the level will be chosen automatically. """ if ef is None: @@ -208,7 +414,7 @@ def insert( ) } # For all levels above the current level, we create an empty graph. - for _ in range(len(self._graphs), level): + for _ in range(len(self._graphs), level + 1): self._graphs.append(_Layer(key)) # We set the entry point for each new level to be the new node. self._entry_point = key @@ -303,7 +509,7 @@ def query( """Search for the k nearest neighbors of the query point. Args: - query (np.ndarray): The query point. + query_point (np.ndarray): The query point. k (Optional[int]): The number of neighbors to return. If None, return all neighbors found. ef (Optional[int]): The number of neighbors to consider during search. @@ -474,3 +680,53 @@ def _heuristic_prune( if good: pruned.append((candidate_dist, candidate_key)) return pruned + + def remove(self, key: Any) -> None: + """Remove a point from the index. + + Args: + key (Any): The key of the point to remove. + + Raises: + ValueError: If the key does not exist in the index. + + NOTE: This method is not implemented yet. + """ + raise NotImplementedError("Remove is not implemented yet.") + + def merge(self, other: HNSW) -> HNSW: + """Create a new index by merging the current index with another index. + The new index will contain all points from both indexes. + If a point exists in both, the point from the other index will be used. + The new index will have the same parameters as the current index and + a copy of the current index's random state. + + Args: + other (HNSW): The other index to merge with. + + Returns: + HNSW: A new index containing all points from both indexes. + + Example: + + .. code-block:: python + + from datasketch.hnsw import HNSW + import numpy as np + + data1 = np.random.random_sample((1000, 10)) + data2 = np.random.random_sample((1000, 10)) + index1 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) + index2 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y)) + + # Batch insert data into the indexes. + index1.update({i: d for i, d in enumerate(data1)}) + index2.update({i + len(data1): d for i, d in enumerate(data2)}) + + # Merge the indexes. + index = index1.merge(index2) + + """ + new_index = self.copy() + new_index.update(other) + return new_index diff --git a/test/test_hnsw.py b/test/test_hnsw.py index b2639994..33887a0d 100644 --- a/test/test_hnsw.py +++ b/test/test_hnsw.py @@ -9,108 +9,140 @@ def l2_distance(x, y): return np.linalg.norm(x - y) +def jaccard_distance(x, y): + return 1.0 - float(len(np.intersect1d(x, y, assume_unique=False))) / float( + len(np.union1d(x, y)) + ) + + class TestHNSW(unittest.TestCase): - def test_search_l2(self): - data = np.random.rand(100, 10) + def _create_random_points(self, n=100, dim=10): + return np.random.rand(n, dim) + + def _create_index(self, vecs, keys=None): hnsw = HNSW( - distance_func=lambda x, y: np.linalg.norm(x - y), + distance_func=l2_distance, m=16, ef_construction=100, ) - for i in range(len(data)): - hnsw.insert(i, data[i]) - self.assertIn(i, hnsw) - self.assertTrue(np.array_equal(hnsw[i], data[i])) - for i in range(len(data)): - results = hnsw.query(data[i], 10) - self.assertEqual(len(results), 10) - for j in range(len(results) - 1): - self.assertLessEqual( - np.linalg.norm(hnsw[results[j][0]] - data[i]), - np.linalg.norm(hnsw[results[j + 1][0]] - data[i]), - ) + self._insert_points(hnsw, vecs, keys) + return hnsw - def test_search_jaccard(self): - data = np.random.randint(0, 100, (100, 10)) - jaccard_func = lambda x, y: ( - 1.0 - - float(len(np.intersect1d(x, y, assume_unique=False))) - / float(len(np.union1d(x, y))) - ) - hnsw = HNSW(distance_func=jaccard_func, m=16, ef_construction=100) - for i in range(len(data)): - hnsw.insert(i, data[i]) - self.assertIn(i, hnsw) - self.assertTrue(np.array_equal(hnsw[i], data[i])) - for i in range(len(data)): - results = hnsw.query(data[i], 10) - self.assertEqual(len(results), 10) - for j in range(len(results) - 1): - self.assertLessEqual( - jaccard_func(hnsw[results[j][0]], data[i]), - jaccard_func(hnsw[results[j + 1][0]], data[i]), - ) + def _search_index(self, index, queries, k=10): + return self._search_index_dist(index, queries, l2_distance, k) - def test_update_point_l2(self): - data = np.random.rand(100, 10) - hnsw = HNSW( - distance_func=lambda x, y: np.linalg.norm(x - y), - m=16, - ef_construction=100, - ) - for i in range(len(data)): - hnsw.insert(i, data[i]) - new_data = np.random.rand(10, 10) - for i in range(len(new_data)): - hnsw.insert(i, new_data[i]) - self.assertTrue(np.array_equal(hnsw[i], new_data[i])) - for i in range(len(data)): - results = hnsw.query(data[i], 10) + def _insert_points(self, index, points, keys=None): + original_length = len(index) + + if keys is None: + keys = list(range(len(points))) + + for key, point in zip(keys, points): + # Test insert. + if np.random.random_sample() < 0.5: + index.insert(key, point) + else: + index[key] = point + # Make sure the entry point is set. + self.assertTrue(index._entry_point is not None) + # Test contains. + self.assertIn(key, index) + if original_length == 0: + self.assertNotIn(key + 1, index) + # Test get. + self.assertTrue(np.array_equal(index.get(key), point)) + self.assertTrue(np.array_equal(index[key], point)) + + if original_length == 0: + # Test length. + self.assertEqual(len(index), len(points)) + + # Test order. + for key_indexed, key in zip(index, keys): + self.assertEqual(key_indexed, key) + for key_indexed, key in zip(index.keys(), keys): + self.assertEqual(key_indexed, key) + for vec_indexed, vec in zip(index.values(), points): + self.assertTrue(np.array_equal(vec_indexed, vec)) + for (key_indexed, vec_indexed), key, vec in zip( + index.items(), keys, points + ): + self.assertEqual(key_indexed, key) + self.assertTrue(np.array_equal(vec_indexed, vec)) + + def _search_index_dist(self, index, queries, distance_func, k=10): + for i in range(len(queries)): + results = index.query(queries[i], 10) self.assertEqual(len(results), 10) for j in range(len(results) - 1): self.assertLessEqual( - np.linalg.norm(hnsw[results[j][0]] - data[i]), - np.linalg.norm(hnsw[results[j + 1][0]] - data[i]), + distance_func(index[results[j][0]], queries[i]), + distance_func(index[results[j + 1][0]], queries[i]), ) - def test_update_jaccard(self): - data = np.random.randint(0, 100, (100, 10)) - jaccard_func = lambda x, y: ( - 1.0 - - float(len(np.intersect1d(x, y, assume_unique=False))) - / float(len(np.union1d(x, y))) + def test_search(self): + data = self._create_random_points() + hnsw = self._create_index(data) + self._search_index(hnsw, data) + + def test_upsert(self): + data = self._create_random_points() + hnsw = self._create_index(data) + new_data = self._create_random_points(n=10, dim=10) + self._insert_points(hnsw, new_data) + self._search_index(hnsw, new_data) + + def test_update(self): + data = self._create_random_points() + hnsw = self._create_index(data) + new_data = self._create_random_points(n=10, dim=10) + hnsw.update({i: new_data[i] for i in range(len(new_data))}) + self._search_index(hnsw, new_data) + + def test_merge(self): + data1 = self._create_random_points() + data2 = self._create_random_points() + hnsw1 = self._create_index(data1, keys=list(range(len(data1)))) + hnsw2 = self._create_index( + data2, keys=list(range(len(data1), len(data1) + len(data2))) ) - hnsw = HNSW(distance_func=jaccard_func, m=16, ef_construction=100) - for i in range(len(data)): - hnsw.insert(i, data[i]) - new_data = np.random.randint(0, 100, (10, 10)) - for i in range(len(new_data)): - hnsw.insert(i, new_data[i]) - self.assertTrue(np.array_equal(hnsw[i], new_data[i])) - for i in range(len(data)): - results = hnsw.query(data[i], 10) - self.assertEqual(len(results), 10) - for j in range(len(results) - 1): - self.assertLessEqual( - jaccard_func(hnsw[results[j][0]], data[i]), - jaccard_func(hnsw[results[j + 1][0]], data[i]), - ) + new_index = hnsw1.merge(hnsw2) + self._search_index(new_index, data1) + self._search_index(new_index, data2) + for i in range(len(data1)): + self.assertIn(i, new_index) + self.assertTrue(np.array_equal(new_index[i], data1[i])) + for i in range(len(data2)): + self.assertIn(i + len(data1), new_index) + self.assertTrue(np.array_equal(new_index[i + len(data1)], data2[i])) def test_pickle(self): - data = np.random.rand(100, 10) + data = self._create_random_points() + hnsw = self._create_index(data) + import pickle + + hnsw2 = pickle.loads(pickle.dumps(hnsw)) + self.assertEqual(hnsw, hnsw2) + + def test_copy(self): + data = self._create_random_points() + hnsw = self._create_index(data) + hnsw2 = hnsw.copy() + self.assertEqual(hnsw, hnsw2) + + +class TestHNSWJaccard(TestHNSW): + def _create_random_points(self, high=50, n=100, dim=10): + return np.random.randint(0, high, (n, dim)) + + def _create_index(self, sets, keys=None): hnsw = HNSW( - distance_func=l2_distance, + distance_func=jaccard_distance, m=16, ef_construction=100, ) - for i in range(len(data)): - hnsw.insert(i, data[i]) - - import pickle - - hnsw2 = pickle.loads(pickle.dumps(hnsw)) + self._insert_points(hnsw, sets, keys) + return hnsw - for i in range(len(data)): - results1 = hnsw.query(data[i], 10) - results2 = hnsw2.query(data[i], 10) - self.assertEqual(results1, results2) + def _search_index(self, index, queries, k=10): + return super()._search_index_dist(index, queries, jaccard_distance, k)