From 1602aea9f5da01c9a6c82f1bb8822c721d9cc92b Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Mon, 10 Jul 2023 19:55:37 -0700 Subject: [PATCH] remote inference: add unit test for create connector request/response (#1067) * remote inference: add unit test for create connector request/response Signed-off-by: Yaliang Wu * fix failed UT Signed-off-by: Yaliang Wu * fix failed UT Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../connector/MLCreateConnectorInput.java | 33 +++++++-- .../MLCreateConnectorRequestTest.java | 72 +++++++++++++++++++ .../MLCreateConnectorResponseTest.java | 38 ++++++++++ .../TransportCreateConnectorAction.java | 2 +- .../TransportRegisterModelAction.java | 6 +- .../TransportCreateConnectorActionTests.java | 3 + .../TransportRegisterModelActionTests.java | 1 + 7 files changed, 144 insertions(+), 11 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index d2d9dfd80e..dc7c6f0b0d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -39,6 +39,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; public static final String OWNER_FIELD = "owner"; public static final String ACCESS_MODE_FIELD = "access_mode"; + public static final String DRY_RUN_FIELD = "dry_run"; public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector"; @@ -52,6 +53,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { private List backendRoles; private Boolean addAllBackendRoles; private AccessMode access; + private boolean dryRun = false; @Builder(toBuilder = true) public MLCreateConnectorInput(String name, @@ -63,8 +65,20 @@ public MLCreateConnectorInput(String name, List actions, List backendRoles, Boolean addAllBackendRoles, - AccessMode access + AccessMode access, + boolean dryRun ) { + if (!dryRun) { + if (name == null) { + throw new IllegalArgumentException("Connector name is null"); + } + if (version == null) { + throw new IllegalArgumentException("Connector version is null"); + } + if (protocol == null) { + throw new IllegalArgumentException("Connector protocol is null"); + } + } this.name = name; this.description = description; this.version = version; @@ -75,6 +89,7 @@ public MLCreateConnectorInput(String name, this.backendRoles = backendRoles; this.addAllBackendRoles = addAllBackendRoles; this.access = access; + this.dryRun = dryRun; } public static MLCreateConnectorInput parse(XContentParser parser) throws IOException { @@ -88,6 +103,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep List backendRoles = null; Boolean addAllBackendRoles = null; AccessMode access = null; + boolean dryRun = false; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -133,12 +149,15 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep case ACCESS_MODE_FIELD: access = AccessMode.from(parser.text()); break; + case DRY_RUN_FIELD: + dryRun = parser.booleanValue(); + break; default: parser.skipChildren(); break; } } - return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access); + return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun); } @Override @@ -181,7 +200,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput output) throws IOException { output.writeString(name); - output.writeString(description); + output.writeOptionalString(description); output.writeString(version); output.writeString(protocol); if (parameters != null) { @@ -211,20 +230,19 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } - if (addAllBackendRoles != null) { - output.writeBoolean(addAllBackendRoles); - } + output.writeOptionalBoolean(addAllBackendRoles); if (access != null) { output.writeBoolean(true); output.writeEnum(access); } else { output.writeBoolean(false); } + output.writeBoolean(dryRun); } public MLCreateConnectorInput(StreamInput input) throws IOException { name = input.readString(); - description = input.readString(); + description = input.readOptionalString(); version = input.readString(); protocol = input.readString(); if (input.readBoolean()) { @@ -247,5 +265,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { if (input.readBoolean()) { this.access = input.readEnum(AccessMode.class); } + dryRun = input.readBoolean(); } } \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java new file mode 100644 index 0000000000..4fb636a6e7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; + +import java.io.IOException; +import java.io.UncheckedIOException; + +public class MLCreateConnectorRequestTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void validate_nullInput() { + MLCreateConnectorRequest request = new MLCreateConnectorRequest((MLCreateConnectorInput)null); + ActionRequestValidationException exception = request.validate(); + Assert.assertTrue(exception.getMessage().contains("ML Connector input can't be null")); + } + + @Test + public void readFromStream() throws IOException { + MLCreateConnectorInput input = MLCreateConnectorInput.builder() + .name("test_connector") + .protocol("http") + .version("1") + .description("test") + .build(); + MLCreateConnectorRequest request = new MLCreateConnectorRequest(input); + BytesStreamOutput output = new BytesStreamOutput(); + request.writeTo(output); + MLCreateConnectorRequest request2 = new MLCreateConnectorRequest(output.bytes().streamInput()); + Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName()); + Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol()); + Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion()); + Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription()); + } + + @Test + public void fromActionRequest() { + MLCreateConnectorInput input = MLCreateConnectorInput.builder() + .name("test_connector") + .protocol("http") + .version("1") + .description("test") + .build(); + ActionRequest request = new MLCreateConnectorRequest(input); + MLCreateConnectorRequest request2 = MLCreateConnectorRequest.fromActionRequest(request); + Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName()); + Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol()); + Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion()); + Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription()); + } + + @Test + public void fromActionRequest_Exception() { + exceptionRule.expect(UncheckedIOException.class); + exceptionRule.expectMessage("Failed to parse ActionRequest into MLCreateConnectorRequest"); + ActionRequest request = new MLConnectorGetRequest("test_id", true); + MLCreateConnectorRequest.fromActionRequest(request); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java new file mode 100644 index 0000000000..4a829d16f7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; + +public class MLCreateConnectorResponseTest { + + @Test + public void toXContent() throws IOException { + MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + Assert.assertEquals("{\"connector_id\":\"test_id\"}", content); + } + + @Test + public void readFromStream() throws IOException { + MLCreateConnectorResponse response = new MLCreateConnectorResponse("test_id"); + BytesStreamOutput output = new BytesStreamOutput(); + response.writeTo(output); + + MLCreateConnectorResponse response2 = new MLCreateConnectorResponse(output.bytes().streamInput()); + Assert.assertEquals("test_id", response2.getConnectorId()); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index aaa07ebedd..a0bf60788e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -84,7 +84,7 @@ public TransportCreateConnectorAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.fromActionRequest(request); MLCreateConnectorInput mlCreateConnectorInput = mlCreateConnectorRequest.getMlCreateConnectorInput(); - if (MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME.equals(mlCreateConnectorInput.getName())) { + if (mlCreateConnectorInput.isDryRun()) { MLCreateConnectorResponse response = new MLCreateConnectorResponse(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME); listener.onResponse(response); return; diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 51645228ba..ce0cc0046f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -184,7 +184,7 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener< log.error(e.getMessage(), e); listener.onFailure(e); }); - MLCreateConnectorRequest mlCreateConnectorRequest = createConnectorRequest(); + MLCreateConnectorRequest mlCreateConnectorRequest = createDryRunConnectorRequest(); client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, dryRunResultListener); } } else { @@ -207,8 +207,8 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis } } - private MLCreateConnectorRequest createConnectorRequest() { - MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().name("dryRunConnector").build(); + private MLCreateConnectorRequest createDryRunConnectorRequest() { + MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build(); return new MLCreateConnectorRequest(createConnectorInput); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index 58d1a3c7d7..5f07a0e472 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -141,6 +141,8 @@ public void setup() { Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); input = MLCreateConnectorInput .builder() + .name("test_name") + .version("1") .actions(actions) .parameters(parameters) .protocol(ConnectorProtocols.HTTP) @@ -430,6 +432,7 @@ public void test_execute_dryRun_connector_creation() { MLCreateConnectorInput mlCreateConnectorInput = mock(MLCreateConnectorInput.class); when(mlCreateConnectorInput.getName()).thenReturn(MLCreateConnectorInput.DRY_RUN_CONNECTOR_NAME); + when(mlCreateConnectorInput.isDryRun()).thenReturn(true); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); action.doExecute(task, request, actionListener); verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 9d73d708ab..55737ad3dc 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -409,6 +409,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() { MLRegisterModelInput input = mock(MLRegisterModelInput.class); when(request.getRegisterModelInput()).thenReturn(input); when(input.getModelName()).thenReturn("Test Model"); + when(input.getVersion()).thenReturn("1"); when(input.getModelGroupId()).thenReturn("modelGroupID"); when(input.getFunctionName()).thenReturn(FunctionName.REMOTE); Connector connector = mock(Connector.class);