Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persist model definition in model metadata #1527

Merged
merged 10 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
* Optimize Faiss Query With Filters: Reduce iteration and memory for id filter [#1402](https://github.com/opensearch-project/k-NN/pull/1402)
* Add MethodComponentContext to ModelMetadata [#1527](https://github.com/opensearch-project/k-NN/pull/1527)
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
### Bug Fixes
* Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518)
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNConstants {
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";
public static final String MODEL_NODE_ASSIGNMENT = "training_node_assignment";
public static final String MODEL_METHOD_COMPONENT_CONTEXT = "method_component_context";
public static final String PARAM_SIZE = "size";
public static final Integer SEARCH_MODEL_MIN_SIZE = 1;
public static final Integer SEARCH_MODEL_MAX_SIZE = 1000;
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,16 @@
public class IndexUtil {

public static final String MODEL_NODE_ASSIGNMENT_KEY = KNNConstants.MODEL_NODE_ASSIGNMENT;
public static final String MODEL_METHOD_COMPONENT_CONTEXT_KEY = KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT;

private static final Version MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED = Version.V_2_11_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0;
private static final Map<String, Version> minimalRequiredVersionMap = new HashMap<String, Version>() {
{
put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED);
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
}
};

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public static synchronized KNNMethodContext getDefault() {
@NonNull
private final MethodComponentContext methodComponentContext;

private static final String DELIMITER = ";";
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved

/**
* Constructor from stream.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.apache.commons.lang.math.NumberUtils;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -26,6 +27,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
Expand All @@ -41,6 +43,8 @@
@RequiredArgsConstructor
public class MethodComponentContext implements ToXContentFragment, Writeable {

public static final MethodComponentContext DEFAULT = new MethodComponentContext("", Collections.emptyMap());

@Getter
private final String name;
private final Map<String, Object> parameters;
Expand Down Expand Up @@ -193,6 +197,69 @@ public Map<String, Object> getParameters() {
return parameters;
}

/**
 * Renders this context as a single string of the form
 * "{name=NAME;parameters=[KEY=VALUE;KEY=VALUE;]}".
 *
 * Each parameter value is written using its own toString(), after which every
 * "," in that text is replaced with "$%$" (presumably so the result can be
 * embedded in the comma-delimited model metadata string; confirm against
 * ModelMetadata). When the parameters map is null, nothing is written between
 * the square brackets.
 *
 * @return the string form of this method component context
 */
@Override
public String toString() {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("{name=").append(name).append(";");
stringBuilder.append("parameters=[");
// A null parameters map is tolerated and produces an empty "parameters=[]" section.
if (Objects.nonNull(parameters)) {
for (Map.Entry<String, Object> entry : parameters.entrySet()) {
stringBuilder.append(entry.getKey()).append("=");
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// NOTE(review): a null parameter value would throw a NullPointerException on this call.
String value = entry.getValue().toString();
// Commas are swapped for the "$%$" token so the value text contains no ",".
value = value.replace(",", "$%$");
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
// Every entry, including the last one, is terminated with ";".
stringBuilder.append(value).append(";");
}
}
stringBuilder.append("]}");
return stringBuilder.toString();
}

/**
 * Parses a string of the form "{name=NAME;parameters=[KEY=VALUE;...]}" into a
 * MethodComponentContext.
 *
 * Steps, in the order the code performs them:
 * 1. The input is split on "{" (limit -1, so trailing empty pieces are kept). If the
 *    first piece is empty, i.e. the input began with "{", it is skipped.
 * 2. The selected piece is split on ";". The text after the first "=" of the first
 *    segment becomes the name.
 * 3. If that split produced more than two segments, the remaining segments are read
 *    as parameters. For the first of them the leading text up to and including the
 *    character after the first "=" is dropped (presumably the "parameters=[" prefix).
 *    Reading stops at the first segment that starts with "]". Each segment is divided
 *    at its first "=" into a key and a value string.
 * 4. The value string is converted as follows: empty means a nested context, parsed
 *    recursively from the next "{"-delimited piece; a string accepted by
 *    NumberUtils.isNumber is passed to Integer.parseInt; "true" / "false" become a
 *    Boolean; anything else is kept as a String.
 * 5. If the split produced two or fewer segments, the parameters are
 *    Collections.emptyMap().
 *
 * NOTE(review): NumberUtils.isNumber presumably also accepts non-integer forms (for
 * example decimals), for which Integer.parseInt would throw NumberFormatException.
 * Confirm and handle if such values can occur.
 * NOTE(review): the index used for the nested piece is not advanced inside the loop,
 * so every empty-valued parameter is parsed from the same piece.
 * NOTE(review): this method contains no step that turns "$%$" back into ",".
 * NOTE(review): input that does not follow the expected shape is not validated here
 * and may fail with an index or substring exception.
 *
 * @param in string to parse
 * @return a new MethodComponentContext built from the parsed name and parameters
 */
public static MethodComponentContext fromString(String in) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method will need a few comments describing what it will do. Also, lets add java doc

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a java doc with a small description. Let me know if I need to add more

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name this fromClusterStateString for consistency.

// Position within the "{"-delimited pieces of the input.
int index = 0;

// split takes a regex, so the brace is escaped; limit -1 keeps trailing empty pieces.
String[] outerMethodComponentContextArray = in.split("\\{", -1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need "\"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The split method takes a regex as an input. { is a special character in regex so the \ escapes it to just split by {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid this by just confirming the first and last character are {} and then building a substring from then which is processed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The split still needs to be there for the recursion to work. Otherwise, I can't pass the whole substring of a nested MethodComponentContext because it will be split by ;


if (outerMethodComponentContextArray[index].isEmpty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we split this into a few methods? maybe parse name and parse parameters?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll try to see if it works.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a parse parameters method. Didn't add a parse name method because it's only two lines

index++;
}

String[] innerMethodComponentContextArray = outerMethodComponentContextArray[index].split(";", -1);
index++;

String name = "";
Map<String, Object> parameters = new HashMap<>();
name = innerMethodComponentContextArray[0].substring(innerMethodComponentContextArray[0].indexOf("=") + 1);
if (innerMethodComponentContextArray.length > 2) {
for (int i = 1; i < innerMethodComponentContextArray.length; i++) {
String substring = innerMethodComponentContextArray[i];
if (i == 1) {
substring = substring.substring(substring.indexOf("=") + 2);
}
if (substring.charAt(0) == ']') {
break;
}
String key = substring.substring(0, substring.indexOf("="));
String stringValue = substring.substring(substring.indexOf("=") + 1);
Object value;
if (stringValue.isEmpty()) {
value = fromString(outerMethodComponentContextArray[index]);
} else if (NumberUtils.isNumber(stringValue)) {
value = Integer.parseInt(stringValue);
} else if (stringValue.equals("true") || stringValue.equals("false")) {
value = Boolean.parseBoolean(stringValue);
} else {
value = stringValue;
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
}

parameters.put(key, value);
}
} else {
parameters = Collections.emptyMap();
}

return new MethodComponentContext(name, parameters);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.name);
Expand Down
106 changes: 90 additions & 16 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,17 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;
import static org.opensearch.knn.common.KNNConstants.MODEL_NODE_ASSIGNMENT;
import static org.opensearch.knn.common.KNNConstants.*;
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved

@Log4j2
public class ModelMetadata implements Writeable, ToXContentObject {
Expand All @@ -51,6 +46,7 @@
final private String timestamp;
final private String description;
final private String trainingNodeAssignment;
private MethodComponentContext methodComponentContext;
private String error;

/**
Expand All @@ -76,6 +72,12 @@
} else {
this.trainingNodeAssignment = "";
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) {
this.methodComponentContext = new MethodComponentContext(in);
} else {
this.methodComponentContext = new MethodComponentContext("", Collections.emptyMap());

Check warning on line 79 in src/main/java/org/opensearch/knn/indices/ModelMetadata.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L79

Added line #L79 was not covered by tests
}
}

/**
Expand All @@ -88,6 +90,8 @@
* @param timestamp timevalue when model was created
* @param description information about the model
* @param error error message associated with model
* @param trainingNodeAssignment node assignment for the model
* @param methodComponentContext method component context associated with model
*/
public ModelMetadata(
KNNEngine knnEngine,
Expand All @@ -97,7 +101,8 @@
String timestamp,
String description,
String error,
String trainingNodeAssignment
String trainingNodeAssignment,
MethodComponentContext methodComponentContext
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
Expand All @@ -118,6 +123,7 @@
this.description = Objects.requireNonNull(description, "description must not be null");
this.error = Objects.requireNonNull(error, "error must not be null");
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null");
}

/**
Expand Down Expand Up @@ -192,6 +198,15 @@
return trainingNodeAssignment;
}

/**
 * Returns the method component context associated with this model.
 *
 * @return the model's methodComponentContext
 */
public MethodComponentContext getMethodComponentContext() {
    return this.methodComponentContext;
}

/**
* setter for model's state
*
Expand Down Expand Up @@ -221,7 +236,8 @@
timestamp,
description,
error,
trainingNodeAssignment
trainingNodeAssignment,
methodComponentContext.toString()
);
}

Expand Down Expand Up @@ -268,17 +284,50 @@
// Because models can be created on older versions and the cluster can be upgraded after,
// we need to accept model metadata arrays with or without the training node assignment
// and with or without the method component context.
if (modelMetadataArray.length == 7) {
log.debug("Model metadata array does not contain training node assignment. Assuming empty string.");
log.debug(
"Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context."
);
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
int dimension = Integer.parseInt(modelMetadataArray[2]);
ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, "");
return new ModelMetadata(
knnEngine,
spaceType,
dimension,
modelState,
timestamp,
description,
error,
"",
new MethodComponentContext("", Collections.emptyMap())
);
} else if (modelMetadataArray.length == 8) {
log.debug("Model metadata contains training node assignment");
log.debug("Model metadata contains training node assignment. Assuming empty method component context.");
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
int dimension = Integer.parseInt(modelMetadataArray[2]);
ModelState modelState = ModelState.getModelState(modelMetadataArray[3]);
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
String trainingNodeAssignment = modelMetadataArray[7];
return new ModelMetadata(

Check warning on line 318 in src/main/java/org/opensearch/knn/indices/ModelMetadata.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L309-L318

Added lines #L309 - L318 were not covered by tests
knnEngine,
spaceType,
dimension,
modelState,
timestamp,
description,
error,
trainingNodeAssignment,
new MethodComponentContext("", Collections.emptyMap())

Check warning on line 327 in src/main/java/org/opensearch/knn/indices/ModelMetadata.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/knn/indices/ModelMetadata.java#L327

Added line #L327 was not covered by tests
);
} else if (modelMetadataArray.length == 9) {
log.debug("Model metadata contains training node assignment and method context");
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
int dimension = Integer.parseInt(modelMetadataArray[2]);
Expand All @@ -287,11 +336,22 @@
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
String trainingNodeAssignment = modelMetadataArray[7];
return new ModelMetadata(knnEngine, spaceType, dimension, modelState, timestamp, description, error, trainingNodeAssignment);
MethodComponentContext methodComponentContext = MethodComponentContext.fromString(modelMetadataArray[8]);
return new ModelMetadata(
knnEngine,
spaceType,
dimension,
modelState,
timestamp,
description,
error,
trainingNodeAssignment,
methodComponentContext
);
} else {
throw new IllegalArgumentException(
"Illegal format for model metadata. Must be of the form "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or \"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\"."
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or \"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\" or \"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>\"."
);
}
}
Expand Down Expand Up @@ -321,11 +381,16 @@
Object description = modelSourceMap.get(KNNConstants.MODEL_DESCRIPTION);
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);
Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT);
Object methodComponentContext = modelSourceMap.get(MODEL_METHOD_COMPONENT_CONTEXT);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
}

if (methodComponentContext == null) {
methodComponentContext = new MethodComponentContext("", Collections.emptyMap());
}

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
SpaceType.getSpace(objectToString(space)),
Expand All @@ -334,7 +399,8 @@
objectToString(timestamp),
objectToString(description),
objectToString(error),
objectToString(trainingNodeAssignment)
objectToString(trainingNodeAssignment),
(MethodComponentContext) methodComponentContext
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to cast?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, otherwise it throws a parsing exception

);
return modelMetadata;
}
Expand All @@ -351,6 +417,9 @@
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
out.writeString(getNodeAssignment());
}
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) {
getMethodComponentContext().writeTo(out);
}
}

@Override
Expand All @@ -366,6 +435,11 @@
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
builder.field(MODEL_NODE_ASSIGNMENT, getNodeAssignment());
}
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) {
builder.field(MODEL_METHOD_COMPONENT_CONTEXT).startObject();
getMethodComponentContext().toXContent(builder, params);
builder.endObject();
}
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ public TrainingJob(
ZonedDateTime.now(ZoneOffset.UTC).toString(),
description,
"",
nodeAssignment
nodeAssignment,
knnMethodContext.getMethodComponentContext()
),
null,
this.modelId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.IOException;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -62,7 +63,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException
ZonedDateTime.now(ZoneOffset.UTC).toString(),
"",
"",
"test-node"
"test-node",
new MethodComponentContext("", Collections.emptyMap())
);

Model model = new Model(modelMetadata, modelBlob, modelId);
Expand Down
Loading
Loading