Commit abfd88e

Toy: Limited Training Size
1 parent 1c5b688 commit abfd88e

1 file changed: +136 -62 lines changed


section_faiss_vector_index.go

Lines changed: 136 additions & 62 deletions
@@ -276,9 +276,10 @@ func calculateNprobe(nlist int, indexOptimizedFor string) int32 {
 // perhaps, parallelized merging can help speed things up over here.
 func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
 	indexes []*vecIndexMeta, w *CountHashWriter, closeCh chan struct{}) error {
-
 	vecIndexes := make([]*faiss.IndexImpl, 0, len(sbs))
 	var finalVecIDCap, indexDataCap, reconsCap int
+	nvecs := 0
+	dims := 0
 	for segI, segBase := range sbs {
 		// Considering merge operations on vector indexes are expensive, it is
 		// worth including an early exit if the merge is aborted, saving us
@@ -296,6 +297,8 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
 		}
 		if len(indexes[segI].vecIds) > 0 {
 			indexReconsLen := len(indexes[segI].vecIds) * index.D()
+			dims = index.D()
+			nvecs += len(indexes[segI].vecIds)
 			if indexReconsLen > reconsCap {
 				reconsCap = indexReconsLen
 			}
@@ -309,51 +312,6 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
 	if len(vecIndexes) == 0 {
 		return nil
 	}
-
-	finalVecIDs := make([]int64, 0, finalVecIDCap)
-	// merging of indexes with reconstruction method.
-	// the indexes[i].vecIds has only the valid vecs of this vector
-	// index present in it, so we'd be reconstructing only those.
-	indexData := make([]float32, 0, indexDataCap)
-	// reusable buffer for reconstruction
-	recons := make([]float32, 0, reconsCap)
-	var err error
-	for i := 0; i < len(vecIndexes); i++ {
-		if isClosed(closeCh) {
-			freeReconstructedIndexes(vecIndexes)
-			return seg.ErrClosed
-		}
-
-		// reconstruct the vectors only if present, it could be that
-		// some of the indexes had all of their vectors updated/deleted.
-		if len(indexes[i].vecIds) > 0 {
-			neededReconsLen := len(indexes[i].vecIds) * vecIndexes[i].D()
-			recons = recons[:neededReconsLen]
-			// todo: parallelize reconstruction
-			recons, err = vecIndexes[i].ReconstructBatch(indexes[i].vecIds, recons)
-			if err != nil {
-				freeReconstructedIndexes(vecIndexes)
-				return err
-			}
-			indexData = append(indexData, recons...)
-			// Adding vector IDs in the same order as the vectors
-			finalVecIDs = append(finalVecIDs, indexes[i].vecIds...)
-		}
-	}
-
-	if len(indexData) == 0 {
-		// no valid vectors for this index, so we don't even have to
-		// record it in the section
-		freeReconstructedIndexes(vecIndexes)
-		return nil
-	}
-	recons = nil
-
-	nvecs := len(finalVecIDs)
-
-	// safe to assume that all the indexes are of the same config values, given
-	// that they are extracted from the field mapping info.
-	dims := vecIndexes[0].D()
 	metric := vecIndexes[0].MetricType()
 	indexOptimizedFor := indexes[0].indexOptimizedFor

@@ -366,8 +324,6 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
 	// to do the same is not needed because the following operations don't need
 	// the reconstructed ones anymore and doing so will hold up memory which can
 	// be detrimental while creating indexes during introduction.
-	freeReconstructedIndexes(vecIndexes)
-	vecIndexes = nil

 	faissIndex, err := faiss.IndexFactory(dims, indexDescription, metric)
 	if err != nil {
@@ -376,34 +332,152 @@ func (v *vectorIndexOpaque) mergeAndWriteVectorIndexes(sbs []*SegmentBase,
 	defer faissIndex.Close()

 	if indexClass == IndexTypeIVF {
-		// the direct map maintained in the IVF index is essential for the
-		// reconstruction of vectors based on vector IDs in the future merges.
-		// the AddWithIDs API also needs a direct map to be set before using.
 		err = faissIndex.SetDirectMap(2)
 		if err != nil {
 			return err
 		}

 		nprobe := calculateNprobe(nlist, indexOptimizedFor)
 		faissIndex.SetNProbe(nprobe)
+	}
+
+	if nvecs < 100000 {
+		finalVecIDs := make([]int64, 0, finalVecIDCap)
+		// merging of indexes with reconstruction method.
+		// the indexes[i].vecIds has only the valid vecs of this vector
+		// index present in it, so we'd be reconstructing only those.
+		indexData := make([]float32, 0, indexDataCap)
+		// reusable buffer for reconstruction
+		recons := make([]float32, 0, reconsCap)
+		var err error
+		for i := 0; i < len(vecIndexes); i++ {
+			if isClosed(closeCh) {
+				freeReconstructedIndexes(vecIndexes)
+				return seg.ErrClosed
+			}
+
+			// reconstruct the vectors only if present, it could be that
+			// some of the indexes had all of their vectors updated/deleted.
+			if len(indexes[i].vecIds) > 0 {
+				neededReconsLen := len(indexes[i].vecIds) * vecIndexes[i].D()
+				recons = recons[:neededReconsLen]
+				// todo: parallelize reconstruction
+				recons, err = vecIndexes[i].ReconstructBatch(indexes[i].vecIds, recons)
+				if err != nil {
+					freeReconstructedIndexes(vecIndexes)
+					return err
+				}
+				indexData = append(indexData, recons...)
+				// Adding vector IDs in the same order as the vectors
+				finalVecIDs = append(finalVecIDs, indexes[i].vecIds...)
+			}
+		}
+
+		if len(indexData) == 0 {
+			// no valid vectors for this index, so we don't even have to
+			// record it in the section
+			freeReconstructedIndexes(vecIndexes)
+			return nil
+		}

-	// train the vector index, essentially performs k-means clustering to partition
-	// the data space of indexData such that during the search time, we probe
-	// only a subset of vectors -> non-exhaustive search. could be a time
-	// consuming step when the indexData is large.
-	err = faissIndex.Train(indexData)
+		recons = nil
+		freeReconstructedIndexes(vecIndexes)
+		vecIndexes = nil
+
+		if indexClass == IndexTypeIVF {
+			err = faissIndex.Train(indexData)
+			if err != nil {
+				return err
+			}
+		}
+		err = faissIndex.AddWithIDs(indexData, finalVecIDs)
 		if err != nil {
 			return err
 		}
-	}

-	err = faissIndex.AddWithIDs(indexData, finalVecIDs)
-	if err != nil {
-		return err
+		indexData = nil
+		finalVecIDs = nil
+
+	} else {
+		recons := make([]float32, 0, reconsCap)
+		curVecs := 0
+		vecLimit := 100000
+		if vecLimit < nlist*40 {
+			vecLimit = nlist * 40
+		}
+		finalVecIDs := make([]int64, 0, vecLimit)
+		indexData := make([]float32, 0, vecLimit*dims)
+		trained := false
+
+		var err error
+
+		for i := 0; i < len(vecIndexes); i++ {
+			if isClosed(closeCh) {
+				freeReconstructedIndexes(vecIndexes)
+				return seg.ErrClosed
+			}
+
+			if len(indexes[i].vecIds) > 0 {
+				neededReconsLen := len(indexes[i].vecIds) * vecIndexes[i].D()
+				recons = recons[:neededReconsLen]
+				// todo: parallelize reconstruction
+				recons, err = vecIndexes[i].ReconstructBatch(indexes[i].vecIds, recons)
+				if err != nil {
+					freeReconstructedIndexes(vecIndexes)
+					return err
+				}
+				vecLen := len(indexes[i].vecIds)
+				shift := 0
+				dims := vecIndexes[i].D()
+				for curVecs+vecLen > vecLimit {
+					indexData = append(indexData, recons[shift*dims:(shift+vecLimit-curVecs)*dims]...)
+					finalVecIDs = append(finalVecIDs, indexes[i].vecIds[shift:shift+vecLimit-curVecs]...)
+					if !trained {
+						err = faissIndex.Train(indexData)
+						if err != nil {
+							freeReconstructedIndexes(vecIndexes)
+							return err
+						}
+						trained = true
+					}
+					err = faissIndex.AddWithIDs(indexData, finalVecIDs)
+					if err != nil {
+						freeReconstructedIndexes(vecIndexes)
+						return err
+					}
+					indexData = indexData[:0]
+					finalVecIDs = finalVecIDs[:0]
+					shift += vecLimit - curVecs
+					vecLen -= vecLimit - curVecs
+					curVecs = 0
+				}
+				if vecLen != 0 {
+					indexData = append(indexData, recons[shift*dims:(shift+vecLen)*dims]...)
+					finalVecIDs = append(finalVecIDs, indexes[i].vecIds[shift:shift+vecLen]...)
+					curVecs = len(finalVecIDs)
+				}
+			}
+		}
+
+		recons = nil
+		freeReconstructedIndexes(vecIndexes)
+		vecIndexes = nil
+		if curVecs > 0 {
+			if !trained {
+				err = faissIndex.Train(indexData)
+				if err != nil {
+					return err
+				}
+			}
+			err = faissIndex.AddWithIDs(indexData, finalVecIDs)
+			if err != nil {
+				return err
+			}
+		}
+		indexData = nil
+		finalVecIDs = nil
 	}

-	indexData = nil
-	finalVecIDs = nil
 	var mergedIndexBytes []byte
 	mergedIndexBytes, err = faiss.WriteIndexIntoBuffer(faissIndex)
 	if err != nil {
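
What the change amounts to: for merges with fewer than 100,000 vectors the existing single Train + AddWithIDs path is kept, while larger merges now train the IVF index on at most the first vecLimit vectors (100,000, or nlist*40 if that is larger, presumably to stay near FAISS's recommended minimum of roughly 39 training points per centroid) and feed the rest into the index in vecLimit-sized batches through AddWithIDs. Below is a minimal, self-contained sketch of that train-on-the-first-chunk, add-in-chunks pattern, assuming the blevesearch go-faiss bindings this file already uses (IndexFactory, Train, AddWithIDs, Close); the addInChunks helper, the chunk size, and the random demo data are illustrative and not part of the commit.

package main

import (
	"fmt"
	"math/rand"

	faiss "github.com/blevesearch/go-faiss"
)

// addInChunks trains the IVF index on the first chunk only (bounding the
// k-means training cost) and then adds every chunk, IDs included, with
// AddWithIDs. chunkSize plays the role of vecLimit in the commit.
func addInChunks(idx *faiss.IndexImpl, dims, chunkSize int, vecs []float32, ids []int64) error {
	trained := false
	for start := 0; start < len(ids); start += chunkSize {
		end := start + chunkSize
		if end > len(ids) {
			end = len(ids)
		}
		chunk := vecs[start*dims : end*dims]
		if !trained {
			// k-means clustering only ever sees the first chunk.
			if err := idx.Train(chunk); err != nil {
				return err
			}
			trained = true
		}
		if err := idx.AddWithIDs(chunk, ids[start:end]); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	const (
		dims  = 8
		nlist = 4
		nvecs = 1000
		limit = 256 // stand-in for the 100000 / nlist*40 cap in the commit
	)

	// Random demo data; the real code passes vectors reconstructed from the
	// segments being merged.
	vecs := make([]float32, nvecs*dims)
	for i := range vecs {
		vecs[i] = rand.Float32()
	}
	ids := make([]int64, nvecs)
	for i := range ids {
		ids[i] = int64(i)
	}

	idx, err := faiss.IndexFactory(dims, fmt.Sprintf("IVF%d,Flat", nlist), faiss.MetricL2)
	if err != nil {
		panic(err)
	}
	defer idx.Close()

	if err := addInChunks(idx, dims, limit, vecs, ids); err != nil {
		panic(err)
	}
	fmt.Println("indexed", nvecs, "vectors")
}

The trade-off is the usual one for IVF training: clustering a bounded sample keeps merge time predictable as segments grow, at the cost of centroids that may fit the vectors outside the training sample slightly less well.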
