Skip to content

Commit

Permalink
apply review comments
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr committed Mar 14, 2024
1 parent e5b3cf8 commit 7635cc6
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 93 deletions.
30 changes: 0 additions & 30 deletions src/main/java/org/opensearch/knn/common/KNNValidationUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
import java.util.Locale;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class KNNValidationUtil {
Expand Down Expand Up @@ -70,34 +68,6 @@ public static void validateByteVectorValue(float value) {
}
}

/**
 * Checks that a byte vector is acceptable for the supplied space type.
 * <p>
 * Only {@code SpaceType.COSINESIMIL} is restricted here: a vector reported as a zero
 * vector by {@code isZeroVector} is rejected for that space type. Every other space
 * type is accepted without inspecting the vector.
 *
 * @param vector the given vector
 * @param spaceType the given space type
 * @throws IllegalArgumentException if the space type is cosine similarity and the vector is a zero vector
 */
public static void validateByteVector(byte[] vector, SpaceType spaceType) {
    // Other space types impose no restriction, so the vector is never examined for them.
    if (spaceType != SpaceType.COSINESIMIL) {
        return;
    }
    if (!isZeroVector(vector)) {
        return;
    }
    final String message = String.format(
        Locale.ROOT,
        "zero vector is not supported when space type is [%s]",
        spaceType.getValue()
    );
    throw new IllegalArgumentException(message);
}

/**
 * Checks that a float vector is acceptable for the supplied space type.
 * <p>
 * Only {@code SpaceType.COSINESIMIL} is restricted here: a vector reported as a zero
 * vector by {@code isZeroVector} is rejected for that space type. Every other space
 * type is accepted without inspecting the vector.
 *
 * @param vector the given vector
 * @param spaceType the given space type
 * @throws IllegalArgumentException if the space type is cosine similarity and the vector is a zero vector
 */
public static void validateFloatVector(float[] vector, SpaceType spaceType) {
    final boolean cosineSpace = spaceType == SpaceType.COSINESIMIL;
    // Short-circuit keeps the zero-vector check from running for unrestricted space types.
    final boolean rejected = cosineSpace && isZeroVector(vector);
    if (!rejected) {
        return;
    }
    throw new IllegalArgumentException(
        String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", spaceType.getValue())
    );
}

/**
* Validate if the given vector size matches with the dimension provided in mapping.
*
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@

package org.opensearch.knn.index;

import java.util.Locale;
import org.apache.lucene.index.VectorSimilarityFunction;

import java.util.HashSet;
import java.util.Set;

import static org.opensearch.knn.common.KNNVectorUtil.isZeroVector;

/**
* Enum contains spaces supported for approximate nearest neighbor search in the k-NN plugin. Each engine's methods are
* expected to support a subset of these spaces. Validation should be done in the jni layer and an exception should be
Expand Down Expand Up @@ -44,6 +47,24 @@ public float scoreTranslation(float rawScore) {
public VectorSimilarityFunction getVectorSimilarityFunction() {
return VectorSimilarityFunction.COSINE;
}

@Override
public void validateVector(byte[] vector) {
    // Anything that is not a zero vector is accepted for this space type.
    if (!isZeroVector(vector)) {
        return;
    }
    final String message = String.format(
        Locale.ROOT,
        "zero vector is not supported when space type is [%s]",
        getValue()
    );
    throw new IllegalArgumentException(message);
}

Check warning on line 58 in src/main/java/org/opensearch/knn/index/SpaceType.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/SpaceType.java#L58

Added line #L58 was not covered by tests

@Override
public void validateVector(float[] vector) {
    final boolean zeroVector = isZeroVector(vector);
    if (zeroVector) {
        // The message names this space type via getValue().
        throw new IllegalArgumentException(
            String.format(Locale.ROOT, "zero vector is not supported when space type is [%s]", getValue())
        );
    }
}
},
L1("l1") {
@Override
Expand Down Expand Up @@ -105,6 +126,24 @@ public VectorSimilarityFunction getVectorSimilarityFunction() {
throw new UnsupportedOperationException(String.format("Space [%s] does not have a vector similarity function", getValue()));
}

/**
* Validate if the given byte vector is supported by this space type.
* <p>
* This default implementation accepts every vector and never throws. Enum constants
* that restrict their input override it.
*
* @param vector the given vector
*/
public void validateVector(byte[] vector) {
// intentionally empty: no restriction by default
}

/**
* Validate if the given float vector is supported by this space type.
* <p>
* This default implementation accepts every vector and never throws. Enum constants
* that restrict their input override it.
*
* @param vector the given vector
*/
public void validateVector(float[] vector) {
// intentionally empty: no restriction by default
}

