Skip to content

Commit

Permalink
remote inference: add unit test for model and register model input (#…
Browse files Browse the repository at this point in the history
…1059)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored and zane-neo committed Sep 1, 2023
1 parent c786142 commit da76cc8
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 18 deletions.
4 changes: 1 addition & 3 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ public MLModel(StreamInput input) throws IOException{
deployToAllNodes = input.readBoolean();
modelGroupId = input.readOptionalString();
if (input.readBoolean()) {
String connectorProtocol = input.readString();
connector = MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{connectorProtocol, input}, String.class, StreamInput.class);
connector = Connector.fromStream(input);
}
connectorId = input.readOptionalString();
}
Expand Down Expand Up @@ -263,7 +262,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(modelGroupId);
if (connector != null) {
out.writeBoolean(true);
out.writeString(connector.getProtocol());
connector.writeTo(out);
} else {
out.writeBoolean(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ private void parseFromStream(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.access = input.readEnum(AccessMode.class);
}
if (input.readBoolean()) {
this.owner = new User(input);
}
}

@Override
Expand Down Expand Up @@ -235,6 +238,12 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (owner != null) {
out.writeBoolean(true);
owner.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException {
this.deployModel = in.readBoolean();
this.modelNodeIds = in.readOptionalStringArray();
if (in.readBoolean()) {
String protocol = in.readString();
this.connector = MLCommonsClassLoader.initConnector(protocol, new Object[]{protocol, in}, String.class, StreamInput.class);
this.connector = Connector.fromStream(in);
}
this.connectorId = in.readOptionalString();
if (in.readBoolean()) {
Expand Down Expand Up @@ -184,7 +183,6 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalStringArray(modelNodeIds);
if (connector != null) {
out.writeBoolean(true);
out.writeString(connector.getProtocol());
connector.writeTo(out);
} else {
out.writeBoolean(false);
Expand Down
155 changes: 155 additions & 0 deletions common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.HttpConnectorTest;

import java.io.IOException;

import static org.junit.Assert.assertEquals;
import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;

public class RemoteModelTests {

@Test
public void toXContent_ConnectorId() throws IOException {
MLModel mlModel = MLModel.builder()
.algorithm(FunctionName.REMOTE)
.name("test_model_name")
.version("1.0.0")
.modelGroupId("test_group_id")
.description("test model")
.connectorId("test_connector_id")
.build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"" +
",\"model_version\":\"1.0.0\",\"description\":\"test model\"," +
"\"connector_id\":\"test_connector_id\"}", mlModelContent);
}

@Test
public void toXContent_InternalConnector() throws IOException {
Connector connector = HttpConnectorTest.createHttpConnector();
MLModel mlModel = MLModel.builder()
.algorithm(FunctionName.REMOTE)
.name("test_model_name")
.version("1.0.0")
.modelGroupId("test_group_id")
.description("test model")
.connector(connector)
.build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"," +
"\"model_version\":\"1.0.0\",\"description\":\"test model\",\"connector\":{\"name\":\"test_connector_name\"," +
"\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," +
"\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," +
"\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," +
"\"headers\":{\"api_key\":\"${credential.key}\"}," +
"\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," +
"\"pre_process_function\":\"connector.pre_process.openai.embedding\"," +
"\"post_process_function\":\"connector.post_process.openai.embedding\"}]," +
"\"backend_roles\":[\"role1\",\"role2\"]," +
"\"access\":\"public\"}}", mlModelContent);
}

@Test
public void parse_ConnectorId() throws IOException {
MLModel mlModel = MLModel.builder()
.algorithm(FunctionName.REMOTE)
.name("test_model_name")
.version("1.0.0")
.modelGroupId("test_group_id")
.description("test model")
.connectorId("test_connector_id")
.build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String jsonStr = TestHelper.xContentBuilderToString(builder);
XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
parser.nextToken();
MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name());
Assert.assertNull(parsedModel.getConnector());
Assert.assertEquals(mlModel.getConnectorId(), parsedModel.getConnectorId());
}

@Test
public void parse_InternalConnector() throws IOException {
Connector connector = HttpConnectorTest.createHttpConnector();
MLModel mlModel = MLModel.builder()
.algorithm(FunctionName.REMOTE)
.name("test_model_name")
.version("1.0.0")
.modelGroupId("test_group_id")
.description("test model")
.connector(connector)
.build();

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String jsonStr = TestHelper.xContentBuilderToString(builder);
XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr);
parser.nextToken();
MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name());
Assert.assertEquals(mlModel.getConnector(), parsedModel.getConnector());
}


@Test
public void readInputStream_ConnectorId() throws IOException {
MLModel mlModel = MLModel.builder()
.algorithm(FunctionName.REMOTE)
.name("test_model_name")
.version("1.0.0")
.modelGroupId("test_group_id")
.description("test model")
.connectorId("test_connector_id")
.build();
readInputStream(mlModel);
}

@Test
public void readInputStream_InternalConnector() throws IOException {
Connector connector = HttpConnectorTest.createHttpConnector();
MLModel mlModel = MLModel.builder()
.algorithm(FunctionName.REMOTE)
.name("test_model_name")
.version("1.0.0")
.modelGroupId("test_group_id")
.description("test model")
.connector(connector)
.build();
readInputStream(mlModel);
}

public void readInputStream(MLModel mlModel) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlModel.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
MLModel parsedMLModel = new MLModel(streamInput);
assertEquals(mlModel.getName(), parsedMLModel.getName());
assertEquals(mlModel.getAlgorithm(), parsedMLModel.getAlgorithm());
assertEquals(mlModel.getVersion(), parsedMLModel.getVersion());
assertEquals(mlModel.getContent(), parsedMLModel.getContent());
assertEquals(mlModel.getUser(), parsedMLModel.getUser());
assertEquals(mlModel.getConnectorId(), parsedMLModel.getConnectorId());
assertEquals(mlModel.getConnector(), parsedMLModel.getConnector());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,19 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.TestHelper;
import org.opensearch.search.SearchModule;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -108,16 +118,65 @@ public void constructor_NoPredictAction() {
Assert.assertEquals("decrypted: ENCRYPTED: TEST_REGION", connector.getRegion());
}

@Test
public void constructor_Parser() throws IOException {
AwsConnector awsConnector = createAwsConnector();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
awsConnector.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = TestHelper.xContentBuilderToString(builder);

XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
Collections.emptyList()).getNamedXContents()), null, jsonStr);
parser.nextToken();

