Skip to content

Commit

Permalink
Add validation for pq m parameter before training starts (opensearch-…
Browse files Browse the repository at this point in the history
…project#1713)

* Add validation for pq code count before training starts

Signed-off-by: Ryan Bogan <[email protected]>

* Add integration test

Signed-off-by: Ryan Bogan <[email protected]>

* Add unit tests

Signed-off-by: Ryan Bogan <[email protected]>

* Clean up code

Signed-off-by: Ryan Bogan <[email protected]>

* Remove unnecessary lines

Signed-off-by: Ryan Bogan <[email protected]>

* Add changelog entry

Signed-off-by: Ryan Bogan <[email protected]>

* Change framework to add validation with data

Signed-off-by: Ryan Bogan <[email protected]>

* Remove unused error message

Signed-off-by: Ryan Bogan <[email protected]>

* Add unit tests

Signed-off-by: Ryan Bogan <[email protected]>

* Change space type check name for readability

Signed-off-by: Ryan Bogan <[email protected]>

* Add javadocs

Signed-off-by: Ryan Bogan <[email protected]>

* Modify validation error wording and add json structure to tests

Signed-off-by: Ryan Bogan <[email protected]>

* Change TrainingDataSpec to VectorSpaceInfo

Signed-off-by: Ryan Bogan <[email protected]>

* Add unit tests

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored May 30, 2024
1 parent 7a88f40 commit 3701d19
Show file tree
Hide file tree
Showing 16 changed files with 599 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add KnnCircuitBreakerException and modify exception message [#1688](https://github.com/opensearch-project/k-NN/pull/1688)
* Add stats for radial search [#1684](https://github.com/opensearch-project/k-NN/pull/1684)
* Support script score when doc value is disabled and fix misusing DISI [#1696](https://github.com/opensearch-project/k-NN/pull/1696)
* Add validation for pq m parameter before training starts [#1713](https://github.com/opensearch-project/k-NN/pull/1713)
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
* Update threshold value after new result is added [#1715](https://github.com/opensearch-project/k-NN/pull/1715)
Expand Down
41 changes: 39 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import lombok.Getter;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -41,7 +42,7 @@ public class KNNMethod {
* @param space to be checked
* @return true if the space is supported; false otherwise
*/
public boolean containsSpace(SpaceType space) {
public boolean isSpaceTypeSupported(SpaceType space) {
return spaces.contains(space);
}

Expand All @@ -53,7 +54,7 @@ public boolean containsSpace(SpaceType space) {
*/
public ValidationException validate(KNNMethodContext knnMethodContext) {
List<String> errorMessages = new ArrayList<>();
if (!containsSpace(knnMethodContext.getSpaceType())) {
if (!isSpaceTypeSupported(knnMethodContext.getSpaceType())) {
errorMessages.add(
String.format(
"\"%s\" configuration does not support space type: " + "\"%s\".",
Expand All @@ -77,6 +78,42 @@ public ValidationException validate(KNNMethodContext knnMethodContext) {
return validationException;
}

/**
 * Checks the supplied KNNMethodContext against this method while also consulting information that the
 * method context itself does not carry.
 *
 * @param knnMethodContext method context to check
 * @param vectorSpaceInfo supplementary data that is not part of the method context
 * @return a ValidationException holding every collected error; null when no validation errors were found
 */
public ValidationException validateWithData(KNNMethodContext knnMethodContext, VectorSpaceInfo vectorSpaceInfo) {
    final List<String> problems = new ArrayList<>();

    // The space type is checked first so its message precedes any errors reported by the method component.
    final boolean spaceTypeSupported = isSpaceTypeSupported(knnMethodContext.getSpaceType());
    if (!spaceTypeSupported) {
        problems.add(
            String.format(
                "\"%s\" configuration does not support space type: \"%s\".",
                this.methodComponent.getName(),
                knnMethodContext.getSpaceType().getValue()
            )
        );
    }

    // Delegate the parameter-level checks to the method component and fold its errors into ours.
    final ValidationException componentResult = methodComponent.validateWithData(
        knnMethodContext.getMethodComponentContext(),
        vectorSpaceInfo
    );
    if (componentResult != null) {
        problems.addAll(componentResult.validationErrors());
    }

    if (!problems.isEmpty()) {
        final ValidationException result = new ValidationException();
        result.addValidationErrors(problems);
        return result;
    }
    return null;
}

/**
* returns whether training is required or not
*
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.stream.Collectors;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.knn.training.VectorSpaceInfo;

import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
Expand Down Expand Up @@ -86,6 +87,16 @@ public ValidationException validate() {
return knnEngine.validateMethod(this);
}

/**
* Validates this method context through the knnEngine, passing along additional data that is not present
* in the method context itself. The call is forwarded to knnEngine.validateMethodWithData with this
* context and the supplied vectorSpaceInfo, and the engine's result is returned unchanged.
*
* @param vectorSpaceInfo additional data not present in the method context
* @return ValidationException produced by validation errors; null if there are no validation errors
*/
public ValidationException validateWithData(VectorSpaceInfo vectorSpaceInfo) {
return knnEngine.validateMethodWithData(this, vectorSpaceInfo);
}

/**
* This method returns whether training is required or not from knnEngine
*
Expand Down
38 changes: 38 additions & 0 deletions src/main/java/org/opensearch/knn/index/MethodComponent.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -102,6 +103,43 @@ public ValidationException validate(MethodComponentContext methodComponentContex
return validationException;
}

/**
 * Checks that the given methodComponentContext is a valid configuration for this methodComponent, taking into
 * account additional data that the method component context does not contain.
 *
 * @param methodComponentContext configuration to check
 * @param vectorSpaceInfo supplementary data that is not part of the method component context
 * @return a ValidationException holding every collected error; null when no parameters were provided or no
 *         validation errors were found
 */
public ValidationException validateWithData(MethodComponentContext methodComponentContext, VectorSpaceInfo vectorSpaceInfo) {
    final Map<String, Object> supplied = methodComponentContext.getParameters();
    if (supplied == null) {
        // Nothing was provided, so there is nothing to validate.
        return null;
    }

    final List<String> problems = new ArrayList<>();
    for (Map.Entry<String, Object> entry : supplied.entrySet()) {
        final String key = entry.getKey();
        if (parameters.containsKey(key)) {
            // Known parameter: let its definition judge the supplied value.
            final ValidationException outcome = parameters.get(key).validateWithData(entry.getValue(), vectorSpaceInfo);
            if (outcome != null) {
                problems.addAll(outcome.validationErrors());
            }
        } else {
            // Parameter name is not defined for this component.
            problems.add(String.format("Invalid parameter for method \"%s\".", getName()));
        }
    }

    if (!problems.isEmpty()) {
        final ValidationException result = new ValidationException();
        result.addValidationErrors(problems);
        return result;
    }
    return null;
}

/**
* gets requiresTraining value
*
Expand Down
149 changes: 148 additions & 1 deletion src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
package org.opensearch.knn.index;

import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Predicate;

/**
Expand All @@ -26,6 +28,7 @@ public abstract class Parameter<T> {
// Name of the parameter; assigned by the constructors.
private String name;
// Default value of the parameter; assigned by the constructors.
private T defaultValue;
// Single-argument check on a typed value; assigned by the constructors.
protected Predicate<T> validator;
// Check that receives the typed value together with a VectorSpaceInfo. It is null when the parameter is built
// through the three-argument constructor; the validateWithData implementations skip the check in that case.
protected BiFunction<T, VectorSpaceInfo, Boolean> validatorWithData;

/**
* Constructor
Expand All @@ -38,6 +41,14 @@ public Parameter(String name, T defaultValue, Predicate<T> validator) {
this.name = name;
this.defaultValue = defaultValue;
this.validator = validator;
this.validatorWithData = null;
}

/**
 * Constructor that additionally registers a validator which is given a VectorSpaceInfo
 *
 * @param name of the parameter
 * @param defaultValue of the parameter
 * @param validator used to check a value on its own
 * @param validatorWithData used to check a value together with a VectorSpaceInfo
 */
public Parameter(String name, T defaultValue, Predicate<T> validator, BiFunction<T, VectorSpaceInfo, Boolean> validatorWithData) {
    // Validators first, then the descriptive fields; the assignments are independent of each other.
    this.validatorWithData = validatorWithData;
    this.validator = validator;
    this.defaultValue = defaultValue;
    this.name = name;
}

/**
Expand Down Expand Up @@ -66,6 +77,15 @@ public T getDefaultValue() {
*/
public abstract ValidationException validate(Object value);

/**
* Check if the value passed in is valid, using additional data that is not present in the value itself
*
* @param value to be checked
* @param vectorSpaceInfo additional data not present in the value
* @return ValidationException produced by validation errors; null if there are no validation errors
*/
public abstract ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo);

/**
* Boolean method parameter
*/
Expand All @@ -74,12 +94,23 @@ public BooleanParameter(String name, Boolean defaultValue, Predicate<Boolean> va
super(name, defaultValue, validator);
}

/**
* Constructor that also accepts a validator which is given a VectorSpaceInfo. All four arguments are passed
* through unchanged to the Parameter constructor.
*
* @param name of the parameter
* @param defaultValue of the parameter
* @param validator used to check a Boolean value on its own
* @param validatorWithData used to check a Boolean value together with a VectorSpaceInfo
*/
public BooleanParameter(
String name,
Boolean defaultValue,
Predicate<Boolean> validator,
BiFunction<Boolean, VectorSpaceInfo, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
if (!(value instanceof Boolean)) {
validationException = new ValidationException();
validationException.addValidationError(String.format("value not of type Boolean for Boolean parameter [%s].", getName()));
validationException.addValidationError(
String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName())
);
return validationException;
}

Expand All @@ -89,6 +120,27 @@ public ValidationException validate(Object value) {
}
return validationException;
}

/**
 * Check if the value passed in is a valid Boolean for this parameter, using additional data not present in
 * the value.
 *
 * @param value to be checked; must be an instance of Boolean
 * @param vectorSpaceInfo additional data handed to the data-aware validator
 * @return ValidationException describing the failure; null if the value is a Boolean and either no
 *         data-aware validator is configured or that validator accepts the value
 */
@Override
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
    if (!(value instanceof Boolean)) {
        final ValidationException typeMismatch = new ValidationException();
        // Wording matches validate(Object) and the Integer/String parameters so that callers see the
        // same type-mismatch message from both validation paths.
        typeMismatch.addValidationError(
            String.format("value is not an instance of Boolean for Boolean parameter [%s].", getName())
        );
        return typeMismatch;
    }

    // No data-aware validator was registered for this parameter: nothing further to check.
    if (validatorWithData == null) {
        return null;
    }

    if (validatorWithData.apply((Boolean) value, vectorSpaceInfo)) {
        return null;
    }

    final ValidationException rejected = new ValidationException();
    rejected.addValidationError(String.format("parameter validation failed for Boolean parameter [%s].", getName()));
    return rejected;
}
}

/**
Expand All @@ -99,6 +151,15 @@ public IntegerParameter(String name, Integer defaultValue, Predicate<Integer> va
super(name, defaultValue, validator);
}

/**
* Constructor that also accepts a validator which is given a VectorSpaceInfo. All four arguments are passed
* through unchanged to the Parameter constructor.
*
* @param name of the parameter
* @param defaultValue of the parameter
* @param validator used to check an Integer value on its own
* @param validatorWithData used to check an Integer value together with a VectorSpaceInfo
*/
public IntegerParameter(
String name,
Integer defaultValue,
Predicate<Integer> validator,
BiFunction<Integer, VectorSpaceInfo, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

@Override
public ValidationException validate(Object value) {
ValidationException validationException = null;
Expand All @@ -118,6 +179,29 @@ public ValidationException validate(Object value) {
}
return validationException;
}

/**
 * Check if the value passed in is a valid Integer for this parameter, using additional data not present in
 * the value.
 *
 * @param value to be checked; must be an instance of Integer
 * @param vectorSpaceInfo additional data handed to the data-aware validator
 * @return ValidationException describing the failure; null if the value is an Integer and either no
 *         data-aware validator is configured or that validator accepts the value
 */
@Override
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
    if (value instanceof Integer) {
        // A missing data-aware validator means there is nothing further to check.
        if (validatorWithData == null || validatorWithData.apply((Integer) value, vectorSpaceInfo)) {
            return null;
        }

        final ValidationException rejected = new ValidationException();
        rejected.addValidationError(String.format("parameter validation failed for Integer parameter [%s].", getName()));
        return rejected;
    }

    final ValidationException typeMismatch = new ValidationException();
    typeMismatch.addValidationError(String.format("value is not an instance of Integer for Integer parameter [%s].", getName()));
    return typeMismatch;
}
}

/**
Expand All @@ -136,6 +220,15 @@ public StringParameter(String name, String defaultValue, Predicate<String> valid
super(name, defaultValue, validator);
}

/**
* Constructor that also accepts a validator which is given a VectorSpaceInfo. All four arguments are passed
* through unchanged to the Parameter constructor.
*
* @param name of the parameter
* @param defaultValue of the parameter
* @param validator used to check a String value on its own
* @param validatorWithData used to check a String value together with a VectorSpaceInfo
*/
public StringParameter(
String name,
String defaultValue,
Predicate<String> validator,
BiFunction<String, VectorSpaceInfo, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

/**
* Check if the value passed in is valid
*
Expand All @@ -161,6 +254,29 @@ public ValidationException validate(Object value) {
}
return validationException;
}

/**
 * Check if the value passed in is a valid String for this parameter, using additional data not present in
 * the value.
 *
 * @param value to be checked; must be an instance of String
 * @param vectorSpaceInfo additional data handed to the data-aware validator
 * @return ValidationException describing the failure; null if the value is a String and either no
 *         data-aware validator is configured or that validator accepts the value
 */
@Override
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
    final String failure;
    if (!(value instanceof String)) {
        failure = String.format("value is not an instance of String for String parameter [%s].", getName());
    } else if (validatorWithData != null && !validatorWithData.apply((String) value, vectorSpaceInfo)) {
        failure = String.format("parameter validation failed for String parameter [%s].", getName());
    } else {
        // Either no data-aware validator is configured or it accepted the value.
        return null;
    }

    final ValidationException result = new ValidationException();
    result.addValidationError(failure);
    return result;
}
}

/**
Expand Down Expand Up @@ -190,6 +306,12 @@ public MethodComponentContextParameter(
}

return methodComponents.get(methodComponentContext.getName()).validate(methodComponentContext) == null;
}, (methodComponentContext, vectorSpaceInfo) -> {
if (!methodComponents.containsKey(methodComponentContext.getName())) {
return false;
}
return methodComponents.get(methodComponentContext.getName())
.validateWithData(methodComponentContext, vectorSpaceInfo) == null;
});
this.methodComponents = methodComponents;
}
Expand All @@ -216,6 +338,31 @@ public ValidationException validate(Object value) {
return validationException;
}

/**
 * Check if the value passed in is a valid MethodComponentContext for this parameter, using additional data
 * not present in the value.
 *
 * @param value to be checked; must be an instance of MethodComponentContext
 * @param vectorSpaceInfo additional data handed to the data-aware validator
 * @return ValidationException describing the failure; null if the value is a MethodComponentContext and
 *         either no data-aware validator is configured or that validator accepts the value
 */
@Override
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
    if (!(value instanceof MethodComponentContext)) {
        final ValidationException typeMismatch = new ValidationException();
        // The expected type is named explicitly; the previous text read "is not an instance of for ...".
        typeMismatch.addValidationError(
            String.format(
                "value is not an instance of MethodComponentContext for MethodComponentContext parameter [%s].",
                getName()
            )
        );
        return typeMismatch;
    }

    // No data-aware validator was registered for this parameter: nothing further to check.
    if (validatorWithData == null) {
        return null;
    }

    if (validatorWithData.apply((MethodComponentContext) value, vectorSpaceInfo)) {
        return null;
    }

    final ValidationException rejected = new ValidationException();
    rejected.addValidationError(
        String.format("parameter validation failed for MethodComponentContext parameter [%s].", getName())
    );
    return rejected;
}

/**
* Get method component by name
*
Expand Down
Loading

0 comments on commit 3701d19

Please sign in to comment.