diff --git a/faiss_vector_batch_executor.go b/faiss_vector_batch_executor.go new file mode 100644 index 0000000..93dbaf4 --- /dev/null +++ b/faiss_vector_batch_executor.go @@ -0,0 +1,218 @@ +// Copyright (c) 2025 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vectors +// +build vectors + +package zap + +import ( + "encoding/json" + "slices" + "sync" + "time" + + "github.com/RoaringBitmap/roaring/v2/roaring64" + faiss "github.com/blevesearch/go-faiss" + segment "github.com/blevesearch/scorch_segment_api/v2" +) + +// batchKey represents a unique combination of k and params for batching +type batchKey struct { + k int64 + params string // string representation of params for comparison +} + +// batchRequest represents a single vector search request in a batch +type batchRequest struct { + qVector []float32 + result chan batchResult +} + +// batchGroup represents a group of requests with the same k and params +type batchGroup struct { + requests []*batchRequest + vecIndex *faiss.IndexImpl + vecDocIDMap map[int64]uint32 + vectorIDsToExclude []int64 +} + +// batchExecutor manages batched vector search requests +type batchExecutor struct { + batchDelay time.Duration + + m sync.RWMutex + groups map[batchKey]*batchGroup +} + +func newBatchExecutor(options *segment.InterpretVectorIndexOptions) *batchExecutor { + batchDelay := segment.DefaultBatchExecutionDelay + if options != nil && options.BatchExecutionDelay > 0 { + batchDelay = 
options.BatchExecutionDelay + } + + return &batchExecutor{ + batchDelay: batchDelay, + groups: make(map[batchKey]*batchGroup), + } +} + +type batchResult struct { + pl segment.VecPostingsList + err error +} + +func (be *batchExecutor) close() { + be.m.Lock() + defer be.m.Unlock() + + for key, group := range be.groups { + for _, req := range group.requests { + close(req.result) + } + delete(be.groups, key) + } +} + +// queueRequest adds a vector search request to the appropriate batch group +func (be *batchExecutor) queueRequest(qVector []float32, k int64, params json.RawMessage, + vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32, + vectorIDsToExclude []int64) <-chan batchResult { + + // Create a channel for the result + resultCh := make(chan batchResult, 1) + + // Create batch key + key := batchKey{ + k: k, + params: string(params), + } + + be.m.Lock() + defer be.m.Unlock() + + // Get or create batch group + group, exists := be.groups[key] + if !exists { + group = &batchGroup{ + requests: make([]*batchRequest, 0), + vecIndex: vecIndex, + vecDocIDMap: vecDocIDMap, + vectorIDsToExclude: vectorIDsToExclude, + } + be.groups[key] = group + } + + // Add request to group + group.requests = append(group.requests, &batchRequest{ + qVector: qVector, + result: resultCh, + }) + + // If this is the first request in the group, start a timer to process the batch + if len(group.requests) == 1 { + be.processBatchAfterDelay(key, be.batchDelay) + } + + return resultCh +} + +// processBatchAfterDelay waits for the specified delay and then processes the batch +func (be *batchExecutor) processBatchAfterDelay(key batchKey, delay time.Duration) { + time.AfterFunc(delay, func() { + be.m.Lock() + group, exists := be.groups[key] + if !exists { + be.m.Unlock() + return + } + + requests := slices.Clone(group.requests) + group.requests = group.requests[:0] // re-use + vecIndex := group.vecIndex + vecDocIDMap := group.vecDocIDMap + vectorIDsToExclude := group.vectorIDsToExclude + 
be.m.Unlock() + + // Process the batch + be.processBatch(key, requests, vecIndex, vecDocIDMap, vectorIDsToExclude) + }) +} + +// processBatch executes a batch of vector search requests +func (be *batchExecutor) processBatch(key batchKey, requests []*batchRequest, + vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32, + vectorIDsToExclude []int64) { + if len(requests) == 0 { + return + } + + // Prepare vectors for batch search + dim := vecIndex.D() + vecs := make([]float32, len(requests)*dim) + for i, req := range requests { + copy(vecs[i*dim:(i+1)*dim], req.qVector) + } + + // Execute batch search + scores, ids, err := vecIndex.SearchWithoutIDs(vecs, key.k, vectorIDsToExclude, + json.RawMessage(key.params)) + if err != nil { + // Send error to all channels + for _, req := range requests { + req.result <- batchResult{ + err: err, + } + close(req.result) + } + return + } + + // Calculate number of results per request + resultsPerRequest := int(key.k) + totalResults := len(scores) + + // Process results and send to respective channels + for i := range requests { + pl := &VecPostingsList{ + postings: roaring64.New(), + } + + // Calculate start and end indices for this request's results + startIdx := i * resultsPerRequest + endIdx := startIdx + resultsPerRequest + if endIdx > totalResults { + endIdx = totalResults + } + + // Get this request's results + currScores := scores[startIdx:endIdx] + currIDs := ids[startIdx:endIdx] + + // Add results to postings list + for j := 0; j < len(currIDs); j++ { + vecID := currIDs[j] + if docID, ok := vecDocIDMap[vecID]; ok { + code := getVectorCode(docID, currScores[j]) + pl.postings.Add(code) + } + } + + // Send result to channel + requests[i].result <- batchResult{ + pl: pl, + } + close(requests[i].result) + } +} diff --git a/faiss_vector_cache.go b/faiss_vector_cache.go index ce8e1bf..18d7774 100644 --- a/faiss_vector_cache.go +++ b/faiss_vector_cache.go @@ -25,6 +25,7 @@ import ( "github.com/RoaringBitmap/roaring/v2" faiss 
"github.com/blevesearch/go-faiss" + segment "github.com/blevesearch/scorch_segment_api/v2" ) func newVectorIndexCache() *vectorIndexCache { @@ -56,17 +57,18 @@ func (vc *vectorIndexCache) Clear() { // present. It also returns the batch executor for the field if it's present in the // cache. func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, - loadDocVecIDMap bool, except *roaring.Bitmap) ( + loadDocVecIDMap bool, except *roaring.Bitmap, + options *segment.InterpretVectorIndexOptions) ( index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64, - vecIDsToExclude []int64, err error) { + vecIDsToExclude []int64, batchExec *batchExecutor, err error) { vc.m.RLock() entry, ok := vc.cache[fieldID] if ok { - index, vecDocIDMap, docVecIDMap = entry.load() + index, vecDocIDMap, docVecIDMap, batchExec = entry.load() vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except) if !loadDocVecIDMap || len(entry.docVecIDMap) > 0 { vc.m.RUnlock() - return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil + return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil } vc.m.RUnlock() @@ -76,14 +78,14 @@ func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, // typically seen for the first filtered query. docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry) vc.m.Unlock() - return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil + return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil } vc.m.RUnlock() // acquiring a lock since this is modifying the cache. vc.m.Lock() defer vc.m.Unlock() - return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except) + return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except, options) } func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint32][]int64 { @@ -104,21 +106,23 @@ func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint // Rebuilding the cache on a miss. 
func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte, - loadDocVecIDMap bool, except *roaring.Bitmap) ( + loadDocVecIDMap bool, except *roaring.Bitmap, + options *segment.InterpretVectorIndexOptions) ( index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, - docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) { + docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, + batchExec *batchExecutor, err error) { // Handle concurrent accesses (to avoid unnecessary work) by adding a // check within the write lock here. entry := vc.cache[fieldID] if entry != nil { - index, vecDocIDMap, docVecIDMap = entry.load() + index, vecDocIDMap, docVecIDMap, batchExec = entry.load() vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except) if !loadDocVecIDMap || len(entry.docVecIDMap) > 0 { - return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil + return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil } docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry) - return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil + return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil } // if the cache doesn't have the entry, construct the vector to doc id map and @@ -154,16 +158,17 @@ func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte, index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, err } - vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap) - return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil + batchExec = newBatchExecutor(options) + vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap, batchExec) + return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, batchExec, nil } func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, loadDocVecIDMap bool, - docVecIDMap 
map[uint32][]int64) { + docVecIDMap map[uint32][]int64, batchExec *batchExecutor) { // the first time we've hit the cache, try to spawn a monitoring routine // which will reconcile the moving averages for all the fields being hit if len(vc.cache) == 0 { @@ -178,7 +183,7 @@ func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16, // longer time and thereby the index to be resident in the cache // for longer time. vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap, - loadDocVecIDMap, docVecIDMap, 0.4) + loadDocVecIDMap, docVecIDMap, 0.4, batchExec) } } @@ -272,7 +277,8 @@ func (e *ewma) add(val uint64) { // ----------------------------------------------------------------------------- func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, - loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64) *cacheEntry { + loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64, + batchExec *batchExecutor) *cacheEntry { ce := &cacheEntry{ index: index, vecDocIDMap: vecDocIDMap, @@ -280,7 +286,8 @@ func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, alpha: alpha, sample: 1, }, - refs: 1, + refs: 1, + batchExec: batchExec, } if loadDocVecIDMap { ce.docVecIDMap = docVecIDMap @@ -299,6 +306,8 @@ type cacheEntry struct { index *faiss.IndexImpl vecDocIDMap map[int64]uint32 docVecIDMap map[uint32][]int64 + + batchExec *batchExecutor } func (ce *cacheEntry) incHit() { @@ -313,10 +322,14 @@ func (ce *cacheEntry) decRef() { atomic.AddInt64(&ce.refs, -1) } -func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32][]int64) { +func (ce *cacheEntry) load() ( + *faiss.IndexImpl, + map[int64]uint32, + map[uint32][]int64, + *batchExecutor) { ce.incHit() ce.addRef() - return ce.index, ce.vecDocIDMap, ce.docVecIDMap + return ce.index, ce.vecDocIDMap, ce.docVecIDMap, ce.batchExec } func (ce *cacheEntry) close() { @@ -325,6 +338,7 @@ func (ce *cacheEntry) close() { ce.index = nil ce.vecDocIDMap = 
nil ce.docVecIDMap = nil + ce.batchExec.close() }() } diff --git a/faiss_vector_posting.go b/faiss_vector_posting.go index 2a77199..7ab46b1 100644 --- a/faiss_vector_posting.go +++ b/faiss_vector_posting.go @@ -308,7 +308,7 @@ func (i *vectorIndexWrapper) Size() uint64 { // (3) close attached vector index // (4) get the size of the attached vector index func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool, - except *roaring.Bitmap) ( + except *roaring.Bitmap, options *segment.InterpretVectorIndexOptions) ( segment.VectorIndex, error) { // Params needed for the closures var vecIndex *faiss.IndexImpl @@ -317,10 +317,10 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool var vectorIDsToExclude []int64 var fieldIDPlus1 uint16 var vecIndexSize uint64 + var batchExec *batchExecutor // Utility function to add the corresponding docID and scores for each vector - // returned after the kNN query to the newly - // created vecPostingsList + // returned after the kNN query to a VecPostingsList addIDsToPostingsList := func(pl *VecPostingsList, ids []int64, scores []float32) { for i := 0; i < len(ids); i++ { vecID := ids[i] @@ -344,24 +344,34 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool // 4. VecPostings would just have the docNum and the score. Every call of Next() // and Advance just returns the next VecPostings. The caller would do a vp.Number() // and the Score() to get the corresponding values - rv := &VecPostingsList{ - except: nil, // todo: handle the except bitmap within postings iterator. 
- postings: roaring64.New(), - } if vecIndex == nil || vecIndex.D() != len(qVector) { // vector index not found or dimensionality mismatched - return rv, nil + return &VecPostingsList{postings: roaring64.New()}, nil + } + + if options != nil && options.Batch && batchExec != nil { + // Queue request for batch processing + resultCh := batchExec.queueRequest(qVector, k, params, vecIndex, + vecDocIDMap, vectorIDsToExclude) + + // Wait for batch processing result + rv := <-resultCh + return rv.pl, rv.err } + + // Fall back to individual search when batching is not enabled for this query scores, ids, err := vecIndex.SearchWithoutIDs(qVector, k, vectorIDsToExclude, params) if err != nil { return nil, err } + rv := &VecPostingsList{ + except: nil, // todo: handle the except bitmap within postings iterator. + postings: roaring64.New(), + } addIDsToPostingsList(rv, ids, scores) - return rv, nil }, searchWithFilter: func(qVector []float32, k int64, @@ -385,6 +395,9 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool if len(eligibleDocIDs) == 0 { return rv, nil } + + // TODO: Support batching for searchWithFilter + + // If every element in the index is eligible (full selectivity), + // then this can basically be considered unfiltered kNN. 
if len(eligibleDocIDs) == int(sb.numDocs) { @@ -547,9 +560,9 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool pos += n } - vecIndex, vecDocIDMap, docVecIDMap, vectorIDsToExclude, err = + vecIndex, vecDocIDMap, docVecIDMap, vectorIDsToExclude, batchExec, err = sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], requiresFiltering, - except) + except, options) if vecIndex != nil { vecIndexSize = vecIndex.Size() diff --git a/faiss_vector_test.go b/faiss_vector_test.go index ac50a47..ceda3dc 100644 --- a/faiss_vector_test.go +++ b/faiss_vector_test.go @@ -8,6 +8,7 @@ import ( "math" "math/rand" "os" + "sync" "testing" "github.com/RoaringBitmap/roaring/v2" @@ -488,7 +489,7 @@ func TestVectorSegment(t *testing.T) { hitDocIDs := []uint64{2, 9, 9} hitVecs := [][]float32{data[0], data[7][0:3], data[7][3:6]} if vecSeg, ok := segOnDisk.(segment.VectorSegment); ok { - vecIndex, err := vecSeg.InterpretVectorIndex("stubVec", false, nil) + vecIndex, err := vecSeg.InterpretVectorIndex("stubVec", false, nil, nil) if err != nil { t.Fatal(err) } @@ -582,7 +583,7 @@ func TestPersistedVectorSegment(t *testing.T) { hitDocIDs := []uint64{2, 9, 9} hitVecs := [][]float32{data[0], data[7][0:3], data[7][3:6]} if vecSeg, ok := segOnDisk.(segment.VectorSegment); ok { - vecIndex, err := vecSeg.InterpretVectorIndex("stubVec", false, nil) + vecIndex, err := vecSeg.InterpretVectorIndex("stubVec", false, nil, nil) if err != nil { t.Fatal(err) } @@ -746,3 +747,109 @@ func TestValidVectorMerge(t *testing.T) { _ = os.RemoveAll(mergedSegPath1) }() } + +func TestBatchingRequestsToVectorIndex(t *testing.T) { + docs := buildMultiDocDataset(stubVecData, stubVec1Data) + + vecSegPlugin := &ZapPlugin{} + seg, _, err := vecSegPlugin.New(docs) + if err != nil { + t.Fatal(err) + } + + path := "./test-seg" + if unPersistedSeg, ok := seg.(segment.UnpersistedSegment); ok { + err = unPersistedSeg.Persist(path) + if err != nil { + t.Fatal(err) + } + } + + segOnDisk, err := 
vecSegPlugin.Open(path) + if err != nil { + t.Fatal(err) + } + + defer func() { + cerr := segOnDisk.Close() + if cerr != nil { + t.Fatalf("error closing segment: %v", cerr) + } + _ = os.RemoveAll(path) + }() + + expectedHits := [][]struct { + docID uint64 + score float32 + }{ + { // query: [0.0, 0.0, 0.0] + {docID: 2, score: 14}, + {docID: 9, score: 0.84558594}, + {docID: 9, score: 1.504926}, + }, + { // query: [1.0, 2.0, 3.0] + {docID: 2, score: 0}, + {docID: 8, score: 0.27308902}, + {docID: 9, score: 8.041586}, + }, + { // query: [12.0, 42.6, 78.65] + {docID: 3, score: 0}, + {docID: 4, score: 6557.6226}, + {docID: 8, score: 6848.291}, + }, + } + + if vecSeg, ok := segOnDisk.(segment.VectorSegment); ok { + vecIndex, err := vecSeg.InterpretVectorIndex("stubVec", false, nil, + &segment.InterpretVectorIndexOptions{ + Batch: true, + }) + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + for i, queryVec := range [][]float32{ + {0.0, 0.0, 0.0}, + {1.0, 2.0, 3.0}, + {12.0, 42.6, 78.65}, + } { + wg.Add(1) + go func(i int, queryVec []float32) { + defer wg.Done() + pl, err := vecIndex.Search(queryVec, 3, nil) + if err != nil { + vecIndex.Close() + t.Fatal(err) + } + itr := pl.Iterator(nil) + + hitCounter := 0 + for { + next, err := itr.Next() + if err != nil { + vecIndex.Close() + t.Fatal(err) + } + if next == nil { + break + } + + if next.Number() != expectedHits[i][hitCounter].docID || + next.Score() != expectedHits[i][hitCounter].score { + t.Fatalf("[%d] expected %d %f, got %d %f", + i, expectedHits[i][hitCounter].docID, expectedHits[i][hitCounter].score, + next.Number(), next.Score()) + } + hitCounter++ + } + + if hitCounter != 3 { + t.Fatalf("[%d] expected hitCounter: 3, got: %d", i, hitCounter) + } + }(i, queryVec) + } + wg.Wait() + vecIndex.Close() + } +} diff --git a/go.mod b/go.mod index b7c466a..04ef554 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/blevesearch/bleve_index_api v1.2.8 github.com/blevesearch/go-faiss v1.0.25 
github.com/blevesearch/mmap-go v1.0.4 - github.com/blevesearch/scorch_segment_api/v2 v2.3.10 + github.com/blevesearch/scorch_segment_api/v2 v2.3.11-0.20250527202424-37f101287093 github.com/blevesearch/vellum v1.1.0 github.com/golang/snappy v0.0.4 github.com/spf13/cobra v1.7.0 diff --git a/go.sum b/go.sum index cf65326..23edcbf 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,8 @@ github.com/blevesearch/go-faiss v1.0.25 h1:lel1rkOUGbT1CJ0YgzKwC7k+XH0XVBHnCVWah github.com/blevesearch/go-faiss v1.0.25/go.mod h1:OMGQwOaRRYxrmeNdMrXJPvVx8gBnvE5RYrr0BahNnkk= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.3.10 h1:Yqk0XD1mE0fDZAJXTjawJ8If/85JxnLd8v5vG/jWE/s= -github.com/blevesearch/scorch_segment_api/v2 v2.3.10/go.mod h1:Z3e6ChN3qyN35yaQpl00MfI5s8AxUJbpTR/DL8QOQ+8= +github.com/blevesearch/scorch_segment_api/v2 v2.3.11-0.20250527202424-37f101287093 h1:QHWCknx3jQsu4KRjsFLBWNnyN1kBl5yOCgQ+VqEL7Jc= +github.com/blevesearch/scorch_segment_api/v2 v2.3.11-0.20250527202424-37f101287093/go.mod h1:Z3e6ChN3qyN35yaQpl00MfI5s8AxUJbpTR/DL8QOQ+8= github.com/blevesearch/vellum v1.1.0 h1:CinkGyIsgVlYf8Y2LUQHvdelgXr6PYuvoDIajq6yR9w= github.com/blevesearch/vellum v1.1.0/go.mod h1:QgwWryE8ThtNPxtgWJof5ndPfx0/YMBh+W2weHKPw8Y= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=