Skip to content

Commit

Permalink
apply review comments
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr committed Mar 11, 2024
1 parent 9b86b35 commit 7bb5450
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 108 deletions.
113 changes: 113 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.common;

import java.util.Locale;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNValidationUtil {
/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range.
*
* @param value float vector value
*/
public static void validateFloatVectorValue(float value) {
if (Float.isNaN(value)) {
throw new IllegalArgumentException("KNN vector values cannot be NaN");
}

if (Float.isInfinite(value)) {
throw new IllegalArgumentException("KNN vector values cannot be infinity");
}
}

/**
* Validate the float vector value in the byte range if it is a finite number,
* with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException.
*
* @param value float value in byte range
*/
public static void validateByteVectorValue(float value) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
)

);
}
if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
)
);
}
}

/**
* Validate if the given byte vector is supported by the given space type
*
* @param vector the given vector
* @param spaceType the given space type
*/
public static void validateByteVector(byte[] vector, SpaceType spaceType) {
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
);
}
}

/**
* Validate if the given float vector is supported by the given space type
*
* @param vector the given vector
* @param spaceType the given space type
*/
public static void validateFloatVector(float[] vector, SpaceType spaceType) {
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
);
}
}

/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
* @param dimension dimension of vector
* @param vectorSize size of the vector
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
}
}
}
9 changes: 7 additions & 2 deletions src/main/java/org/opensearch/knn/common/KNNVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@

package org.opensearch.knn.common;

public class KNNVectorUtil {
private KNNVectorUtil() {}
import java.util.Objects;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorUtil {
/**
* Check if all the elements of a given vector are zero
*
* @param vector the vector
* @return true if yes; otherwise false
*/
public static boolean isZeroVector(byte[] vector) {
Objects.requireNonNull(vector, "vector must not be null");
for (byte e : vector) {
if (e != 0) {
return false;
Expand All @@ -30,6 +34,7 @@ public static boolean isZeroVector(byte[] vector) {
* @return true if yes; otherwise false
*/
public static boolean isZeroVector(float[] vector) {
Objects.requireNonNull(vector, "vector must not be null");
for (float e : vector) {
if (e != 0f) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;

/**
* Field Mapper for KNN vector type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

package org.opensearch.knn.index.mapper;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

Expand All @@ -25,98 +26,9 @@
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNVectorFieldMapperUtil {
/**
* Validate the float vector value and throw exception if it is not a number or not in the finite range.
*
* @param value float vector value
*/
public static void validateFloatVectorValue(float value) {
if (Float.isNaN(value)) {
throw new IllegalArgumentException("KNN vector values cannot be NaN");
}

if (Float.isInfinite(value)) {
throw new IllegalArgumentException("KNN vector values cannot be infinity");
}
}

/**
* Validate the float vector value in the byte range if it is a finite number,
* with no decimal values and in the byte range of [-128 to 127]. If not throw IllegalArgumentException.
*
* @param value float value in byte range
*/
public static void validateByteVectorValue(float value) {
validateFloatVectorValue(value);
if (value % 1 != 0) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue()
)

);
}
if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
VECTOR_DATA_TYPE_FIELD,
VectorDataType.BYTE.getValue(),
Byte.MIN_VALUE,
Byte.MAX_VALUE
)
);
}
}

/**
* Validate if the given byte vector is supported by the given space type
*
* @param vector the given vector
* @param spaceType the given space type
*/
public static void validateByteVector(byte[] vector, SpaceType spaceType) {
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
);
}
}

/**
* Validate if the given float vector is supported by the given space type
*
* @param vector the given vector
* @param spaceType the given space type
*/
public static void validateFloatVector(float[] vector, SpaceType spaceType) {
if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
);
}
}

/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
* @param dimension dimension of vector
* @param vectorSize size of the vector
*/
public static void validateVectorDimension(int dimension, int vectorSize) {
if (dimension != vectorSize) {
String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize);
throw new IllegalArgumentException(errorMessage);
}
}

/**
* Validates and throws exception if data_type field is set in the index mapping
* using any VectorDataType (other than float, which is default) because other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;

/**
* Field mapper for case when Lucene has been set as an engine.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import java.util.Objects;

import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;

/**
* Helper class to build the KNN query
Expand Down Expand Up @@ -290,8 +290,10 @@ protected Query doToQuery(QueryShardContext context) {
SpaceType spaceType = knnVectorFieldType.getSpaceType();

if (fieldDimension == -1) {
if (spaceType != null) {
throw new IllegalStateException("Space type should be null when the field uses a model");

Check warning on line 294 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L294

Added line #L294 was not covered by tests
}
// If dimension is not set, the field uses a model and the information needs to be retrieved from there
assert spaceType == null;
ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType);
fieldDimension = modelMetadata.getDimension();
knnEngine = modelMetadata.getKnnEngine();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import java.util.Map;
import java.util.function.BiFunction;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.util.Base64;

import static org.opensearch.index.mapper.NumberFieldMapper.NumberType.LONG;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;

public class KNNScoringSpaceUtil {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;

public class KNNScoringUtil {
private static Logger logger = LogManager.getLogger(KNNScoringUtil.class);
Expand Down

0 comments on commit 7bb5450

Please sign in to comment.