Skip to content

Commit

Permalink
Refractor createParseField function in mappers for code reusability (o…
Browse files Browse the repository at this point in the history
…pensearch-project#1726)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored May 31, 2024
1 parent 180216f commit 623b610
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.function.Supplier;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
Expand Down Expand Up @@ -540,6 +541,38 @@ private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMet
return knnMethodContext.getMethodComponentContext();
}

/**
* Function returns a list of fields to be indexed when the vector is float type.
*
* @param array array of floats
* @param fieldType {@link FieldType}
* @return {@link List} of {@link Field}
*/
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
if (this.stored) {
fields.add(createStoredFieldForFloatVector(name(), array));
}
return fields;
}

/**
* Function returns a list of fields to be indexed when the vector is byte type.
*
* @param array array of bytes
* @param fieldType {@link FieldType}
* @return {@link List} of {@link Field}
*/
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
if (this.stored) {
fields.add(createStoredFieldForByteVector(name(), array));
}
return fields;
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

Expand All @@ -554,12 +587,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);

context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}
context.doc().addAll(getFieldsForByteVector(array, fieldType));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

Expand All @@ -568,11 +596,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
VectorField point = new VectorField(name(), array, fieldType);
context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}
context.doc().addAll(getFieldsForFloatVector(array, fieldType));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,23 @@

package org.opensearch.knn.index.mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NonNull;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.common.Explicit;
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector;
import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType;
Expand Down Expand Up @@ -77,54 +74,33 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}

@Override
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
throws IOException {

validateIfKNNPluginEnabled();
validateIfCircuitBreakerIsNotTriggered();

if (VectorDataType.BYTE == vectorDataType) {
Optional<byte[]> bytesArrayOptional = getBytesFromContext(context, dimension);
if (bytesArrayOptional.isEmpty()) {
return;
}
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType);

context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForByteVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
return;
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
KnnVectorField point = new KnnVectorField(name(), array, fieldType);

context.doc().add(point);
if (this.stored) {
context.doc().add(createStoredFieldForFloatVector(name(), array));
}

if (hasDocValues && vectorFieldType != null) {
context.doc().add(new VectorField(name(), array, vectorFieldType));
}
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
);
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType));

if (hasDocValues && vectorFieldType != null) {
fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
}

if (this.stored) {
fieldsToBeAdded.add(createStoredFieldForFloatVector(name(), array));
}
return fieldsToBeAdded;
}

@Override
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType));

context.path().remove();
if (hasDocValues && vectorFieldType != null) {
fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType));
}

if (this.stored) {
fieldsToBeAdded.add(createStoredFieldForByteVector(name(), array));
}
return fieldsToBeAdded;
}

@Override
Expand Down

0 comments on commit 623b610

Please sign in to comment.