218 changes: 218 additions & 0 deletions faiss_vector_batch_executor.go
@@ -0,0 +1,218 @@
// Copyright (c) 2025 Couchbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build vectors
// +build vectors

package zap

import (
	"encoding/json"
	"slices"
	"sync"
	"time"

	"github.com/RoaringBitmap/roaring/v2/roaring64"
	faiss "github.com/blevesearch/go-faiss"
	segment "github.com/blevesearch/scorch_segment_api/v2"
)

// batchKey represents a unique combination of k and params for batching
type batchKey struct {
	k      int64
	params string // params serialized to a string so the key is comparable and usable as a map key
}

// batchRequest represents a single vector search request in a batch
type batchRequest struct {
	qVector []float32
	result  chan batchResult
}

// batchGroup represents a group of requests with the same k and params
type batchGroup struct {
	requests           []*batchRequest
	vecIndex           *faiss.IndexImpl
	vecDocIDMap        map[int64]uint32
	vectorIDsToExclude []int64
}

// batchExecutor manages batched vector search requests
type batchExecutor struct {
	batchDelay time.Duration

	m      sync.RWMutex
	groups map[batchKey]*batchGroup
}

func newBatchExecutor(options *segment.InterpretVectorIndexOptions) *batchExecutor {
	batchDelay := segment.DefaultBatchExecutionDelay
	if options != nil && options.BatchExecutionDelay > 0 {
		batchDelay = options.BatchExecutionDelay
	}

	return &batchExecutor{
		batchDelay: batchDelay,
		groups:     make(map[batchKey]*batchGroup),
	}
}

type batchResult struct {
	pl  segment.VecPostingsList
	err error
}

func (be *batchExecutor) close() {
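	// Drain every pending group. Each request's channel is closed without a
	// value being sent, so any goroutine still waiting on it receives the
	// zero-value batchResult (nil pl, nil err).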
	be.m.Lock()
	defer be.m.Unlock()

	for key, group := range be.groups {
		for _, req := range group.requests {
			close(req.result)
		}
		delete(be.groups, key)
	}
}

// queueRequest adds a vector search request to the appropriate batch group
func (be *batchExecutor) queueRequest(qVector []float32, k int64, params json.RawMessage,
	vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
	vectorIDsToExclude []int64) <-chan batchResult {

	// Create a channel for the result
	resultCh := make(chan batchResult, 1)

	// Create batch key
	key := batchKey{
		k:      k,
		params: string(params),
	}

	be.m.Lock()
	defer be.m.Unlock()

	// Get or create batch group
	group, exists := be.groups[key]
	if !exists {
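		// The index, doc-ID map, and exclusion list captured from the first
		// request are reused for every request batched under this key.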
		group = &batchGroup{
			requests:           make([]*batchRequest, 0),
			vecIndex:           vecIndex,
			vecDocIDMap:        vecDocIDMap,
			vectorIDsToExclude: vectorIDsToExclude,
		}
		be.groups[key] = group
	}

	// Add request to group
	group.requests = append(group.requests, &batchRequest{
		qVector: qVector,
		result:  resultCh,
	})

	// If this is the first request in the group, start a timer to process the batch
	if len(group.requests) == 1 {
		be.processBatchAfterDelay(key, be.batchDelay)
	}

	return resultCh
}

// processBatchAfterDelay waits for the specified delay and then processes the batch
func (be *batchExecutor) processBatchAfterDelay(key batchKey, delay time.Duration) {
	time.AfterFunc(delay, func() {
		be.m.Lock()
		group, exists := be.groups[key]
		if !exists {
			be.m.Unlock()
			return
		}

		requests := slices.Clone(group.requests)
		group.requests = group.requests[:0] // reuse the backing array for the next batch
		vecIndex := group.vecIndex
		vecDocIDMap := group.vecDocIDMap
		vectorIDsToExclude := group.vectorIDsToExclude
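		// Release the lock before running the search so callers can queue
		// requests for the next batch under this key.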
		be.m.Unlock()

		// Process the batch
		be.processBatch(key, requests, vecIndex, vecDocIDMap, vectorIDsToExclude)
	})
}

// processBatch executes a batch of vector search requests
func (be *batchExecutor) processBatch(key batchKey, requests []*batchRequest,
	vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
	vectorIDsToExclude []int64) {
	if len(requests) == 0 {
		return
	}

	// Prepare vectors for batch search
	dim := vecIndex.D()
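	// All query vectors are flattened into one contiguous slice of
	// len(requests)*dim floats, the layout faiss expects for a batched search.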
	vecs := make([]float32, len(requests)*dim)
	for i, req := range requests {
		copy(vecs[i*dim:(i+1)*dim], req.qVector)
	}

	// Execute batch search
	scores, ids, err := vecIndex.SearchWithoutIDs(vecs, key.k, vectorIDsToExclude,
		json.RawMessage(key.params))
	if err != nil {
		// Send error to all channels
		for _, req := range requests {
			req.result <- batchResult{
				err: err,
			}
			close(req.result)
		}
		return
	}

	// Calculate number of results per request
	resultsPerRequest := int(key.k)
	totalResults := len(scores)
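	// Results come back contiguously per query: request i's scores and ids
	// occupy [i*k, (i+1)*k), clamped below to what faiss actually returned.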

	// Process results and send to respective channels
	for i := range requests {
		pl := &VecPostingsList{
			postings: roaring64.New(),
		}

		// Calculate start and end indices for this request's results
		startIdx := i * resultsPerRequest
		endIdx := startIdx + resultsPerRequest
		if endIdx > totalResults {
			endIdx = totalResults
		}

		// Get this request's results
		currScores := scores[startIdx:endIdx]
		currIDs := ids[startIdx:endIdx]

		// Add results to postings list
		for j := 0; j < len(currIDs); j++ {
			vecID := currIDs[j]
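			// faiss pads missing neighbors with id -1; the map lookup below
			// drops those, since -1 never appears in vecDocIDMap.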
			if docID, ok := vecDocIDMap[vecID]; ok {
				code := getVectorCode(docID, currScores[j])
				pl.postings.Add(code)
			}
		}

		// Send result to channel
		requests[i].result <- batchResult{
			pl: pl,
		}
		close(requests[i].result)
	}
}
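
Reviewer note (not part of the diff): a minimal sketch of how a caller inside this package might use the executor, assuming it already holds the index, doc-ID map, and exclusion list obtained from the vector index cache. The helper name searchOne is hypothetical and not introduced by this PR.

// searchOne is a hypothetical wrapper illustrating the intended call pattern:
// queue one query vector, then block until the batch containing it has run.
func searchOne(be *batchExecutor, qVector []float32, k int64, params json.RawMessage,
	vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
	vectorIDsToExclude []int64) (segment.VecPostingsList, error) {

	resCh := be.queueRequest(qVector, k, params, vecIndex, vecDocIDMap, vectorIDsToExclude)

	// The channel is buffered (capacity 1) and closed after the send, so a
	// single receive is enough; a channel closed by batchExecutor.close()
	// yields the zero value, i.e. a nil postings list and nil error.
	res, ok := <-resCh
	if !ok {
		return nil, nil
	}
	return res.pl, res.err
}
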
54 changes: 34 additions & 20 deletions faiss_vector_cache.go
@@ -25,6 +25,7 @@ import (

"github.com/RoaringBitmap/roaring/v2"
faiss "github.com/blevesearch/go-faiss"
segment "github.com/blevesearch/scorch_segment_api/v2"
)

func newVectorIndexCache() *vectorIndexCache {
@@ -56,17 +57,18 @@ func (vc *vectorIndexCache) Clear() {
// present. It also returns the batch executor for the field if it's present in the
// cache.
func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
loadDocVecIDMap bool, except *roaring.Bitmap) (
loadDocVecIDMap bool, except *roaring.Bitmap,
options *segment.InterpretVectorIndexOptions) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64,
vecIDsToExclude []int64, err error) {
vecIDsToExclude []int64, batchExec *batchExecutor, err error) {
vc.m.RLock()
entry, ok := vc.cache[fieldID]
if ok {
index, vecDocIDMap, docVecIDMap = entry.load()
index, vecDocIDMap, docVecIDMap, batchExec = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
if !loadDocVecIDMap || len(entry.docVecIDMap) > 0 {
vc.m.RUnlock()
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
}

vc.m.RUnlock()
@@ -76,14 +78,14 @@ func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
// typically seen for the first filtered query.
docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
vc.m.Unlock()
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
}

vc.m.RUnlock()
// acquiring a lock since this is modifying the cache.
vc.m.Lock()
defer vc.m.Unlock()
return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except)
return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except, options)
}

func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint32][]int64 {
Expand All @@ -104,21 +106,23 @@ func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint

// Rebuilding the cache on a miss.
func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
loadDocVecIDMap bool, except *roaring.Bitmap) (
loadDocVecIDMap bool, except *roaring.Bitmap,
options *segment.InterpretVectorIndexOptions) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) {
docVecIDMap map[uint32][]int64, vecIDsToExclude []int64,
batchExec *batchExecutor, err error) {

// Handle concurrent accesses (to avoid unnecessary work) by adding a
// check within the write lock here.
entry := vc.cache[fieldID]
if entry != nil {
index, vecDocIDMap, docVecIDMap = entry.load()
index, vecDocIDMap, docVecIDMap, batchExec = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
if !loadDocVecIDMap || len(entry.docVecIDMap) > 0 {
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
}
docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
}

// if the cache doesn't have the entry, construct the vector to doc id map and
@@ -154,16 +158,17 @@ func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,

index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, err
}

vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
batchExec = newBatchExecutor(options)
vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap, batchExec)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil
}

func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, loadDocVecIDMap bool,
docVecIDMap map[uint32][]int64) {
docVecIDMap map[uint32][]int64, batchExec *batchExecutor) {
// the first time we've hit the cache, try to spawn a monitoring routine
// which will reconcile the moving averages for all the fields being hit
if len(vc.cache) == 0 {
@@ -178,7 +183,7 @@ func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
// longer time and thereby the index to be resident in the cache
// for longer time.
vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap,
loadDocVecIDMap, docVecIDMap, 0.4)
loadDocVecIDMap, docVecIDMap, 0.4, batchExec)
}
}

@@ -272,15 +277,17 @@ func (e *ewma) add(val uint64) {
// -----------------------------------------------------------------------------

func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64) *cacheEntry {
loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64,
batchExec *batchExecutor) *cacheEntry {
ce := &cacheEntry{
index: index,
vecDocIDMap: vecDocIDMap,
tracker: &ewma{
alpha: alpha,
sample: 1,
},
refs: 1,
refs: 1,
batchExec: batchExec,
}
if loadDocVecIDMap {
ce.docVecIDMap = docVecIDMap
@@ -299,6 +306,8 @@ type cacheEntry struct {
index *faiss.IndexImpl
vecDocIDMap map[int64]uint32
docVecIDMap map[uint32][]int64

batchExec *batchExecutor
}

func (ce *cacheEntry) incHit() {
@@ -313,10 +322,14 @@ func (ce *cacheEntry) decRef() {
atomic.AddInt64(&ce.refs, -1)
}

func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32][]int64) {
func (ce *cacheEntry) load() (
*faiss.IndexImpl,
map[int64]uint32,
map[uint32][]int64,
*batchExecutor) {
ce.incHit()
ce.addRef()
return ce.index, ce.vecDocIDMap, ce.docVecIDMap
return ce.index, ce.vecDocIDMap, ce.docVecIDMap, ce.batchExec
}

func (ce *cacheEntry) close() {
@@ -325,6 +338,7 @@ func (ce *cacheEntry) close() {
ce.index = nil
ce.vecDocIDMap = nil
ce.docVecIDMap = nil
ce.batchExec.close()
}()
}
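
Reviewer note (not part of the diff): with the widened signature, a caller of loadOrCreate now receives the per-field batch executor alongside the index. A hedged sketch of such a call site follows; the helper name getFieldSearchState and its exact return shape are placeholders, not from this PR.

// getFieldSearchState is a hypothetical helper showing the new return values
// of loadOrCreate being threaded through to where queries are executed.
func getFieldSearchState(vc *vectorIndexCache, fieldID uint16, mem []byte,
	except *roaring.Bitmap, options *segment.InterpretVectorIndexOptions) (
	*faiss.IndexImpl, map[int64]uint32, []int64, *batchExecutor, error) {

	index, vecDocIDMap, _, vecIDsToExclude, batchExec, err :=
		vc.loadOrCreate(fieldID, mem, false, except, options)
	if err != nil {
		return nil, nil, nil, nil, err
	}
	// batchExec is owned by the cache entry and closed when the entry is
	// closed, so callers only borrow it here.
	return index, vecDocIDMap, vecIDsToExclude, batchExec, nil
}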
