Skip to content

Commit 8baa603

Browse files
authored
Benchmark HNSW for Jaccard (#226)
1 parent e11bb70 commit 8baa603

File tree

9 files changed

+482
-332
lines changed

9 files changed

+482
-332
lines changed

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ sub-linear query time:
3737
+---------------------------+-----------------------------+------------------------+
3838
| `MinHash LSH Ensemble`_ | MinHash | Containment Threshold |
3939
+---------------------------+-----------------------------+------------------------+
40-
| `HNSW`_ | Customizable | Metric Distances |
40+
| `HNSW`_ | Any | Custom Metric Top-K |
4141
+---------------------------+-----------------------------+------------------------+
4242

4343
datasketch must be used with Python 3.7 or above, NumPy 1.11 or above, and SciPy.

benchmark/indexes/jaccard/exact.py

+26-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import json
12
import time
23
import sys
34
import collections
45

56
from SetSimilaritySearch import SearchIndex
7+
import tqdm
68

79

810
def _query_jaccard_topk(index, query, k):
@@ -22,20 +24,31 @@ def _query_jaccard_topk(index, query, k):
2224
return candidates[:k]
2325

2426

25-
def search_jaccard_topk(index_data, query_data, k):
26-
(index_sets, index_keys) = index_data
27-
(query_sets, query_keys) = query_data
28-
print("Building jaccard search index.")
29-
start = time.perf_counter()
30-
# Build the search index with the 0 threshold to index all tokens.
31-
index = SearchIndex(
32-
index_sets, similarity_func_name="jaccard", similarity_threshold=0.0
33-
)
34-
indexing_time = time.perf_counter() - start
35-
print("Finished building index in {:.3f}.".format(indexing_time))
27+
def search_jaccard_topk(index_data, query_data, index_params, k):
28+
(index_sets, index_keys, _, index_cache) = index_data
29+
(query_sets, query_keys, _) = query_data
30+
cache_key = json.dumps(index_params)
31+
if cache_key not in index_cache:
32+
print("Building jaccard search index.")
33+
start = time.perf_counter()
34+
# Build the search index with the 0 threshold to index all tokens.
35+
index = SearchIndex(
36+
index_sets, similarity_func_name="jaccard", similarity_threshold=0.0
37+
)
38+
indexing_time = time.perf_counter() - start
39+
print("Finished building index in {:.3f}.".format(indexing_time))
40+
index_cache[cache_key] = (
41+
index,
42+
{
43+
"indexing_time": indexing_time,
44+
},
45+
)
46+
index, indexing = index_cache[cache_key]
3647
times = []
3748
results = []
38-
for query_set, query_key in zip(query_sets, query_keys):
49+
for query_set, query_key in tqdm.tqdm(
50+
zip(query_sets, query_keys), total=len(query_keys), desc="Querying", unit=" set"
51+
):
3952
start = time.perf_counter()
4053
result = [
4154
[index_keys[i], similarity]
@@ -44,6 +57,4 @@ def search_jaccard_topk(index_data, query_data, k):
4457
duration = time.perf_counter() - start
4558
times.append(duration)
4659
results.append((query_key, result))
47-
sys.stdout.write("\rQueried {} sets.".format(len(results)))
48-
sys.stdout.write("\n")
49-
return (indexing_time, results, times)
60+
return (indexing, results, times)

benchmark/indexes/jaccard/hnsw.py

+147-22
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,55 @@
1+
import json
12
import time
23
import sys
34

4-
import nmslib
5+
import tqdm
6+
from datasketch.hnsw import HNSW
57

6-
from utils import compute_jaccard
8+
from utils import (
9+
compute_jaccard,
10+
compute_jaccard_distance,
11+
compute_minhash_jaccard_distance,
12+
lazy_create_minhashes_from_sets,
13+
)
714

815

9-
def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
10-
(index_sets, index_keys) = index_data
11-
(query_sets, query_keys) = query_data
12-
print("Building HNSW Index.")
13-
start = time.perf_counter()
14-
index = nmslib.init(
15-
method="hnsw",
16-
space="jaccard_sparse",
17-
data_type=nmslib.DataType.OBJECT_AS_STRING,
18-
)
19-
index.addDataPointBatch(
20-
[" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
21-
)
22-
index.createIndex(index_params)
23-
indexing_time = time.perf_counter() - start
24-
print("Indexing time: {:.3f}.".format(indexing_time))
16+
def search_nswlib_jaccard_topk(index_data, query_data, index_params, k):
17+
import nmslib
18+
19+
(index_sets, index_keys, _, index_cache) = index_data
20+
(query_sets, query_keys, _) = query_data
21+
cache_key = json.dumps(index_params)
22+
if cache_key not in index_cache:
23+
print("Building HNSW Index.")
24+
start = time.perf_counter()
25+
index = nmslib.init(
26+
method="hnsw",
27+
space="jaccard_sparse",
28+
data_type=nmslib.DataType.OBJECT_AS_STRING,
29+
)
30+
index.addDataPointBatch(
31+
[" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
32+
)
33+
index.createIndex(index_params)
34+
indexing_time = time.perf_counter() - start
35+
print("Indexing time: {:.3f}.".format(indexing_time))
36+
index_cache[cache_key] = (
37+
index,
38+
{
39+
"indexing_time": indexing_time,
40+
},
41+
)
42+
index, indexing = index_cache[cache_key]
2543
print("Querying.")
2644
times = []
2745
results = []
2846
index.setQueryTimeParams({"efSearch": index_params["efConstruction"]})
29-
for query_set, query_key in zip(query_sets, query_keys):
47+
for query_set, query_key in tqdm.tqdm(
48+
zip(query_sets, query_keys),
49+
total=len(query_keys),
50+
desc="Querying",
51+
unit=" query",
52+
):
3053
start = time.perf_counter()
3154
result, _ = index.knnQuery(" ".join(str(v) for v in query_set), k)
3255
result = [
@@ -36,6 +59,108 @@ def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
3659
duration = time.perf_counter() - start
3760
times.append(duration)
3861
results.append((query_key, result))
39-
sys.stdout.write(f"\rQueried {len(results)} sets")
40-
sys.stdout.write("\n")
41-
return (indexing_time, results, times)
62+
return (indexing, results, times)
63+
64+
65+
def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
    """Benchmark top-k Jaccard search with a datasketch HNSW index over raw sets.

    The index is built at most once per distinct ``index_params`` value: the
    JSON serialization of ``index_params`` is used as the key into the cache
    carried inside ``index_data``, and a cached index is reused on later calls.

    Args:
        index_data: 4-tuple ``(index_sets, index_keys, _, index_cache)``. The
            third element is ignored here. ``index_cache`` is a mutable
            mapping that this function writes the built index into.
        query_data: 3-tuple ``(query_sets, query_keys, _)``. The third
            element is ignored here.
        index_params: dict forwarded as keyword arguments to ``HNSW`` and
            serialized with ``json.dumps`` to form the cache key.
        k: number of neighbors requested from ``index.query``.

    Returns:
        A tuple ``(indexing, results, times)``:
        ``indexing`` is the dict ``{"indexing_time": seconds}`` stored with
        the cached index; ``results`` is a list of
        ``(query_key, [(index_key, similarity), ...])``; ``times`` is the list
        of per-query wall-clock durations in seconds.
    """
    index_sets, index_keys, _, index_cache = index_data
    query_sets, query_keys, _ = query_data
    cache_key = json.dumps(index_params)

    if cache_key not in index_cache:
        print("Building HNSW Index.")
        build_started = time.perf_counter()
        hnsw = HNSW(distance_func=compute_jaccard_distance, **index_params)
        # Sets are inserted under their position so that query results can be
        # mapped back to ``index_keys`` afterwards.
        insert_progress = tqdm.tqdm(
            range(len(index_keys)),
            desc="Indexing",
            unit=" set",
            total=len(index_keys),
        )
        for position in insert_progress:
            hnsw.insert(position, index_sets[position])
        indexing_time = time.perf_counter() - build_started
        print("Indexing time: {:.3f}.".format(indexing_time))
        index_cache[cache_key] = (hnsw, {"indexing_time": indexing_time})

    index, indexing = index_cache[cache_key]

    print("Querying.")
    times = []
    results = []
    query_progress = tqdm.tqdm(
        zip(query_sets, query_keys),
        total=len(query_keys),
        desc="Querying",
        unit=" query",
    )
    for query_set, query_key in query_progress:
        query_started = time.perf_counter()
        neighbors = index.query(query_set, k)
        # The index reports Jaccard distances; report similarities instead.
        # The conversion is part of the timed region, as is the lookup of
        # the original keys.
        similar = [(index_keys[i], 1.0 - dist) for i, dist in neighbors]
        elapsed = time.perf_counter() - query_started
        times.append(elapsed)
        results.append((query_key, similar))
    return (indexing, results, times)
106+
107+
108+
def search_hnsw_minhash_jaccard_topk(index_data, query_data, index_params, k):
    """Benchmark top-k Jaccard search with a datasketch HNSW index over MinHashes.

    MinHashes are indexed and queried in the HNSW graph; the retrieved
    candidates are then re-scored with the exact Jaccard similarity between
    the original query set and the original indexed set, and sorted by that
    similarity in descending order.

    The index is built at most once per distinct ``index_params`` value: the
    JSON serialization of ``index_params`` is the key into the cache carried
    inside ``index_data``.

    Args:
        index_data: 4-tuple
            ``(index_sets, index_keys, index_minhashes, index_cache)``.
            ``index_minhashes`` is indexed by ``num_perm`` and then by
            position; ``index_cache`` is a mutable mapping that this function
            writes the built index into.
        query_data: 3-tuple ``(query_sets, query_keys, query_minhashes)``.
            ``query_minhashes`` is indexed by ``num_perm``.
        index_params: dict that must contain ``"num_perm"``; every other
            entry is forwarded as a keyword argument to ``HNSW``.
        k: number of neighbors requested from ``index.query``.

    Returns:
        A tuple ``(indexing, results, times)``:
        ``indexing`` is a dict with ``"index_minhash_time"``,
        ``"query_minhash_time"`` and ``"indexing_time"``; ``results`` is a
        list of ``(query_key, [[index_key, similarity], ...])``; ``times`` is
        the list of per-query wall-clock durations in seconds.

    Raises:
        KeyError: if ``index_params`` has no ``"num_perm"`` entry.
    """
    (index_sets, index_keys, index_minhashes, index_cache) = index_data
    (query_sets, query_keys, query_minhashes) = query_data
    num_perm = index_params["num_perm"]
    cache_key = json.dumps(index_params)
    if cache_key not in index_cache:
        # Create minhashes.
        # NOTE(review): this only runs on a cache miss, so a cache hit relies
        # on ``query_minhashes[num_perm]`` having been populated by an earlier
        # call with the same containers -- confirm against the caller.
        index_minhash_time, query_minhash_time = lazy_create_minhashes_from_sets(
            index_minhashes,
            index_sets,
            query_minhashes,
            query_sets,
            num_perm,
        )
        print("Building HNSW Index for MinHash.")
        start = time.perf_counter()
        # "num_perm" configures the MinHashes, not the HNSW graph, so it must
        # not be forwarded to the HNSW constructor.
        kwargs = {
            name: value for name, value in index_params.items() if name != "num_perm"
        }
        index = HNSW(distance_func=compute_minhash_jaccard_distance, **kwargs)
        # MinHashes are inserted under their position so that query results
        # can be mapped back to ``index_keys`` / ``index_sets``.
        for i in tqdm.tqdm(
            range(len(index_keys)),
            desc="Indexing",
            unit=" minhash",
            total=len(index_keys),
        ):
            index.insert(i, index_minhashes[num_perm][i])
        indexing_time = time.perf_counter() - start
        print("Indexing time: {:.3f}.".format(indexing_time))
        index_cache[cache_key] = (
            index,
            {
                "index_minhash_time": index_minhash_time,
                "query_minhash_time": query_minhash_time,
                "indexing_time": indexing_time,
            },
        )
    index, indexing = index_cache[cache_key]
    print("Querying.")
    times = []
    results = []
    for query_minhash, query_key, query_set in tqdm.tqdm(
        zip(query_minhashes[num_perm], query_keys, query_sets),
        total=len(query_keys),
        desc="Querying",
        unit=" query",
    ):
        start = time.perf_counter()
        result = index.query(query_minhash, k)
        # Recover the retrieved indexed sets and compute the exact Jaccard
        # similarities. The query yields (key, distance) pairs; only the key
        # (the insertion position) is needed because the approximate MinHash
        # distance is replaced by the exact similarity.
        result = [
            [index_keys[i], compute_jaccard(query_set, index_sets[i])]
            for i, _ in result
        ]
        # Sort by similarity.
        result.sort(key=lambda x: x[1], reverse=True)
        duration = time.perf_counter() - start
        times.append(duration)
        results.append((query_key, result))
    return (indexing, results, times)

benchmark/indexes/jaccard/lsh.py

+44-22
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,59 @@
1+
import json
12
import time
23
import sys
34

5+
import tqdm
6+
47
from datasketch import MinHashLSH
58

6-
from utils import compute_jaccard
9+
from utils import compute_jaccard, lazy_create_minhashes_from_sets
710

811

9-
def search_lsh_jaccard_topk(index_data, query_data, b, r, k):
10-
(index_sets, index_keys, index_minhashes) = index_data
12+
def search_lsh_jaccard_topk(index_data, query_data, index_params, k):
13+
(index_sets, index_keys, index_minhashes, index_cache) = index_data
1114
(query_sets, query_keys, query_minhashes) = query_data
15+
b, r = index_params["b"], index_params["r"]
1216
num_perm = b * r
13-
print("Building LSH Index.")
14-
start = time.perf_counter()
15-
index = MinHashLSH(
16-
num_perm=num_perm,
17-
params=(b, r),
18-
)
19-
# Use the indices of the indexed sets as keys in LSH.
20-
for i in range(len(index_keys)):
21-
index.insert(
22-
i,
23-
index_minhashes[num_perm][i],
24-
check_duplication=False,
17+
cache_key = json.dumps(index_params)
18+
if cache_key not in index_cache:
19+
# Create minhashes
20+
index_minhash_time, query_minhash_time = lazy_create_minhashes_from_sets(
21+
index_minhashes,
22+
index_sets,
23+
query_minhashes,
24+
query_sets,
25+
num_perm,
26+
)
27+
print("Building LSH Index.")
28+
start = time.perf_counter()
29+
index = MinHashLSH(num_perm=num_perm, params=(b, r))
30+
# Use the indices of the indexed sets as keys in LSH.
31+
for i in tqdm.tqdm(
32+
range(len(index_keys)),
33+
desc="Indexing",
34+
unit=" minhash",
35+
total=len(index_keys),
36+
):
37+
index.insert(i, index_minhashes[num_perm][i], check_duplication=False)
38+
indexing_time = time.perf_counter() - start
39+
print("Indexing time: {:.3f}.".format(indexing_time))
40+
index_cache[cache_key] = (
41+
index,
42+
{
43+
"index_minhash_time": index_minhash_time,
44+
"query_minhash_time": query_minhash_time,
45+
"indexing_time": indexing_time,
46+
},
2547
)
26-
indexing_time = time.perf_counter() - start
27-
print("Indexing time: {:.3f}.".format(indexing_time))
48+
index, indexing = index_cache[cache_key]
2849
print("Querying.")
2950
times = []
3051
results = []
31-
for query_minhash, query_key, query_set in zip(
32-
query_minhashes[num_perm], query_keys, query_sets
52+
for query_minhash, query_key, query_set in tqdm.tqdm(
53+
zip(query_minhashes[num_perm], query_keys, query_sets),
54+
total=len(query_keys),
55+
desc="Querying",
56+
unit=" query",
3357
):
3458
start = time.perf_counter()
3559
result = index.query(query_minhash)
@@ -45,6 +69,4 @@ def search_lsh_jaccard_topk(index_data, query_data, b, r, k):
4569
duration = time.perf_counter() - start
4670
times.append(duration)
4771
results.append((query_key, result))
48-
sys.stdout.write(f"\rQueried {len(results)} sets")
49-
sys.stdout.write("\n")
50-
return (indexing_time, results, times)
72+
return (indexing, results, times)

0 commit comments

Comments
 (0)