Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support score type threshold in radial search #1589

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498)
* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546)
* Support score type threshold in radial search [#1589](https://github.com/opensearch-project/k-NN/pull/1589)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
90 changes: 76 additions & 14 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static final ParseField DISTANCE_FIELD = new ParseField("distance");
public static final ParseField SCORE_FIELD = new ParseField("score");
public static final int K_MAX = 10000;
/**
* The name for the knn query
Expand All @@ -64,6 +65,7 @@
private final float[] vector;
private int k = 0;
private Float distance = null;
private Float score = null;
private QueryBuilder filter;
private boolean ignoreUnmapped = false;

Expand Down Expand Up @@ -92,13 +94,14 @@
*
* @param k K nearest neighbours for the given vector
*/
public KNNQueryBuilder k(int k) {
public KNNQueryBuilder k(Integer k) {
if (k == null) {
throw new IllegalArgumentException("[" + NAME + "] requires k to be set");

Check warning on line 99 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L99

Added line #L99 was not covered by tests
}
validSingleQueryType(k, distance, score);
if (k <= 0 || k > K_MAX) {
throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX);
}
if (distance != null) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
this.k = k;
return this;
}
Expand All @@ -112,13 +115,28 @@
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
}
if (k != 0) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validSingleQueryType(k, distance, score);
this.distance = distance;
return this;
}

/**
 * Configures this builder for a score-threshold radial search.
 *
 * @param score minimum score threshold for the nearest neighbours; must be non-null and greater than 0
 * @return this builder, to allow call chaining
 * @throws IllegalArgumentException if {@code score} is null, if the combination of k, distance and
 *                                  score does not describe exactly one query type, or if
 *                                  {@code score} is not greater than 0
 */
public KNNQueryBuilder score(Float score) {
    if (score == null) {
        throw new IllegalArgumentException("[" + NAME + "] requires score to be set");
    }

    // The incoming score is checked together with the k and distance already held by this builder.
    validSingleQueryType(this.k, this.distance, score);

    final float threshold = score.floatValue();
    if (threshold <= 0) {
        throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0");
    }

    this.score = score;
    return this;
}

/**
* Builder method for filter
*
Expand Down Expand Up @@ -163,6 +181,7 @@
this.filter = filter;
this.ignoreUnmapped = false;
this.distance = null;
this.score = null;
}

public static void initialize(ModelDao modelDao) {
Expand Down Expand Up @@ -200,6 +219,9 @@
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
distance = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
score = in.readOptionalFloat();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand All @@ -211,6 +233,7 @@
float boost = AbstractQueryBuilder.DEFAULT_BOOST;
Integer k = null;
Float distance = null;
Float score = null;
QueryBuilder filter = null;
String queryName = null;
String currentFieldName = null;
Expand Down Expand Up @@ -241,6 +264,8 @@
queryName = parser.text();
} else if (DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
distance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else if (SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
score = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -270,9 +295,7 @@
}
}

if ((k != null && distance != null) || (k == null && distance == null)) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
}
validSingleQueryType(k, distance, score);

KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter)
.ignoreUnmapped(ignoreUnmapped)
Expand All @@ -281,8 +304,10 @@

if (k != null) {
knnQueryBuilder.k(k);
} else {
} else if (distance != null) {
knnQueryBuilder.distance(distance);
} else if (score != null) {
knnQueryBuilder.score(score);
}

return knnQueryBuilder;
Expand All @@ -300,6 +325,9 @@
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(distance);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(score);
}
}

/**
Expand All @@ -324,6 +352,10 @@
return this.distance;
}

/**
 * Returns the score threshold of this query as a primitive.
 *
 * @return the configured score threshold
 * @throws NullPointerException if no score threshold has been set, because the backing field is a
 *                              nullable {@code Float} that is unboxed here
 */
public float getScore() {
    return score.floatValue();
}

public QueryBuilder getFilter() {
return this.filter;
}
Expand Down Expand Up @@ -358,6 +390,9 @@
if (ignoreUnmapped) {
builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), ignoreUnmapped);
}
if (score != null) {
builder.field(SCORE_FIELD.getPreferredName(), score);

Check warning on line 394 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L394

Added line #L394 was not covered by tests
}
printBoostAndQueryName(builder);
builder.endObject();
builder.endObject();
Expand Down Expand Up @@ -397,8 +432,8 @@
spaceType = knnMethodContext.getSpaceType();
}

