Skip to content
Open
91 changes: 64 additions & 27 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package zap

import (
"encoding/binary"
"log"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -57,28 +58,29 @@ func (vc *vectorIndexCache) Clear() {
// map. It's false otherwise.
func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
loadDocVecIDMap bool, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64,
indexes []*faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64,
vecIDsToExclude []int64, err error) {
index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err = vc.loadFromCache(
indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, err = vc.loadFromCache(
fieldID, loadDocVecIDMap, mem, except)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err
return indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, err
}

// function to load the vectorDocIDMap and if required, docVecIDMap from cache
// If not, it will create these and add them to the cache.
func (vc *vectorIndexCache) loadFromCache(fieldID uint16, loadDocVecIDMap bool,
mem []byte, except *roaring.Bitmap) (index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) {
mem []byte, except *roaring.Bitmap) (indexes []*faiss.IndexImpl,
vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64,
vecIDsToExclude []int64, err error) {

vc.m.RLock()

entry, ok := vc.cache[fieldID]
if ok {
index, vecDocIDMap, docVecIDMap = entry.load()
indexes, vecDocIDMap, docVecIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
if !loadDocVecIDMap || (loadDocVecIDMap && len(entry.docVecIDMap) > 0) {
vc.m.RUnlock()
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

vc.m.RUnlock()
Expand All @@ -88,7 +90,7 @@ func (vc *vectorIndexCache) loadFromCache(fieldID uint16, loadDocVecIDMap bool,
// typically seen for the first filtered query.
docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
vc.m.Unlock()
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

vc.m.RUnlock()
Expand Down Expand Up @@ -117,20 +119,20 @@ func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint
// Rebuilding the cache on a miss.
func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
loadDocVecIDMap bool, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
indexes []*faiss.IndexImpl, vecDocIDMap map[int64]uint32,
docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) {

// Handle concurrent accesses (to avoid unnecessary work) by adding a
// check within the write lock here.
entry := vc.cache[fieldID]
if entry != nil {
index, vecDocIDMap, docVecIDMap = entry.load()
indexes, vecDocIDMap, docVecIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
if !loadDocVecIDMap || (loadDocVecIDMap && len(entry.docVecIDMap) > 0) {
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}
docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

// if the cache doesn't have the entry, construct the vector to doc id map and
Expand Down Expand Up @@ -161,21 +163,39 @@ func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
}
}

indexes = make([]*faiss.IndexImpl, 0)
binaryIndexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

// Read binary index with proper flags
binaryIndex, err := faiss.ReadBinaryIndexFromBuffer(mem[pos:pos+int(binaryIndexSize)], faissIOFlags)
if err != nil {
log.Printf("Error reading binary index: %v", err)
return nil, nil, nil, nil, err
}
indexes = append(indexes, binaryIndex)
pos += int(binaryIndexSize)

indexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags)
index, err := faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags)
if err != nil {
return nil, nil, nil, nil, err
}
indexes = append(indexes, index)

vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
cacheEntryStub := cacheEntryReqs{
index: indexes[1],
binaryIndex: indexes[0],
vecDocIDMap: vecDocIDMap,
}

vc.insertLOCKED(fieldID, cacheEntryStub)
return indexes, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, loadDocVecIDMap bool,
docVecIDMap map[uint32][]int64) {
func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16, ce cacheEntryReqs) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

take pointer to the cacheEntryReqs struct, don't copy the struct

// the first time we've hit the cache, try to spawn a monitoring routine
// which will reconcile the moving averages for all the fields being hit
if len(vc.cache) == 0 {
Expand All @@ -189,8 +209,7 @@ func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
// this makes the average to be kept above the threshold value for a
// longer time and thereby the index to be resident in the cache
// for longer time.
vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap,
loadDocVecIDMap, docVecIDMap, 0.4)
vc.cache[fieldIDPlus1] = createCacheEntry(&ce, 0.4, ce.loadDocVecIDMap)
}
}

