diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 760d7db9d9..1282e6008e 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -208,8 +208,7 @@ public MLModel(StreamInput input) throws IOException{ deployToAllNodes = input.readBoolean(); modelGroupId = input.readOptionalString(); if (input.readBoolean()) { - String connectorProtocol = input.readString(); - connector = MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{connectorProtocol, input}, String.class, StreamInput.class); + connector = Connector.fromStream(input); } connectorId = input.readOptionalString(); } @@ -263,7 +262,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelGroupId); if (connector != null) { out.writeBoolean(true); - out.writeString(connector.getProtocol()); connector.writeTo(out); } else { out.writeBoolean(false); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index e00c682cfe..29586bc856 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -199,6 +199,9 @@ private void parseFromStream(StreamInput input) throws IOException { if (input.readBoolean()) { this.access = input.readEnum(AccessMode.class); } + if (input.readBoolean()) { + this.owner = new User(input); + } } @Override @@ -235,6 +238,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (owner != null) { + out.writeBoolean(true); + owner.writeTo(out); + } else { + out.writeBoolean(false); + } } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index 67adc82b17..bcc9ff2da6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -145,8 +145,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.deployModel = in.readBoolean(); this.modelNodeIds = in.readOptionalStringArray(); if (in.readBoolean()) { - String protocol = in.readString(); - this.connector = MLCommonsClassLoader.initConnector(protocol, new Object[]{protocol, in}, String.class, StreamInput.class); + this.connector = Connector.fromStream(in); } this.connectorId = in.readOptionalString(); if (in.readBoolean()) { @@ -183,7 +182,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalStringArray(modelNodeIds); if (connector != null) { out.writeBoolean(true); - out.writeString(connector.getProtocol()); connector.writeTo(out); } else { out.writeBoolean(false); diff --git a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java new file mode 100644 index 0000000000..bcbcc149c9 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java @@ -0,0 +1,155 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.HttpConnectorTest; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +public class RemoteModelTests { + + @Test + public void toXContent_ConnectorId() throws IOException { + MLModel mlModel = MLModel.builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"" + + ",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector_id\":\"test_connector_id\"}", mlModelContent); + } + + @Test + public void toXContent_InternalConnector() throws IOException { + Connector connector = HttpConnectorTest.createHttpConnector(); + MLModel mlModel = MLModel.builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"," + + "\"model_version\":\"1.0.0\",\"description\":\"test model\",\"connector\":{\"name\":\"test_connector_name\"," + + "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"}}", mlModelContent); + } + + @Test + public void parse_ConnectorId() throws IOException { + MLModel mlModel = MLModel.builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name()); + Assert.assertNull(parsedModel.getConnector()); + Assert.assertEquals(mlModel.getConnectorId(), parsedModel.getConnectorId()); + } + + @Test + public void parse_InternalConnector() throws IOException { + Connector connector = HttpConnectorTest.createHttpConnector(); + MLModel mlModel = MLModel.builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlModel.toXContent(builder, EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + parser.nextToken(); + MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name()); + Assert.assertEquals(mlModel.getConnector(), parsedModel.getConnector()); + } + + + @Test + public void readInputStream_ConnectorId() throws IOException { + MLModel mlModel = MLModel.builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); + readInputStream(mlModel); + } + + @Test + public void readInputStream_InternalConnector() throws IOException { + Connector connector = HttpConnectorTest.createHttpConnector(); + MLModel mlModel = MLModel.builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); + readInputStream(mlModel); + } + + public void readInputStream(MLModel mlModel) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlModel.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLModel parsedMLModel = new MLModel(streamInput); + assertEquals(mlModel.getName(), parsedMLModel.getName()); + assertEquals(mlModel.getAlgorithm(), parsedMLModel.getAlgorithm()); + assertEquals(mlModel.getVersion(), parsedMLModel.getVersion()); + assertEquals(mlModel.getContent(), parsedMLModel.getContent()); + assertEquals(mlModel.getUser(), parsedMLModel.getUser()); + assertEquals(mlModel.getConnectorId(), parsedMLModel.getConnectorId()); + assertEquals(mlModel.getConnector(), parsedMLModel.getConnector()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index 0efb3e42ff..a242c213ea 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -10,9 +10,19 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; +import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; @@ -108,9 +118,36 @@ public void constructor_NoPredictAction() { Assert.assertEquals("decrypted: ENCRYPTED: TEST_REGION", connector.getRegion()); } + @Test + public void constructor_Parser() throws IOException { + AwsConnector awsConnector = createAwsConnector(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + awsConnector.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonStr = TestHelper.xContentBuilderToString(builder); + + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + AwsConnector connector = new AwsConnector(awsConnector.getProtocol(), parser); + Assert.assertEquals(awsConnector, connector); + } + @Test public void constructor() { - AwsConnector connector = createAwsConnector(); + Map parameters = new HashMap<>(); + parameters.put("input", "test input value"); + parameters.put(SERVICE_NAME_FIELD, "test_service"); + parameters.put(REGION_FIELD, "us-west-2"); + parameters.put("endpoint", "test.com"); + + Map credential = new HashMap<>(); + credential.put(ACCESS_KEY_FIELD, "test_access_key"); + credential.put(SECRET_KEY_FIELD, "test_secret_key"); + credential.put(SESSION_TOKEN_FIELD, "test_session_token"); + String url = "https://${parameters.endpoint}/model1"; + + AwsConnector connector = createAwsConnector(parameters, credential, url); connector.encrypt(encryptFunction); connector.decrypt(decryptFunction); Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); @@ -118,6 +155,28 @@ public void constructor() { Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); Assert.assertEquals("test_service", connector.getServiceName()); Assert.assertEquals("us-west-2", connector.getRegion()); + Assert.assertEquals("https://test.com/model1", connector.getPredictEndpoint(parameters)); + } + + @Test + public void constructor_NoParameter() { + Map credential = new HashMap<>(); + credential.put(ACCESS_KEY_FIELD, "test_access_key"); + credential.put(SECRET_KEY_FIELD, "test_secret_key"); + credential.put(SESSION_TOKEN_FIELD, "test_session_token"); + credential.put(SERVICE_NAME_FIELD, "test_service"); + credential.put(REGION_FIELD, "us-west-2"); + + String url = "https://test.com"; + AwsConnector connector = createAwsConnector(null, credential, url); + connector.encrypt(encryptFunction); + connector.decrypt(decryptFunction); + Assert.assertEquals("decrypted: ENCRYPTED: TEST_ACCESS_KEY", connector.getAccessKey()); + Assert.assertEquals("decrypted: ENCRYPTED: TEST_SECRET_KEY", connector.getSecretKey()); + Assert.assertEquals("decrypted: ENCRYPTED: TEST_SESSION_TOKEN", connector.getSessionToken()); + Assert.assertEquals("decrypted: ENCRYPTED: TEST_SERVICE", connector.getServiceName()); + Assert.assertEquals("decrypted: ENCRYPTED: US-WEST-2", connector.getRegion()); + Assert.assertEquals("https://test.com", connector.getPredictEndpoint(null)); } @Test @@ -128,17 +187,6 @@ public void cloneConnector() { } private AwsConnector createAwsConnector() { - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "POST"; - String url = "https://test.com"; - Map headers = new HashMap<>(); - headers.put("api_key", "${credential.key}"); - String requestBody = "{\"input\": \"${parameters.input}\"}"; - String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; - String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, requestBody, preProcessFunction, postProcessFunction); - Map parameters = new HashMap<>(); parameters.put("input", "test input value"); parameters.put(SERVICE_NAME_FIELD, "test_service"); @@ -148,6 +196,20 @@ private AwsConnector createAwsConnector() { credential.put(ACCESS_KEY_FIELD, "test_access_key"); credential.put(SECRET_KEY_FIELD, "test_secret_key"); credential.put(SESSION_TOKEN_FIELD, "test_session_token"); + String url = "https://test.com"; + return createAwsConnector(parameters, credential, url); + } + + private AwsConnector createAwsConnector(Map parameters, Map credential, String url) { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "POST"; + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + String requestBody = "{\"input\": \"${parameters.input}\"}"; + String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; + String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; + + ConnectorAction action = new ConnectorAction(actionType, method, url, headers, requestBody, preProcessFunction, postProcessFunction); AwsConnector connector = AwsConnector.awsConnectorBuilder() .name("test_connector_name") diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index cb7b61ca50..4c42f8361f 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -19,6 +19,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.connector.HttpConnectorTest; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; @@ -230,6 +232,42 @@ public void readInputStream_SuccessWithNullFields() throws IOException { }); } + @Test + public void readInputStream_WithConnectorId() throws IOException { + String connectorId = "test_connector_id"; + input = MLRegisterModelInput.builder() + .functionName(FunctionName.REMOTE) + .modelName(modelName) + .description("test model input") + .version(version) + .modelGroupId(modelGroupId) + .connectorId(connectorId) + .build(); + readInputStream(input, parsedInput -> { + assertNull(parsedInput.getModelConfig()); + assertNull(parsedInput.getModelFormat()); + assertNull(parsedInput.getConnector()); + assertEquals(connectorId, parsedInput.getConnectorId()); + }); + } + + @Test + public void readInputStream_WithInternalConnector() throws IOException { + HttpConnector connector = HttpConnectorTest.createHttpConnector(); + input = MLRegisterModelInput.builder() + .functionName(FunctionName.REMOTE) + .modelName(modelName) + .description("test model input") + .version(version) + .modelGroupId(modelGroupId) + .connector(connector) + .build(); + readInputStream(input, parsedInput -> { + assertNull(parsedInput.getModelConfig()); + assertNull(parsedInput.getModelFormat()); + assertEquals(input.getConnector(), parsedInput.getConnector()); + }); + } private void readInputStream(MLRegisterModelInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();