Skip to content

Commit f7c8908

Browse files
support for binary indexes - wip
1 parent 371fb38 commit f7c8908

File tree

5 files changed

+307
-42
lines changed

5 files changed

+307
-42
lines changed

go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
module github.com/blevesearch/go-faiss
22

3-
go 1.21
3+
go 1.22
4+
5+
toolchain go1.23.0

index.go

Lines changed: 173 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@ package faiss
22

33
/*
44
#include <stdlib.h>
5+
#include <stdint.h>
56
#include <faiss/c_api/Index_c.h>
67
#include <faiss/c_api/IndexIVF_c.h>
8+
#include <faiss/c_api/IndexBinary_c.h>
79
#include <faiss/c_api/IndexIVF_c_ex.h>
810
#include <faiss/c_api/Index_c_ex.h>
911
#include <faiss/c_api/impl/AuxIndexStructures_c.h>
1012
#include <faiss/c_api/index_factory_c.h>
1113
#include <faiss/c_api/MetaIndexes_c.h>
14+
#include <faiss/c_api/IndexBinary_c.h>
1215
*/
1316
import "C"
1417
import (
@@ -36,13 +39,13 @@ type Index interface {
3639
MetricType() int
3740

3841
// Train trains the index on a representative set of vectors.
39-
Train(x []float32) error
42+
Train(x interface{}) error
4043

4144
// Add adds vectors to the index.
42-
Add(x []float32) error
45+
Add(x interface{}) error
4346

4447
// AddWithIDs is like Add, but stores xids instead of sequential IDs.
45-
AddWithIDs(x []float32, xids []int64) error
48+
AddWithIDs(x interface{}, xids []int64) error
4649

4750
// Returns true if the index is an IVF index.
4851
IsIVFIndex() bool
@@ -75,6 +78,12 @@ type Index interface {
7578
SearchWithIDs(x []float32, k int64, include []int64, params json.RawMessage) (distances []float32,
7679
labels []int64, err error)
7780

81+
SearchBinaryWithIDs(x []uint8, k int64, params json.RawMessage) (distances []int32,
82+
labels []int64, err error)
83+
84+
SearchBinary(x []uint8, k int64) (distances []int32,
85+
labels []int64, err error)
86+
7887
// Applicable only to IVF indexes: Search clusters whose IDs are in eligibleCentroidIDs
7988
SearchClustersFromIVFIndex(selector Selector, eligibleCentroidIDs []int64,
8089
minEligibleCentroids int, k int64, x, centroidDis []float32,
@@ -104,50 +113,101 @@ type Index interface {
104113
Size() uint64
105114

106115
cPtr() *C.FaissIndex
116+
117+
cPtrBinary() *C.FaissIndexBinary
118+
119+
IVFDistCompute(queryData []float32, ids []int64, k int, distances []float32)
107120
}
108121

109122
type faissIndex struct {
110-
idx *C.FaissIndex
123+
idx *C.FaissIndex
124+
idxBinary *C.FaissIndexBinary
111125
}
112126

113127
// cPtr exposes the underlying C handle of the float-vector index.
func (idx *faissIndex) cPtr() *C.FaissIndex {
	handle := idx.idx
	return handle
}
116130

131+
func (idx *faissIndex) IVFDistCompute(queryData []float32, ids []int64, k int, distances []float32) {
132+
C.faiss_IndexIVF_dist_compute(idx.idx, (*C.float)(&queryData[0]),
133+
(*C.idx_t)(&ids[0]), (C.size_t)(k), (*C.float)(&distances[0]))
134+
}
135+
136+
func (idx *faissIndex) cPtrBinary() *C.FaissIndexBinary {
137+
return idx.idxBinary
138+
}
139+
117140
// Size returns the value reported by faiss_Index_size for the float-vector
// handle, converted to uint64.
func (idx *faissIndex) Size() uint64 {
	return uint64(C.faiss_Index_size(idx.idx))
}
121144

122145
func (idx *faissIndex) D() int {
123-
return int(C.faiss_Index_d(idx.idx))
146+
if idx.idx != nil {
147+
return int(C.faiss_Index_d(idx.idx))
148+
}
149+
return int(C.faiss_IndexBinary_d(idx.idxBinary))
124150
}
125151

126152
// IsTrained reports whether the wrapped index is trained.
//
// An index created by IndexBinaryFactory only has its binary handle set, so
// the binary getter is used whenever that handle is present; passing the nil
// float-vector handle to faiss_Index_is_trained would otherwise hand a null
// pointer to the C library. Indexes without a binary handle keep using the
// float-vector getter.
func (idx *faissIndex) IsTrained() bool {
	if idx.idxBinary != nil {
		return C.faiss_IndexBinary_is_trained(idx.idxBinary) != 0
	}
	return C.faiss_Index_is_trained(idx.idx) != 0
}
129155

130156
func (idx *faissIndex) Ntotal() int64 {
157+
if idx.idxBinary != nil {
158+
return int64(C.faiss_IndexBinary_ntotal(idx.idxBinary))
159+
}
131160
return int64(C.faiss_Index_ntotal(idx.idx))
132161
}
133162

134163
// MetricType returns the metric type of the wrapped index as an int.
//
// An index created by IndexBinaryFactory only has its binary handle set, so
// the binary getter is used whenever that handle is present instead of
// passing a nil float-vector handle to faiss_Index_metric_type. Indexes
// without a binary handle keep using the float-vector getter.
func (idx *faissIndex) MetricType() int {
	if idx.idxBinary != nil {
		return int(C.faiss_IndexBinary_metric_type(idx.idxBinary))
	}
	return int(C.faiss_Index_metric_type(idx.idx))
}
137166

138-
func (idx *faissIndex) Train(x []float32) error {
139-
n := len(x) / idx.D()
140-
if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 {
141-
return getLastError()
167+
func (idx *faissIndex) Train(x interface{}) error {
168+
floatVec, ok := x.([]float32)
169+
if ok {
170+
n := len(floatVec) / idx.D()
171+
if c := C.faiss_Index_train(idx.idx, C.idx_t(n), (*C.float)(&floatVec[0])); c != 0 {
172+
return getLastError()
173+
}
174+
} else {
175+
c, ok := x.([]uint8)
176+
if ok {
177+
n := (len(c) * 8) / idx.D()
178+
if c := C.faiss_IndexBinary_train(idx.idxBinary, C.idx_t(n), (*C.uint8_t)(&c[0])); c != 0 {
179+
return getLastError()
180+
}
181+
}
142182
}
143183
return nil
144184
}
145185

146-
func (idx *faissIndex) Add(x []float32) error {
147-
n := len(x) / idx.D()
148-
if c := C.faiss_Index_add(idx.idx, C.idx_t(n), (*C.float)(&x[0])); c != 0 {
149-
return getLastError()
186+
func (idx *faissIndex) Add(x interface{}) error {
187+
floatVec, ok := x.([]float32)
188+
if ok {
189+
n := len(floatVec) / idx.D()
190+
if c := C.faiss_Index_add(
191+
idx.idx,
192+
C.idx_t(n),
193+
(*C.float)(&floatVec[0]),
194+
); c != 0 {
195+
return getLastError()
196+
}
197+
} else {
198+
c, ok := x.([]uint8)
199+
if ok {
200+
n := (len(c) * 8) / idx.D()
201+
if c := C.faiss_IndexBinary_add(
202+
idx.idxBinary,
203+
C.idx_t(n),
204+
(*C.uint8_t)(&c[0]),
205+
); c != 0 {
206+
return getLastError()
207+
}
208+
}
150209
}
210+
151211
return nil
152212
}
153213

@@ -257,16 +317,50 @@ func (idx *faissIndex) SearchClustersFromIVFIndex(selector Selector,
257317
return distances, labels, nil
258318
}
259319

260-
func (idx *faissIndex) AddWithIDs(x []float32, xids []int64) error {
261-
n := len(x) / idx.D()
262-
if c := C.faiss_Index_add_with_ids(
263-
idx.idx,
264-
C.idx_t(n),
265-
(*C.float)(&x[0]),
266-
(*C.idx_t)(&xids[0]),
267-
); c != 0 {
268-
return getLastError()
320+
// packBits packs a slice of per-bit flags into bytes, eight flags per byte
// with the first flag of each group stored in the most significant bit.
// Only elements equal to 1 set a bit; every other value leaves the bit
// cleared. The result holds ceil(len(bits)/8) bytes, with unused trailing
// bits left at zero.
func packBits(bits []uint8) []uint8 {
	packed := make([]uint8, (len(bits)+7)/8)
	for pos, flag := range bits {
		if flag != 1 {
			continue
		}
		// pos>>3 selects the byte, pos&7 the offset from the top bit.
		packed[pos>>3] |= uint8(0x80) >> uint(pos&7)
	}
	return packed
}
336+
337+
func (idx *faissIndex) AddWithIDs(x interface{}, xids []int64) error {
338+
floatVec, ok := x.([]float32)
339+
if ok {
340+
n := len(floatVec) / idx.D()
341+
if c := C.faiss_Index_add_with_ids(
342+
idx.idx,
343+
C.idx_t(n),
344+
(*C.float)(&floatVec[0]),
345+
(*C.idx_t)(&xids[0]),
346+
); c != 0 {
347+
return getLastError()
348+
}
349+
} else {
350+
c, ok := x.([]uint8)
351+
if ok {
352+
n := (len(c) * 8) / idx.D()
353+
if c := C.faiss_IndexBinary_add_with_ids(
354+
idx.idxBinary,
355+
C.idx_t(n),
356+
(*C.uint8_t)(&c[0]),
357+
(*C.idx_t)(&xids[0]),
358+
); c != 0 {
359+
return getLastError()
360+
}
361+
}
269362
}
363+
270364
return nil
271365
}
272366

@@ -318,6 +412,51 @@ func (idx *faissIndex) SearchWithoutIDs(x []float32, k int64, exclude []int64, p
318412
return
319413
}
320414

415+
func (idx *faissIndex) SearchBinaryWithIDs(x []uint8, k int64,
416+
params json.RawMessage) (distances []int32, labels []int64, err error,
417+
) {
418+
d := idx.D()
419+
nq := (len(x) * 8) / d
420+
421+
distances = make([]int32, int64(nq)*k)
422+
labels = make([]int64, int64(nq)*k)
423+
424+
if c := C.faiss_IndexBinary_search(
425+
idx.idxBinary,
426+
C.idx_t(nq),
427+
(*C.uint8_t)(&x[0]),
428+
C.idx_t(k),
429+
(*C.int32_t)(&distances[0]),
430+
(*C.idx_t)(&labels[0]),
431+
); c != 0 {
432+
err = getLastError()
433+
}
434+
435+
return distances, labels, nil
436+
}
437+
438+
func (idx *faissIndex) SearchBinary(x []uint8, k int64) (distances []int32, labels []int64, err error,
439+
) {
440+
d := idx.D()
441+
nq := (len(x) * 8) / d
442+
443+
distances = make([]int32, int64(nq)*k)
444+
labels = make([]int64, int64(nq)*k)
445+
446+
if c := C.faiss_IndexBinary_search(
447+
idx.idxBinary,
448+
C.idx_t(nq),
449+
(*C.uint8_t)(&x[0]),
450+
C.idx_t(k),
451+
(*C.int32_t)(&distances[0]),
452+
(*C.idx_t)(&labels[0]),
453+
); c != 0 {
454+
err = getLastError()
455+
}
456+
457+
return distances, labels, nil
458+
}
459+
321460
func (idx *faissIndex) SearchWithIDs(x []float32, k int64, include []int64,
322461
params json.RawMessage) (distances []float32, labels []int64, err error,
323462
) {
@@ -426,6 +565,7 @@ func (idx *faissIndex) RemoveIDs(sel *IDSelector) (int, error) {
426565

427566
// Close frees the native index handles held by this wrapper.
//
// Each handle is freed only when set and is cleared afterwards, so calling
// Close again does not free the same native object twice.
func (idx *faissIndex) Close() {
	if idx.idx != nil {
		C.faiss_Index_free(idx.idx)
		idx.idx = nil
	}
	if idx.idxBinary != nil {
		C.faiss_IndexBinary_free(idx.idxBinary)
		idx.idxBinary = nil
	}
}
430570

431571
func (idx *faissIndex) searchWithParams(x []float32, k int64, searchParams *C.FaissSearchParameters) (
@@ -507,6 +647,17 @@ func IndexFactory(d int, description string, metric int) (*IndexImpl, error) {
507647
return &IndexImpl{&idx}, nil
508648
}
509649

650+
func IndexBinaryFactory(d int, description string, metric int) (*IndexImpl, error) {
651+
cdesc := C.CString(description)
652+
defer C.free(unsafe.Pointer(cdesc))
653+
var idx faissIndex
654+
c := C.faiss_index_binary_factory(&idx.idxBinary, C.int(d), cdesc)
655+
if c != 0 {
656+
return nil, getLastError()
657+
}
658+
return &IndexImpl{&idx}, nil
659+
}
660+
510661
func SetOMPThreads(n uint) {
511662
C.faiss_set_omp_threads(C.uint(n))
512663
}

index_flat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,5 @@ func (idx *IndexImpl) AsFlat() *IndexFlat {
5252
if ptr == nil {
5353
panic("index is not a flat index")
5454
}
55-
return &IndexFlat{&faissIndex{ptr}}
55+
return &IndexFlat{&faissIndex{idx: ptr}}
5656
}

0 commit comments

Comments
 (0)