Skip to content

Commit

Permalink
Allow build graph greedily for quantization scenarios
Browse files Browse the repository at this point in the history
Previously, we only supported building the graph greedily for
non-quantization scenarios. This commit removes that
constraint; however, we cannot skip writing the quantization
state, since it is required irrespective of the type of search
that is executed later.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Oct 2, 2024
1 parent d61e7d4 commit af96fed
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// Will consider building vector data structure based on threshold only for non quantization indices
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
fieldInfo.name,
Expand Down Expand Up @@ -139,8 +140,9 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
// Will consider building vector data structure based on threshold only for non quantization indices
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
fieldInfo.name,
Expand Down
118 changes: 118 additions & 0 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,124 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() {
validateGraphEviction();
}

/**
 * Creates an HNSW index that uses the FAISS SQ encoder with the fp16 type, with index settings
 * obtained from {@code buildKNNIndexSettings(-1)}, indexes {@code numDocs} documents and asserts
 * that a k-NN query against the field yields no results even though every document is present.
 */
@SneakyThrows
public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCreatingGraph() {
    final String indexName = "test-index-hnsw-sqfp16";
    final String fieldName = "test-field-hnsw-sqfp16";
    final int dimension = 128;
    final int numDocs = 100;

    // Choose one of the two candidate space types at random for this run
    final SpaceType[] candidateSpaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT };
    final Random random = new Random();
    final SpaceType spaceType = candidateSpaceTypes[random.nextInt(candidateSpaceTypes.length)];

    // Every component of the query vector is set to numDocs
    final float[] queryVector = new float[dimension];
    Arrays.fill(queryVector, (float) numDocs);

    // Mapping: knn_vector field -> HNSW method -> SQ encoder configured with the fp16 type
    final XContentBuilder mappingBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject(fieldName)
        .field("type", "knn_vector")
        .field("dimension", dimension)
        .startObject(KNN_METHOD)
        .field(NAME, METHOD_HNSW)
        .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
        .field(KNN_ENGINE, KNNEngine.FAISS.getName())
        .startObject(PARAMETERS)
        .startObject(METHOD_ENCODER_PARAMETER)
        .field(NAME, ENCODER_SQ)
        .startObject(PARAMETERS)
        .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject();

    final Map<String, Object> expectedMapping = xContentBuilderToMap(mappingBuilder);
    final String mappingJson = mappingBuilder.toString();

    // Create the index and confirm the stored mapping matches what was requested
    createKnnIndex(indexName, buildKNNIndexSettings(-1), mappingJson);
    assertEquals(new TreeMap<>(expectedMapping), new TreeMap<>(getIndexMappingAsMap(indexName)));

    // Ingest the documents and make sure all of them are counted
    indexTestData(indexName, fieldName, dimension, numDocs);
    assertEquals(numDocs, getDocCount(indexName));

    // The k-NN query is expected to come back empty
    final Response knnResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1);
    final String knnResponseBody = EntityUtils.toString(knnResponse.getEntity());
    final List<KNNResult> hits = parseSearchResponse(knnResponseBody, fieldName);
    assertEquals(0, hits.size());

    deleteKNNIndex(indexName);
    validateGraphEviction();
}

/**
 * Creates an HNSW index that uses the FAISS SQ encoder with the fp16 type, with index settings
 * obtained from {@code buildKNNIndexSettings(numDocs)}. After indexing {@code numDocs} documents
 * a k-NN query is asserted to return nothing; the index is then force merged and the shared
 * {@code queryTestData} checks are run against it.
 */
@SneakyThrows
public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph() {
    final String indexName = "test-index-hnsw-sqfp16";
    final String fieldName = "test-field-hnsw-sqfp16";
    final int dimension = 128;
    final int numDocs = 100;

    // Choose one of the two candidate space types at random for this run
    final SpaceType[] candidateSpaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT };
    final Random random = new Random();
    final SpaceType spaceType = candidateSpaceTypes[random.nextInt(candidateSpaceTypes.length)];

    // Every component of the query vector is set to numDocs
    final float[] queryVector = new float[dimension];
    Arrays.fill(queryVector, (float) numDocs);

    // Mapping: knn_vector field -> HNSW method -> SQ encoder configured with the fp16 type
    final XContentBuilder mappingBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject(fieldName)
        .field("type", "knn_vector")
        .field("dimension", dimension)
        .startObject(KNN_METHOD)
        .field(NAME, METHOD_HNSW)
        .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
        .field(KNN_ENGINE, KNNEngine.FAISS.getName())
        .startObject(PARAMETERS)
        .startObject(METHOD_ENCODER_PARAMETER)
        .field(NAME, ENCODER_SQ)
        .startObject(PARAMETERS)
        .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject();

    final Map<String, Object> expectedMapping = xContentBuilderToMap(mappingBuilder);
    final String mappingJson = mappingBuilder.toString();

    // Create the index and confirm the stored mapping matches what was requested
    createKnnIndex(indexName, buildKNNIndexSettings(numDocs), mappingJson);
    assertEquals(new TreeMap<>(expectedMapping), new TreeMap<>(getIndexMappingAsMap(indexName)));

    // Ingest the documents and make sure all of them are counted
    indexTestData(indexName, fieldName, dimension, numDocs);
    assertEquals(numDocs, getDocCount(indexName));

    // Before the force merge the k-NN query is expected to come back empty
    final Response knnResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1);
    final String knnResponseBody = EntityUtils.toString(knnResponse.getEntity());
    final List<KNNResult> hits = parseSearchResponse(knnResponseBody, fieldName);
    assertEquals(0, hits.size());

    // Force merge the index (second argument is 1), then run the shared query checks
    forceMergeKnnIndex(indexName, 1);
    queryTestData(indexName, fieldName, dimension, numDocs);

    deleteKNNIndex(indexName);
    validateGraphEviction();
}

