Skip to content

Commit

Permalink
SQFP16 Range Validation for Faiss IVF Models
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Mar 18, 2024
1 parent 774cd8c commit e0cb80d
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527)
* Added Inner Product Space type support for Lucene Engine [#1551](https://github.com/opensearch-project/k-NN/pull/1551)
* Add Range Validation for Faiss SQFP16 [#1493](https://github.com/opensearch-project/k-NN/pull/1493)
* SQFP16 Range Validation for Faiss IVF Models [#1557](https://github.com/opensearch-project/k-NN/pull/1557)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
Expand Down Expand Up @@ -276,6 +277,76 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex
}
}

public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenSucceed() throws Exception {
if (!isRunningAgainstOldCluster()) {
int dimension = 2;

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension);
int trainingDataCount = 200;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, METHOD_IVF)
.field(KNN_ENGINE, FAISS_NAME)
.field(METHOD_PARAMETER_SPACE_TYPE, "l2")
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_NLIST, 1)
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_SQ)
.startObject(PARAMETERS)
.field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
.field(FAISS_SQ_CLIP, true)
.endObject()
.endObject()
.endObject()
.endObject();
Map<String, Object> method = xContentBuilderToMap(builder);

trainModel(TEST_MODEL, TRAIN_INDEX, TRAIN_TEST_FIELD, dimension, method, "faiss ivf sqfp16 test description");

// Make sure training succeeds after 30 seconds
assertTrainingSucceeds(TEST_MODEL, 30, 1000);

// Create knn index from model
String indexMapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(TEST_FIELD)
.field("type", "knn_vector")
.field(MODEL_ID, TEST_MODEL)
.endObject()
.endObject()
.endObject()
.toString();

createKnnIndex(testIndex, getKNNDefaultIndexSettings(), indexMapping);
Float[] vector1 = { -65523.76f, 65504.2f };
Float[] vector2 = { -270.85f, 65514.2f };
Float[] vector3 = { -150.9f, 65504.0f };
Float[] vector4 = { -20.89f, 100000000.0f };
addKnnDoc(testIndex, "1", TEST_FIELD, vector1);
addKnnDoc(testIndex, "2", TEST_FIELD, vector2);
addKnnDoc(testIndex, "3", TEST_FIELD, vector3);
addKnnDoc(testIndex, "4", TEST_FIELD, vector4);

float[] queryVector = { -10.5f, 25.48f };
int k = 4;
Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, k), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), TEST_FIELD);
assertEquals(k, results.size());
for (int i = 0; i < k; i++) {
assertEquals(k - i, Integer.parseInt(results.get(i).getDocId()));
}

deleteKNNIndex(testIndex);
deleteKNNIndex(TRAIN_INDEX);
deleteModel(TEST_MODEL);
validateGraphEviction();
}
}

private void validateGraphEviction() throws Exception {
// Search every 5 seconds 14 times to confirm graph gets evicted
int intervals = 14;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
Expand Down Expand Up @@ -530,10 +529,23 @@ protected String contentType() {

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType());
parseCreateField(
context,
fieldType().getDimension(),
fieldType().getSpaceType(),
getMethodComponentContext(fieldType().getKnnMethodContext())
);
}

private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) {
if (Objects.isNull(knnMethodContext)) {
return null;
}
return knnMethodContext.getMethodComponentContext();
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand All @@ -551,7 +563,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
Expand All @@ -571,34 +583,30 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}

// Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16"
protected boolean isFaissSQfp16(KNNMethodContext knnMethodContext) {

// KNNMethodContext shouldn't be null
if (Objects.isNull(knnMethodContext)) {
protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) {
if (Objects.isNull(methodComponentContext)) {
return false;
}

// engine should be faiss
if (!FAISS_NAME.equals(knnMethodContext.getKnnEngine().getName())) {
if (methodComponentContext.getParameters().size() == 0) {
return false;
}

// Should have Method Component Parameters
if (knnMethodContext.getMethodComponentContext().getParameters().size() == 0) {
return false;
}
Map<String, Object> methodComponentParams = knnMethodContext.getMethodComponentContext().getParameters();
Map<String, Object> methodComponentParams = methodComponentContext.getParameters();

// The method component parameters should have an encoder
if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) {
return false;
}

MethodComponentContext methodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER);
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER);

// returns true if encoder name is "sq" and type is "fp16"
return ENCODER_SQ.equals(methodComponentContext.getName())
&& FAISS_SQ_ENCODER_FP16.equals(methodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16));
return ENCODER_SQ.equals(encoderMethodComponentContext.getName())
&& FAISS_SQ_ENCODER_FP16.equals(
encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
);

}

// Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index
Expand Down Expand Up @@ -659,21 +667,19 @@ Optional<byte[]> getBytesFromContext(ParseContext context, int dimension) throws
return Optional.of(array);
}

Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext)
throws IOException {
context.path().add(simpleName());

// Returns an optional array of float values where each value in the vector is parsed as a float and validated
// if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'.
// If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be
// clipped to FP16 range.
boolean isFaissSQfp16Flag = isFaissSQfp16(fieldType().getKnnMethodContext());
boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext);
boolean clipVectorValueToFP16RangeFlag = false;
if (isFaissSQfp16Flag) {
clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) fieldType().getKnnMethodContext()
.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER)
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
Expand Down Expand Up @@ -75,7 +76,8 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}

@Override
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand All @@ -96,7 +98,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ protected void parseCreateField(ParseContext context) throws IOException {
);
}

parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType());
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
}
}
Loading

0 comments on commit e0cb80d

Please sign in to comment.