Skip to content

Commit

Permalink
Resolve feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jun 18, 2024
1 parent 688b70d commit 26ec968
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 103 deletions.
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorQueryType.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,8 @@ public KNNCounter getQueryWithFilterStatCounter() {
public abstract KNNCounter getQueryStatCounter();

public abstract KNNCounter getQueryWithFilterStatCounter();

public boolean isRadialSearch() {
return this == MAX_DISTANCE || this == MIN_SCORE;
}
}
70 changes: 41 additions & 29 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.parser.MethodParametersParser;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.index.util.QueryContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -95,6 +96,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private QueryBuilder filter;
@Getter
private boolean ignoreUnmapped;
@Getter
private VectorQueryType vectorQueryType;

/**
* Constructs a new query with the given field name and vector
Expand Down Expand Up @@ -189,13 +192,22 @@ public Builder boost(float boost) {
}

public KNNQueryBuilder build() {
validate();
VectorQueryType vectorQueryType = validate();
int k = this.k == null ? 0 : this.k;
return new KNNQueryBuilder(fieldName, vector, k, maxDistance, minScore, methodParameters, filter, ignoreUnmapped).boost(boost)
.queryName(queryName);
}

private void validate() {
return new KNNQueryBuilder(
fieldName,
vector,
k,
maxDistance,
minScore,
methodParameters,
filter,
ignoreUnmapped,
vectorQueryType
).boost(boost).queryName(queryName);
}

private VectorQueryType validate() throws IllegalArgumentException {
if (Strings.isNullOrEmpty(fieldName)) {
throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
}
Expand Down Expand Up @@ -243,6 +255,7 @@ private void validate() {
if (filter != null) {
vectorQueryType.getQueryWithFilterStatCounter().increment();
}
return vectorQueryType;
}
}

Expand Down Expand Up @@ -518,6 +531,28 @@ protected Query doToQuery(QueryShardContext context) {
methodComponentContext = knnMethodContext.getMethodComponentContext();
}

final String method = methodComponentContext != null ? methodComponentContext.getName() : null;
if (StringUtils.isNotBlank(method)) {
final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method);
KNNEngine finalKnnEngine = knnEngine;
QueryContext methodContext = () -> (vectorQueryType.isRadialSearch() && KNNEngine.LUCENE.equals(finalKnnEngine)) == false;
ValidationException validationException = validateParameters(
engineSpecificMethodContext.supportedMethodParameters(methodContext),
(Map<String, Object>) methodParameters
);
if (validationException != null) {
throw new IllegalArgumentException(
String.format(
"Parameters not valid for [%s]:[%s]:[%s] combination: [%s]",
knnEngine,
method,
vectorQueryType.getQueryTypeName(),
validationException.getMessage()
)
);
}
}

// Currently, k-NN supports distance and score types radial search
// We need transform distance/score to right type of engine required radius.
Float radius = null;
Expand All @@ -539,29 +574,6 @@ protected Query doToQuery(QueryShardContext context) {
radius = knnEngine.scoreToRadialThreshold(this.minScore, spaceType);
}

final String method = methodComponentContext != null ? methodComponentContext.getName() : null;
if (StringUtils.isNotBlank(method)) {
final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method);
Float finalRadius = radius;
EngineSpecificMethodContext.Context methodContext = () -> finalRadius != null;
ValidationException validationException = validateParameters(
engineSpecificMethodContext.supportedMethodParameters(methodContext),
(Map<String, Object>) methodParameters,
knnEngine,
methodContext
);
if (validationException != null) {
throw new IllegalArgumentException(
String.format(
"Parameters not valid for [%s]:[%s] combination: [%s]",
knnEngine,
method,
validationException.getMessage()
)
);
}
}

if (fieldDimension != vector.length) {
throw new IllegalArgumentException(
String.format("Query vector has invalid dimension: %d. Dimension should be: %d", vector.length, fieldDimension)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public final class DefaultHnswContext implements EngineSpecificMethodContext {
.build();

@Override
public Map<String, Parameter<?>> supportedMethodParameters(Context ctx) {
public Map<String, Parameter<?>> supportedMethodParameters(QueryContext ctx) {
return supportedMethodParameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@
*/
public interface EngineSpecificMethodContext {

Map<String, Parameter<?>> supportedMethodParameters(Context ctx);
Map<String, Parameter<?>> supportedMethodParameters(QueryContext ctx);

EngineSpecificMethodContext EMPTY = ctx -> Collections.emptyMap();

interface Context {
boolean isRadialSearch();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ public class LuceneHNSWContext implements EngineSpecificMethodContext {
.build();

@Override
public Map<String, Parameter<?>> supportedMethodParameters(Context ctx) {
if (ctx.isRadialSearch()) {
// return empty map if radial search is true
return Collections.emptyMap();
public Map<String, Parameter<?>> supportedMethodParameters(QueryContext ctx) {
if (ctx.is_ef_search_parameter_supported()) {
// Return the supported method parameters for non-radial cases
return supportedMethodParameters;
}
// Return the supported method parameters for non-radial cases
return supportedMethodParameters;
// return empty map if radial search is true
return Collections.emptyMap();
}
}
19 changes: 19 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/QueryContext.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.util;

/**
* Context interface for query-specific information.
*/
public interface QueryContext {
boolean is_ef_search_parameter_supported();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,89 +14,35 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.Parameter;
import org.opensearch.knn.index.util.EngineSpecificMethodContext;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;

public final class ParameterValidator {

/**
* A function which validates request parameters.
* @param validParameters A set of valid parameters that can be requestParameters can be validated against
* @param requestParameters parameters from the request
* @return
*/
@Nullable
public static ValidationException validateParameters(
final Map<String, Parameter<?>> validParameters,
final Map<String, Object> requestParameters
) {
return validateParameters(validParameters, requestParameters, null, null);
}

/**
* A function which validates request parameters.
* @param validParameters A set of valid parameters that can be requestParameters can be validated against
* @param requestParameters parameters from the request
* @param knnEngine The KNN engine
* @param context The engine specific method context
*/
@Nullable
public static ValidationException validateParameters(
final Map<String, Parameter<?>> validParameters,
final Map<String, Object> requestParameters,
final KNNEngine knnEngine,
final EngineSpecificMethodContext.Context context
) {
validateNonNullParameters(validParameters);

if (requestParameters == null || requestParameters.isEmpty()) {
return null;
}

List<String> errorMessages = new ArrayList<>();
Set<String> checkedParameters = new HashSet<>();

checkEngineSpecificErrors(knnEngine, context, errorMessages, checkedParameters);
validateRequestParameters(validParameters, requestParameters, errorMessages, checkedParameters);

return buildValidationException(errorMessages);
}

private static void validateNonNullParameters(Map<String, Parameter<?>> validParameters) {
if (validParameters == null) {
throw new IllegalArgumentException("validParameters cannot be null");
}
}

private static void checkEngineSpecificErrors(
KNNEngine knnEngine,
EngineSpecificMethodContext.Context context,
List<String> errorMessages,
Set<String> checkedParameters
) {
if (KNNEngine.LUCENE.equals(knnEngine) && context != null && context.isRadialSearch()) {
errorMessages.add("ef_search is not supported for Lucene engine radial search");
checkedParameters.add(METHOD_PARAMETER_EF_SEARCH);
if (requestParameters == null || requestParameters.isEmpty()) {
return null;
}
}

private static void validateRequestParameters(
Map<String, Parameter<?>> validParameters,
Map<String, Object> requestParameters,
List<String> errorMessages,
Set<String> checkedParameters
) {
final List<String> errorMessages = new ArrayList<>();
for (Map.Entry<String, Object> parameter : requestParameters.entrySet()) {
if (checkedParameters.contains(parameter.getKey())) {
continue;
}
if (validParameters.containsKey(parameter.getKey())) {
final ValidationException parameterValidation = validParameters.get(parameter.getKey()).validate(parameter.getValue());
if (parameterValidation != null) {
Expand All @@ -106,10 +52,7 @@ private static void validateRequestParameters(
errorMessages.add("Unknown parameter '" + parameter.getKey() + "' found");
}
}
}

@Nullable
private static ValidationException buildValidationException(List<String> errorMessages) {
if (errorMessages.isEmpty()) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {

public void testEngineSpecificMethods() {
String methodName1 = "test-method-1";
EngineSpecificMethodContext.Context engineSpecificMethodContext = () -> false;
QueryContext engineSpecificMethodContext = () -> false;
EngineSpecificMethodContext context = ctx -> ImmutableMap.of(
"myparameter",
new Parameter.BooleanParameter("myparameter", null, value -> true)
Expand Down

0 comments on commit 26ec968

Please sign in to comment.