Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Persist model definition in model metadata #1548

Merged
merged 4 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502)
* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501)
* Persist model definition in model metadata [#1527](https://github.com/opensearch-project/k-NN/pull/1527)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNConstants {
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment";
public static final String MODEL_METHOD_COMPONENT_CONTEXT = "model_definition";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@
public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;
public static final String MODEL_METHOD_COMPONENT_CONTEXT_KEY = KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0;
public static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
}
};

Expand Down
202 changes: 202 additions & 0 deletions src/main/java/org/opensearch/knn/index/MethodComponentContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,29 @@

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.commons.lang.math.NumberUtils;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MapperParsingException;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.knn.indices.ModelMetadata;

import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
Expand All @@ -41,6 +46,13 @@
@RequiredArgsConstructor
public class MethodComponentContext implements ToXContentFragment, Writeable {

// EMPTY method component context can only occur if a model originated on a cluster before 2.13.0 and the cluster is then upgraded to
// 2.13.0
public static final MethodComponentContext EMPTY = new MethodComponentContext("", Collections.emptyMap());

private static final String DELIMITER = ";";
private static final String DELIMITER_PLACEHOLDER = "$%$";

@Getter
private final String name;
private final Map<String, Object> parameters;
Expand Down Expand Up @@ -161,6 +173,15 @@
return builder;
}

/**
 * Builds a MethodComponentContext from the map that the given parser is positioned on.
 *
 * @param xContentParser parser to read from; when it has no current token yet it is advanced once first
 * @return the context produced by {@link #parse} from the parser's map
 * @throws IOException if reading from the parser fails
 */
public static MethodComponentContext fromXContent(XContentParser xContentParser) throws IOException {
    // A parser that has not been advanced yet reports no current token; step onto the first one
    final boolean parserNotStarted = xContentParser.currentToken() == null;
    if (parserNotStarted) {
        xContentParser.nextToken();
    }
    return MethodComponentContext.parse(xContentParser.map());
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
Expand Down Expand Up @@ -193,6 +214,187 @@
return parameters;
}

/**
 * Serializes this MethodComponentContext into the compact string form that
 * {@link #fromClusterStateString(String)} reads back.
 * <p>
 * Sample return:
 * {name=ivf;parameters=[nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]}
 * <p>
 * Every entry is written as key=value followed by the delimiter; nested MethodComponentContext values are
 * serialized recursively with this same method, all other values through their toString().
 *
 * @return string representation
 */
public String toClusterStateString() {
    final StringBuilder serialized = new StringBuilder("{name=").append(name).append(DELIMITER).append("parameters=[");
    if (parameters != null) {
        for (Map.Entry<String, Object> parameter : parameters.entrySet()) {
            final Object rawValue = parameter.getValue();
            final String rendered = rawValue instanceof MethodComponentContext
                ? ((MethodComponentContext) rawValue).toClusterStateString()
                : rawValue.toString();
            // Model Metadata splits its own string form on ModelMetadata.DELIMITER in its fromString method:
            // https://github.com/opensearch-project/k-NN/blob/2.12/src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L265
            // A value containing that delimiter would break the split, so it is swapped for a placeholder
            // sequence that is very unlikely to occur in a real value.
            // https://github.com/opensearch-project/k-NN/issues/1337
            serialized.append(parameter.getKey())
                .append("=")
                .append(rendered.replace(ModelMetadata.DELIMITER, DELIMITER_PLACEHOLDER))
                .append(DELIMITER);
        }
    }
    return serialized.append("]}").toString();
}

/**
 * Rebuilds a MethodComponentContext from the string produced by toClusterStateString().
 *
 * @param in a string representation of MethodComponentContext, wrapped in '{' and '}'
 * @return a MethodComponentContext object
 * @throws IllegalArgumentException if the string is not wrapped in braces or does not split into a name
 *                                  section and a parameters section
 */
