Skip to content

Commit a961c59

Browse files
clean up
1 parent 5ad828d commit a961c59

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

index.go

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ type Index interface {
7878
SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32,
7979
labels []int64, err error)
8080

81-
SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32,
81+
SearchBinary(x []uint8, k int64) (distances []int32,
82+
labels []int64, err error)
83+
84+
SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) (distances []int32,
8285
labels []int64, err error)
8386

8487
SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
@@ -399,8 +402,8 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
399402
return
400403
}
401404

402-
func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
403-
params json.RawMessage) (distances []int32, labels []int64, err error,
405+
func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32,
406+
labels []int64, err error,
404407
) {
405408
d := idx.D()
406409
nq := (len(x) * 8) / d
@@ -429,8 +432,52 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
429432
return distances, labels, nil
430433
}
431434

432-
func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64,params json.RawMessage) (distances []int32,
435+
func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, include []int64,
436+
params json.RawMessage) (distances []int32, labels []int64, err error,
437+
) {
438+
d := idx.D()
439+
nq := (len(x) * 8) / d
440+
441+
distances = make([]int32, int64(nq)*k)
442+
labels = make([]int64, int64(nq)*k)
443+
444+
var selector *C.FaissIDSelector
445+
if len(include) > 0 {
446+
includeSelector, err := NewIDSelectorBatch(include)
447+
if err != nil {
448+
return nil, nil, err
449+
}
450+
selector = includeSelector.Get()
451+
defer includeSelector.Delete()
452+
}
453+
454+
searchParams, err := NewSearchParams(idx, params, selector, nil)
455+
if err != nil {
456+
return nil, nil, err
457+
}
458+
defer searchParams.Delete()
459+
460+
if c := C.faiss_IndexBinary_search_with_params(
461+
idx.idxBinary,
462+
C.idx_t(nq),
463+
(*C.uint8_t)(&x[0]),
464+
C.idx_t(k),
465+
searchParams.sp,
466+
(*C.int32_t)(&distances[0]),
467+
(*C.idx_t)(&labels[0]),
468+
); c != 0 {
469+
err = getLastError()
470+
}
471+
472+
return distances, labels, nil
473+
}
474+
475+
func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
433476
labels []int64, err error) {
477+
if len(exclude) == 0 && params == nil {
478+
return idx.SearchBinary(x, k, params)
479+
}
480+
434481
d := idx.D()
435482
nq := (len(x) * 8) / d
436483

@@ -461,7 +508,7 @@ func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int6
461508
searchParams.sp,
462509
(*C.int32_t)(&distances[0]),
463510
(*C.idx_t)(&labels[0]),
464-
); c != 0 {
511+
); c != 0 {
465512
err = getLastError()
466513
}
467514

0 commit comments

Comments
 (0)