Skip to content

Commit 329d981

Browse files
clean up
1 parent 5ad828d commit 329d981

File tree

1 file changed

+45
-5
lines changed

1 file changed

+45
-5
lines changed

index.go

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ type Index interface {
7878
SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32,
7979
labels []int64, err error)
8080

81-
SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32,
81+
SearchBinary(x []uint8, k int64) (distances []int32,
82+
labels []int64, err error)
83+
84+
SearchBinaryWithIDs(x []uint8, k int64, include []int64, params json.RawMessage) (distances []int32,
8285
labels []int64, err error)
8386

8487
SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
@@ -399,7 +402,30 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
399402
return
400403
}
401404

402-
func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
405+
func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32,
406+
labels []int64, err error,
407+
) {
408+
d := idx.D()
409+
nq := (len(x) * 8) / d
410+
411+
distances = make([]int32, int64(nq)*k)
412+
labels = make([]int64, int64(nq)*k)
413+
414+
if c := C.faiss_IndexBinary_search(
415+
idx.idxBinary,
416+
C.idx_t(nq),
417+
(*C.uint8_t)(&x[0]),
418+
C.idx_t(k),
419+
(*C.int32_t)(&distances[0]),
420+
(*C.idx_t)(&labels[0]),
421+
); c != 0 {
422+
err = getLastError()
423+
}
424+
425+
return distances, labels, nil
426+
}
427+
428+
func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64, include []int64,
403429
params json.RawMessage) (distances []int32, labels []int64, err error,
404430
) {
405431
d := idx.D()
@@ -408,7 +434,17 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
408434
distances = make([]int32, int64(nq)*k)
409435
labels = make([]int64, int64(nq)*k)
410436

411-
searchParams, err := NewSearchParams(idx, params, nil, nil)
437+
var selector *C.FaissIDSelector
438+
if len(include) > 0 {
439+
includeSelector, err := NewIDSelectorBatch(include)
440+
if err != nil {
441+
return nil, nil, err
442+
}
443+
selector = includeSelector.Get()
444+
defer includeSelector.Delete()
445+
}
446+
447+
searchParams, err := NewSearchParams(idx, params, selector, nil)
412448
if err != nil {
413449
return nil, nil, err
414450
}
@@ -429,8 +465,12 @@ func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
429465
return distances, labels, nil
430466
}
431467

432-
func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64,params json.RawMessage) (distances []int32,
468+
func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int64, params json.RawMessage) (distances []int32,
433469
labels []int64, err error) {
470+
if len(exclude) == 0 && params == nil {
471+
return idx.SearchBinary(x, k)
472+
}
473+
434474
d := idx.D()
435475
nq := (len(x) * 8) / d
436476

@@ -461,7 +501,7 @@ func (idx *faissIndex) SearchBinaryWithoutIDs(x []uint8, k int64, exclude []int6
461501
searchParams.sp,
462502
(*C.int32_t)(&distances[0]),
463503
(*C.idx_t)(&labels[0]),
464-
); c != 0 {
504+
); c != 0 {
465505
err = getLastError()
466506
}
467507

0 commit comments

Comments
 (0)