public static MethodComponentContext fromClusterStateString(String in) {
    final String body = unwrapString(in, '{', '}');

    // Limit of 2: the first section is "name=...", everything after the first delimiter is the parameters section
    final String[] sections = body.split(DELIMITER, 2);
    checkExpectedArrayLength(sections, 2);

    final String parsedName = parseName(sections[0]);
    final Map<String, Object> parsedParameters = parseParameters(sections[1]);
    return new MethodComponentContext(parsedName, parsedParameters);
}

/**
 * Extracts the component name from a section that is expected to look like "name=ivf".
 *
 * @param candidateNameString the name section of the serialized context
 * @return the text after "name=", or an empty string when nothing follows the key
 * @throws IllegalArgumentException if the section is empty, its key is not "name", or it holds more than one '='
 */
private static String parseName(String candidateNameString) {
    checkStringNotEmpty(candidateNameString);
    final String[] tokens = candidateNameString.split("=");
    checkStringMatches(tokens[0], "name");
    if (tokens.length != 1) {
        checkExpectedArrayLength(tokens, 2);
        return tokens[1];
    }
    // Only the key was present (e.g. "name="): the name is empty
    return "";
}

/**
 * Parses the parameters section of a serialized MethodComponentContext, i.e. the key "parameters"
 * followed by "=" and a bracketed parameter list.
 *
 * @param candidateParameterString the parameters section of the serialized context
 * @return the parameter map produced by parseParametersValue, or an empty map when nothing follows the key
 * @throws IllegalArgumentException if the section is empty or its key is not "parameters"
 */
private static Map<String, Object> parseParameters(String candidateParameterString) {
checkStringNotEmpty(candidateParameterString);
// Split on the first "=" only; the parameter list itself contains further "=" characters
String[] parametersKeyAndValue = candidateParameterString.split("=", 2);
checkStringMatches(parametersKeyAndValue[0], "parameters");
// No "=" in the section at all: treat it as having no parameters
if (parametersKeyAndValue.length == 1) {
return Collections.emptyMap();

Check warning on line 289 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L289

Added line #L289 was not covered by tests
}
// With a split limit of 2 and the length-1 case handled above, this check is a defensive guard
checkExpectedArrayLength(parametersKeyAndValue, 2);
return parseParametersValue(parametersKeyAndValue[1]);
}

/**
 * Parses the bracketed parameter list of a serialized MethodComponentContext into a map.
 * <p>
 * Expected input is [nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]
 * <p>
 * Entries are consumed left to right. A delimiter found at the head of the remaining text (this is what is
 * left over right after a nested {...} value) is skipped so that it never becomes part of the next key.
 *
 * @param candidateParameterValueString the bracketed list of key=value entries, each terminated by the delimiter
 * @return a map from parameter name to parsed value; empty when the list holds no entries
 * @throws IllegalArgumentException if the input is empty, is not wrapped in '[' and ']', or contains an
 *                                  entry without '='
 */
private static Map<String, Object> parseParametersValue(String candidateParameterValueString) {
    checkStringNotEmpty(candidateParameterValueString);
    String remaining = unwrapString(candidateParameterValueString, '[', ']');
    Map<String, Object> parameters = new HashMap<>();
    while (!remaining.isEmpty()) {
        // Skip a delimiter left over from the previous entry instead of gluing it onto the next key
        if (remaining.startsWith(DELIMITER)) {
            remaining = remaining.substring(DELIMITER.length());
            continue;
        }

        // Split on the first "=" only; the value may itself be a nested context containing "="
        String[] keyAndValueToParse = remaining.split("=", 2);
        if (keyAndValueToParse.length != 2) {
            throw new IllegalArgumentException(
                "Unable to parse MethodComponentContext. Parameter entry is missing '=': " + remaining
            );
        }

        ValueAndRestToParse parsed = parseParameterValueAndRestToParse(keyAndValueToParse[1]);
        parameters.put(keyAndValueToParse[0], parsed.getValue());
        remaining = parsed.getRestToParse();
    }

    return parameters;
}

