From ac0ee512c3e43c6f58defbfea7d01878e6b16939 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 18 Mar 2024 15:51:43 -0500 Subject: [PATCH] SQFP16 Range Validation for Faiss IVF Models (#1557) * SQFP16 Range Validation for Faiss IVF Models Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda --------- Signed-off-by: Naveen Tatikonda --- CHANGELOG.md | 1 + .../org/opensearch/knn/bwc/FaissSQIT.java | 71 +++++++ .../index/mapper/KNNVectorFieldMapper.java | 59 +++--- .../knn/index/mapper/LuceneFieldMapper.java | 6 +- .../knn/index/mapper/ModelFieldMapper.java | 2 +- .../org/opensearch/knn/index/FaissIT.java | 179 ++++++++++++++++++ .../mapper/KNNVectorFieldMapperTests.java | 34 +++- 7 files changed, 319 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index df79dccc4..fdfedc126 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Persist model definition in model metadata [#1527] (https://github.com/opensearch-project/k-NN/pull/1527) * Added Inner Product Space type support for Lucene Engine [#1551](https://github.com/opensearch-project/k-NN/pull/1551) * Add Range Validation for Faiss SQFP16 [#1493](https://github.com/opensearch-project/k-NN/pull/1493) +* SQFP16 Range Validation for Faiss IVF Models [#1557](https://github.com/opensearch-project/k-NN/pull/1557) ### Bug Fixes * Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518) * Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java index 9516c2d26..b6f1697bb 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java @@ -40,6 +40,7 @@ import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -276,6 +277,76 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex } } + public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenSucceed() throws Exception { + if (!isRunningAgainstOldCluster()) { + int dimension = 2; + + // Add training data + createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension); + int trainingDataCount = 200; + bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension); + + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .field(FAISS_SQ_CLIP, true) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + trainModel(TEST_MODEL, TRAIN_INDEX, TRAIN_TEST_FIELD, dimension, method, "faiss ivf sqfp16 test description"); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(TEST_MODEL, 30, 1000); + + // Create knn index from model + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(TEST_FIELD) + .field("type", "knn_vector") + .field(MODEL_ID, TEST_MODEL) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), indexMapping); + Float[] vector1 = { -65523.76f, 65504.2f }; + Float[] vector2 = { -270.85f, 65514.2f }; + Float[] vector3 = { -150.9f, 65504.0f }; + Float[] vector4 = { -20.89f, 100000000.0f }; + addKnnDoc(testIndex, "1", TEST_FIELD, vector1); + addKnnDoc(testIndex, "2", TEST_FIELD, vector2); + addKnnDoc(testIndex, "3", TEST_FIELD, vector3); + addKnnDoc(testIndex, "4", TEST_FIELD, vector4); + + float[] queryVector = { -10.5f, 25.48f }; + int k = 4; + Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, k), k); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), TEST_FIELD); + assertEquals(k, results.size()); + for (int i = 0; i < k; i++) { + assertEquals(k - i, Integer.parseInt(results.get(i).getDocId())); + } + + deleteKNNIndex(testIndex); + deleteKNNIndex(TRAIN_INDEX); + deleteModel(TEST_MODEL); + validateGraphEviction(); + } + } + private void validateGraphEviction() throws Exception { // Search every 5 seconds 14 times to confirm graph gets evicted int intervals = 14; diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index bab81c578..a36a4222b 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -64,7 +64,6 @@ import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; @@ -530,10 +529,23 @@ protected String contentType() { @Override protected void parseCreateField(ParseContext context) throws IOException { - parseCreateField(context, fieldType().getDimension(), fieldType().getSpaceType()); + parseCreateField( + context, + fieldType().getDimension(), + fieldType().getSpaceType(), + getMethodComponentContext(fieldType().getKnnMethodContext()) + ); } - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException { + private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { + if (Objects.isNull(knnMethodContext)) { + return null; + } + return knnMethodContext.getMethodComponentContext(); + } + + protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) + throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); @@ -551,7 +563,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s context.doc().add(point); addStoredFieldForVectorField(context, fieldType, name(), point.toString()); } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContext(context, dimension); + Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); if (floatsArrayOptional.isEmpty()) { return; @@ -571,34 +583,35 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s } // Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" - protected boolean isFaissSQfp16(KNNMethodContext knnMethodContext) { - - // KNNMethodContext shouldn't be null - if (Objects.isNull(knnMethodContext)) { + protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { + if (Objects.isNull(methodComponentContext)) { return false; } - // engine should be faiss - if (!FAISS_NAME.equals(knnMethodContext.getKnnEngine().getName())) { + if (methodComponentContext.getParameters().size() == 0) { return false; } - // Should have Method Component Parameters - if (knnMethodContext.getMethodComponentContext().getParameters().size() == 0) { - return false; - } - Map methodComponentParams = knnMethodContext.getMethodComponentContext().getParameters(); + Map methodComponentParams = methodComponentContext.getParameters(); // The method component parameters should have an encoder if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { return false; } - MethodComponentContext methodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); + // Validate if the object is of type MethodComponentContext before casting it later + if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return false; + } + + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); // returns true if encoder name is "sq" and type is "fp16" - return ENCODER_SQ.equals(methodComponentContext.getName()) - && FAISS_SQ_ENCODER_FP16.equals(methodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16)); + return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) + && FAISS_SQ_ENCODER_FP16.equals( + encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + ); + } // Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index @@ -659,21 +672,19 @@ Optional getBytesFromContext(ParseContext context, int dimension) throws return Optional.of(array); } - Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { + Optional getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext) + throws IOException { context.path().add(simpleName()); // Returns an optional array of float values where each value in the vector is parsed as a float and validated // if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'. // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be // clipped to FP16 range. - boolean isFaissSQfp16Flag = isFaissSQfp16(fieldType().getKnnMethodContext()); + boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext); boolean clipVectorValueToFP16RangeFlag = false; if (isFaissSQfp16Flag) { clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) fieldType().getKnnMethodContext() - .getMethodComponentContext() - .getParameters() - .get(METHOD_ENCODER_PARAMETER) + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) ); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 81c7216bf..185ab3dc4 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -18,6 +18,7 @@ import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; @@ -75,7 +76,8 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { } @Override - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType) throws IOException { + protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) + throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); @@ -96,7 +98,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s context.doc().add(new VectorField(name(), array, vectorFieldType)); } } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContext(context, dimension); + Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); if (floatsArrayOptional.isEmpty()) { return; diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 2367d7422..ce92d2967 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -61,6 +61,6 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType()); + parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 1c4638a45..3fafae9ba 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -568,6 +568,185 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then validateGraphEviction(); } + @SneakyThrows + public void testIVFSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { + String modelId = "test-model-ivf-sqfp16"; + int dimension = 128; + + String trainingIndexName = "train-index-ivf-sqfp16"; + String trainingFieldName = "train-field-ivf-sqfp16"; + + // Add training data + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + int trainingDataCount = 200; + bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "faiss ivf sqfp16 test description"); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + + // Create knn index from model + String fieldName = "test-field-name-ivf-sqfp16"; + String indexName = "test-index-name-ivf-sqfp16"; + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); + Float[] vector = { -10.76f, 65504.2f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "1", fieldName, vector)); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + + Float[] vector1 = { -65506.84f, 12.56f }; + + ResponseException ex1 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "2", fieldName, vector1)); + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + + Float[] vector2 = { -65526.4567f, 65526.4567f }; + + ResponseException ex2 = expectThrows(ResponseException.class, () -> addKnnDoc(indexName, "3", fieldName, vector2)); + assertTrue( + ex2.getMessage() + .contains( + String.format( + Locale.ROOT, + "encoder name is set as [%s] and type is set as [%s] in index mapping. But, KNN vector values are not within in the FP16 range [%f, %f]", + ENCODER_SQ, + FAISS_SQ_ENCODER_FP16, + FP16_MIN_VALUE, + FP16_MAX_VALUE + ) + ) + ); + deleteKNNIndex(indexName); + deleteKNNIndex(trainingIndexName); + deleteModel(modelId); + } + + @SneakyThrows + public void testIVFSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenSucceed() { + String modelId = "test-model-ivf-sqfp16"; + int dimension = 2; + + String trainingIndexName = "train-index-ivf-sqfp16"; + String trainingFieldName = "train-field-ivf-sqfp16"; + + // Add training data + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + int trainingDataCount = 200; + bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .field(FAISS_SQ_CLIP, true) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "faiss ivf sqfp16 test description"); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + + // Create knn index from model + String fieldName = "test-field-name-ivf-sqfp16"; + String indexName = "test-index-name-ivf-sqfp16"; + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); + Float[] vector1 = { -65523.76f, 65504.2f }; + Float[] vector2 = { -270.85f, 65514.2f }; + Float[] vector3 = { -150.9f, 65504.0f }; + Float[] vector4 = { -20.89f, 100000000.0f }; + addKnnDoc(indexName, "1", fieldName, vector1); + addKnnDoc(indexName, "2", fieldName, vector2); + addKnnDoc(indexName, "3", fieldName, vector3); + addKnnDoc(indexName, "4", fieldName, vector4); + + float[] queryVector = { -10.5f, 25.48f }; + int k = 4; + Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k), k); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(k, results.size()); + for (int i = 0; i < k; i++) { + assertEquals(k - i, Integer.parseInt(results.get(i).getDocId())); + } + + deleteKNNIndex(indexName); + deleteKNNIndex(trainingIndexName); + deleteModel(modelId); + validateGraphEviction(); + } + @SneakyThrows public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed() { String indexName = "test-index"; diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 4d3bcd62e..a9b65878f 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -739,10 +739,16 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.path()).thenReturn(contentPath); LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) + .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); + luceneFieldMapper.parseCreateField( + parseContext, + TEST_DIMENSION, + luceneFieldMapper.fieldType().spaceType, + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField List fields = document.getFields(); @@ -776,11 +782,17 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { inputBuilder.hasDocValues(false); luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) + .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); + luceneFieldMapper.parseCreateField( + parseContext, + TEST_DIMENSION, + luceneFieldMapper.fieldType().spaceType, + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + ); // Document should have 1 field: one for KnnVectorField fields = document.getFields(); @@ -809,7 +821,12 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); + luceneFieldMapper.parseCreateField( + parseContext, + TEST_DIMENSION, + luceneFieldMapper.fieldType().spaceType, + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField List fields = document.getFields(); @@ -846,7 +863,12 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType); + luceneFieldMapper.parseCreateField( + parseContext, + TEST_DIMENSION, + luceneFieldMapper.fieldType().spaceType, + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + ); // Document should have 1 field: one for KnnByteVectorField fields = document.getFields();