Skip to content

Commit

Permalink
Resolve PR feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jan 9, 2024
1 parent cf37f59 commit e500258
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ public static void initialize(ModelDao modelDao) {
}

private static float[] ObjectsToFloats(List<Object> objs) {
if (Objects.isNull(objs)) {
throw new IllegalArgumentException(String.format("[%s] requires 'vector' to be non-null", NAME));
if (Objects.isNull(objs) || objs.isEmpty()) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
}
float[] vec = new float[objs.size()];
for (int i = 0; i < objs.size(); i++) {
if (!(objs.get(i) instanceof Number)) {
throw new IllegalArgumentException(String.format("[%s] requires 'vector' to be an array of numbers", NAME));
if ((objs.get(i) instanceof Number) == false) {
throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
}
vec[i] = ((Number) objs.get(i)).floatValue();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -133,7 +134,11 @@ public void testFromXContent_invalidQueryVectorType() throws Exception {
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

String[] invalidTypeQueryVector = { "a", "b", "c", "d" };
List<Object> invalidTypeQueryVector = new ArrayList<>();
invalidTypeQueryVector.add(1.5);
invalidTypeQueryVector.add(2.5);
invalidTypeQueryVector.add("a");
invalidTypeQueryVector.add(null);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
Expand All @@ -148,7 +153,7 @@ public void testFromXContent_invalidQueryVectorType() throws Exception {
IllegalArgumentException.class,
() -> KNNQueryBuilder.fromXContent(contentParser)
);
assertTrue(exception.getMessage().contains("[knn] requires 'vector' to be an array of numbers"));
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be an array of numbers"));
}

public void testFromXContent_missingQueryVector() throws Exception {
Expand All @@ -157,19 +162,34 @@ public void testFromXContent_missingQueryVector() throws Exception {
final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
knnClusterUtil.initialize(clusterService);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder.startObject();
builder.startObject(FIELD_NAME);
builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builder.endObject();
builder.endObject();
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
// Test without vector field
XContentBuilder builderWithoutVectorField = XContentFactory.jsonBuilder();
builderWithoutVectorField.startObject();
builderWithoutVectorField.startObject(FIELD_NAME);
builderWithoutVectorField.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builderWithoutVectorField.endObject();
builderWithoutVectorField.endObject();
XContentParser contentParserWithoutVectorField = createParser(builderWithoutVectorField);
contentParserWithoutVectorField.nextToken();
IllegalArgumentException exception = expectThrows(
IllegalArgumentException.class,
() -> KNNQueryBuilder.fromXContent(contentParser)
() -> KNNQueryBuilder.fromXContent(contentParserWithoutVectorField)
);
assertTrue(exception.getMessage().contains("[knn] requires 'vector' to be non-null"));
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));

// Test empty vector field
List<Object> emptyQueryVector = new ArrayList<>();
XContentBuilder builderWithEmptyVector = XContentFactory.jsonBuilder();
builderWithEmptyVector.startObject();
builderWithEmptyVector.startObject(FIELD_NAME);
builderWithEmptyVector.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), emptyQueryVector);
builderWithEmptyVector.field(KNNQueryBuilder.K_FIELD.getPreferredName(), K);
builderWithEmptyVector.endObject();
builderWithEmptyVector.endObject();
XContentParser contentParserWithEmptyVector = createParser(builderWithEmptyVector);
contentParserWithEmptyVector.nextToken();
exception = expectThrows(IllegalArgumentException.class, () -> KNNQueryBuilder.fromXContent(contentParserWithEmptyVector));
assertTrue(exception.getMessage().contains("[knn] field 'vector' requires to be non-null and non-empty"));
}

@Override
Expand Down

0 comments on commit e500258

Please sign in to comment.