Skip to content

Commit

Permalink
Enable script score to work with model based indices
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed Apr 23, 2024
1 parent dc8eb6b commit 4372e15
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.opensearch.knn.plugin.rest.RestTrainModelHandler;
import org.opensearch.knn.plugin.rest.RestClearCacheHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelTransportAction;
Expand Down Expand Up @@ -204,6 +205,7 @@ public Collection<Object> createComponents(
TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client);
KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNScoringSpaceUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance());
TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.opensearch.knn.plugin.script.KNNScoringSpaceUtil.parseToLong;

public interface KNNScoringSpace {

/**
* Return the correct scoring script for a given query. The scoring script
*
Expand Down Expand Up @@ -60,7 +61,7 @@ public L2(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l2Squared(q, v));
Expand Down Expand Up @@ -96,7 +97,7 @@ public CosineSimilarity(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
SpaceType.COSINESIMIL.validateVector(processedQuery);
Expand Down Expand Up @@ -191,7 +192,7 @@ public L1(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1Norm(q, v));
Expand Down Expand Up @@ -226,7 +227,7 @@ public LInf(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.lInfNorm(q, v));
Expand Down Expand Up @@ -263,7 +264,7 @@ public InnerProd(Object query, MappedFieldType fieldType) {

this.processedQuery = parseToFloatArray(
query,
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension(),
KNNScoringSpaceUtil.getExpectedDimensions((KNNVectorFieldMapper.KNNVectorFieldType) fieldType),
((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getVectorDataType()
);
this.scoringMethod = (float[] q, float[] v) -> KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(q, v));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import java.util.List;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.index.mapper.BinaryFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
Expand All @@ -21,6 +24,12 @@

public class KNNScoringSpaceUtil {

/** Model DAO consulted when model metadata has to be looked up; assigned through {@link #initialize(ModelDao)}. */
private static ModelDao modelDao;

/**
 * Stores the model DAO that this utility class uses for subsequent model metadata lookups.
 *
 * @param modelDao model DAO to keep in the static field of this class
 */
public static void initialize(final ModelDao modelDao) {
    KNNScoringSpaceUtil.modelDao = modelDao;
}

/**
* Check if the passed in fieldType is of type NumberFieldType with numericType being Long
*
Expand Down Expand Up @@ -137,4 +146,43 @@ public static float getVectorMagnitudeSquared(float[] inputVector) {
}
return normInputVector;
}

/**
 * Resolves the number of dimensions expected for the given knn vector field type.
 *
 * The dimension reported by the field type is returned as-is unless it is -1. A value of -1
 * marks a model-based index, and in that case the dimension is read from the model metadata
 * associated with the field.
 *
 * @param knnVectorFieldType knn vector field type
 * @return expected dimensions
 * @throws IllegalArgumentException if the dimension is -1 and the field has no model id, or the
 *         referenced model is not in the created state
 */
public static int getExpectedDimensions(KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
    final int mappedDimension = knnVectorFieldType.getDimension();
    // Anything other than the -1 sentinel is an explicit dimension and can be returned directly.
    if (mappedDimension != -1) {
        return mappedDimension;
    }
    // Model-based index: the real dimension is stored in the model metadata.
    return getModelMetadataForField(knnVectorFieldType).getDimension();
}

/**
 * Looks up the model metadata referenced by a knn vector field.
 *
 * @param knnVectorField knn vector field
 * @return the metadata of the model whose id is configured on knnVectorField
 * @throws IllegalArgumentException if the field has no model id, or if the model behind that id
 *         is not reported as created by {@code ModelUtil.isModelCreated}
 */
private static ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
    final String modelId = knnVectorField.getModelId();

    if (modelId == null) {
        // NOTE(review): the message says "Field" but the value inserted is the method component
        // name taken from the knn method context, not the field name - confirm this is intended.
        final String componentName = knnVectorField.getKnnMethodContext().getMethodComponentContext().getName();
        throw new IllegalArgumentException(String.format("Field '%s' does not have model.", componentName));
    }

    final ModelMetadata metadata = modelDao.getMetadata(modelId);
    if (ModelUtil.isModelCreated(metadata)) {
        return metadata;
    }
    throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
package org.opensearch.knn.plugin.script;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.index.mapper.BinaryFieldMapper;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;

import java.math.BigInteger;
import java.util.ArrayList;
Expand Down Expand Up @@ -75,4 +80,44 @@ public void testParseKNNVectorQuery() {
String invalidObject = "invalidObject";
expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT));
}

/**
 * Exercises {@code KNNScoringSpaceUtil.getExpectedDimensions} for a field with an explicit
 * dimension, a model-based field backed by a created model, a model that is still training,
 * and a field that reports dimension -1 but has no model id.
 */
public void testGetExpectedDimensions() {
    final String modelId = "test-model";
    final String componentName = "test-field";

    // Field type that carries its own dimension.
    KNNVectorFieldMapper.KNNVectorFieldType explicitDimensionType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
    when(explicitDimensionType.getDimension()).thenReturn(3);

    // Model-based field type: dimension is -1 and a model id is present.
    KNNVectorFieldMapper.KNNVectorFieldType modelBasedType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
    when(modelBasedType.getDimension()).thenReturn(-1);
    when(modelBasedType.getModelId()).thenReturn(modelId);

    // Model metadata served by the DAO for that model id.
    ModelMetadata metadata = mock(ModelMetadata.class);
    when(metadata.getState()).thenReturn(ModelState.CREATED);
    when(metadata.getDimension()).thenReturn(4);
    ModelDao dao = mock(ModelDao.class);
    when(dao.getMetadata(modelId)).thenReturn(metadata);

    KNNScoringSpaceUtil.initialize(dao);

    // Explicit dimension is returned directly; model-based dimension comes from the metadata.
    assertEquals(3, KNNScoringSpaceUtil.getExpectedDimensions(explicitDimensionType));
    assertEquals(4, KNNScoringSpaceUtil.getExpectedDimensions(modelBasedType));

    // A model that is not yet created must be rejected.
    when(metadata.getState()).thenReturn(ModelState.TRAINING);
    IllegalArgumentException notCreated = expectThrows(
        IllegalArgumentException.class,
        () -> KNNScoringSpaceUtil.getExpectedDimensions(modelBasedType)
    );
    assertEquals(String.format("Model ID '%s' is not created.", modelId), notCreated.getMessage());

    // A field without a model id must be rejected, reporting the method component name.
    when(modelBasedType.getModelId()).thenReturn(null);
    KNNMethodContext methodContext = mock(KNNMethodContext.class);
    MethodComponentContext componentContext = mock(MethodComponentContext.class);
    when(componentContext.getName()).thenReturn(componentName);
    when(methodContext.getMethodComponentContext()).thenReturn(componentContext);
    when(modelBasedType.getKnnMethodContext()).thenReturn(methodContext);

    IllegalArgumentException noModel = expectThrows(
        IllegalArgumentException.class,
        () -> KNNScoringSpaceUtil.getExpectedDimensions(modelBasedType)
    );
    assertEquals(String.format("Field '%s' does not have model.", componentName), noModel.getMessage());
}
}

0 comments on commit 4372e15

Please sign in to comment.