@@ -78,7 +78,10 @@ type Index interface {
7878 SearchWithIDs (x []float32 , k int64 , include []int64 , params json.RawMessage ) (distances []float32 ,
7979 labels []int64 , err error )
8080
81- SearchBinaryWithIDs (x []uint8 , k int64 , params json.RawMessage ) (distances []int32 ,
81+ SearchBinary (x []uint8 , k int64 ) (distances []int32 ,
82+ labels []int64 , err error )
83+
84+ SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 , params json.RawMessage ) (distances []int32 ,
8285 labels []int64 , err error )
8386
8487 SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
@@ -399,7 +402,30 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
399402 return
400403}
401404
402- func (idx * faissIndex ) SearchBinaryWithIDs (x []uint8 , k int64 ,
405+ func (idx * faissIndex ) SearchBinary (x []uint8 , k int64 ) (distances []int32 ,
406+ labels []int64 , err error ,
407+ ) {
408+ d := idx .D ()
409+ nq := (len (x ) * 8 ) / d
410+
411+ distances = make ([]int32 , int64 (nq )* k )
412+ labels = make ([]int64 , int64 (nq )* k )
413+
414+ if c := C .faiss_IndexBinary_search (
415+ idx .idxBinary ,
416+ C .idx_t (nq ),
417+ (* C .uint8_t )(& x [0 ]),
418+ C .idx_t (k ),
419+ (* C .int32_t )(& distances [0 ]),
420+ (* C .idx_t )(& labels [0 ]),
421+ ); c != 0 {
422+ err = getLastError ()
423+ }
424+
425+ return distances , labels , nil
426+ }
427+
428+ func (idx * faissIndex ) SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 ,
403429 params json.RawMessage ) (distances []int32 , labels []int64 , err error ,
404430) {
405431 d := idx .D ()
@@ -408,7 +434,17 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
408434 distances = make ([]int32 , int64 (nq )* k )
409435 labels = make ([]int64 , int64 (nq )* k )
410436
411- searchParams , err := NewSearchParams (idx , params , nil , nil )
437+ var selector * C.FaissIDSelector
438+ if len (include ) > 0 {
439+ includeSelector , err := NewIDSelectorBatch (include )
440+ if err != nil {
441+ return nil , nil , err
442+ }
443+ selector = includeSelector .Get ()
444+ defer includeSelector .Delete ()
445+ }
446+
447+ searchParams , err := NewSearchParams (idx , params , selector , nil )
412448 if err != nil {
413449 return nil , nil , err
414450 }
@@ -429,8 +465,12 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
429465 return distances , labels , nil
430466}
431467
432- func (idx * faissIndex ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 ,params json.RawMessage ) (distances []int32 ,
468+ func (idx * faissIndex ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
433469 labels []int64 , err error ) {
470+ if len (exclude ) == 0 && params == nil {
471+ return idx .SearchBinary (x , k )
472+ }
473+
434474 d := idx .D ()
435475 nq := (len (x ) * 8 ) / d
436476
@@ -461,7 +501,7 @@ func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int6
461501 searchParams .sp ,
462502 (* C .int32_t )(& distances [0 ]),
463503 (* C .idx_t )(& labels [0 ]),
464- ); c != 0 {
504+ ); c != 0 {
465505 err = getLastError ()
466506 }
467507
0 commit comments