Skip to content

Commit f0ae48b

Browse files
123epsilonArham KhanArham KhanArham Khan
authored
Retrieve MinHash from LSHForest (#234)
* add get minhash from lshforest
* format
* fix format string
* return hashvalues instead of MinHash
* preallocate hashvalue buffer
* add direct hashvalue check to test

---------

Co-authored-by: Arham Khan <[email protected]>
Co-authored-by: Arham Khan <[email protected]>
Co-authored-by: Arham Khan <[email protected]>
1 parent 9973b09 commit f0ae48b

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

datasketch/lshforest.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import defaultdict
22
from typing import Hashable, List
3+
import numpy as np
34

45
from datasketch.minhash import MinHash
56

@@ -128,6 +129,30 @@ def query(self, minhash: MinHash, k: int) -> List[Hashable]:
128129
r -= 1
129130
return list(results)
130131

132+
def get_minhash_hashvalues(self, key: Hashable) -> np.ndarray:
    """
    Return the hash values of the MinHash stored under ``key`` in the LSHForest.

    This is useful for reconstructing the original MinHash object, e.g. to
    manually check the Jaccard similarity of the top-k results from a query.

    Args:
        key (Hashable): The key whose MinHash hash values should be retrieved.

    Returns:
        np.ndarray: A new, writable ``uint64`` array holding the hash values
        decoded from every byte segment stored for ``key``, in storage order.

    Raises:
        KeyError: If ``key`` is not present in the LSHForest.
    """
    byteslist = self.keys.get(key, None)
    if byteslist is None:
        raise KeyError(f"The provided key does not exist in the LSHForest: {key}")
    # Decode all segments in one pass: concatenating first avoids a Python-level
    # loop and manual offset bookkeeping, and also copes with an empty list.
    stored = np.frombuffer(b"".join(byteslist), dtype=np.uint64)
    # Unswap the bytes, as their representation is flipped during storage.
    # byteswap() returns a fresh copy, so the result does not alias the
    # read-only buffer handed to frombuffer.
    return stored.byteswap()
155+
131156
def _binary_search(self, n, func):
132157
"""
133158
https://golang.org/src/sort/search.go?s=2247:2287#L49

test/test_lshforest.py

+12
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ def test_query(self):
6262
results = forest.query(data[key], 10)
6363
self.assertIn(key, results)
6464

65+
def test_get_minhash_hashvalues(self):
    """Hash values fetched from the forest must rebuild an identical MinHash."""
    forest, data = self._setup()
    for key in data:
        expected = data[key]
        fetched = forest.get_minhash_hashvalues(key)
        rebuilt = MinHash(hashvalues=fetched)
        rebuilt_values = rebuilt.hashvalues
        # Same number of hash values before and after the round trip.
        self.assertEqual(len(fetched), len(rebuilt_values))
        # The rebuilt sketch must be indistinguishable from the inserted one.
        self.assertEqual(rebuilt.jaccard(expected), 1.0)
        # Direct element-wise check; lengths were asserted equal above,
        # so zip visits every position.
        for fetched_value, rebuilt_value in zip(fetched, rebuilt_values):
            self.assertEqual(fetched_value, rebuilt_value)
76+
6577
def test_pickle(self):
6678
forest, _ = self._setup()
6779
forest2 = pickle.loads(pickle.dumps(forest))

0 commit comments

Comments
 (0)