/**
 * Parses the value at the head of the given text and returns it together with the text that follows it.
 * <p>
 * A value starting with '{' is parsed as a nested MethodComponentContext. Any other value runs up to the next
 * delimiter and is converted to a number, a Boolean ("true"/"false") or a String, in that order of preference.
 * In the String case the delimiter placeholder is turned back into ModelMetadata.DELIMITER. In both cases the
 * delimiter that terminates the value is consumed, so the returned rest starts at the next entry.
 *
 * @param candidateParameterValueAndRestToParse text starting with a value, e.g. "4;type=fp16;"
 * @return the parsed value and the remaining, not yet parsed text
 * @throws IllegalArgumentException if the text is empty, a nested context is not closed, or a plain value is
 *                                  not terminated by the delimiter (NumberFormatException, a subclass, is
 *                                  raised for numeric text that cannot be converted)
 */
private static ValueAndRestToParse parseParameterValueAndRestToParse(String candidateParameterValueAndRestToParse) {
    if (candidateParameterValueAndRestToParse.isEmpty()) {
        throw new IllegalArgumentException("Unable to parse MethodComponentContext. Parameter value is missing.");
    }

    if (candidateParameterValueAndRestToParse.charAt(0) == '{') {
        int endOfNestedMap = findClosingPosition(candidateParameterValueAndRestToParse, '{', '}');
        String nestedMethodContext = candidateParameterValueAndRestToParse.substring(0, endOfNestedMap + 1);
        Object nestedParse = fromClusterStateString(nestedMethodContext);
        String restToParse = candidateParameterValueAndRestToParse.substring(endOfNestedMap + 1);
        // Consume the delimiter that terminates the nested value; otherwise it would be read as the
        // first character of the next parameter's key
        if (restToParse.startsWith(DELIMITER)) {
            restToParse = restToParse.substring(DELIMITER.length());
        }
        return new ValueAndRestToParse(nestedParse, restToParse);
    }

    String[] stringValueAndRestToParse = candidateParameterValueAndRestToParse.split(DELIMITER, 2);
    if (stringValueAndRestToParse.length != 2) {
        throw new IllegalArgumentException(
            "Unable to parse MethodComponentContext. Parameter value is not terminated by '" + DELIMITER + "'."
        );
    }

    String stringValue = stringValueAndRestToParse[0];
    Object value;
    if (NumberUtils.isNumber(stringValue)) {
        value = parseNumericValue(stringValue);
    } else if (stringValue.equals("true") || stringValue.equals("false")) {
        value = Boolean.parseBoolean(stringValue);
    } else {
        value = stringValue.replace(DELIMITER_PLACEHOLDER, ModelMetadata.DELIMITER);
    }

    return new ValueAndRestToParse(value, stringValueAndRestToParse[1]);
}

/**
 * Converts text already accepted by NumberUtils.isNumber into a Number. Plain integers stay Integer;
 * other accepted forms (decimals, exponents, hex, type-suffixed literals) are delegated to
 * NumberUtils.createNumber, which picks the resulting Number type.
 */
private static Number parseNumericValue(String numericString) {
    try {
        return Integer.parseInt(numericString);
    } catch (NumberFormatException ignored) {
        // Not a plain int: isNumber also accepts forms that Integer.parseInt rejects
        return NumberUtils.createNumber(numericString);
    }
}

/**
 * Removes the first and last character of the input after verifying that they are the expected
 * opening and closing characters.
 *
 * @param in            the wrapped string, e.g. "{...}" or "[...]"
 * @param expectedStart character the input must start with
 * @param expectedEnd   character the input must end with
 * @return the input without its first and last character
 * @throws IllegalArgumentException if the input is shorter than two characters or is not wrapped as expected
 */
