@@ -10,10 +10,15 @@ namespace hnswlib {
1010 for (unsigned i = 0 ; i < qty; i++) {
1111 res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
1212 }
13- return ( 1 . 0f - res) ;
13+ return res;
1414
1515 }
1616
17+ static float
18+ InnerProductDistance (const void *pVect1, const void *pVect2, const void *qty_ptr) {
19+ return 1 .0f - InnerProduct (pVect1, pVect2, qty_ptr);
20+ }
21+
1722#if defined(USE_AVX)
1823
1924// Favor using AVX if available.
@@ -61,8 +66,13 @@ namespace hnswlib {
6166
6267 _mm_store_ps (TmpRes, sum_prod);
6368 float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];;
64- return 1 .0f - sum;
65- }
69+ return sum;
70+ }
71+
72+ static float
73+ InnerProductDistanceSIMD4ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
74+ return 1 .0f - InnerProductSIMD4ExtAVX (pVect1v, pVect2v, qty_ptr);
75+ }
6676
6777#endif
6878
@@ -121,7 +131,12 @@ namespace hnswlib {
121131 _mm_store_ps (TmpRes, sum_prod);
122132 float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
123133
124- return 1 .0f - sum;
134+ return sum;
135+ }
136+
137+ static float
138+ InnerProductDistanceSIMD4ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
139+ return 1 .0f - InnerProductSIMD4ExtSSE (pVect1v, pVect2v, qty_ptr);
125140 }
126141
127142#endif
@@ -156,7 +171,12 @@ namespace hnswlib {
156171 _mm512_store_ps (TmpRes, sum512);
157172 float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ] + TmpRes[8 ] + TmpRes[9 ] + TmpRes[10 ] + TmpRes[11 ] + TmpRes[12 ] + TmpRes[13 ] + TmpRes[14 ] + TmpRes[15 ];
158173
159- return 1 .0f - sum;
174+ return sum;
175+ }
176+
177+ static float
178+ InnerProductDistanceSIMD16ExtAVX512 (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
179+ return 1 .0f - InnerProductSIMD16ExtAVX512 (pVect1v, pVect2v, qty_ptr);
160180 }
161181
162182#endif
@@ -196,15 +216,20 @@ namespace hnswlib {
196216 _mm256_store_ps (TmpRes, sum256);
197217 float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ] + TmpRes[4 ] + TmpRes[5 ] + TmpRes[6 ] + TmpRes[7 ];
198218
199- return 1 .0f - sum;
219+ return sum;
220+ }
221+
222+ static float
223+ InnerProductDistanceSIMD16ExtAVX (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
224+ return 1 .0f - InnerProductSIMD16ExtAVX (pVect1v, pVect2v, qty_ptr);
200225 }
201226
202227#endif
203228
204229#if defined(USE_SSE)
205230
206- static float
207- InnerProductSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
231+ static float
232+ InnerProductSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
208233 float PORTABLE_ALIGN32 TmpRes[8 ];
209234 float *pVect1 = (float *) pVect1v;
210235 float *pVect2 = (float *) pVect2v;
@@ -245,17 +270,24 @@ namespace hnswlib {
245270 _mm_store_ps (TmpRes, sum_prod);
246271 float sum = TmpRes[0 ] + TmpRes[1 ] + TmpRes[2 ] + TmpRes[3 ];
247272
248- return 1 .0f - sum;
273+ return sum;
274+ }
275+
276+ static float
277+ InnerProductDistanceSIMD16ExtSSE (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
278+ return 1 .0f - InnerProductSIMD16ExtSSE (pVect1v, pVect2v, qty_ptr);
249279 }
250280
251281#endif
252282
253283#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
254284 DISTFUNC<float > InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
255285 DISTFUNC<float > InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
286+ DISTFUNC<float > InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
287+ DISTFUNC<float > InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
256288
257289 static float
258- InnerProductSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
290+ InnerProductDistanceSIMD16ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
259291 size_t qty = *((size_t *) qty_ptr);
260292 size_t qty16 = qty >> 4 << 4 ;
261293 float res = InnerProductSIMD16Ext (pVect1v, pVect2v, &qty16);
@@ -264,11 +296,11 @@ namespace hnswlib {
264296
265297 size_t qty_left = qty - qty16;
266298 float res_tail = InnerProduct (pVect1, pVect2, &qty_left);
267- return res + res_tail - 1 . 0f ;
299+ return 1 . 0f - ( res + res_tail) ;
268300 }
269301
270302 static float
271- InnerProductSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
303+ InnerProductDistanceSIMD4ExtResiduals (const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
272304 size_t qty = *((size_t *) qty_ptr);
273305 size_t qty4 = qty >> 2 << 2 ;
274306
@@ -279,7 +311,7 @@ namespace hnswlib {
279311 float *pVect2 = (float *) pVect2v + qty4;
280312 float res_tail = InnerProduct (pVect1, pVect2, &qty_left);
281313
282- return res + res_tail - 1 . 0f ;
314+ return 1 . 0f - ( res + res_tail) ;
283315 }
284316#endif
285317
@@ -290,30 +322,37 @@ namespace hnswlib {
290322 size_t dim_;
291323 public:
292324 InnerProductSpace (size_t dim) {
293- fstdistfunc_ = InnerProduct ;
325+ fstdistfunc_ = InnerProductDistance ;
294326 #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
295327 #if defined(USE_AVX512)
296- if (AVX512Capable ())
328+ if (AVX512Capable ()) {
297329 InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
298- else if (AVXCapable ())
330+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
331+ } else if (AVXCapable ()) {
299332 InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
333+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
334+ }
300335 #elif defined(USE_AVX)
301- if (AVXCapable ())
336+ if (AVXCapable ()) {
302337 InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
338+ InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
339+ }
303340 #endif
304341 #if defined(USE_AVX)
305- if (AVXCapable ())
342+ if (AVXCapable ()) {
306343 InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
344+ InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
345+ }
307346 #endif
308347
309348 if (dim % 16 == 0 )
310- fstdistfunc_ = InnerProductSIMD16Ext ;
349+ fstdistfunc_ = InnerProductDistanceSIMD16Ext ;
311350 else if (dim % 4 == 0 )
312- fstdistfunc_ = InnerProductSIMD4Ext ;
351+ fstdistfunc_ = InnerProductDistanceSIMD4Ext ;
313352 else if (dim > 16 )
314- fstdistfunc_ = InnerProductSIMD16ExtResiduals ;
353+ fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals ;
315354 else if (dim > 4 )
316- fstdistfunc_ = InnerProductSIMD4ExtResiduals ;
355+ fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals ;
317356 #endif
318357 dim_ = dim;
319358 data_size_ = dim * sizeof (float );
@@ -334,5 +373,4 @@ namespace hnswlib {
334373 ~InnerProductSpace () {}
335374 };
336375
337-
338376}
0 commit comments