Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persist model definition in model metadata #1527

Merged
merged 10 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Detect AVX2 Dynamically on the System [#1502](https://github.com/opensearch-project/k-NN/pull/1502)
* Validate zero vector when using cosine metric [#1501](https://github.com/opensearch-project/k-NN/pull/1501)
* Persist model definition in model metadata [#1527](https://github.com/opensearch-project/k-NN/pull/1527)
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
* Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNConstants {
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment";
public static final String MODEL_METHOD_COMPONENT_CONTEXT = "model_definition";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@
public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;
public static final String MODEL_METHOD_COMPONENT_CONTEXT_KEY = KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
}
};

Expand Down
214 changes: 211 additions & 3 deletions src/main/java/org/opensearch/knn/index/MethodComponentContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,26 @@

package org.opensearch.knn.index;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.commons.lang.math.NumberUtils;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentFragment;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MapperParsingException;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.*;
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
import java.util.stream.Collectors;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.knn.indices.ModelMetadata;

import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
Expand All @@ -41,6 +43,13 @@
@RequiredArgsConstructor
public class MethodComponentContext implements ToXContentFragment, Writeable {

// EMPTY method component context can only occur if a model originated on a cluster before 2.13.0 and the cluster is then upgraded to
// 2.13.0
public static final MethodComponentContext EMPTY = new MethodComponentContext("", Collections.emptyMap());

private static final String DELIMITER = ";";
private static final String DELIMITER_PLACEHOLDER = "$%$";

@Getter
private final String name;
private final Map<String, Object> parameters;
Expand Down Expand Up @@ -161,6 +170,15 @@
return builder;
}

/**
 * Builds a MethodComponentContext from the map read off the given parser.
 *
 * @param xContentParser parser to read the map from; it is advanced to its first token if it has no current token
 * @return the result of {@code MethodComponentContext.parse} applied to the parsed map
 * @throws IOException if the parser fails while reading
 */
public static MethodComponentContext fromXContent(XContentParser xContentParser) throws IOException {
    // A parser without a current token has not been advanced yet, so step onto the first token
    boolean isFreshParser = xContentParser.currentToken() == null;
    if (isFreshParser) {
        xContentParser.nextToken();
    }
    return MethodComponentContext.parse(xContentParser.map());
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;
Expand Down Expand Up @@ -193,6 +211,196 @@
return parameters;
}

/**
*
* Provides a String representation of MethodComponentContext
* Sample return:
* {name=ivf;parameters=[nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need { and } at beginning and end? can we not skip it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added those to help with the recursion if there is a nested MethodComponentContext object. Otherwise, there is no way to differentiate the object from a normal field.

*
* @return string representation
*/
public String toClusterStateString() {
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("{name=").append(name).append(DELIMITER);
stringBuilder.append("parameters=[");
if (Objects.nonNull(parameters)) {
for (Map.Entry<String, Object> entry : parameters.entrySet()) {
stringBuilder.append(entry.getKey()).append("=");
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
Object objectValue = entry.getValue();
String value;
if (objectValue instanceof MethodComponentContext) {
value = ((MethodComponentContext) objectValue).toClusterStateString();
} else {
value = entry.getValue().toString();
}
// Model Metadata uses a delimiter to split the input string in its fromString method
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// https://github.com/opensearch-project/k-NN/blob/2.12/src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L265
// If any of the values in the method component context contain this delimiter,
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// then the method will not work correctly. Therefore, we replace the delimiter with an uncommon
// sequence that is very unlikely to appear in the value itself.
// https://github.com/opensearch-project/k-NN/issues/1337
value = value.replace(ModelMetadata.DELIMITER, DELIMITER_PLACEHOLDER);
stringBuilder.append(value).append(DELIMITER);
}
}
stringBuilder.append("]}");
return stringBuilder.toString();
}

/**
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
* This method converts a string created by the toClusterStateString() method of MethodComponentContext
* to a MethodComponentContext object.
*
* @param in a string representation of MethodComponentContext
* @return a MethodComponentContext object
*/
public static MethodComponentContext fromClusterStateString(String in) {
String stringToParse = unwrapString(in, '{', '}');

// Parse name from string
String[] nameAndParameters = stringToParse.split(DELIMITER, 2);
System.out.println("nameAndParameters: " + Arrays.toString(nameAndParameters));
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
checkExpectedArrayLength(nameAndParameters, 2);
String name = parseName(nameAndParameters[0]);
System.out.println("name: " + name);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
String parametersString = nameAndParameters[1];
Map<String, Object> parameters = parseParameters(parametersString);
return new MethodComponentContext(name, parameters);
}

/**
 * Extracts the component name from a "name=<value>" fragment.
 *
 * @param candidateNameString fragment expected to look like "name=ivf"
 * @return the text after the first '=', or an empty string when there is nothing after the key
 */
private static String parseName(String candidateNameString) {
    // Expecting candidateNameString to look like "name=ivf"
    checkStringNotEmpty(candidateNameString);
    // Limit the split to two parts so that a '=' inside the name value stays part of the name
    String[] nameKeyAndValue = candidateNameString.split("=", 2);
    checkStringMatches(nameKeyAndValue[0], "name");
    // Only the key is present (no '=' at all): treat the name as empty
    if (nameKeyAndValue.length == 1) {
        return "";
    }
    return nameKeyAndValue[1];
}

private static Map<String, Object> parseParameters(String candidateParameterString) {
checkStringNotEmpty(candidateParameterString);
String[] parametersKeyAndValue = candidateParameterString.split("=", 2);
checkStringMatches(parametersKeyAndValue[0], "parameters");
if (parametersKeyAndValue.length == 1) {
return Collections.emptyMap();

Check warning on line 288 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L288

Added line #L288 was not covered by tests
}
checkExpectedArrayLength(parametersKeyAndValue, 2);
return parseParametersValue(parametersKeyAndValue[1]);
}

private static Map<String, Object> parseParametersValue(String candidateParameterValueString) {
// Expected input is [nlist=4;type=fp16;encoder={name=sq;parameters=[nprobes=2;clip=false;]};]
System.out.println("parameter value: " + candidateParameterValueString);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
checkStringNotEmpty(candidateParameterValueString);
candidateParameterValueString = unwrapString(candidateParameterValueString, '[', ']');
Map<String, Object> parameters = new HashMap<>();
while (!candidateParameterValueString.isEmpty()) {
String[] keyAndValueToParse = candidateParameterValueString.split("=", 2);
if (keyAndValueToParse.length == 1 && keyAndValueToParse[0].charAt(0) == ';') {
break;
}
String key = keyAndValueToParse[0];
ValueAndRestToParse parsed = parseParameterValueAndRestToParse(keyAndValueToParse[1]);
parameters.put(key, parsed.getValue());
candidateParameterValueString = parsed.getRestToParse();
}

return parameters;
}

private static ValueAndRestToParse parseParameterValueAndRestToParse(String candidateParameterValueAndRestToParse) {
if (candidateParameterValueAndRestToParse.charAt(0) == '{') {
int endOfNestedMap = findClosingPosition(candidateParameterValueAndRestToParse, '{', '}');
String nestedMethodContext = candidateParameterValueAndRestToParse.substring(0, endOfNestedMap + 1);
System.out.println("nested method context: " + nestedMethodContext);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
Object nestedParse = fromClusterStateString(nestedMethodContext);
String restToParse = candidateParameterValueAndRestToParse.substring(endOfNestedMap + 1);
return new ValueAndRestToParse(nestedParse, restToParse);
}

String[] stringValueAndRestToParse = candidateParameterValueAndRestToParse.split(DELIMITER, 2);
String stringValue = stringValueAndRestToParse[0];
System.out.println("string value: " + stringValue);
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
Object value;
if (NumberUtils.isNumber(stringValue)) {
value = Integer.parseInt(stringValue);
} else if (stringValue.equals("true") || stringValue.equals("false")) {
value = Boolean.parseBoolean(stringValue);
} else {
stringValue = stringValue.replace(DELIMITER_PLACEHOLDER, ModelMetadata.DELIMITER);
value = stringValue;
}

return new ValueAndRestToParse(value, stringValueAndRestToParse[1]);
}

private static String unwrapString(String in, char expectedStart, char expectedEnd) {
if (in.length() < 2) {
throw new IllegalArgumentException("Invalid string.");

Check warning on line 342 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L342

Added line #L342 was not covered by tests
}

if (in.charAt(0) != expectedStart || in.charAt(in.length() - 1) != expectedEnd) {
throw new IllegalArgumentException("Invalid string." + in);

Check warning on line 346 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L346

Added line #L346 was not covered by tests
}
return in.substring(1, in.length() - 1);
}

private static int findClosingPosition(String in, char expectedStart, char expectedEnd) {
int nestedLevel = 0;
for (int i = 0; i < in.length(); i++) {
if (in.charAt(i) == expectedStart) {
nestedLevel++;
continue;
}

if (in.charAt(i) == expectedEnd) {
nestedLevel--;
}

if (nestedLevel == 0) {
return i;
}
}

throw new IllegalArgumentException("Invalid string. No end to the nesting");

Check warning on line 368 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L368

Added line #L368 was not covered by tests
}

private static void checkStringNotEmpty(String string) {
if (string.isEmpty()) {
// TODO: Come up with better exception
throw new RuntimeException("think of better exception");

Check warning on line 374 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L374

Added line #L374 was not covered by tests
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
}
}

private static void checkStringMatches(String string, String expected) {
if (!Objects.equals(string, expected)) {
// TODO: Come up with better exception
throw new RuntimeException("think of better exception");

Check warning on line 381 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L381

Added line #L381 was not covered by tests
}
}

private static void checkExpectedArrayLength(String[] array, int expectedLength) {
if (null == array) {
// TODO: Come up with better exception
throw new RuntimeException("think of better exception");

Check warning on line 388 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L388

Added line #L388 was not covered by tests
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
}

if (array.length != expectedLength) {
// TODO: Come up with better exception
throw new RuntimeException("not expected length");

Check warning on line 393 in src/main/java/org/opensearch/knn/index/MethodComponentContext.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/index/MethodComponentContext.java#L393

Added line #L393 was not covered by tests
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
}
}

/**
 * Pair returned by the parameter-value parsing step: the value that was parsed from the front of the input
 * and the text that still has to be parsed. The constructor and getters are generated by Lombok.
 */
@AllArgsConstructor
@Getter
private static class ValueAndRestToParse {
// The parsed parameter value
private final Object value;
// The part of the input that follows the parsed value
private final String restToParse;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.name);
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.common.exception.DeleteModelWhenInTrainStateException;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
Expand Down Expand Up @@ -288,6 +292,13 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());

MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject();
put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, builder.toString());
}
}
};

Expand Down
Loading
Loading