AwsConnector connector = new AwsConnector(awsConnector.getProtocol(), parser);
Assert.assertEquals(awsConnector, connector);
}

@Test
public void constructor() {
AwsConnector connector = createAwsConnector();
Map<String, String> parameters = new HashMap<>();
parameters.put("input", "test input value");
parameters.put(SERVICE_NAME_FIELD, "test_service");
parameters.put(REGION_FIELD, "us-west-2");
parameters.put("endpoint", "test.com");

Map<String, String> credential = new HashMap<>();
credential.put(ACCESS_KEY_FIELD, "test_access_key");
credential.put(SECRET_KEY_FIELD, "test_secret_key");
credential.put(SESSION_TOKEN_FIELD, "test_session_token");
String url = "https://${parameters.endpoint}/model1";

AwsConnector connector = createAwsConnector(parameters, credential, url);
connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
Assert.assertEquals("test_service", connector.getServiceName());
Assert.assertEquals("us-west-2", connector.getRegion());
Assert.assertEquals("https://test.com/model1", connector.getPredictEndpoint(parameters));
}

@Test
public void constructor_NoParameter() {
Map<String, String> credential = new HashMap<>();
credential.put(ACCESS_KEY_FIELD, "test_access_key");
credential.put(SECRET_KEY_FIELD, "test_secret_key");
credential.put(SESSION_TOKEN_FIELD, "test_session_token");
credential.put(SERVICE_NAME_FIELD, "test_service");
credential.put(REGION_FIELD, "us-west-2");

String url = "https://test.com";
AwsConnector connector = createAwsConnector(null, credential, url);
connector.encrypt(encryptFunction);
connector.decrypt(decryptFunction);
Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken());
Assert.assertEquals("decrypted: ENCRYPTED: TEST_SERVICE", connector.getServiceName());
Assert.assertEquals("decrypted: ENCRYPTED: US-WEST-2", connector.getRegion());
Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null));
}

@Test
Expand All @@ -128,17 +187,6 @@ public void cloneConnector() {
}

private AwsConnector createAwsConnector() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "POST";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("api_key", "${credential.key}");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(actionType, method, url, headers, requestBody, preProcessFunction, postProcessFunction);

Map<String, String> parameters = new HashMap<>();
parameters.put("input", "test input value");
parameters.put(SERVICE_NAME_FIELD, "test_service");
Expand All @@ -148,6 +196,20 @@ private AwsConnector createAwsConnector() {
credential.put(ACCESS_KEY_FIELD, "test_access_key");
credential.put(SECRET_KEY_FIELD, "test_secret_key");
credential.put(SESSION_TOKEN_FIELD, "test_session_token");
String url = "https://test.com";
return createAwsConnector(parameters, credential, url);
}

private AwsConnector createAwsConnector(Map<String, String> parameters, Map<String, String> credential, String url) {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "POST";
Map<String, String> headers = new HashMap<>();
headers.put("api_key", "${credential.key}");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(actionType, method, url, headers, requestBody, preProcessFunction, postProcessFunction);

AwsConnector connector = AwsConnector.awsConnectorBuilder()
.name("test_connector_name")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.connector.HttpConnectorTest;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
Expand Down Expand Up @@ -229,6 +231,42 @@ public void readInputStream_SuccessWithNullFields() throws IOException {
});
}

@Test
public void readInputStream_WithConnectorId() throws IOException {
String connectorId = "test_connector_id";
input = MLRegisterModelInput.builder()
.functionName(FunctionName.REMOTE)
.modelName(modelName)
.description("test model input")
.version(version)
.modelGroupId(modelGroupId)
.connectorId(connectorId)
.build();
readInputStream(input, parsedInput -> {
assertNull(parsedInput.getModelConfig());
assertNull(parsedInput.getModelFormat());
assertNull(parsedInput.getConnector());
assertEquals(connectorId, parsedInput.getConnectorId());
});
}

@Test
public void readInputStream_WithInternalConnector() throws IOException {
HttpConnector connector = HttpConnectorTest.createHttpConnector();
input = MLRegisterModelInput.builder()
.functionName(FunctionName.REMOTE)
.modelName(modelName)
.description("test model input")
.version(version)
.modelGroupId(modelGroupId)
.connector(connector)
.build();
readInputStream(input, parsedInput -> {
assertNull(parsedInput.getModelConfig());
assertNull(parsedInput.getModelFormat());
assertEquals(input.getConnector(), parsedInput.getConnector());
});
}

private void readInputStream(MLRegisterModelInput input, Consumer<MLRegisterModelInput> verify) throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
Expand Down

0 comments on commit da76cc8

Please sign in to comment.