From d4f70dea6ee6e30ebe9b6b41a954eeaf291703b4 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Wed, 3 Apr 2024 14:08:15 -0700 Subject: [PATCH] Resolve feedback Signed-off-by: Junqiu Lei --- src/main/java/org/opensearch/knn/index/SpaceType.java | 3 +++ .../opensearch/knn/index/query/KNNQueryBuilder.java | 10 +++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/SpaceType.java b/src/main/java/org/opensearch/knn/index/SpaceType.java index 5b679902d..240bfbe91 100644 --- a/src/main/java/org/opensearch/knn/index/SpaceType.java +++ b/src/main/java/org/opensearch/knn/index/SpaceType.java @@ -39,6 +39,9 @@ public VectorSimilarityFunction getVectorSimilarityFunction() { @Override public float scoreToDistanceTranslation(float score) { + if (score == 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "score cannot be 0 when space type is [%s]", getValue())); + } return 1 / score - 1; } }, diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 933365ab2..78ddb532d 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -98,7 +98,7 @@ public KNNQueryBuilder k(Integer k) { if (k == null) { throw new IllegalArgumentException("[" + NAME + "] requires k to be set"); } - validSingleQueryType(k, distance, score); + validateSingleQueryType(k, distance, score); if (k <= 0 || k > K_MAX) { throw new IllegalArgumentException("[" + NAME + "] requires 0 < k <= " + K_MAX); } @@ -115,7 +115,7 @@ public KNNQueryBuilder distance(Float distance) { if (distance == null) { throw new IllegalArgumentException("[" + NAME + "] requires distance to be set"); } - validSingleQueryType(k, distance, score); + validateSingleQueryType(k, distance, score); this.distance = distance; return this; } @@ -129,7 +129,7 @@ public KNNQueryBuilder score(Float score) { if (score == null) { throw new IllegalArgumentException("[" + NAME + "] requires score to be set"); } - validSingleQueryType(k, distance, score); + validateSingleQueryType(k, distance, score); if (score <= 0) { throw new IllegalArgumentException("[" + NAME + "] requires score greater than 0"); } @@ -295,7 +295,7 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - validSingleQueryType(k, distance, score); + validateSingleQueryType(k, distance, score); KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) .ignoreUnmapped(ignoreUnmapped) @@ -542,7 +542,7 @@ public String getWriteableName() { return NAME; } - private static void validSingleQueryType(Integer k, Float distance, Float score) { + private static void validateSingleQueryType(Integer k, Float distance, Float score) { int countSetFields = 0; if (k != null && k != 0) {