Skip to content

Commit

Permalink
HNSW as MutableMap (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu committed Sep 6, 2023
1 parent e8d1cfa commit 54a8a26
Show file tree
Hide file tree
Showing 2 changed files with 385 additions and 97 deletions.
280 changes: 268 additions & 12 deletions datasketch/hnsw.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
from __future__ import annotations
import heapq
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Union,
)

import numpy as np

Expand Down Expand Up @@ -63,8 +76,23 @@ def __setitem__(self, key: Any, value: Dict[Any, float]) -> None:
for neighbor in value:
self._reverse_edges.setdefault(neighbor, set()).add(key)

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, _Layer):
return False
return (
self._graph == __value._graph
and self._reverse_edges == __value._reverse_edges
)

def copy(self) -> _Layer:
"""Create a copy of the layer."""
new_layer = _Layer(None)
new_layer._graph = {k: v.copy() for k, v in self._graph.items()}
new_layer._reverse_edges = self._reverse_edges.copy()
return new_layer


class HNSW(object):
class HNSW(MutableMapping):
"""Hierarchical Navigable Small World (HNSW) graph index for approximate
nearest neighbor search. This implementation is based on the paper
"Efficient and robust approximate nearest neighbor search using Hierarchical
Expand All @@ -82,16 +110,45 @@ class HNSW(object):
seed (Optional[int]): The random seed to use for the random number
generator.
Example:
Examples:
Create an HNSW index with Euclidean distance and insert 1000 random
vectors of dimension 10.
.. code-block:: python
import hnsw
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index = hnsw.HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
for i, d in enumerate(data):
index.insert(i, d)
# Query the index for the 10 nearest neighbors of the first vector.
index.query(data[0], k=10)
Create an HNSW index with Jaccard distance and insert 1000 random
sets of 10 elements each.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
# Each set is represented as a 10-element vector of random integers
# between 0 and 100.
# Deduplication is handled by the distance function.
data = np.random.randint(0, 100, size=(1000, 10))
jaccard_distance = lambda x, y: (
1.0 - float(len(np.intersect1d(x, y, assume_unique=False)))
/ float(len(np.union1d(x, y)))
)
index = HNSW(distance_func=jaccard_distance)
for i, d in enumerate(data):
index[i] = d
# Query the index for the 10 nearest neighbors of the first set.
index.query(data[0], k=10)
"""
Expand All @@ -115,33 +172,180 @@ def __init__(
self._random = np.random.RandomState(seed)

def __len__(self) -> int:
"""Return the number of points in the index."""
return len(self._data)

def __contains__(self, key: Any) -> bool:
"""Return ``True`` if the index contains the key, else ``False``."""
return key in self._data

def __getitem__(self, key: Any) -> np.ndarray:
"""Get the point associated with the key."""
"""Get the point associated with the key. Raises KeyError if the key
does not exist in the index."""
return self._data[key]

def __setitem__(self, key: Any, value: np.ndarray) -> None:
"""Set the point associated with the key and update the index.
This is equivalent to calling :meth:`insert` with the key and point."""
self.insert(key, value)

def __delitem__(self, key: Any) -> None:
"""Delete the point associated with the key. Raises a KeyError if the
key does not exist in the index.
NOTE: This method is not implemented yet.
"""
raise NotImplementedError("del is not implemented yet.")

def __iter__(self) -> Iterable[Any]:
"""Return an iterator over the keys of the index."""
return iter(self._data.keys())

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, HNSW):
return False
# Check if the index parameters are equal.
if (
self._distance_func != __value._distance_func
or self._m != __value._m
or self._ef_construction != __value._ef_construction
or self._m0 != __value._m0
or self._level_mult != __value._level_mult
or self._entry_point != __value._entry_point
):
return False
# Check if the random states are equal.
rand_state_1 = self._random.get_state()
rand_state_2 = __value._random.get_state()
for i in range(len(rand_state_1)):
if isinstance(rand_state_1[i], np.ndarray):
if not np.array_equal(rand_state_1[i], rand_state_2[i]):
return False
else:
if rand_state_1[i] != rand_state_2[i]:
return False
# Check if keys and points are equal.
return (
all(key in self._data for key in __value._data)
and all(key in __value._data for key in self._data)
and all(
np.array_equal(self._data[key], __value._data[key])
for key in self._data
)
and self._graphs == __value._graphs
)

def get(self, key: Any, default: Optional[np.ndarray] = None) -> np.ndarray:
"""Return the point for key in the index, else default. If default is not
given and key is not in the index, return None."""
return self._data.get(key)

def items(self) -> Iterable[Tuple[Any, np.ndarray]]:
"""Get an iterator over (key, point) pairs in the index."""
"""Return a new view of the indexed points as (key, point) pairs."""
return self._data.items()

def keys(self) -> Iterable[Any]:
"""Get an iterator over keys in the index."""
"""Return a new view of the keys of the index points."""
return self._data.keys()

