Skip to content

Commit

Permalink
Validate zero vector when using cosine metric (#1501)
Browse files Browse the repository at this point in the history
Ensure zero vector is not used when using functionality with cosine similarity metric.

Signed-off-by: panguixin <[email protected]>
(cherry picked from commit b7bdda4)
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
bugmakerrrrrr authored and jmazanec15 committed Mar 14, 2024
1 parent e3e3aa1 commit 91376ae
Show file tree
Hide file tree
Showing 20 changed files with 472 additions and 154 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
* Optize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502)
* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down
83 changes: 83 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.common;

import java.util.Locale;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.VectorDataType;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNValidationUtil {
/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range.
*
* @param value float vector value
*/
public static void validateFloatVectorValue(float value) {
if (Float.isNaN(value)) {
throw new IllegalArgumentException("KNN vector values cannot be NaN");
}

if (Float.isInfinite(value)) {
throw new IllegalArgumentException("KNN vector values cannot be infinity");
}
}

/**
* Validate the float vector value in the byte range if it is a finite number,
* with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException.
*
* @param value float value in byte range
*/
public static void validateByteVectorValue(float value) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
)

);
}
if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
)
);
}
}

/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
* @param dimension dimension of vector
* @param vectorSize size of the vector
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
}
}
}
45 changes: 45 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNVectorUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.common;

import java.util.Objects;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorUtil {
/**
* Check if all the elements of a given vector are zero
*
* @param vector the vector
* @return true if yes; otherwise false
*/
public static boolean isZeroVector(byte[] vector) {
Objects.requireNonNull(vector, "vector must not be null");
for (byte e : vector) {
if (e != 0) {
return false;
}
}
return true;
}

/**
* Check if all the elements of a given vector are zero
*
* @param vector the vector
* @return true if yes; otherwise false
*/
public static boolean isZeroVector(float[] vector) {
Objects.requireNonNull(vector, "vector must not be null");
for (float e : vector) {
if (e != 0f) {
return false;
}
}
return true;
}
}
39 changes: 39 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

package org.opensearch.knn.index;

import java.util.Locale;
import org.apache.lucene.index.VectorSimilarityFunction;

import java.util.HashSet;
import java.util.Set;

import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

/**
* Enum contains spaces supported for approximate nearest neighbor search in the k-NN plugin. Each engine's methods are
* expected to support a subset of these spaces. Validation should be done in the jni layer and an exception should be
Expand Down Expand Up @@ -44,6 +47,24 @@ public float scoreTranslation(float rawScore) {
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.COSINE;
}

@Override
public void validateVector(byte[] vector) {
if (isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue())
);
}
}

Check warning on line 58 in src/main/java/org/opensearch/knn/index/SpaceType.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/SpaceType.java#L58

Added line #L58 was not covered by tests

@Override
public void validateVector(float[] vector) {
if (isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue())
);
}
}
},
L1("l1") {
@Override
Expand Down Expand Up @@ -105,6 +126,24 @@ public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a vector similarity function", getValue()));
}

/**
* Validate if the given byte vector is supported by this space type
*
* @param vector the given vector
*/
public void validateVector(byte[] vector) {
// do nothing
}

/**
* Validate if the given float vector is supported by this space type
*
* @param vector the given vector
*/
public void validateVector(float[] vector) {
// do nothing
}

/**
* Get space type name in engine
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,29 @@

package org.opensearch.knn.index.mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.common.Nullable;
import org.opensearch.common.ValidationException;
import org.opensearch.common.xcontent.support.XContentMapValues;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.common.xcontent.support.XContentMapValues;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.FieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -32,35 +39,28 @@
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension;

/**
* Field Mapper for KNN vector type.
Expand Down Expand Up @@ -313,7 +313,13 @@ public KNNVectorFieldMapper build(BuilderContext context) {

return new LegacyFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), metaValue, dimension.getValue(), vectorDataType.getValue()),
new KNNVectorFieldType(
buildFullName(context),
metaValue,
dimension.getValue(),
vectorDataType.getValue(),
SpaceType.getSpace(spaceType)
),
multiFieldsBuilder,
copyToBuilder,
ignoreMalformed,
Expand Down Expand Up @@ -384,17 +390,24 @@ public static class KNNVectorFieldType extends MappedFieldType {
String modelId;
KNNMethodContext knnMethodContext;
VectorDataType vectorDataType;
SpaceType spaceType;

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, VectorDataType vectorDataType) {
this(name, meta, dimension, null, null, vectorDataType);
public KNNVectorFieldType(
String name,
Map<String, String> meta,
int dimension,
VectorDataType vectorDataType,
SpaceType spaceType
) {
this(name, meta, dimension, null, null, vectorDataType, spaceType);
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext) {
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD);
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType());
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) {
this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD);
this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null);
}

public KNNVectorFieldType(
Expand All @@ -404,22 +417,24 @@ public KNNVectorFieldType(
KNNMethodContext knnMethodContext,
VectorDataType vectorDataType
) {
this(name, meta, dimension, knnMethodContext, null, vectorDataType);
this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType());
}

public KNNVectorFieldType(
String name,
Map<String, String> meta,
int dimension,
KNNMethodContext knnMethodContext,
String modelId,
VectorDataType vectorDataType
@Nullable KNNMethodContext knnMethodContext,
@Nullable String modelId,
VectorDataType vectorDataType,
@Nullable SpaceType spaceType
) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dimension = dimension;
this.modelId = modelId;
this.knnMethodContext = knnMethodContext;
this.vectorDataType = vectorDataType;
this.spaceType = spaceType;
}

@Override
Expand Down Expand Up @@ -496,34 +511,35 @@ protected String contentType() {

@Override
protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, fieldType().getDimension());
parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType());
}

protected void parseCreateField(ParseContext context, int dimension) throws IOException {
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);

if (!bytesArrayOptional.isPresent()) {
if (bytesArrayOptional.isEmpty()) {
return;
}
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension);

if (!floatsArrayOptional.isPresent()) {
if (floatsArrayOptional.isEmpty()) {
return;
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
} else {
Expand Down
Loading

0 comments on commit 91376ae

Please sign in to comment.