From 4372e150e72f8bbd6f069097d010825724bfe5ec Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 23 Apr 2024 14:02:23 -0700 Subject: [PATCH] Enable script score to work with model based indices Signed-off-by: Ryan Bogan --- .../org/opensearch/knn/plugin/KNNPlugin.java | 2 + .../knn/plugin/script/KNNScoringSpace.java | 11 +++-- .../plugin/script/KNNScoringSpaceUtil.java | 48 +++++++++++++++++++ .../script/KNNScoringSpaceUtilTests.java | 45 +++++++++++++++++ 4 files changed, 101 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 2e5a55092..bc17e80e7 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -33,6 +33,7 @@ import org.opensearch.knn.plugin.rest.RestTrainModelHandler; import org.opensearch.knn.plugin.rest.RestClearCacheHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; +import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil; import org.opensearch.knn.plugin.stats.KNNStats; import org.opensearch.knn.plugin.transport.DeleteModelAction; import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; @@ -204,6 +205,7 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); + KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java index 5a8cdb036..3ba8bce63 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpace.java @@ -28,6 +28,7 @@ import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong; public interface KNNScoringSpace { + /** * Return the correct scoring script for a given query. The scoring script * @@ -60,7 +61,7 @@ public L2(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v)); @@ -96,7 +97,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); SpaceType.COSINESIMIL.validateVector(processedQuery); @@ -191,7 +192,7 @@ public L1(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v)); @@ -226,7 +227,7 @@ public LInf(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v)); @@ -263,7 +264,7 @@ public InnerProd(Object query, MappedFieldType fieldType) { this.processedQuery = parseToFloatArray( query, - ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(), + KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType), ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType() ); this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v)); diff --git a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java index c482413fb..888184e54 100644 --- a/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java +++ b/src/main/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtil.java @@ -8,6 +8,9 @@ import java.util.List; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.MappedFieldType; @@ -21,6 +24,12 @@ public class KNNScoringSpaceUtil { + private static ModelDao modelDao; + + public static void initialize(ModelDao modelDao) { + KNNScoringSpaceUtil.modelDao = modelDao; + } + /** * Check if the passed in fieldType is of type NumberFieldType with numericType being Long * @@ -137,4 +146,43 @@ public static float getVectorMagnitudeSquared(float[] inputVector) { } return normInputVector; } + + /** + * Get the expected dimensions from a specified knn vector field type. + * + * If the field is model-based, get dimensions from model metadata. + * @param knnVectorFieldType knn vector field type + * @return expected dimensions + */ + public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) { + int expectedDimensions = knnVectorFieldType.getDimension(); + // Value will be -1 when a model-based index is used. In this case, retrieve expected dimensions from model metadata. + if (expectedDimensions == -1) { + ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); + expectedDimensions = modelMetadata.getDimension(); + } + return expectedDimensions; + } + + /** + * Returns the model metadata for a specified knn vector field + * + * @param knnVectorField knn vector field + * @return the model metadata from knnVectorField + */ + private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) { + String modelId = knnVectorField.getModelId(); + + if (modelId == null) { + throw new IllegalArgumentException( + String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + ); + } + + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + } + return modelMetadata; + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index b5bc4b95f..1497e3e17 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -6,10 +6,15 @@ package org.opensearch.knn.plugin.script; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import java.math.BigInteger; import java.util.ArrayList; @@ -75,4 +80,44 @@ public void testParseKNNVectorQuery() { String invalidObject = "invalidObject"; expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } + + public void testGetExpectedDimensions() { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldType.getDimension()).thenReturn(3); + + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + String modelId = "test-model"; + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + + ModelDao modelDao = mock(ModelDao.class); + ModelMetadata modelMetadata = mock(ModelMetadata.class); + when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getDimension()).thenReturn(4); + when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + + KNNScoringSpaceUtil.initialize(modelDao); + + assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldType)); + assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); + + when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); + + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased) + ); + assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); + + when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); + String fieldName = "test-field"; + when(methodComponentContext.getName()).thenReturn(fieldName); + when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); + when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); + + e = expectThrows(IllegalArgumentException.class, () -> KNNScoringSpaceUtil.getExpectedDimensions(knnVectorFieldTypeModelBased)); + assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); + } }