@@ -78,7 +78,10 @@ type Index interface {
7878 SearchWithIDs (x []float32 , k int64 , include []int64 , params json.RawMessage ) (distances []float32 ,
7979 labels []int64 , err error )
8080
81- SearchBinaryWithIDs (x []uint8 , k int64 , params json.RawMessage ) (distances []int32 ,
81+ SearchBinary (x []uint8 , k int64 ) (distances []int32 ,
82+ labels []int64 , err error )
83+
84+ SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 , params json.RawMessage ) (distances []int32 ,
8285 labels []int64 , err error )
8386
8487 SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
@@ -399,8 +402,8 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
399402 return
400403}
401404
402- func (idx * faissIndex ) SearchBinaryWithIDs (x []uint8 , k int64 ,
403- params json. RawMessage ) ( distances [] int32 , labels []int64 , err error ,
405+ func (idx * faissIndex ) SearchBinary (x []uint8 , k int64 ) ( distances [] int32 ,
406+ labels []int64 , err error ,
404407) {
405408 d := idx .D ()
406409 nq := (len (x ) * 8 ) / d
@@ -429,8 +432,52 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
429432 return distances , labels , nil
430433}
431434
432- func (idx * faissIndex ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 ,params json.RawMessage ) (distances []int32 ,
435+ func (idx * faissIndex ) SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 ,
436+ params json.RawMessage ) (distances []int32 , labels []int64 , err error ,
437+ ) {
438+ d := idx .D ()
439+ nq := (len (x ) * 8 ) / d
440+
441+ distances = make ([]int32 , int64 (nq )* k )
442+ labels = make ([]int64 , int64 (nq )* k )
443+
444+ var selector * C.FaissIDSelector
445+ if len (include ) > 0 {
446+ includeSelector , err := NewIDSelectorBatch (include )
447+ if err != nil {
448+ return nil , nil , err
449+ }
450+ selector = includeSelector .Get ()
451+ defer includeSelector .Delete ()
452+ }
453+
454+ searchParams , err := NewSearchParams (idx , params , selector , nil )
455+ if err != nil {
456+ return nil , nil , err
457+ }
458+ defer searchParams .Delete ()
459+
460+ if c := C .faiss_IndexBinary_search_with_params (
461+ idx .idxBinary ,
462+ C .idx_t (nq ),
463+ (* C .uint8_t )(& x [0 ]),
464+ C .idx_t (k ),
465+ searchParams .sp ,
466+ (* C .int32_t )(& distances [0 ]),
467+ (* C .idx_t )(& labels [0 ]),
468+ ); c != 0 {
469+ err = getLastError ()
470+ }
471+
472+ return distances , labels , nil
473+ }
474+
475+ func (idx * faissIndex ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
433476 labels []int64 , err error ) {
477+ if len (exclude ) == 0 && params == nil {
478+ return idx .SearchBinary (x , k , params )
479+ }
480+
434481 d := idx .D ()
435482 nq := (len (x ) * 8 ) / d
436483
@@ -461,7 +508,7 @@ func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int6
461508 searchParams .sp ,
462509 (* C .int32_t )(& distances [0 ]),
463510 (* C .idx_t )(& labels [0 ]),
464- ); c != 0 {
511+ ); c != 0 {
465512 err = getLastError ()
466513 }
467514
0 commit comments