private static String unwrapString(String in, char expectedStart, char expectedEnd) {
// Fewer than two characters cannot hold both an opening and a closing character
if (in.length() < 2) {
throw new IllegalArgumentException("Invalid string.");

Check warning on line 340 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L340

Added line #L340 was not covered by tests
}

if (in.charAt(0) != expectedStart || in.charAt(in.length() - 1) != expectedEnd) {
// The offending input is appended directly to the message text
throw new IllegalArgumentException("Invalid string." + in);

Check warning on line 344 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L344

Added line #L344 was not covered by tests
}
return in.substring(1, in.length() - 1);
}

/**
 * Scans the input from index 0 while tracking the nesting depth, and returns the first index at which the
 * depth is zero after a character that is not expectedStart. For an input that begins with expectedStart this
 * is the index of the matching expectedEnd.
 * <p>
 * Note: if the first character is neither expectedStart nor expectedEnd, the depth is already zero there and
 * 0 is returned.
 *
 * @param in            the text to scan
 * @param expectedStart character that opens a nesting level
 * @param expectedEnd   character that closes a nesting level
 * @return index at which the nesting depth returns to zero
 * @throws IllegalArgumentException if the end of the input is reached without the depth returning to zero
 */
private static int findClosingPosition(String in, char expectedStart, char expectedEnd) {
int nestedLevel = 0;
for (int i = 0; i < in.length(); i++) {
// Opening character: go one level deeper; the depth cannot be zero here, so skip the check below
if (in.charAt(i) == expectedStart) {
nestedLevel++;
continue;
}

if (in.charAt(i) == expectedEnd) {
nestedLevel--;
}

// Depth back at zero: position i closes the outermost level
if (nestedLevel == 0) {
return i;
}
}

throw new IllegalArgumentException("Invalid string. No end to the nesting");

Check warning on line 366 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L366

Added line #L366 was not covered by tests
}

/**
 * Validation helper that rejects empty strings.
 *
 * @param string the string to check; a null argument fails with a NullPointerException on isEmpty()
 * @throws IllegalArgumentException if the string is empty
 */
private static void checkStringNotEmpty(String string) {
if (string.isEmpty()) {
throw new IllegalArgumentException("Unable to parse MethodComponentContext");

Check warning on line 371 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L371

Added line #L371 was not covered by tests
}
}

/**
 * Validation helper that verifies a parsed key equals the expected key.
 *
 * @param string   the key that was parsed (null-safe comparison via Objects.equals)
 * @param expected the key that is required at this position
 * @throws IllegalArgumentException if the two differ; the message text is fixed and does not include
 *                                  either argument
 */
private static void checkStringMatches(String string, String expected) {
if (!Objects.equals(string, expected)) {
throw new IllegalArgumentException("Unexpected key in MethodComponentContext. Expected 'name' or 'parameters'");

Check warning on line 377 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L377

Added line #L377 was not covered by tests
}
}

/**
 * Validation helper that verifies a split result has exactly the expected number of elements.
 *
 * @param array          the array to check
 * @param expectedLength the exact length the array must have
 * @throws IllegalArgumentException if the array is null or its length differs from expectedLength
 */
private static void checkExpectedArrayLength(String[] array, int expectedLength) {
if (null == array) {
throw new IllegalArgumentException("Error parsing MethodComponentContext. Array is null.");

Check warning on line 383 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L383

Added line #L383 was not covered by tests
}

if (array.length != expectedLength) {
throw new IllegalArgumentException("Error parsing MethodComponentContext. Array is not expected length.");

Check warning on line 387 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L387

Added line #L387 was not covered by tests
}
}

/**
 * Result of parsing one parameter value: the parsed value together with the part of the input that
 * still has to be parsed. The all-arguments constructor and the getters are generated by Lombok.
 */
@AllArgsConstructor
@Getter
private static class ValueAndRestToParse {
// The parsed value: a number, a Boolean, a String or a nested MethodComponentContext
private final Object value;
// The text that follows the parsed value in the serialized parameter list
private final String restToParse;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.name);
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
Expand Down Expand Up @@ -288,6 +292,13 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());

MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject();
put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, builder.toString());
}
}
};

Expand Down
Loading
Loading