Expand Down Expand Up @@ -283,19 +302,33 @@ func (e *ewma) add(val uint64) {

// -----------------------------------------------------------------------------

func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64) *cacheEntry {
// required info to create a cache entry.
type cacheEntryReqs struct {
alpha float64
index *faiss.IndexImpl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just have the indexes array itself here

indexes []*faiss.IndexImpl

binaryIndex *faiss.IndexImpl
vecDocIDMap map[int64]uint32
// Used to indicate if the below fields are populated - will only be
// used for pre-filtered queries.
loadDocVecIDMap bool
docVecIDMap map[uint32][]int64
clusterAssignment map[int64]*roaring.Bitmap
}

func createCacheEntry(stub *cacheEntryReqs, alpha float64,
loadDocVecIDMap bool) *cacheEntry {
ce := &cacheEntry{
index: index,
vecDocIDMap: vecDocIDMap,
index: stub.index,
binaryIndex: stub.binaryIndex,
vecDocIDMap: stub.vecDocIDMap,
tracker: &ewma{
alpha: alpha,
sample: 1,
},
refs: 1,
}
if loadDocVecIDMap {
ce.docVecIDMap = docVecIDMap
ce.docVecIDMap = stub.docVecIDMap
}
return ce
}
Expand All @@ -309,6 +342,7 @@ type cacheEntry struct {
refs int64

index *faiss.IndexImpl
binaryIndex *faiss.IndexImpl
vecDocIDMap map[int64]uint32
docVecIDMap map[uint32][]int64
}
Expand All @@ -325,16 +359,19 @@ func (ce *cacheEntry) decRef() {
atomic.AddInt64(&ce.refs, -1)
}

func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32][]int64) {
func (ce *cacheEntry) load() ([]*faiss.IndexImpl,
map[int64]uint32, map[uint32][]int64) {
ce.incHit()
ce.addRef()
return ce.index, ce.vecDocIDMap, ce.docVecIDMap
return []*faiss.IndexImpl{ce.binaryIndex, ce.index}, ce.vecDocIDMap, ce.docVecIDMap
}

func (ce *cacheEntry) close() {
go func() {
ce.index.Close()
ce.index = nil
ce.binaryIndex.Close()
ce.binaryIndex = nil
ce.vecDocIDMap = nil
ce.docVecIDMap = nil
}()
Expand Down
63 changes: 60 additions & 3 deletions faiss_vector_posting.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"math"
"reflect"

"container/heap"
"github.com/RoaringBitmap/roaring/v2"
"github.com/RoaringBitmap/roaring/v2/roaring64"
"github.com/bits-and-blooms/bitset"
Expand Down Expand Up @@ -301,6 +302,31 @@ func (i *vectorIndexWrapper) Size() uint64 {
return i.size()
}

// distanceID represents a distance-ID pair for heap operations
type distanceID struct {
distance float32
id int64
}

// maxHeap implements heap.Interface for distanceID
type maxHeap []*distanceID

func (h maxHeap) Len() int { return len(h) }
func (h maxHeap) Less(i, j int) bool { return h[i].distance > h[j].distance }
func (h maxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

func (h *maxHeap) Push(x interface{}) {
*h = append(*h, x.(*distanceID))
}

func (h *maxHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

// InterpretVectorIndex returns a construct of closures (vectorIndexWrapper)
// that will allow the caller to -
// (1) search within an attached vector index
Expand All @@ -312,6 +338,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
segment.VectorIndex, error) {
// Params needed for the closures
var vecIndex *faiss.IndexImpl
var binaryIndex *faiss.IndexImpl
var vecDocIDMap map[int64]uint32
var docVecIDMap map[uint32][]int64
var vectorIDsToExclude []int64
Expand Down Expand Up @@ -354,12 +381,39 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
return rv, nil
}

scores, ids, err := vecIndex.SearchWithoutIDs(qVector, k,
vectorIDsToExclude, params)
binaryQueryVector := convertToBinary(qVector)
_, binIDs, err := binaryIndex.SearchBinaryWithoutIDs(binaryQueryVector,
k*4, vectorIDsToExclude, params)
if err != nil {
return nil, err
}

distances := make([]float32, k*4)
err = vecIndex.DistCompute(qVector, binIDs, int(k*4), distances)
if err != nil {
return nil, err
}

// Need to map distances to the original IDs to get the top K.
// Use a heap to keep track of the top K.
h := &maxHeap{}
heap.Init(h)
for i := 0; i < len(binIDs); i++ {
heap.Push(h, &distanceID{distance: distances[i], id: binIDs[i]})
if h.Len() > int(k) {
heap.Pop(h)
}
}

// Pop the top K in reverse order to get them in ascending order
ids := make([]int64, k)
scores := make([]float32, k)
for i := int(k) - 1; i >= 0; i-- {
distanceID := heap.Pop(h).(*distanceID)
scores[i] = distanceID.distance
ids[i] = distanceID.id
}

addIDsToPostingsList(rv, ids, scores)

return rv, nil
Expand Down Expand Up @@ -547,9 +601,12 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool
pos += n
}

vecIndex, vecDocIDMap, docVecIDMap, vectorIDsToExclude, err =
vecIndexes := make([]*faiss.IndexImpl, 2)
vecIndexes, vecDocIDMap, docVecIDMap, vectorIDsToExclude, err =
sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], requiresFiltering,
except)
vecIndex = vecIndexes[1]
binaryIndex = vecIndexes[0]

if vecIndex != nil {
vecIndexSize = vecIndex.Size()
Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/blevesearch/zapx/v16

go 1.21
go 1.22

toolchain go1.23.0

require (
github.com/RoaringBitmap/roaring/v2 v2.4.5
Expand All @@ -20,3 +22,5 @@ require (
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sys v0.13.0 // indirect
)

replace github.com/blevesearch/go-faiss => ../go-faiss
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ github.com/bits-and-blooms/bitset v1.22.0 h1:Tquv9S8+SGaS3EhyA+up3FXzmkhxPGjQQCk
github.com/bits-and-blooms/bitset v1.22.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
github.com/blevesearch/bleve_index_api v1.2.8 h1:Y98Pu5/MdlkRyLM0qDHostYo7i+Vv1cDNhqTeR4Sy6Y=
github.com/blevesearch/bleve_index_api v1.2.8/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0=
github.com/blevesearch/go-faiss v1.0.25 h1:lel1rkOUGbT1CJ0YgzKwC7k+XH0XVBHnCVWahdCXk4U=
github.com/blevesearch/go-faiss v1.0.25/go.mod h1:OMGQwOaRRYxrmeNdMrXJPvVx8gBnvE5RYrr0BahNnkk=
github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc=
github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs=
github.com/blevesearch/scorch_segment_api/v2 v2.3.10 h1:Yqk0XD1mE0fDZAJXTjawJ8If/85JxnLd8v5vG/jWE/s=
Expand Down
Loading