def values(self) -> Iterable[np.ndarray]:
"""Get an iterator over points in the index."""
"""Return a new view of the index points."""
return self._data.values()

def pop(self, key: Any, default: Optional[np.ndarray] = None) -> np.ndarray:
"""If key is in the index, remove it and return its associated point,
else return default. If default is not given and key is not in the index,
raise KeyError.
NOTE: This method is not implemented yet.
"""
raise NotImplementedError("pop is not implemented yet.")

def popitem(self) -> Tuple[Any, np.ndarray]:
"""Remove and return a (key, point) pair from the index. Pairs are
returned in LIFO order. If the index is empty, raise KeyError.
NOTE: This method is not implemented yet.
"""
raise NotImplementedError("popitem is not implemented yet.")

def clear(self) -> None:
"""Clear the index of all data points."""
"""Clear the index of all data points. This will not reset the random
number generator."""
self._data = {}
self._graphs = []
self._entry_point = None

def copy(self) -> HNSW:
"""Create a copy of the index. The copy will have the same parameters
as the original index and the same keys and points, but will not share
any index data structures (i.e., graphs) with the original index.
The new index's random state will start from a copy of the original
index's."""
new_index = HNSW(
self._distance_func,
m=self._m,
ef_construction=self._ef_construction,
m0=self._m0,
)
new_index._data = self._data.copy()
new_index._graphs = [layer.copy() for layer in self._graphs]
new_index._entry_point = self._entry_point
new_index._random.set_state(self._random.get_state())
return new_index

def update(self, other: Union[Mapping, HNSW]) -> None:
"""Update the index with the points from the other Mapping or HNSW object,
overwriting existing keys.
Args:
other (Union[Mapping, HNSW]): The other Mapping or HNSW object.
Examples:
Create an HNSW index with a dictionary of points.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
# Batch insert 1000 points.
index.update({i: d for i, d in enumerate(data)})
Create an HNSW index with another HNSW index.
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data = np.random.random_sample((1000, 10))
index1 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
index2 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
# Batch insert 1000 points.
index1.update({i: d for i, d in enumerate(data)})
# Update index2 with the points from index1.
index2.update(index1)
"""
for key, point in other.items():
self.insert(key, point)

def setdefault(self, key: Any, default: np.ndarray) -> np.ndarray:
"""If key is in the index, return its associated point. If not, insert
key with a value of default and return default. default cannot be None."""
if default is None:
raise ValueError("Default value cannot be None.")
if key not in self._data:
self.insert(key, default)
return self._data[key]

def insert(
self,
key: Any,
Expand All @@ -156,7 +360,9 @@ def insert(
index, the point will be updated and the index will be repaired.
new_point (np.ndarray): The new point to add to the index.
ef (Optional[int]): The number of neighbors to consider during insertion.
If None, use the construction ef.
level (Optional[int]): The level at which to insert the new point.
If None, the level will be chosen automatically.
"""
if ef is None:
Expand Down Expand Up @@ -208,7 +414,7 @@ def insert(
)
}
# For all levels above the current level, we create an empty graph.
for _ in range(len(self._graphs), level):
for _ in range(len(self._graphs), level + 1):
self._graphs.append(_Layer(key))
# We set the entry point for each new level to be the new node.
self._entry_point = key
Expand Down Expand Up @@ -303,7 +509,7 @@ def query(
"""Search for the k nearest neighbors of the query point.
Args:
query (np.ndarray): The query point.
query_point (np.ndarray): The query point.
k (Optional[int]): The number of neighbors to return. If None, return
all neighbors found.
ef (Optional[int]): The number of neighbors to consider during search.
Expand Down Expand Up @@ -474,3 +680,53 @@ def _heuristic_prune(
if good:
pruned.append((candidate_dist, candidate_key))
return pruned

def remove(self, key: Any) -> None:
"""Remove a point from the index.
Args:
key (Any): The key of the point to remove.
Raises:
ValueError: If the key does not exist in the index.
NOTE: This method is not implemented yet.
"""
raise NotImplementedError("Remove is not implemented yet.")

def merge(self, other: HNSW) -> HNSW:
"""Create a new index by merging the current index with another index.
The new index will contain all points from both indexes.
If a point exists in both, the point from the other index will be used.
The new index will have the same parameters as the current index and
a copy of the current index's random state.
Args:
other (HNSW): The other index to merge with.
Returns:
HNSW: A new index containing all points from both indexes.
Example:
.. code-block:: python
from datasketch.hnsw import HNSW
import numpy as np
data1 = np.random.random_sample((1000, 10))
data2 = np.random.random_sample((1000, 10))
index1 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
index2 = HNSW(distance_func=lambda x, y: np.linalg.norm(x - y))
# Batch insert data into the indexes.
index1.update({i: d for i, d in enumerate(data1)})
index2.update({i + len(data1): d for i, d in enumerate(data2)})
# Merge the indexes.
index = index1.merge(index2)
"""
new_index = self.copy()
new_index.update(other)
return new_index
Loading

0 comments on commit 54a8a26

Please sign in to comment.