Skip to content

Commit

Permalink
Merging (Identically Specified) MinHashLSH objects (#232)
Browse files Browse the repository at this point in the history
* -Issue: #205: Merging (Identically Specified) MinHashLSH objects

* Merging (Identically Specified) MinHashLSH objects
Fixes #205

* Merging (Identically Specified) MinHashLSH objects
Fixes #205

* Merging (Identically Specified) MinHashLSH objects

* Merging (Identically Specified) MinHashLSH objects
Fixes #205
  • Loading branch information
rupeshkumaar authored Mar 12, 2024
1 parent bfd9e7f commit a532f06
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 0 deletions.
55 changes: 55 additions & 0 deletions datasketch/lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,29 @@ def insert(
"""
self._insert(key, minhash, check_duplication=check_duplication, buffer=False)

def merge(
self,
other: MinHashLSH,
check_overlap: bool = False
):
"""Merge the other MinHashLSH with this one, making this one the union
of both.
Note:
Only num_perm, number of bands and sizes of each band is checked for equivalency of two MinHashLSH indexes.
Other initialization parameters threshold, weights, storage_config, prepickle and hash_func are not checked.
Args:
other (MinHashLSH): The other MinHashLSH.
check_overlap (bool): Check if there are any overlapping keys before merging and raise if there are any.
(`default=False`)
Raises:
ValueError: If the two MinHashLSH have different initialization
parameters, or if `check_overlap` is `True` and there are overlapping keys.
"""
self._merge(other, check_overlap=check_overlap, buffer=False)

def insertion_session(self, buffer_size: int = 50000) -> MinHashLSHInsertionSession:
"""
Create a context manager for fast insertion into this index.
Expand Down Expand Up @@ -282,6 +305,38 @@ def _insert(
for H, hashtable in zip(Hs, self.hashtables):
hashtable.insert(H, key, buffer=buffer)

def __equivalent(self, other:MinHashLSH) -> bool:
"""
Returns:
bool: If the two MinHashLSH have equal num_perm, number of bands, size of each band then two are equivalent.
"""
return (
type(self) is type(other) and
self.h == other.h and
self.b == other.b and
self.r == other.r
)

def _merge(
self,
other: MinHashLSH,
check_overlap: bool = False,
buffer: bool = False
) -> MinHashLSH:
if self.__equivalent(other):
if check_overlap and set(self.keys).intersection(set(other.keys)):
raise ValueError("The keys are overlapping, duplicate key exists.")
for key in other.keys:
Hs = other.keys.get(key)
self.keys.insert(key, *Hs, buffer=buffer)
for H, hashtable in zip(Hs, self.hashtables):
hashtable.insert(H, key, buffer=buffer)
else:
if type(self) is not type(other):
raise ValueError(f"Cannot merge type MinHashLSH and type {type(other).__name__}.")
raise ValueError(
"Cannot merge MinHashLSH with different initialization parameters.")

def query(self, minhash) -> List[Hashable]:
"""
Giving the MinHash of the query set, retrieve
Expand Down
8 changes: 8 additions & 0 deletions docs/lsh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ plotting code.
.. figure:: /_static/lsh_benchmark.png
:alt: MinHashLSH Benchmark

You can merge two MinHashLSH indexes to create a union index using the ``merge`` method. This
makes MinHashLSH useful in parallel processing.

.. code:: python
# This merges the lsh1 with lsh2.
lsh1.merge(lsh2)
There are other optional parameters that can be used to tune the index.
See the documentation of :class:`datasketch.MinHashLSH` for details.

Expand Down
13 changes: 13 additions & 0 deletions examples/lsh_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def eg1():
result = lsh.query(m1)
print("Approximate neighbours with Jaccard similarity > 0.5", result)

# Merge two LSH index
lsh1 = MinHashLSH(threshold=0.5, num_perm=128)
lsh1.insert("m2", m2)
lsh1.insert("m3", m3)

lsh2 = MinHashLSH(threshold=0.5, num_perm=128)
lsh2.insert("m1", m1)

lsh1.merge(lsh2)
print("Does m1 exist in the lsh1...", "m1" in lsh1.keys)
# if check_overlap flag is set to True then it will check the overlapping of the keys in the two MinHashLSH
lsh1.merge(lsh2,check_overlap=True)

def eg2():
mg = WeightedMinHashGenerator(10, 5)
m1 = mg.minhash(v1)
Expand Down
111 changes: 111 additions & 0 deletions test/test_lsh.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,117 @@ def test_get_counts(self):
for table in counts:
self.assertEqual(sum(table.values()), 2)

def test_merge(self):
lsh1 = MinHashLSH(threshold=0.5, num_perm=16)
m1 = MinHash(16)
m1.update("a".encode("utf-8"))
m2 = MinHash(16)
m2.update("b".encode("utf-8"))
lsh1.insert("a",m1)
lsh1.insert("b",m2)

lsh2 = MinHashLSH(threshold=0.5, num_perm=16)
m3 = MinHash(16)
m3.update("c".encode("utf-8"))
m4 = MinHash(16)
m4.update("d".encode("utf-8"))
lsh2.insert("c",m1)
lsh2.insert("d",m2)

lsh1.merge(lsh2)
for t in lsh1.hashtables:
self.assertTrue(len(t) >= 1)
items = []
for H in t:
items.extend(t[H])
self.assertTrue("c" in items)
self.assertTrue("d" in items)
self.assertTrue("a" in lsh1)
self.assertTrue("b" in lsh1)
self.assertTrue("c" in lsh1)
self.assertTrue("d" in lsh1)
for i, H in enumerate(lsh1.keys["c"]):
self.assertTrue("c" in lsh1.hashtables[i][H])

self.assertTrue(lsh1.merge, lsh2)
self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True)

m5 = MinHash(16)
m5.update("e".encode("utf-8"))
lsh3 = MinHashLSH(threshold=0.5, num_perm=16)
lsh3.insert("a",m5)

self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True)

lsh1.merge(lsh3)

m6 = MinHash(16)
m6.update("e".encode("utf-8"))
lsh4 = MinHashLSH(threshold=0.5, num_perm=16)
lsh4.insert("a",m6)

lsh1.merge(lsh4, check_overlap=False)


def test_merge_redis(self):
with patch('redis.Redis', fake_redis) as mock_redis:
lsh1 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={
'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
})
lsh2 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={
'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
})

m1 = MinHash(16)
m1.update("a".encode("utf8"))
m2 = MinHash(16)
m2.update("b".encode("utf8"))
lsh1.insert("a", m1)
lsh1.insert("b", m2)

m3 = MinHash(16)
m3.update("c".encode("utf8"))
m4 = MinHash(16)
m4.update("d".encode("utf8"))
lsh2.insert("c", m3)
lsh2.insert("d", m4)

lsh1.merge(lsh2)
for t in lsh1.hashtables:
self.assertTrue(len(t) >= 1)
items = []
for H in t:
items.extend(t[H])
self.assertTrue(pickle.dumps("c") in items)
self.assertTrue(pickle.dumps("d") in items)
self.assertTrue("a" in lsh1)
self.assertTrue("b" in lsh1)
self.assertTrue("c" in lsh1)
self.assertTrue("d" in lsh1)
for i, H in enumerate(lsh1.keys[pickle.dumps("c")]):
self.assertTrue(pickle.dumps("c") in lsh1.hashtables[i][H])

self.assertTrue(lsh1.merge, lsh2)
self.assertRaises(ValueError, lsh1.merge, lsh2, check_overlap=True)

m5 = MinHash(16)
m5.update("e".encode("utf-8"))
lsh3 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={
'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
})
lsh3.insert("a",m5)

self.assertRaises(ValueError, lsh1.merge, lsh3, check_overlap=True)

m6 = MinHash(16)
m6.update("e".encode("utf-8"))
lsh4 = MinHashLSH(threshold=0.5, num_perm=16, storage_config={
'type': 'redis', 'redis': {'host': 'localhost', 'port': 6379}
})
lsh4.insert("a",m6)

lsh1.merge(lsh4, check_overlap=False)


class TestWeightedMinHashLSH(unittest.TestCase):

Expand Down

0 comments on commit a532f06

Please sign in to comment.