/**
* Get space type name in engine
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,29 @@

package org.opensearch.knn.index.mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.common.Nullable;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;

import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.common.Nullable;
import org.opensearch.common.ValidationException;
import org.opensearch.common.xcontent.support.XContentMapValues;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.common.xcontent.support.XContentMapValues;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.mapper.FieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -33,6 +39,7 @@
import org.opensearch.index.mapper.ValueFetcher;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.KNNVectorIndexFieldData;
Expand All @@ -44,27 +51,16 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;

/**
* Field Mapper for KNN vector type.
Expand Down Expand Up @@ -285,7 +281,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {

return new ModelFieldMapper(
name,
new KNNVectorFieldType(buildFullName(context), metaValue, -1, modelIdAsString),
new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString),
multiFieldsBuilder,
copyToBuilder,
ignoreMalformed,
Expand Down Expand Up @@ -410,8 +406,8 @@ public KNNVectorFieldType(String name, Map<String, String> meta, int dimension,
this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType());
}

public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, String modelId) {
this(name, meta, dimension, null, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null);
public KNNVectorFieldType(String name, Map<String, String> meta, int dimension, KNNMethodContext knnMethodContext, String modelId) {
this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null);
}

public KNNVectorFieldType(
Expand Down Expand Up @@ -530,7 +526,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
return;
}
final byte[] array = bytesArrayOptional.get();
validateByteVector(array, spaceType);
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
Expand All @@ -542,7 +538,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
return;
}
final float[] array = floatsArrayOptional.get();
validateFloatVector(array, spaceType);
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.knn.index.mapper;

import java.io.IOException;
import java.util.Locale;
import java.util.Optional;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
Expand All @@ -20,15 +23,9 @@
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Locale;
import java.util.Optional;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;

/**
* Field mapper for case when Lucene has been set as an engine.
Expand Down Expand Up @@ -89,7 +86,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
return;
}
final byte[] array = bytesArrayOptional.get();
validateByteVector(array, spaceType);
spaceType.validateVector(array);
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
Expand All @@ -105,7 +102,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
return;
}
final float[] array = floatsArrayOptional.get();
validateFloatVector(array, spaceType);
spaceType.validateVector(array);
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
Expand Down
35 changes: 16 additions & 19 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,25 @@

package org.opensearch.knn.index.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
Expand All @@ -19,25 +32,9 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVector;
import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;

/**
* Helper class to build the KNN query
Expand Down Expand Up @@ -316,9 +313,9 @@ protected Query doToQuery(QueryShardContext context) {
validateByteVectorValue(vector[i]);
byteVector[i] = (byte) vector[i];
}
validateByteVector(byteVector, spaceType);
spaceType.validateVector(byteVector);
} else {
validateFloatVector(vector, spaceType);
spaceType.validateVector(vector);
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import java.util.Map;
import java.util.function.BiFunction;

import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.getVectorMagnitudeSquared;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isBinaryFieldType;
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.isKNNVectorFieldType;
Expand Down Expand Up @@ -100,7 +99,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
validateFloatVector(processedQuery, SpaceType.COSINESIMIL);
SpaceType.COSINESIMIL.validateVector(processedQuery);
float qVectorSquaredMagnitude = getVectorMagnitudeSquared(this.processedQuery);
this.scoringMethod = (float[] q, float[] v) -> 1 + KNNScoringUtil.cosinesimilOptimized(q, v, qVectorSquaredMagnitude);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@

package org.opensearch.knn.plugin.script;

import org.opensearch.knn.index.KNNVectorScriptDocValues;
import java.math.BigInteger;
import java.util.List;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.knn.index.KNNVectorScriptDocValues;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;

import java.math.BigInteger;
import java.util.List;
import java.util.Objects;

import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVector;

public class KNNScoringUtil {
private static Logger logger = LogManager.getLogger(KNNScoringUtil.class);
Expand Down Expand Up @@ -137,7 +135,7 @@ public static float cosinesimilOptimized(float[] queryVector, float[] inputVecto
*/
public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) {
float[] inputVector = toFloat(queryVector, docValues.getVectorDataType());
validateFloatVector(inputVector, SpaceType.COSINESIMIL);
SpaceType.COSINESIMIL.validateVector(inputVector);
return cosinesimilOptimized(inputVector, docValues.getValue(), queryVectorMagnitude.floatValue());
}

Expand Down Expand Up @@ -184,7 +182,7 @@ public static float cosinesimil(float[] queryVector, float[] inputVector) {
*/
public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
float[] inputVector = toFloat(queryVector, docValues.getVectorDataType());
validateFloatVector(inputVector, SpaceType.COSINESIMIL);
SpaceType.COSINESIMIL.validateVector(inputVector);
return cosinesimil(inputVector, docValues.getValue());
}

Expand Down
Loading

0 comments on commit 7635cc6

Please sign in to comment.