Skip to content

Commit

Permalink
Add stored fields for knn_vector type (#1630)
Browse files Browse the repository at this point in the history
Fixes bug where we were not creating stored field type for knn_vector
even when the mapping parameter is passed. Along with this, clean up
the field mapper implementations.

Add relevant uTs and iTs to ensure functionality is working as expected.

Signed-off-by: John Mazanec <[email protected]>
(cherry picked from commit 699510d)
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Apr 22, 2024
1 parent 78a489a commit 1e66633
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 46 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
* Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613)
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
### Infrastructure
* Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583)
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private static final class KNNNativeVectorScriptDocValues extends KNNVectorScrip

@Override
protected float[] doGetValue() throws IOException {
return getVectorDataType().getVectorFromDocValues(values.binaryValue());
return getVectorDataType().getVectorFromBytesRef(values.binaryValue());
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
float[] vector = new float[binaryValue.length];
int i = 0;
int j = binaryValue.offset;
Expand All @@ -56,7 +56,7 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio
}

@Override
public float[] getVectorFromDocValues(BytesRef binaryValue) {
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
return vectorSerializer.byteToFloatArray(byteStream);
Expand All @@ -81,12 +81,12 @@ public float[] getVectorFromDocValues(BytesRef binaryValue) {
public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction);

/**
* Deserializes float vector from doc values binary value.
* Deserializes float vector from BytesRef.
*
* @param binaryValue Binary Value of DocValues
* @param binaryValue Binary Value
* @return float vector deserialized from binary value
*/
public abstract float[] getVectorFromDocValues(BytesRef binaryValue);
public abstract float[] getVectorFromBytesRef(BytesRef binaryValue);

/**
* Validates if given VectorDataType is in the list of supported data types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import lombok.Getter;
Expand All @@ -20,6 +21,7 @@
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -52,16 +54,6 @@
import org.opensearch.search.aggregations.support.CoreValuesSourceType;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP;
Expand All @@ -74,19 +66,17 @@
import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue;
import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension;
import static org.opensearch.knn.index.KNNSettings.KNN_INDEX;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting;

/**
* Field Mapper for KNN vector type.
*
* Extends ParametrizedFieldMapper in order to easily configure mapping parameters.
*
* Implementations of this class define what needs to be stored in Lucene's fieldType. This allows us to have
* alternative mappings for the same field type.
* Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType.
* This allows us to have alternative mappings for the same field type.
*/
@Log4j2
public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper {
Expand All @@ -109,8 +99,8 @@ private static KNNVectorFieldMapper toType(FieldMapper in) {
public static class Builder extends ParametrizedFieldMapper.Builder {
protected Boolean ignoreMalformed;

protected final Parameter<Boolean> stored = Parameter.boolParam("store", false, m -> toType(m).stored, false);
protected final Parameter<Boolean> hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true);
protected final Parameter<Boolean> stored = Parameter.storeParam(m -> toType(m).stored, false);
protected final Parameter<Boolean> hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true);
protected final Parameter<Integer> dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> {
if (o == null) {
throw new IllegalArgumentException("Dimension cannot be null");
Expand Down Expand Up @@ -483,6 +473,11 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S
failIfNoDocValues();
return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType);
}

@Override
public Object valueForDisplay(Object value) {
return deserializeStoredVector((BytesRef) value, vectorDataType);
}
}

protected Explicit<Boolean> ignoreMalformed;
Expand Down Expand Up @@ -561,7 +556,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

Expand All @@ -572,7 +569,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down Expand Up @@ -735,11 +734,6 @@ Optional<float[]> getFloatsFromContext(ParseContext context, int dimension, Meth
return Optional.of(array);
}

@Override
protected boolean docValuesByDefault() {
return true;
}

@Override
public ParametrizedFieldMapper.Builder getMergeBuilder() {
return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.util.BytesRef;
import org.opensearch.index.mapper.ParametrizedFieldMapper;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
Expand All @@ -44,7 +45,6 @@ public class KNNVectorFieldMapperUtil {
*/
public static void validateFP16VectorValue(float value) {
validateFloatVectorValue(value);

if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) {
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -136,9 +136,39 @@ public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) {
return field;
}

public static void addStoredFieldForVectorField(ParseContext context, FieldType fieldType, String mapperName, Field vectorField) {
if (fieldType.stored()) {
context.doc().add(new StoredField(mapperName, vectorField.toString()));
/**
* Creates a stored field for a byte vector
*
* @param name field name
* @param vector vector to be added to stored field
*/
public static StoredField createStoredFieldForByteVector(String name, byte[] vector) {
return new StoredField(name, vector);
}

/**
* Creates a stored field for a float vector
*
* @param name field name
* @param vector vector to be added to stored field
*/
public static StoredField createStoredFieldForFloatVector(String name, float[] vector) {
return new StoredField(name, KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(vector));
}

/**
* @param storedVector Vector representation in bytes
* @param vectorDataType type of vector
* @return either int[] or float[] of corresponding vector
*/
public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) {
if (VectorDataType.BYTE == vectorDataType) {
byte[] bytes = storedVector.bytes;
int[] byteAsIntArray = new int[bytes.length];
Arrays.setAll(byteAsIntArray, i -> bytes[i]);
return byteAsIntArray;
}

return vectorDataType.getVectorFromBytesRef(storedVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import org.opensearch.knn.index.util.KNNEngine;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;

/**
Expand Down Expand Up @@ -92,7 +93,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
Expand All @@ -108,7 +111,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
addStoredFieldForVectorField(context, fieldType, name(), point);
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ public class AdvancedFilteringUseCasesIT extends KNNRestTestCase {

private static final String FIELD_NAME_VECTOR = "test_vector";

private static final String PROPERTIES_FIELD = "properties";

private static final String FILTER_FIELD = "filter";

private static final String TERM_FIELD = "term";
Expand Down
Loading

0 comments on commit 1e66633

Please sign in to comment.