@SneakyThrows
public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
verify(flatVectorsWriter).flush(5, null);
if (vectorsPerField.size() > 0) {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0);
}
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
Expand All @@ -618,6 +617,211 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
}
}

/**
 * Flushes quantized fields through a writer whose build threshold is one more than the largest
 * non-empty field size, and verifies that the quantization state is still written for every
 * non-empty field while the native index writer is never touched.
 *
 * @throws IOException declared because {@code flush} on the writers under test may throw it
 */
public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenSkipBuildingGraph()
    throws IOException {
    // Given
    List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
    final Map<Integer, Integer> sizeMap = new HashMap<>();
    IntStream.range(0, vectorsPerField.size()).forEach(i -> {
        final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
            new ArrayList<>(vectorsPerField.get(i).values())
        );
        final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
            VectorDataType.FLOAT,
            randomVectorValues
        );
        sizeMap.put(i, randomVectorValues.size());
        expectedVectorValues.add(knnVectorValues);

    });
    // Largest non-empty field size; the writer below is configured with a threshold one above it
    final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0);
    final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
        segmentWriteState,
        flatVectorsWriter,
        maxThreshold + 1 // to avoid building graph using max doc threshold, the same can be achieved by -1 too
    );

    try (
        MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
        MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
        MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
        MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
        MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
            KNN990QuantizationStateWriter.class
        );
    ) {
        quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);

        IntStream.range(0, vectorsPerField.size()).forEach(i -> {
            final FieldInfo fieldInfo = fieldInfo(
                i,
                VectorEncoding.FLOAT32,
                Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
            );

            NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
            fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
                .thenReturn(field);

            try {
                nativeEngineWriter.addField(fieldInfo);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
            knnVectorValuesFactoryMockedStatic.when(
                () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
            ).thenReturn(expectedVectorValues.get(i));

            // Every field reports quantization params and trains to the shared quantization state
            when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
            try {
                when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
                    .thenReturn(quantizationState);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState))
                .thenReturn(nativeIndexWriter);
        });
        // No flushIndex stub is configured: this test asserts below that the native index writer
        // is never interacted with, so such a stub could never be exercised.

        // When
        nativeEngineWriter.flush(5, null);

        // Then
        verify(flatVectorsWriter).flush(5, null);
        if (vectorsPerField.size() > 0) {
            verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState);
        } else {
            assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
        }
        // Graph building must have been skipped for every field
        verifyNoInteractions(nativeIndexWriter);
        // Quantization state is written only for fields that actually hold vectors
        IntStream.range(0, vectorsPerField.size()).forEach(i -> {
            try {
                if (vectorsPerField.get(i).isEmpty()) {
                    verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState);
                } else {
                    verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState);
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
        // Primitive long: Stream.count() returns long, so boxing it is unnecessary
        final long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count();
        knnVectorValuesFactoryMockedStatic.verify(
            () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
            times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled))
        );
    }
}

/**
 * Flushes quantized fields through a writer configured with {@code BUILD_GRAPH_NEVER_THRESHOLD}
 * and verifies that the quantization state is still written for every non-empty field while the
 * native index writer is never touched.
 *
 * @throws IOException declared because {@code flush} on the writers under test may throw it
 */
public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenSkipBuildingGraph()
    throws IOException {
    // Given
    List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
    IntStream.range(0, vectorsPerField.size()).forEach(i -> {
        final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
            new ArrayList<>(vectorsPerField.get(i).values())
        );
        final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
            VectorDataType.FLOAT,
            randomVectorValues
        );
        expectedVectorValues.add(knnVectorValues);
    });
    final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
        segmentWriteState,
        flatVectorsWriter,
        BUILD_GRAPH_NEVER_THRESHOLD
    );

    try (
        MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
        MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
        MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
        MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
        MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
            KNN990QuantizationStateWriter.class
        );
    ) {
        quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);

        IntStream.range(0, vectorsPerField.size()).forEach(i -> {
            final FieldInfo fieldInfo = fieldInfo(
                i,
                VectorEncoding.FLOAT32,
                Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
            );

            NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
            fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
                .thenReturn(field);

            try {
                nativeEngineWriter.addField(fieldInfo);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
            knnVectorValuesFactoryMockedStatic.when(
                () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
            ).thenReturn(expectedVectorValues.get(i));

            // Every field reports quantization params and trains to the shared quantization state
            when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
            try {
                when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
                    .thenReturn(quantizationState);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }

            nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState))
                .thenReturn(nativeIndexWriter);
        });
        // No flushIndex stub is configured: this test asserts below that the native index writer
        // is never interacted with, so such a stub could never be exercised.

        // When
        nativeEngineWriter.flush(5, null);

        // Then
        verify(flatVectorsWriter).flush(5, null);
        if (vectorsPerField.size() > 0) {
            verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState);
        } else {
            assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
        }
        // Graph building must have been skipped for every field
        verifyNoInteractions(nativeIndexWriter);
        // Quantization state is written only for fields that actually hold vectors
        IntStream.range(0, vectorsPerField.size()).forEach(i -> {
            try {
                if (vectorsPerField.get(i).isEmpty()) {
                    verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState);
                } else {
                    verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState);
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
        // Primitive long: Stream.count() returns long, so boxing it is unnecessary
        final long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count();
        knnVectorValuesFactoryMockedStatic.verify(
            () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
            times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled))
        );
    }
}

private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<String, String> attributes) {
FieldInfo fieldInfo = mock(FieldInfo.class);
when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber);
Expand Down

0 comments on commit af96fed

Please sign in to comment.