+ import json
import time
import sys
- import nmslib
+ import tqdm
+ from datasketch.hnsw import HNSW
- from utils import compute_jaccard
+ from utils import (
+     compute_jaccard,
+     compute_jaccard_distance,
+     compute_minhash_jaccard_distance,
+     lazy_create_minhashes_from_sets,
+ )

- def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
-     (index_sets, index_keys) = index_data
-     (query_sets, query_keys) = query_data
-     print("Building HNSW Index.")
-     start = time.perf_counter()
-     index = nmslib.init(
-         method="hnsw",
-         space="jaccard_sparse",
-         data_type=nmslib.DataType.OBJECT_AS_STRING,
-     )
-     index.addDataPointBatch(
-         [" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
-     )
-     index.createIndex(index_params)
-     indexing_time = time.perf_counter() - start
-     print("Indexing time: {:.3f}.".format(indexing_time))
+ def search_nswlib_jaccard_topk(index_data, query_data, index_params, k):
+     import nmslib
+
+     (index_sets, index_keys, _, index_cache) = index_data
+     (query_sets, query_keys, _) = query_data
+     cache_key = json.dumps(index_params)
+     if cache_key not in index_cache:
+         print("Building HNSW Index.")
+         start = time.perf_counter()
+         index = nmslib.init(
+             method="hnsw",
+             space="jaccard_sparse",
+             data_type=nmslib.DataType.OBJECT_AS_STRING,
+         )
+         index.addDataPointBatch(
+             [" ".join(str(v) for v in s) for s in index_sets], range(len(index_keys))
+         )
+         index.createIndex(index_params)
+         indexing_time = time.perf_counter() - start
+         print("Indexing time: {:.3f}.".format(indexing_time))
+         index_cache[cache_key] = (
+             index,
+             {
+                 "indexing_time": indexing_time,
+             },
+         )
+     index, indexing = index_cache[cache_key]
    print("Querying.")
    times = []
    results = []
    index.setQueryTimeParams({"efSearch": index_params["efConstruction"]})
-     for query_set, query_key in zip(query_sets, query_keys):
+     for query_set, query_key in tqdm.tqdm(
+         zip(query_sets, query_keys),
+         total=len(query_keys),
+         desc="Querying",
+         unit=" query",
+     ):
        start = time.perf_counter()
        result, _ = index.knnQuery(" ".join(str(v) for v in query_set), k)
        result = [
@@ -36,6 +59,108 @@ def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
        duration = time.perf_counter() - start
        times.append(duration)
        results.append((query_key, result))
-         sys.stdout.write(f"\rQueried {len(results)} sets")
-     sys.stdout.write("\n")
-     return (indexing_time, results, times)
+     return (indexing, results, times)
+
+
+ def search_hnsw_jaccard_topk(index_data, query_data, index_params, k):
+     (index_sets, index_keys, _, index_cache) = index_data
+     (query_sets, query_keys, _) = query_data
+     cache_key = json.dumps(index_params)
+     if cache_key not in index_cache:
+         print("Building HNSW Index.")
+         start = time.perf_counter()
+         index = HNSW(distance_func=compute_jaccard_distance, **index_params)
+         for i in tqdm.tqdm(
+             range(len(index_keys)),
+             desc="Indexing",
+             unit=" set",
+             total=len(index_keys),
+         ):
+             index.insert(i, index_sets[i])
+         indexing_time = time.perf_counter() - start
+         print("Indexing time: {:.3f}.".format(indexing_time))
+         index_cache[cache_key] = (
+             index,
+             {
+                 "indexing_time": indexing_time,
+             },
+         )
+     index, indexing = index_cache[cache_key]
+     print("Querying.")
+     times = []
+     results = []
+     for query_set, query_key in tqdm.tqdm(
+         zip(query_sets, query_keys),
+         total=len(query_keys),
+         desc="Querying",
+         unit=" query",
+     ):
+         start = time.perf_counter()
+         result = index.query(query_set, k)
+         # Convert distances to similarities.
+         result = [(index_keys[i], 1.0 - dist) for i, dist in result]
+         duration = time.perf_counter() - start
+         times.append(duration)
+         results.append((query_key, result))
+     return (indexing, results, times)
+
+
+ def search_hnsw_minhash_jaccard_topk(index_data, query_data, index_params, k):
+     (index_sets, index_keys, index_minhashes, index_cache) = index_data
+     (query_sets, query_keys, query_minhashes) = query_data
+     num_perm = index_params["num_perm"]
+     cache_key = json.dumps(index_params)
+     if cache_key not in index_cache:
+         # Create minhashes
+         index_minhash_time, query_minhash_time = lazy_create_minhashes_from_sets(
+             index_minhashes,
+             index_sets,
+             query_minhashes,
+             query_sets,
+             num_perm,
+         )
+         print("Building HNSW Index for MinHash.")
+         start = time.perf_counter()
+         kwargs = index_params.copy()
+         kwargs.pop("num_perm")
+         index = HNSW(distance_func=compute_minhash_jaccard_distance, **kwargs)
+         for i in tqdm.tqdm(
+             range(len(index_keys)),
+             desc="Indexing",
+             unit=" query",
+             total=len(index_keys),
+         ):
+             index.insert(i, index_minhashes[num_perm][i])
+         indexing_time = time.perf_counter() - start
+         print("Indexing time: {:.3f}.".format(indexing_time))
+         index_cache[cache_key] = (
+             index,
+             {
+                 "index_minhash_time": index_minhash_time,
+                 "query_minhash_time": query_minhash_time,
+                 "indexing_time": indexing_time,
+             },
+         )
+     index, indexing = index_cache[cache_key]
+     print("Querying.")
+     times = []
+     results = []
+     for query_minhash, query_key, query_set in tqdm.tqdm(
+         zip(query_minhashes[num_perm], query_keys, query_sets),
+         total=len(query_keys),
+         desc="Querying",
+         unit=" query",
+     ):
+         start = time.perf_counter()
+         result = index.query(query_minhash, k)
+         # Recover the retrieved indexed sets and
+         # compute the exact Jaccard similarities.
+         result = [
+             [index_keys[i], compute_jaccard(query_set, index_sets[i])] for i in result
+         ]
+         # Sort by similarity.
+         result.sort(key=lambda x: x[1], reverse=True)
+         duration = time.perf_counter() - start
+         times.append(duration)
+         results.append((query_key, result))
+     return (indexing, results, times)
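For reference, a minimal driver sketch (not part of this commit) of how these benchmark functions appear to be invoked, based on the tuple unpacking above. The module name, the toy data, and the HNSW parameter names (`m`, `ef_construction` for datasketch's `HNSW` constructor) are assumptions; adjust them to the actual benchmark harness and installed datasketch version.

```python
# Hypothetical usage sketch; module name, data, and parameters are assumptions.
from hnsw import search_hnsw_jaccard_topk  # assumed module name for the file above

# Toy data in the layout the functions unpack:
# index_data = (sets, keys, minhash dict or None, shared cache dict)
index_sets = [{1, 2, 3}, {2, 3, 4}, {3, 4, 5, 6}]
index_keys = ["a", "b", "c"]
query_sets = [{2, 3}, {5, 6}]
query_keys = ["q1", "q2"]

index_cache = {}  # reused across calls so the index is built once per parameter set
index_data = (index_sets, index_keys, None, index_cache)
query_data = (query_sets, query_keys, None)

# Assumed datasketch HNSW constructor arguments; forwarded via **index_params.
index_params = {"m": 16, "ef_construction": 100}

indexing, results, times = search_hnsw_jaccard_topk(index_data, query_data, index_params, k=2)
print(indexing["indexing_time"])
print(results[0])  # (query_key, [(index_key, jaccard_similarity), ...])
```

The same tuple layout would feed `search_nswlib_jaccard_topk` (with nmslib's `efConstruction`/`M` style parameters) and `search_hnsw_minhash_jaccard_topk` (which additionally expects `"num_perm"` in `index_params` and a dict keyed by `num_perm` in place of the `None` placeholders).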