diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java
index dd8c145db..7e697fed7 100644
--- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java
+++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java
@@ -16,6 +16,7 @@ import java.util.function.Supplier;
 
 import lombok.Getter;
 import lombok.extern.log4j.Log4j2;
+import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.index.DocValuesType;
 import org.apache.lucene.index.IndexOptions;
@@ -540,6 +541,38 @@ private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMet
         return knnMethodContext.getMethodComponentContext();
     }
 
+    /**
+     * Function returns a list of fields to be indexed when the vector is float type.
+     *
+     * @param array array of floats
+     * @param fieldType {@link FieldType}
+     * @return {@link List} of {@link Field}
+     */
+    protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
+        final List<Field> fields = new ArrayList<>();
+        fields.add(new VectorField(name(), array, fieldType));
+        if (this.stored) {
+            fields.add(createStoredFieldForFloatVector(name(), array));
+        }
+        return fields;
+    }
+
+    /**
+     * Function returns a list of fields to be indexed when the vector is byte type.
+     *
+     * @param array array of bytes
+     * @param fieldType {@link FieldType}
+     * @return {@link List} of {@link Field}
+     */
+    protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
+        final List<Field> fields = new ArrayList<>();
+        fields.add(new VectorField(name(), array, fieldType));
+        if (this.stored) {
+            fields.add(createStoredFieldForByteVector(name(), array));
+        }
+        return fields;
+    }
+
     protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
         throws IOException {
 
@@ -554,12 +587,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
             }
             final byte[] array = bytesArrayOptional.get();
             spaceType.validateVector(array);
-            VectorField point = new VectorField(name(), array, fieldType);
-
-            context.doc().add(point);
-            if (this.stored) {
-                context.doc().add(createStoredFieldForByteVector(name(), array));
-            }
+            context.doc().addAll(getFieldsForByteVector(array, fieldType));
         } else if (VectorDataType.FLOAT == vectorDataType) {
             Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);
 
@@ -568,11 +596,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
             }
             final float[] array = floatsArrayOptional.get();
             spaceType.validateVector(array);
-            VectorField point = new VectorField(name(), array, fieldType);
-            context.doc().add(point);
-            if (this.stored) {
-                context.doc().add(createStoredFieldForFloatVector(name(), array));
-            }
+            context.doc().addAll(getFieldsForFloatVector(array, fieldType));
         } else {
             throw new IllegalArgumentException(
                 String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java
index 618d77a32..59f4867dd 100644
--- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java
+++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java
@@ -5,26 +5,23 @@
 
 package org.opensearch.knn.index.mapper;
 
-import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Locale;
-import java.util.Optional;
 
 import lombok.AllArgsConstructor;
 import lombok.Getter;
 import lombok.NonNull;
+import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
 import org.apache.lucene.document.KnnByteVectorField;
 import org.apache.lucene.document.KnnVectorField;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.opensearch.common.Explicit;
-import org.opensearch.index.mapper.ParseContext;
 import org.opensearch.knn.index.KNNMethodContext;
-import org.opensearch.knn.index.MethodComponentContext;
-import org.opensearch.knn.index.SpaceType;
 import org.opensearch.knn.index.VectorDataType;
 import org.opensearch.knn.index.VectorField;
 import org.opensearch.knn.index.util.KNNEngine;
-import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
 import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector;
 import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector;
 import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;
@@ -77,54 +74,33 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
     }
 
     @Override
-    protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
-        throws IOException {
-
-        validateIfKNNPluginEnabled();
-        validateIfCircuitBreakerIsNotTriggered();
-
-        if (VectorDataType.BYTE == vectorDataType) {
-            Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);
-            if (bytesArrayOptional.isEmpty()) {
-                return;
-            }
-            final byte[] array = bytesArrayOptional.get();
-            spaceType.validateVector(array);
-            KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);
-
-            context.doc().add(point);
-            if (this.stored) {
-                context.doc().add(createStoredFieldForByteVector(name(), array));
-            }
-
-            if (hasDocValues && vectorFieldType != null) {
-                context.doc().add(new VectorField(name(), array, vectorFieldType));
-            }
-        } else if (VectorDataType.FLOAT == vectorDataType) {
-            Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);
-
-            if (floatsArrayOptional.isEmpty()) {
-                return;
-            }
-            final float[] array = floatsArrayOptional.get();
-            spaceType.validateVector(array);
-            KnnVectorField point = new KnnVectorField(name(), array, fieldType);
-
-            context.doc().add(point);
-            if (this.stored) {
-                context.doc().add(createStoredFieldForFloatVector(name(), array));
-            }
-
-            if (hasDocValues && vectorFieldType != null) {
-                context.doc().add(new VectorField(name(), array, vectorFieldType));
-            }
-        } else {
-            throw new IllegalArgumentException(
-                String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
-            );
+    protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
+        final List<Field> fieldsToBeAdded = new ArrayList<>();
+        fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType));
+
+        if (hasDocValues && vectorFieldType != null) {
+            fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
+        }
+
+        if (this.stored) {
+            fieldsToBeAdded.add(createStoredFieldForFloatVector(name(), array));
         }
+        return fieldsToBeAdded;
+    }
+
+    @Override
+    protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
+        final List<Field> fieldsToBeAdded = new ArrayList<>();
+        fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType));
 
-        context.path().remove();
+        if (hasDocValues && vectorFieldType != null) {
+            fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
+        }
+
+        if (this.stored) {
+            fieldsToBeAdded.add(createStoredFieldForByteVector(name(), array));
+        }
+        return fieldsToBeAdded;
     }
 
     @Override