// Currently, k-NN supports distance type radius search.
// We need transform distance radius to right type of engine required radius.
// Currently, k-NN supports distance and score types of radial search.
// We need to transform the distance/score into the radius type required by the engine.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
Expand All @@ -407,6 +442,13 @@
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
}

if (this.score != null) {
if (this.score > 1 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires score to be in the range (0, 1] for space type: " + spaceType);
}
radius = knnEngine.scoreToRadialThreshold(this.score, spaceType);
}

if (fieldDimension != vector.length) {
throw new IllegalArgumentException(
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension)
Expand Down Expand Up @@ -464,7 +506,7 @@
.build();
return RNNQueryFactory.create(createQueryRequest);
}
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance or score to be set");

Check warning on line 509 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L509

Added line #L509 was not covered by tests
}

private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
Expand Down Expand Up @@ -499,4 +541,24 @@
public String getWriteableName() {
return NAME;
}

/**
 * Checks that exactly one query type is specified among k, distance and score.
 *
 * @param k        number of neighbours; counted as set only when it is non-null and non-zero
 * @param distance distance threshold; counted as set when non-null
 * @param score    score threshold; counted as set when non-null
 * @throws IllegalArgumentException if the number of set values is not exactly one
 */
private static void validSingleQueryType(Integer k, Float distance, Float score) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: valid -> validate

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

int countSetFields = 0;

// The builder's k field is initialised to 0, so a zero k is treated the same as "not provided".
if (k != null && k != 0) {
countSetFields++;
}
if (distance != null) {
countSetFields++;
}
if (score != null) {
countSetFields++;
}

