Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Faiss SQFP16 Range Validation and Clipping #1563

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501)
* Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527)
* Added Inner Product Space type support for Lucene Engine [#1551](https://github.com/opensearch-project/k-NN/pull/1551)
* Add Range Validation for Faiss SQFP16 [#1493](https://github.com/opensearch-project/k-NN/pull/1493)
* SQFP16 Range Validation for Faiss IVF Models [#1557](https://github.com/opensearch-project/k-NN/pull/1557)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ public class KNNConstants {
public static final String FAISS_SQ_TYPE = "type";
public static final String FAISS_SQ_ENCODER_FP16 = "fp16";
public static final List<String> FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16);
public static final String FAISS_SQ_CLIP = "clip";

// Parameter defaults/limits
public static final Integer ENCODER_PARAMETER_PQ_CODE_COUNT_DEFAULT = 1;
Expand All @@ -111,6 +112,9 @@ public class KNNConstants {
public static final Integer MODEL_CACHE_CAPACITY_ATROPHY_THRESHOLD_IN_MINUTES = 30;
public static final Integer MODEL_CACHE_EXPIRE_AFTER_ACCESS_TIME_MINUTES = 30;

public static final Float FP16_MAX_VALUE = 65504.0f;
public static final Float FP16_MIN_VALUE = -65504.0f;

// Lib names
private static final String JNI_LIBRARY_PREFIX = "opensearchknn_";
public static final String FAISS_JNI_LIBRARY_NAME = JNI_LIBRARY_PREFIX + FAISS_NAME;
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@
*/
public abstract ValidationException validate(Object value);

/**
* Boolean method parameter
*/
public static class BooleanParameter extends Parameter<Boolean> {
public BooleanParameter(String name, Boolean defaultValue, Predicate<Boolean> validator) {
super(name, defaultValue, validator);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
if (!(value instanceof Boolean)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
return validationException;

Check warning on line 83 in src/main/java/org/opensearch/knn/index/Parameter.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/Parameter.java#L81-L83

Added lines #L81 - L83 were not covered by tests
}

if (!validator.test((Boolean) value)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));

Check warning on line 88 in src/main/java/org/opensearch/knn/index/Parameter.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/Parameter.java#L87-L88

Added lines #L87 - L88 were not covered by tests
}
return validationException;
}
}

/**
* Integer method parameter
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,39 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;

Expand Down Expand Up @@ -511,10 +529,23 @@

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType());
parseCreateField(
context,
fieldType().getDimension(),
fieldType().getSpaceType(),
getMethodComponentContext(fieldType().getKnnMethodContext())
);
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {
private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) {
if (Objects.isNull(knnMethodContext)) {
return null;
}
return knnMethodContext.getMethodComponentContext();
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand All @@ -532,7 +563,7 @@
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
Expand All @@ -551,6 +582,47 @@
context.path().remove();
}

// Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16"
protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) {
if (Objects.isNull(methodComponentContext)) {
return false;
}

if (methodComponentContext.getParameters().size() == 0) {
return false;
}

Map<String, Object> methodComponentParams = methodComponentContext.getParameters();

// The method component parameters should have an encoder
if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) {
return false;
}

// Validate if the object is of type MethodComponentContext before casting it later
if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) {
return false;

Check warning on line 604 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L604

Added line #L604 was not covered by tests
}

MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER);

// returns true if encoder name is "sq" and type is "fp16"
return ENCODER_SQ.equals(encoderMethodComponentContext.getName())
&& FAISS_SQ_ENCODER_FP16.equals(
encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)
);

}

// Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index
// using "sq" encoder of type "fp16".
protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) {
if (Objects.nonNull(methodComponentContext)) {
return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false);
}
return false;

Check warning on line 623 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L623

Added line #L623 was not covered by tests
}

void validateIfCircuitBreakerIsNotTriggered() {
if (KNNSettings.isCircuitBreakerTriggered()) {
throw new IllegalStateException(
Expand Down Expand Up @@ -600,23 +672,53 @@
return Optional.of(array);
}

Optional<float[]> getFloatsFromContext(ParseContext context, int dimension) throws IOException {
Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext)
throws IOException {
context.path().add(simpleName());

// Returns an optional array of float values where each value in the vector is parsed as a float and validated
// if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'.
// If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be
// clipped to FP16 range.
boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext);
boolean clipVectorValueToFP16RangeFlag = false;
if (isFaissSQfp16Flag) {
clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled(
(MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER)
);
}

ArrayList<Float> vector = new ArrayList<>();
XContentParser.Token token = context.parser().currentToken();
float value;
if (token == XContentParser.Token.START_ARRAY) {
token = context.parser().nextToken();
while (token != XContentParser.Token.END_ARRAY) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);
} else {
validateFP16VectorValue(value);
}
} else {
validateFloatVectorValue(value);
}

vector.add(value);
token = context.parser().nextToken();
}
} else if (token == XContentParser.Token.VALUE_NUMBER) {
value = context.parser().floatValue();
validateFloatVectorValue(value);
if (isFaissSQfp16Flag) {
if (clipVectorValueToFP16RangeFlag) {
value = clipVectorValueToFP16Range(value);

Check warning on line 715 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L715

Added line #L715 was not covered by tests
} else {
validateFP16VectorValue(value);

Check warning on line 717 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L717

Added line #L717 was not covered by tests
}
} else {
validateFloatVectorValue(value);

Check warning on line 720 in src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java#L720

Added line #L720 was not covered by tests
}
vector.add(value);
context.parser().nextToken();
} else if (token == XContentParser.Token.VALUE_NULL) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,55 @@

import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE;
import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorFieldMapperUtil {

/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range
* or is not within the FP16 range of [-65504 to 65504].
*
* @param value float vector value
*/
public static void validateFP16VectorValue(float value) {
validateFloatVectorValue(value);

if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]",
ENCODER_SQ,
FAISS_SQ_ENCODER_FP16,
FP16_MIN_VALUE,
FP16_MAX_VALUE
)
);
}
}

/**
* Validate the float vector value and if it is outside FP16 range,
* then it will be clipped to FP16 range of [-65504 to 65504].
*
* @param value float vector value
* @return vector value clipped to FP16 range
*/
public static float clipVectorValueToFP16Range(float value) {
validateFloatVectorValue(value);
if (value < FP16_MIN_VALUE) return FP16_MIN_VALUE;
if (value > FP16_MAX_VALUE) return FP16_MAX_VALUE;
return value;
}

/**
* Validates and throws exception if data_type field is set in the index mapping
* using any VectorDataType (other than float, which is default) because other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
Expand Down Expand Up @@ -75,7 +76,8 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}

@Override
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();
Expand All @@ -96,7 +98,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ protected void parseCreateField(ParseContext context) throws IOException {
);
}

parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType());
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_IVF_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_PQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES;
Expand Down Expand Up @@ -90,6 +91,7 @@ class Faiss extends NativeLibrary {
FAISS_SQ_TYPE,
new Parameter.StringParameter(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_TYPES::contains)
)
.addParameter(FAISS_SQ_CLIP, new Parameter.BooleanParameter(FAISS_SQ_CLIP, false, Objects::nonNull))
.setMapGenerator(
((methodComponent, methodComponentContext) -> MethodAsMapBuilder.builder(
FAISS_SQ_DESCRIPTION,
Expand Down
Loading
Loading