@@ -2,13 +2,16 @@ package faiss
22
33/*
44#include <stdlib.h>
5+ #include <stdint.h>
56#include <faiss/c_api/Index_c.h>
67#include <faiss/c_api/IndexIVF_c.h>
8+ #include <faiss/c_api/IndexBinary_c.h>
79#include <faiss/c_api/IndexIVF_c_ex.h>
810#include <faiss/c_api/Index_c_ex.h>
911#include <faiss/c_api/impl/AuxIndexStructures_c.h>
1012#include <faiss/c_api/index_factory_c.h>
1113#include <faiss/c_api/MetaIndexes_c.h>
14+ #include <faiss/c_api/IndexBinary_c.h>
1215*/
1316import "C"
1417import (
@@ -36,13 +39,13 @@ type Index interface {
3639 MetricType () int
3740
3841 // Train trains the index on a representative set of vectors.
39- Train (x [] float32 ) error
42+ Train (x interface {} ) error
4043
4144 // Add adds vectors to the index.
42- Add (x [] float32 ) error
45+ Add (x interface {} ) error
4346
4447 // AddWithIDs is like Add, but stores xids instead of sequential IDs.
45- AddWithIDs (x [] float32 , xids []int64 ) error
48+ AddWithIDs (x interface {} , xids []int64 ) error
4649
4750 // Returns true if the index is an IVF index.
4851 IsIVFIndex () bool
@@ -75,6 +78,12 @@ type Index interface {
7578 SearchWithIDs (x []float32 , k int64 , include []int64 , params json.RawMessage ) (distances []float32 ,
7679 labels []int64 , err error )
7780
81+ SearchBinaryWithIDs (x []uint8 , k int64 , params json.RawMessage ) (distances []int32 ,
82+ labels []int64 , err error )
83+
84+ SearchBinary (x []uint8 , k int64 ) (distances []int32 ,
85+ labels []int64 , err error )
86+
7887 // Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs
7988 SearchClustersFromIVFIndex (selector Selector , eligibleCentroidIDs []int64 ,
8089 minEligibleCentroids int , k int64 , x , centroidDis []float32 ,
@@ -104,50 +113,101 @@ type Index interface {
104113 Size () uint64
105114
106115 cPtr () * C.FaissIndex
116+
117+ cPtrBinary () * C.FaissIndexBinary
118+
119+ IVFDistCompute (queryData []float32 , ids []int64 , k int , distances []float32 )
107120}
108121
109122type faissIndex struct {
110- idx * C.FaissIndex
123+ idx * C.FaissIndex
124+ idxBinary * C.FaissIndexBinary
111125}
112126
113127func (idx * faissIndex ) cPtr () * C.FaissIndex {
114128 return idx .idx
115129}
116130
131+ func (idx * faissIndex ) IVFDistCompute (queryData []float32 , ids []int64 , k int , distances []float32 ) {
132+ C .faiss_IndexIVF_dist_compute (idx .idx , (* C .float )(& queryData [0 ]),
133+ (* C .idx_t )(& ids [0 ]), (C .size_t )(k ), (* C .float )(& distances [0 ]))
134+ }
135+
136+ func (idx * faissIndex ) cPtrBinary () * C.FaissIndexBinary {
137+ return idx .idxBinary
138+ }
139+
117140func (idx * faissIndex ) Size () uint64 {
118141 size := C .faiss_Index_size (idx .idx )
119142 return uint64 (size )
120143}
121144
122145func (idx * faissIndex ) D () int {
123- return int (C .faiss_Index_d (idx .idx ))
146+ if idx .idx != nil {
147+ return int (C .faiss_Index_d (idx .idx ))
148+ }
149+ return int (C .faiss_IndexBinary_d (idx .idxBinary ))
124150}
125151
126152func (idx * faissIndex ) IsTrained () bool {
127153 return C .faiss_Index_is_trained (idx .idx ) != 0
128154}
129155
130156func (idx * faissIndex ) Ntotal () int64 {
157+ if idx .idxBinary != nil {
158+ return int64 (C .faiss_IndexBinary_ntotal (idx .idxBinary ))
159+ }
131160 return int64 (C .faiss_Index_ntotal (idx .idx ))
132161}
133162
134163func (idx * faissIndex ) MetricType () int {
135164 return int (C .faiss_Index_metric_type (idx .idx ))
136165}
137166
138- func (idx * faissIndex ) Train (x []float32 ) error {
139- n := len (x ) / idx .D ()
140- if c := C .faiss_Index_train (idx .idx , C .idx_t (n ), (* C .float )(& x [0 ])); c != 0 {
141- return getLastError ()
167+ func (idx * faissIndex ) Train (x interface {}) error {
168+ floatVec , ok := x .([]float32 )
169+ if ok {
170+ n := len (floatVec ) / idx .D ()
171+ if c := C .faiss_Index_train (idx .idx , C .idx_t (n ), (* C .float )(& floatVec [0 ])); c != 0 {
172+ return getLastError ()
173+ }
174+ } else {
175+ c , ok := x .([]uint8 )
176+ if ok {
177+ n := (len (c ) * 8 ) / idx .D ()
178+ if c := C .faiss_IndexBinary_train (idx .idxBinary , C .idx_t (n ), (* C .uint8_t )(& c [0 ])); c != 0 {
179+ return getLastError ()
180+ }
181+ }
142182 }
143183 return nil
144184}
145185
146- func (idx * faissIndex ) Add (x []float32 ) error {
147- n := len (x ) / idx .D ()
148- if c := C .faiss_Index_add (idx .idx , C .idx_t (n ), (* C .float )(& x [0 ])); c != 0 {
149- return getLastError ()
186+ func (idx * faissIndex ) Add (x interface {}) error {
187+ floatVec , ok := x .([]float32 )
188+ if ok {
189+ n := len (floatVec ) / idx .D ()
190+ if c := C .faiss_Index_add (
191+ idx .idx ,
192+ C .idx_t (n ),
193+ (* C .float )(& floatVec [0 ]),
194+ ); c != 0 {
195+ return getLastError ()
196+ }
197+ } else {
198+ c , ok := x .([]uint8 )
199+ if ok {
200+ n := (len (c ) * 8 ) / idx .D ()
201+ if c := C .faiss_IndexBinary_add (
202+ idx .idxBinary ,
203+ C .idx_t (n ),
204+ (* C .uint8_t )(& c [0 ]),
205+ ); c != 0 {
206+ return getLastError ()
207+ }
208+ }
150209 }
210+
151211 return nil
152212}
153213
@@ -257,16 +317,50 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector,
257317 return distances , labels , nil
258318}
259319
260- func (idx * faissIndex ) AddWithIDs (x []float32 , xids []int64 ) error {
261- n := len (x ) / idx .D ()
262- if c := C .faiss_Index_add_with_ids (
263- idx .idx ,
264- C .idx_t (n ),
265- (* C .float )(& x [0 ]),
266- (* C .idx_t )(& xids [0 ]),
267- ); c != 0 {
268- return getLastError ()
320+ func packBits (bits []uint8 ) []uint8 {
321+ n := (len (bits ) + 7 ) / 8
322+ result := make ([]uint8 , n )
323+ for i := 0 ; i < len (bits ); i ++ {
324+ // Determine the index in the result slice
325+ byteIndex := i / 8
326+ // Determine the bit position in the byte
327+ bitPosition := uint (7 - (i % 8 ))
328+ // If the bit is 1, set the corresponding bit in the uint8 value
329+ if bits [i ] == 1 {
330+ result [byteIndex ] |= (1 << bitPosition )
331+ }
332+ }
333+
334+ return result
335+ }
336+
337+ func (idx * faissIndex ) AddWithIDs (x interface {}, xids []int64 ) error {
338+ floatVec , ok := x .([]float32 )
339+ if ok {
340+ n := len (floatVec ) / idx .D ()
341+ if c := C .faiss_Index_add_with_ids (
342+ idx .idx ,
343+ C .idx_t (n ),
344+ (* C .float )(& floatVec [0 ]),
345+ (* C .idx_t )(& xids [0 ]),
346+ ); c != 0 {
347+ return getLastError ()
348+ }
349+ } else {
350+ c , ok := x .([]uint8 )
351+ if ok {
352+ n := (len (c ) * 8 ) / idx .D ()
353+ if c := C .faiss_IndexBinary_add_with_ids (
354+ idx .idxBinary ,
355+ C .idx_t (n ),
356+ (* C .uint8_t )(& c [0 ]),
357+ (* C .idx_t )(& xids [0 ]),
358+ ); c != 0 {
359+ return getLastError ()
360+ }
361+ }
269362 }
363+
270364 return nil
271365}
272366
@@ -318,6 +412,51 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
318412 return
319413}
320414
415+ func (idx * faissIndex ) SearchBinaryWithIDs (x []uint8 , k int64 ,
416+ params json.RawMessage ) (distances []int32 , labels []int64 , err error ,
417+ ) {
418+ d := idx .D ()
419+ nq := (len (x ) * 8 ) / d
420+
421+ distances = make ([]int32 , int64 (nq )* k )
422+ labels = make ([]int64 , int64 (nq )* k )
423+
424+ if c := C .faiss_IndexBinary_search (
425+ idx .idxBinary ,
426+ C .idx_t (nq ),
427+ (* C .uint8_t )(& x [0 ]),
428+ C .idx_t (k ),
429+ (* C .int32_t )(& distances [0 ]),
430+ (* C .idx_t )(& labels [0 ]),
431+ ); c != 0 {
432+ err = getLastError ()
433+ }
434+
435+ return distances , labels , nil
436+ }
437+
438+ func (idx * faissIndex ) SearchBinary (x []uint8 , k int64 ) (distances []int32 , labels []int64 , err error ,
439+ ) {
440+ d := idx .D ()
441+ nq := (len (x ) * 8 ) / d
442+
443+ distances = make ([]int32 , int64 (nq )* k )
444+ labels = make ([]int64 , int64 (nq )* k )
445+
446+ if c := C .faiss_IndexBinary_search (
447+ idx .idxBinary ,
448+ C .idx_t (nq ),
449+ (* C .uint8_t )(& x [0 ]),
450+ C .idx_t (k ),
451+ (* C .int32_t )(& distances [0 ]),
452+ (* C .idx_t )(& labels [0 ]),
453+ ); c != 0 {
454+ err = getLastError ()
455+ }
456+
457+ return distances , labels , nil
458+ }
459+
321460func (idx * faissIndex ) SearchWithIDs (x []float32 , k int64 , include []int64 ,
322461 params json.RawMessage ) (distances []float32 , labels []int64 , err error ,
323462) {
@@ -426,6 +565,7 @@ func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) {
426565
427566func (idx * faissIndex ) Close () {
428567 C .faiss_Index_free (idx .idx )
568+ C .faiss_IndexBinary_free (idx .idxBinary )
429569}
430570
431571func (idx * faissIndex ) searchWithParams (x []float32 , k int64 , searchParams * C.FaissSearchParameters ) (
@@ -507,6 +647,17 @@ func IndexFactory(d int, description string, metric int) (*IndexImpl, error) {
507647 return & IndexImpl {& idx }, nil
508648}
509649
650+ func IndexBinaryFactory (d int , description string , metric int ) (* IndexImpl , error ) {
651+ cdesc := C .CString (description )
652+ defer C .free (unsafe .Pointer (cdesc ))
653+ var idx faissIndex
654+ c := C .faiss_index_binary_factory (& idx .idxBinary , C .int (d ), cdesc )
655+ if c != 0 {
656+ return nil , getLastError ()
657+ }
658+ return & IndexImpl {& idx }, nil
659+ }
660+
510661func SetOMPThreads (n uint ) {
511662 C .faiss_set_omp_threads (C .uint (n ))
512663}
0 commit comments