// Both "none set" and "more than one set" are rejected with the same message.
if (countSetFields != 1) {
throw new IllegalArgumentException(

Check warning on line 559 in src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java#L559

Added line #L559 was not covered by tests
"[" + NAME + "] requires only one query type to be set, it can be either k, distance, or score"
);
}
}
}
27 changes: 25 additions & 2 deletions src/main/java/org/opensearch/knn/index/util/Faiss.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
* Implements NativeLibrary for the faiss native library
*/
class Faiss extends NativeLibrary {
Map<SpaceType, Function<Float, Float>> scoreTransform;

// TODO: Current version is not really current version. Instead, it encodes information in the file name
// about the compatibility version the file is created with. In the future, we should refactor this so that it
Expand All @@ -68,6 +69,14 @@ class Faiss extends NativeLibrary {
rawScore -> SpaceType.INNER_PRODUCT.scoreTranslation(-1 * rawScore)
);

// Map that transforms radial search score threshold to faiss required distance
private final static Map<SpaceType, Function<Float, Float>> SCORE_TO_DISTANCE_TRANSFORMATIONS = ImmutableMap.<
SpaceType,
Function<Float, Float>>builder()
.put(SpaceType.INNER_PRODUCT, score -> score > 1 ? 1 - score : 1 / score - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets add some java doc here why this conversion make sense.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.put(SpaceType.L2, score -> 1 / score - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for L2 spacetype can we put this translation in SpaceType enum class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes updated.

.build();

// Define encoders supported by faiss
private final static MethodComponentContext ENCODER_DEFAULT = new MethodComponentContext(
KNNConstants.ENCODER_FLAT,
Expand Down Expand Up @@ -301,7 +310,13 @@ class Faiss extends NativeLibrary {
).addSpaces(SpaceType.L2, SpaceType.INNER_PRODUCT).build()
);

final static Faiss INSTANCE = new Faiss(METHODS, SCORE_TRANSLATIONS, CURRENT_VERSION, KNNConstants.FAISS_EXTENSION);
final static Faiss INSTANCE = new Faiss(
METHODS,
SCORE_TRANSLATIONS,
CURRENT_VERSION,
KNNConstants.FAISS_EXTENSION,
SCORE_TO_DISTANCE_TRANSFORMATIONS
);

/**
* Constructor for Faiss
Expand All @@ -315,9 +330,11 @@ private Faiss(
Map<String, KNNMethod> methods,
Map<SpaceType, Function<Float, Float>> scoreTranslation,
String currentVersion,
String extension
String extension,
Map<SpaceType, Function<Float, Float>> scoreTransform
) {
super(methods, scoreTranslation, currentVersion, extension);
this.scoreTransform = scoreTransform;
}

@Override
Expand All @@ -326,6 +343,12 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

/**
 * Converts a user-supplied score threshold into the distance threshold used for faiss radial search.
 *
 * @param score     score threshold input from end user
 * @param spaceType space type whose score-to-distance transformation should be applied
 * @return the distance produced by the transformation registered for {@code spaceType}
 * @throws IllegalArgumentException if no score-to-distance transformation is registered for
 *                                  {@code spaceType}
 */
@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
    // Faiss works with distances, so the score threshold has to be converted with the
    // transformation registered for the given space type.
    final Function<Float, Float> transform = this.scoreTransform.get(spaceType);
    if (transform == null) {
        // Fail with a descriptive error instead of a bare NullPointerException on the lookup result.
        throw new IllegalArgumentException("Score threshold is not supported for space type: " + spaceType);
    }
    return transform.apply(score);
}

/**
* MethodAsMap builder is used to create the map that will be passed to the jni to create the faiss index.
* Faiss's index factory takes an "index description" that it uses to build the index. In this description,
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return knnLibrary.distanceToRadialThreshold(distance, spaceType);
}

/**
 * Delegates the score-to-threshold translation to {@code knnLibrary}.
 *
 * @param score     score threshold input from end user
 * @param spaceType spaceType used to compute the threshold
 * @return the value returned by {@code knnLibrary.scoreToRadialThreshold}
 */
@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
    final Float radialThreshold = knnLibrary.scoreToRadialThreshold(score, spaceType);
    return radialThreshold;
}

@Override
public ValidationException validateMethod(KNNMethodContext knnMethodContext) {
return knnLibrary.validateMethod(knnMethodContext);
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ public interface KNNLibrary {
*/
Float distanceToRadialThreshold(Float distance, SpaceType spaceType);

/**
* Translate the score threshold supplied by the end user into the threshold value the engine expects
* for radial search.
*
* @param score score threshold input from end user
* @param spaceType spaceType used to compute the threshold
*
* @return threshold for the library; an implementation may return the score unchanged or convert it
*         to another quantity such as a distance
*/
Float scoreToRadialThreshold(Float score, SpaceType spaceType);

/**
* Validate the knnMethodContext for the given library. A ValidationException should be thrown if the method is
* deemed invalid.
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class Lucene extends JVMLibrary {
Function<Float, Float>>builder()
.put(SpaceType.COSINESIMIL, distance -> (2 - distance) / 2)
.put(SpaceType.L2, distance -> 1 / (1 + distance))
.put(SpaceType.INNER_PRODUCT, distance -> distance <= 0 ? 1 / (1 - distance) : distance + 1)
.build();

final static Lucene INSTANCE = new Lucene(METHODS, Version.LATEST.toString(), DISTANCE_TRANSLATIONS);
Expand Down Expand Up @@ -93,6 +94,12 @@ public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return this.distanceTransform.get(spaceType).apply(distance);
}

@Override
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
// The score threshold is passed through unchanged; only distance thresholds are converted for the
// Lucene engine (see distanceToRadialThreshold)
return score;
}

@Override
public List<String> mmapFileExtensions() {
return List.of("vec", "vex");
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Nmslib.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@
public Float distanceToRadialThreshold(Float distance, SpaceType spaceType) {
return distance;
}

/**
 * Returns the score threshold unchanged; no transformation is applied here.
 *
 * NOTE(review): there is no @Override on this method, unlike the Faiss and Lucene counterparts.
 * The class header is not visible here, so confirm that this implements
 * KNNLibrary#scoreToRadialThreshold and add the annotation.
 *
 * @param score     score threshold input from end user
 * @param spaceType space type of the field; not used by this implementation
 * @return {@code score} as given
 */
public Float scoreToRadialThreshold(Float score, SpaceType spaceType) {
return score;

Check warning on line 78 in src/main/java/org/opensearch/knn/index/util/Nmslib.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/util/Nmslib.java#L78

Added line #L78 was not covered by tests
}
}
Loading
Loading