diff --git a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java index c9a115ce1..ca8e1459a 100644 --- a/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java +++ b/src/main/java/org/opensearch/knn/common/KNNValidationUtil.java @@ -14,11 +14,9 @@ import java.util.Locale; import lombok.AccessLevel; import lombok.NoArgsConstructor; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector; @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNValidationUtil { @@ -70,34 +68,6 @@ public static void validateByteVectorValue(float value) { } } - /** - * Validate if the given byte vector is supported by the given space type - * - * @param vector the given vector - * @param spaceType the given space type - */ - public static void validateByteVector(byte[] vector, SpaceType spaceType) { - if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue()) - ); - } - } - - /** - * Validate if the given float vector is supported by the given space type - * - * @param vector the given vector - * @param spaceType the given space type - */ - public static void validateFloatVector(float[] vector, SpaceType spaceType) { - if (spaceType == SpaceType.COSINESIMIL && isZeroVector(vector)) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue()) - ); - } - } - /** * Validate if the given vector size matches with the dimension provided in mapping. * diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index efa0f1be3..d3e85a642 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -11,11 +11,14 @@ package org.opensearch.knn.index; +import java.util.Locale; import org.apache.lucene.index.VectorSimilarityFunction; import java.util.HashSet; import java.util.Set; +import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector; + /** * Enum contains spaces supported for approximate nearest neighbor search in the k-NN plugin. Each engine's methods are * expected to support a subset of these spaces. Validation should be done in the jni layer and an exception should be @@ -44,6 +47,24 @@ public float scoreTranslation(float rawScore) { public VectorSimilarityFunction getVectorSimilarityFunction() { return VectorSimilarityFunction.COSINE; } + + @Override + public void validateVector(byte[] vector) { + if (isZeroVector(vector)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue()) + ); + } + } + + @Override + public void validateVector(float[] vector) { + if (isZeroVector(vector)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue()) + ); + } + } }, L1("l1") { @Override @@ -105,6 +126,24 @@ public VectorSimilarityFunction getVectorSimilarityFunction() { throw new UnsupportedOperationException(String.format("Space [%s] does not have a vector similarity function", getValue())); } + /** + * Validate if the given byte vector is supported by this space type + * + * @param vector the given vector + */ + public void validateVector(byte[] vector) { + // do nothing + } + + /** + * Validate if the given float vector is supported by this space type + * + * @param vector the given vector + */ + public void validateVector(float[] vector) { + // do nothing + } + /** * Get space type name in engine * diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index b7dba6b09..2369a6937 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -5,23 +5,29 @@ package org.opensearch.knn.index.mapper; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; import lombok.Getter; import lombok.extern.log4j.Log4j2; -import org.opensearch.Version; -import org.opensearch.common.Nullable; -import org.opensearch.common.ValidationException; -import org.opensearch.knn.common.KNNConstants; - import org.apache.lucene.document.FieldType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.common.Explicit; +import org.opensearch.common.Nullable; +import org.opensearch.common.ValidationException; +import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.support.XContentMapValues; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.FieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -33,6 +39,7 @@ import org.opensearch.index.mapper.ValueFetcher; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorIndexFieldData; @@ -44,27 +51,16 @@ import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.function.Supplier; - import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; -import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; +import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; /** * Field Mapper for KNN vector type. @@ -285,7 +281,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { return new ModelFieldMapper( name, - new KNNVectorFieldType(buildFullName(context), metaValue, -1, modelIdAsString), + new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString), multiFieldsBuilder, copyToBuilder, ignoreMalformed, @@ -410,8 +406,8 @@ public KNNVectorFieldType(String name, Map meta, int dimension, this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); } - public KNNVectorFieldType(String name, Map meta, int dimension, String modelId) { - this(name, meta, dimension, null, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); + public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); } public KNNVectorFieldType( @@ -530,7 +526,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s return; } final byte[] array = bytesArrayOptional.get(); - validateByteVector(array, spaceType); + spaceType.validateVector(array); VectorField point = new VectorField(name(), array, fieldType); context.doc().add(point); @@ -542,7 +538,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s return; } final float[] array = floatsArrayOptional.get(); - validateFloatVector(array, spaceType); + spaceType.validateVector(array); VectorField point = new VectorField(name(), array, fieldType); context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point.toString()); diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 07101feea..81c7216bf 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -5,6 +5,9 @@ package org.opensearch.knn.index.mapper; +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; @@ -20,15 +23,9 @@ import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; -import java.io.IOException; -import java.util.Locale; -import java.util.Optional; - import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; -import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector; /** * Field mapper for case when Lucene has been set as an engine. @@ -89,7 +86,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s return; } final byte[] array = bytesArrayOptional.get(); - validateByteVector(array, spaceType); + spaceType.validateVector(array); KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); context.doc().add(point); @@ -105,7 +102,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s return; } final float[] array = floatsArrayOptional.get(); - validateFloatVector(array, spaceType); + spaceType.validateVector(array); KnnVectorField point = new KnnVectorField(name(), array, fieldType); context.doc().add(point); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 2d93803de..2140487c5 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -5,12 +5,25 @@ package org.opensearch.knn.index.query; +import java.io.IOException; import java.util.Arrays; +import java.util.List; +import java.util.Objects; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -19,25 +32,9 @@ import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.plugin.stats.KNNCounter; -import org.apache.lucene.search.Query; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.ParsingException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.mapper.MappedFieldType; -import org.opensearch.index.query.AbstractQueryBuilder; -import org.opensearch.index.query.QueryShardContext; -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; -import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector; +import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; /** * Helper class to build the KNN query @@ -316,9 +313,9 @@ protected Query doToQuery(QueryShardContext context) { validateByteVectorValue(vector[i]); byteVector[i] = (byte) vector[i]; } - validateByteVector(byteVector, spaceType); + spaceType.validateVector(byteVector); } else { - validateFloatVector(vector, spaceType); + spaceType.validateVector(vector); } if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine) diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 852c712b9..5a8cdb036 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -19,7 +19,6 @@ import java.util.Map; import java.util.function.BiFunction; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType; import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType; @@ -100,7 +99,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); - validateFloatVector(processedQuery, SpaceType.COSINESIMIL); + SpaceType.COSINESIMIL.validateVector(processedQuery); float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery); this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude); } diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java index 5dc4deed4..114499100 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringUtil.java @@ -5,18 +5,16 @@ package org.opensearch.knn.plugin.script; -import org.opensearch.knn.index.KNNVectorScriptDocValues; +import java.math.BigInteger; +import java.util.List; +import java.util.Objects; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import java.math.BigInteger; -import java.util.List; -import java.util.Objects; - import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector; public class KNNScoringUtil { private static Logger logger = LogManager.getLogger(KNNScoringUtil.class); @@ -137,7 +135,7 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) { float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - validateFloatVector(inputVector, SpaceType.COSINESIMIL); + SpaceType.COSINESIMIL.validateVector(inputVector); return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue()); } @@ -184,7 +182,7 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) { */ public static float cosineSimilarity(List queryVector, KNNVectorScriptDocValues docValues) { float[] inputVector = toFloat(queryVector, docValues.getVectorDataType()); - validateFloatVector(inputVector, SpaceType.COSINESIMIL); + SpaceType.COSINESIMIL.validateVector(inputVector); return cosinesimil(inputVector, docValues.getValue()); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 59d7a978c..bcd784e23 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -303,6 +303,7 @@ public void testDoToQuery_FromModel() { ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); + when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao);