diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d9d35ef0..1f0acfbaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 9f7d52205..ad1195fe2 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -43,7 +43,7 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); + return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue()); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 23b374e9d..98b767f8d 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -37,7 +37,7 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio } @Override - public float[] getVectorFromDocValues(BytesRef binaryValue) { + public float[] getVectorFromBytesRef(BytesRef binaryValue) { float[] vector = new float[binaryValue.length]; int i = 0; int j = binaryValue.offset; @@ -56,7 +56,7 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio } @Override - public float[] getVectorFromDocValues(BytesRef binaryValue) { + public float[] getVectorFromBytesRef(BytesRef binaryValue) { ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length); final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream); return vectorSerializer.byteToFloatArray(byteStream); @@ -81,12 +81,12 @@ public float[] getVectorFromDocValues(BytesRef binaryValue) { public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); /** - * Deserializes float vector from doc values binary value. + * Deserializes float vector from BytesRef. * - * @param binaryValue Binary Value of DocValues + * @param binaryValue Binary Value * @return float vector deserialized from binary value */ - public abstract float[] getVectorFromDocValues(BytesRef binaryValue); + public abstract float[] getVectorFromBytesRef(BytesRef binaryValue); /** * Validates if given VectorDataType is in the list of supported data types. diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index a36a4222b..b314a5dae 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,6 +11,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; import lombok.Getter; @@ -20,6 +21,7 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.util.BytesRef; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.Nullable; @@ -52,16 +54,6 @@ import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.function.Supplier; - import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; @@ -74,19 +66,17 @@ import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; /** - * Field Mapper for KNN vector type. - * - * Extends ParametrizedFieldMapper in order to easily configure mapping parameters. - * - * Implementations of this class define what needs to be stored in Lucene's fieldType. This allows us to have - * alternative mappings for the same field type. + * Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType. + * This allows us to have alternative mappings for the same field type. */ @Log4j2 public abstract class KNNVectorFieldMapper extends ParametrizedFieldMapper { @@ -109,8 +99,8 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { public static class Builder extends ParametrizedFieldMapper.Builder { protected Boolean ignoreMalformed; - protected final Parameter stored = Parameter.boolParam("store", false, m -> toType(m).stored, false); - protected final Parameter hasDocValues = Parameter.boolParam("doc_values", false, m -> toType(m).hasDocValues, true); + protected final Parameter stored = Parameter.storeParam(m -> toType(m).stored, false); + protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); protected final Parameter dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> { if (o == null) { throw new IllegalArgumentException("Dimension cannot be null"); @@ -483,6 +473,11 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S failIfNoDocValues(); return new KNNVectorIndexFieldData.Builder(name(), CoreValuesSourceType.BYTES, this.vectorDataType); } + + @Override + public Object valueForDisplay(Object value) { + return deserializeStoredVector((BytesRef) value, vectorDataType); + } } protected Explicit ignoreMalformed; @@ -561,7 +556,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s VectorField point = new VectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + if (this.stored) { + context.doc().add(createStoredFieldForByteVector(name(), array)); + } } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); @@ -572,7 +569,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s spaceType.validateVector(array); VectorField point = new VectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + if (this.stored) { + context.doc().add(createStoredFieldForFloatVector(name(), array)); + } } else { throw new IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) @@ -735,11 +734,6 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth return Optional.of(array); } - @Override - protected boolean docValuesByDefault() { - return true; - } - @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 283d35f00..9b1578a45 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -16,11 +16,13 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.util.BytesRef; import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; +import java.util.Arrays; import java.util.Locale; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -43,7 +45,6 @@ public class KNNVectorFieldMapperUtil { */ public static void validateFP16VectorValue(float value) { validateFloatVectorValue(value); - if (value < FP16_MIN_VALUE || value > FP16_MAX_VALUE) { throw new IllegalArgumentException( String.format( @@ -135,14 +136,39 @@ public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) { return field; } - public static void addStoredFieldForVectorField( - ParseContext context, - FieldType fieldType, - String mapperName, - String vectorFieldAsString - ) { - if (fieldType.stored()) { - context.doc().add(new StoredField(mapperName, vectorFieldAsString)); + /** + * Creates a stored field for a byte vector + * + * @param name field name + * @param vector vector to be added to stored field + */ + public static StoredField createStoredFieldForByteVector(String name, byte[] vector) { + return new StoredField(name, vector); + } + + /** + * Creates a stored field for a float vector + * + * @param name field name + * @param vector vector to be added to stored field + */ + public static StoredField createStoredFieldForFloatVector(String name, float[] vector) { + return new StoredField(name, KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(vector)); + } + + /** + * @param storedVector Vector representation in bytes + * @param vectorDataType type of vector + * @return either int[] or float[] of corresponding vector + */ + public static Object deserializeStoredVector(BytesRef storedVector, VectorDataType vectorDataType) { + if (VectorDataType.BYTE == vectorDataType) { + byte[] bytes = storedVector.bytes; + int[] byteAsIntArray = new int[bytes.length]; + Arrays.setAll(byteAsIntArray, i -> bytes[i]); + return byteAsIntArray; } + + return vectorDataType.getVectorFromBytesRef(storedVector); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 185ab3dc4..618d77a32 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -25,7 +25,8 @@ import org.opensearch.knn.index.util.KNNEngine; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; /** @@ -92,7 +93,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + if (this.stored) { + context.doc().add(createStoredFieldForByteVector(name(), array)); + } if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); @@ -108,7 +111,9 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s KnnVectorField point = new KnnVectorField(name(), array, fieldType); context.doc().add(point); - addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + if (this.stored) { + context.doc().add(createStoredFieldForFloatVector(name(), array)); + } if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); diff --git a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java index b559b5760..5380dae90 100644 --- a/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java +++ b/src/test/java/org/opensearch/knn/index/AdvancedFilteringUseCasesIT.java @@ -53,8 +53,6 @@ public class AdvancedFilteringUseCasesIT extends KNNRestTestCase { private static final String FIELD_NAME_VECTOR = "test_vector"; - private static final String PROPERTIES_FIELD = "properties"; - private static final String FILTER_FIELD = "filter"; private static final String TERM_FIELD = "term"; diff --git a/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java b/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java index b2477fa5d..a50691e4b 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java +++ b/src/test/java/org/opensearch/knn/index/KNNMapperSearcherIT.java @@ -5,20 +5,34 @@ package org.opensearch.knn.index; +import lombok.SneakyThrows; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.apache.http.util.EntityUtils; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.client.Response; import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.util.KNNEngine; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.QUERY; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + public class KNNMapperSearcherIT extends KNNRestTestCase { - private static final Logger logger = LogManager.getLogger(KNNMapperSearcherIT.class); + + private static final String INDEX_NAME = "test_index"; + private static final String FIELD_NAME = "test_vector"; /** * Test Data set @@ -239,4 +253,166 @@ public void testLargeK() throws Exception { List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); assertEquals(results.size(), 4); } + + /** + * Request: + * { + * "stored_fields": ["test_vector"], + * "query": { + * "match_all": {} + * } + * } + * + * Example Response: + * { + * "took":248, + * "timed_out":false, + * "_shards":{ + * "total":1, + * "successful":1, + * "skipped":0, + * "failed":0 + * }, + * "hits":{ + * "total":{ + * "value":1, + * "relation":"eq" + * }, + * "max_score":1.0, + * "hits":[ + * { + * "_index":"test_index", + * "_id":"1", + * "_score":1.0, + * "fields":{"test_vector":[[-128,0,1,127]]} + * } + * ] + * } + * } + */ + @SneakyThrows + public void testStoredFields_whenByteDataType_thenSucceed() { + // Create index with stored field and confirm that we can properly retrieve it + int[] testVector = new int[] { -128, 0, 1, 127 }; + String expectedResponse = String.format("\"fields\":{\"%s\":[[-128,0,1,127]]}}", FIELD_NAME); + createKnnIndex( + INDEX_NAME, + createVectorMapping(testVector.length, KNNEngine.LUCENE.getName(), VectorDataType.BYTE.getValue(), true) + ); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, testVector); + + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(STORED_QUERY_FIELD, List.of(FIELD_NAME)); + builder.startObject(QUERY); + builder.startObject(MATCH_ALL_QUERY_FIELD); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + String response = EntityUtils.toString(performSearch(INDEX_NAME, builder.toString()).getEntity()); + assertTrue(response.contains(expectedResponse)); + + deleteKNNIndex(INDEX_NAME); + } + + /** + * Request: + * { + * "stored_fields": ["test_vector"], + * "query": { + * "match_all": {} + * } + * } + * + * Example Response: + * { + * "took":248, + * "timed_out":false, + * "_shards":{ + * "total":1, + * "successful":1, + * "skipped":0, + * "failed":0 + * }, + * "hits":{ + * "total":{ + * "value":1, + * "relation":"eq" + * }, + * "max_score":1.0, + * "hits":[ + * { + * "_index":"test_index", + * "_id":"1", + * "_score":1.0, + * "fields":{"test_vector":[[-100.0,100.0,0.0,1.0]]} + * } + * ] + * } + * } + */ + @SneakyThrows + public void testStoredFields_whenFloatDataType_thenSucceed() { + List enginesToTest = List.of(KNNEngine.NMSLIB, KNNEngine.FAISS, KNNEngine.LUCENE); + float[] testVector = new float[] { -100.0f, 100.0f, 0f, 1f }; + String expectedResponse = String.format("\"fields\":{\"%s\":[[-100.0,100.0,0.0,1.0]]}}", FIELD_NAME); + for (KNNEngine knnEngine : enginesToTest) { + createKnnIndex(INDEX_NAME, createVectorMapping(testVector.length, knnEngine.getName(), VectorDataType.FLOAT.getValue(), true)); + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, testVector); + + final XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.field(STORED_QUERY_FIELD, List.of(FIELD_NAME)); + builder.startObject(QUERY); + builder.startObject(MATCH_ALL_QUERY_FIELD); + builder.endObject(); + builder.endObject(); + builder.endObject(); + + String response = EntityUtils.toString(performSearch(INDEX_NAME, builder.toString()).getEntity()); + assertTrue(response.contains(expectedResponse)); + + deleteKNNIndex(INDEX_NAME); + } + } + + /** + * Mapping + * { + * "properties": { + * "test_vector": { + * "type": "knn_vector", + * "dimension": {dimension}, + * "data_type": "{type}", + * "stored": true + * "method": { + * "name": "hnsw", + * "engine": "{engine}" + * } + * } + * } + * } + */ + @SneakyThrows + private String createVectorMapping(final int dimension, final String engine, final String dataType, final boolean isStored) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, dimension) + .field(VECTOR_DATA_TYPE_FIELD, dataType) + .field(STORE_FIELD, isStored) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, engine) + .endObject() + .endObject() + .endObject() + .endObject(); + + return builder.toString(); + } + } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java new file mode 100644 index 000000000..3fa9f2363 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.StoredField; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; + +import java.io.ByteArrayInputStream; +import java.util.Arrays; + +public class KNNVectorFieldMapperUtilTests extends KNNTestCase { + + private static final String TEST_FIELD_NAME = "test_field_name"; + private static final byte[] TEST_BYTE_VECTOR = new byte[] { -128, 0, 1, 127 }; + private static final float[] TEST_FLOAT_VECTOR = new float[] { -100.0f, 100.0f, 0f, 1f }; + + public void testStoredFields_whenVectorIsByteType_thenSucceed() { + StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForByteVector(TEST_FIELD_NAME, TEST_BYTE_VECTOR); + assertEquals(TEST_FIELD_NAME, storedField.name()); + assertEquals(TEST_BYTE_VECTOR, storedField.binaryValue().bytes); + Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.BYTE); + assertTrue(vector instanceof int[]); + int[] byteAsIntArray = new int[TEST_BYTE_VECTOR.length]; + Arrays.setAll(byteAsIntArray, i -> TEST_BYTE_VECTOR[i]); + assertArrayEquals(byteAsIntArray, (int[]) vector); + } + + public void testStoredFields_whenVectorIsFloatType_thenSucceed() { + StoredField storedField = KNNVectorFieldMapperUtil.createStoredFieldForFloatVector(TEST_FIELD_NAME, TEST_FLOAT_VECTOR); + assertEquals(TEST_FIELD_NAME, storedField.name()); + byte[] bytes = storedField.binaryValue().bytes; + ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bytes, 0, bytes.length); + assertArrayEquals( + TEST_FLOAT_VECTOR, + KNNVectorSerializerFactory.getDefaultSerializer().byteToFloatArray(byteArrayInputStream), + 0.001f + ); + + Object vector = KNNVectorFieldMapperUtil.deserializeStoredVector(storedField.binaryValue(), VectorDataType.FLOAT); + assertTrue(vector instanceof float[]); + assertArrayEquals(TEST_FLOAT_VECTOR, (float[]) vector, 0.001f); + } +} diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 6897091af..d1fb8703d 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -107,6 +107,10 @@ public class KNNRestTestCase extends ODFERestTestCase { public static final String INDEX_NAME = "test_index"; public static final String FIELD_NAME = "test_field"; + public static final String PROPERTIES_FIELD = "properties"; + public static final String STORE_FIELD = "store"; + public static final String STORED_QUERY_FIELD = "stored_fields"; + public static final String MATCH_ALL_QUERY_FIELD = "match_all"; private static final String DOCUMENT_FIELD_SOURCE = "_source"; private static final String DOCUMENT_FIELD_FOUND = "found"; protected static final int DELAY_MILLI_SEC = 1000; @@ -474,7 +478,7 @@ protected void forceMergeKnnIndex(String index, int maxSegments) throws Exceptio /** * Add a single KNN Doc to an index */ - protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void addKnnDoc(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject();