From dacd51078870d766488f4e77478d4e4bc28b682a Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Tue, 11 Jul 2023 20:50:03 -0700 Subject: [PATCH] Add more UT for remote inference classes (#1077) --- common/build.gradle | 4 +- .../connector/MLCreateConnectorInput.java | 2 +- .../MLConnectorDeleteRequestTests.java | 100 ++++++++ .../connector/MLConnectorGetRequestTests.java | 98 ++++++++ .../MLConnectorGetResponseTests.java | 107 +++++++++ .../MLCreateConnectorInputTests.java | 217 ++++++++++++++++++ .../MLCreateConnectorRequestTest.java | 72 ------ .../MLCreateConnectorRequestTests.java | 136 +++++++++++ ...va => MLCreateConnectorResponseTests.java} | 2 +- .../model/MLModelDeleteRequestTest.java | 20 ++ .../model/MLModelGetRequestTest.java | 17 ++ .../model/MLModelGetResponseTest.java | 40 +++- .../MLRegisterModelGroupRequestTest.java | 13 +- 13 files changed, 742 insertions(+), 86 deletions(-) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java delete mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java rename common/src/test/java/org/opensearch/ml/common/transport/connector/{MLCreateConnectorResponseTest.java => MLCreateConnectorResponseTests.java} (96%) diff --git a/common/build.gradle b/common/build.gradle index 6e235683a7..024540091d 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -40,11 +40,11 @@ jacocoTestCoverageVerification { rule { limit { counter = 'LINE' - minimum = 0.6 //TODO: add more test to meet the coverage bar 0.9 + minimum = 0.8 //TODO: add more test to meet the coverage bar 0.9 } limit { counter = 'BRANCH' - minimum = 0.5 //TODO: add more test to meet the coverage bar 0.9 + minimum = 0.7 //TODO: add more test to meet the coverage bar 0.9 } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 17e32e1c84..1652509bf7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -227,7 +227,7 @@ public void writeTo(StreamOutput output) throws IOException { } if (!CollectionUtils.isEmpty(backendRoles)) { output.writeBoolean(true); - output.writeOptionalStringCollection(backendRoles); + output.writeStringCollection(backendRoles); } else { output.writeBoolean(false); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java new file mode 100644 index 0000000000..27bb438599 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class MLConnectorDeleteRequestTests { + private String connectorId; + + @Before + public void setUp() { + connectorId = "test-connector-id"; + } + + @Test + public void writeTo_Success() throws IOException { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() + .connectorId(connectorId).build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlConnectorDeleteRequest.writeTo(bytesStreamOutput); + MLConnectorDeleteRequest parsedConnector = new MLConnectorDeleteRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(parsedConnector.getConnectorId(), connectorId); + } + + @Test + public void valid_Exception_NullConnectorId() { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().build(); + ActionRequestValidationException exception = mlConnectorDeleteRequest.validate(); + assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); + } + + @Test + public void validate_Success() { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() + .connectorId(connectorId).build(); + ActionRequestValidationException actionRequestValidationException = mlConnectorDeleteRequest.validate(); + assertNull(actionRequestValidationException); + } + + @Test + public void fromActionRequest_Success() { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() + .connectorId(connectorId).build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlConnectorDeleteRequest.writeTo(out); + } + }; + MLConnectorDeleteRequest parsedConnector = MLConnectorDeleteRequest.fromActionRequest(actionRequest); + assertNotSame(parsedConnector, mlConnectorDeleteRequest); + assertEquals(parsedConnector.getConnectorId(), connectorId); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLConnectorDeleteRequest.fromActionRequest(actionRequest); + } + + @Test + public void fromActionRequestWithConnectorDeleteRequest_Success() { + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() + .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest.fromActionRequest(mlConnectorDeleteRequest); + assertSame(mlConnectorDeleteRequest, mlConnectorDeleteRequestFromActionRequest); + assertEquals(mlConnectorDeleteRequest.getConnectorId(), mlConnectorDeleteRequestFromActionRequest.getConnectorId()); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java new file mode 100644 index 0000000000..0113aca7d6 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.ml.common.transport.connector; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamOutput; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class MLConnectorGetRequestTests { + private String connectorId; + + @Before + public void setUp() { + connectorId = "test-connector-id"; + } + + @Test + public void writeTo_Success() throws IOException { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlConnectorGetRequest.writeTo(bytesStreamOutput); + MLConnectorGetRequest parsedConnector = new MLConnectorGetRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(connectorId, parsedConnector.getConnectorId()); + } + + @Test + public void fromActionRequest_Success() { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlConnectorGetRequest.writeTo(out); + } + }; + MLConnectorGetRequest mlConnectorGetRequestFromActionRequest = MLConnectorGetRequest.fromActionRequest(actionRequest); + assertNotSame(mlConnectorGetRequest, mlConnectorGetRequestFromActionRequest); + assertEquals(mlConnectorGetRequest.getConnectorId(), mlConnectorGetRequestFromActionRequest.getConnectorId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLConnectorGetRequest.fromActionRequest(actionRequest); + } + + @Test + public void fromActionRequestWithMLConnectorGetRequest_Success() { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); + MLConnectorGetRequest mlConnectorGetRequestFromActionRequest = MLConnectorGetRequest.fromActionRequest(mlConnectorGetRequest); + assertSame(mlConnectorGetRequest, mlConnectorGetRequestFromActionRequest); + assertEquals(mlConnectorGetRequest.getConnectorId(), mlConnectorGetRequestFromActionRequest.getConnectorId()); + } + + @Test + public void validate_Exception_NullConnctorId() { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().build(); + ActionRequestValidationException actionRequestValidationException = mlConnectorGetRequest.validate(); + assertEquals("Validation Failed: 1: ML connector id can't be null;", actionRequestValidationException.getMessage()); + } + + @Test + public void validate_Success() { + MLConnectorGetRequest mlConnectorGetRequest = MLConnectorGetRequest.builder().connectorId(connectorId).build(); + ActionRequestValidationException actionRequestValidationException = mlConnectorGetRequest.validate(); + assertNull(actionRequestValidationException); + } +} + diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java new file mode 100644 index 0000000000..69a19255e8 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionResponse; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.HttpConnectorTest; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +public class MLConnectorGetResponseTests { + Connector connector; + + @Before + public void setUp() { + connector = HttpConnectorTest.createHttpConnector(); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + MLConnectorGetResponse response = MLConnectorGetResponse.builder().mlConnector(connector).build(); + response.writeTo(bytesStreamOutput); + MLConnectorGetResponse parsedResponse = new MLConnectorGetResponse(bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response, parsedResponse); + assertNotSame(response.mlConnector, parsedResponse.mlConnector); + assertEquals(response.mlConnector, parsedResponse.mlConnector); + assertEquals(response.mlConnector.getName(), parsedResponse.mlConnector.getName()); + assertEquals(response.mlConnector.getAccess(), parsedResponse.mlConnector.getAccess()); + assertEquals(response.mlConnector.getProtocol(), parsedResponse.mlConnector.getProtocol()); + assertEquals(response.mlConnector.getDecryptedHeaders(), parsedResponse.mlConnector.getDecryptedHeaders()); + assertEquals(response.mlConnector.getBackendRoles(), parsedResponse.mlConnector.getBackendRoles()); + assertEquals(response.mlConnector.getActions(), parsedResponse.mlConnector.getActions()); + assertEquals(response.mlConnector.getParameters(), parsedResponse.mlConnector.getParameters()); + } + + @Test + public void toXContentTest() throws IOException { + MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build(); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = Strings.toString(builder); + assertEquals("{\"name\":\"test_connector_name\"," + + "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"}", jsonStr); + } + + @Test + public void fromActionResponseWithMLConnectorGetResponse_Success() { + MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build(); + MLConnectorGetResponse mlConnectorGetResponseFromActionResponse = MLConnectorGetResponse.fromActionResponse(mlConnectorGetResponse); + assertSame(mlConnectorGetResponse, mlConnectorGetResponseFromActionResponse); + assertEquals(mlConnectorGetResponse.mlConnector, mlConnectorGetResponseFromActionResponse.mlConnector); + } + + @Test + public void fromActionResponse_Success() { + MLConnectorGetResponse mlConnectorGetResponse = MLConnectorGetResponse.builder().mlConnector(connector).build(); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + mlConnectorGetResponse.writeTo(out); + } + }; + MLConnectorGetResponse mlConnectorGetResponseFromActionResponse = MLConnectorGetResponse.fromActionResponse(actionResponse); + assertNotSame(mlConnectorGetResponse, mlConnectorGetResponseFromActionResponse); + assertEquals(mlConnectorGetResponse.mlConnector, mlConnectorGetResponseFromActionResponse.mlConnector); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponse_IOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLConnectorGetResponse.fromActionResponse(actionResponse); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java new file mode 100644 index 0000000000..c037da1529 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -0,0 +1,217 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; +import org.opensearch.search.SearchModule; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class MLCreateConnectorInputTests { + private MLCreateConnectorInput mlCreateConnectorInput; + private MLCreateConnectorInput mlCreateDryRunConnectorInput; + + @Rule + public final ExpectedException exceptionRule = ExpectedException.none(); + private final String expectedInputStr = "{\"name\":\"test_connector_name\"," + + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + + "\"access_mode\":\"PUBLIC\"}"; + + @Before + public void setUp(){ + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "POST"; + String url = "https://test.com"; + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + String mlCreateConnectorRequestBody = "{\"input\": \"${parameters.input}\"}"; + String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; + String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; + ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); + + mlCreateConnectorInput = MLCreateConnectorInput.builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of(action)) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + + mlCreateDryRunConnectorInput = MLCreateConnectorInput.builder() + .dryRun(true) + .build(); + } + + @Test + public void constructorMLCreateConnectorInput_NullName() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Connector name is null"); + MLCreateConnectorInput.builder() + .name(null) + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + } + + @Test + public void constructorMLCreateConnectorInput_NullVersion() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Connector version is null"); + MLCreateConnectorInput.builder() + .name("test_connector_name") + .description("this is a test connector") + .version(null) + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + } + + @Test + public void constructorMLCreateConnectorInput_NullProtocol() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Connector protocol is null"); + MLCreateConnectorInput.builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol(null) + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + } + + @Test + public void testToXContent_FullFields() throws Exception { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + mlCreateConnectorInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = Strings.toString(builder); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testToXContent_NullFields() throws Exception { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + mlCreateDryRunConnectorInput.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = Strings.toString(builder); + assertEquals("{}", jsonStr); + } + + @Test + public void testParse() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> { + assertEquals("test_connector_name", parsedInput.getName()); + }); + } + + @Test + public void testParseWithDryRun() throws Exception { + String expectedInputStrWithDryRun = "{\"dry_run\":true}"; + testParseFromJsonString(expectedInputStrWithDryRun, parsedInput -> { + assertNull(parsedInput.getName()); + assertTrue(parsedInput.isDryRun()); + }); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(mlCreateConnectorInput, parsedInput -> assertEquals(mlCreateConnectorInput.getName(), parsedInput.getName())); + } + + @Test + public void readInputStream_SuccessWithNullFields() throws IOException { + MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput.builder() + .name("test_connector_name") + .version("1") + .protocol("http") + .build(); + readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { + assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); + assertNull(parsedInput.getActions()); + }); + } + + private void testParseFromJsonString(String expectedInputString, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputString); + parser.nextToken(); + MLCreateConnectorInput parsedInput = MLCreateConnectorInput.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLCreateConnectorInput input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLCreateConnectorInput parsedInput = new MLCreateConnectorInput(streamInput); + verify.accept(parsedInput); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java deleted file mode 100644 index 4fb636a6e7..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTest.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.connector; - -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; - -import java.io.IOException; -import java.io.UncheckedIOException; - -public class MLCreateConnectorRequestTest { - - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - - @Test - public void validate_nullInput() { - MLCreateConnectorRequest request = new MLCreateConnectorRequest((MLCreateConnectorInput)null); - ActionRequestValidationException exception = request.validate(); - Assert.assertTrue(exception.getMessage().contains("ML Connector input can't be null")); - } - - @Test - public void readFromStream() throws IOException { - MLCreateConnectorInput input = MLCreateConnectorInput.builder() - .name("test_connector") - .protocol("http") - .version("1") - .description("test") - .build(); - MLCreateConnectorRequest request = new MLCreateConnectorRequest(input); - BytesStreamOutput output = new BytesStreamOutput(); - request.writeTo(output); - MLCreateConnectorRequest request2 = new MLCreateConnectorRequest(output.bytes().streamInput()); - Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName()); - Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol()); - Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion()); - Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription()); - } - - @Test - public void fromActionRequest() { - MLCreateConnectorInput input = MLCreateConnectorInput.builder() - .name("test_connector") - .protocol("http") - .version("1") - .description("test") - .build(); - ActionRequest request = new MLCreateConnectorRequest(input); - MLCreateConnectorRequest request2 = MLCreateConnectorRequest.fromActionRequest(request); - Assert.assertEquals("test_connector", request2.getMlCreateConnectorInput().getName()); - Assert.assertEquals("http", request2.getMlCreateConnectorInput().getProtocol()); - Assert.assertEquals("1", request2.getMlCreateConnectorInput().getVersion()); - Assert.assertEquals("test", request2.getMlCreateConnectorInput().getDescription()); - } - - @Test - public void fromActionRequest_Exception() { - exceptionRule.expect(UncheckedIOException.class); - exceptionRule.expectMessage("Failed to parse ActionRequest into MLCreateConnectorRequest"); - ActionRequest request = new MLConnectorGetRequest("test_id", true); - MLCreateConnectorRequest.fromActionRequest(request); - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java new file mode 100644 index 0000000000..6fe82d7f2b --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.MLPostProcessFunction; +import org.opensearch.ml.common.connector.MLPreProcessFunction; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class MLCreateConnectorRequestTests { + private MLCreateConnectorInput mlCreateConnectorInput; + + @Before + public void setUp(){ + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "POST"; + String url = "https://test.com"; + Map headers = new HashMap<>(); + headers.put("api_key", "${credential.key}"); + String mlCreateConnectorRequestBody = "{\"input\": \"${parameters.input}\"}"; + String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; + String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; + ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); + + mlCreateConnectorInput = MLCreateConnectorInput.builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of(action)) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + } + + @Test + public void writeTo_Success() throws IOException { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(mlCreateConnectorInput).build(); + BytesStreamOutput output = new BytesStreamOutput(); + mlCreateConnectorRequest.writeTo(output); + MLCreateConnectorRequest parsedRequest = new MLCreateConnectorRequest(output.bytes().streamInput()); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getName(), parsedRequest.getMlCreateConnectorInput().getName()); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getAccess(), parsedRequest.getMlCreateConnectorInput().getAccess()); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getProtocol(), parsedRequest.getMlCreateConnectorInput().getProtocol()); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getBackendRoles(), parsedRequest.getMlCreateConnectorInput().getBackendRoles()); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getActions(), parsedRequest.getMlCreateConnectorInput().getActions()); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getParameters(), parsedRequest.getMlCreateConnectorInput().getParameters()); + } + + @Test + public void validate_Success() { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); + + assertNull(mlCreateConnectorRequest.validate()); + } + + @Test + public void validate_Exception_NullMLRegisterModelGroupInput() { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() + .build(); + ActionRequestValidationException exception = mlCreateConnectorRequest.validate(); + assertEquals("Validation Failed: 1: ML Connector input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success_WithMLRegisterModelRequest() { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); + assertSame(MLCreateConnectorRequest.fromActionRequest(mlCreateConnectorRequest), mlCreateConnectorRequest); + } + + @Test + public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlCreateConnectorRequest.writeTo(out); + } + }; + MLCreateConnectorRequest result = MLCreateConnectorRequest.fromActionRequest(actionRequest); + assertNotSame(result, mlCreateConnectorRequest); + assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getName(), result.getMlCreateConnectorInput().getName()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLCreateConnectorRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java similarity index 96% rename from common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java index 4a829d16f7..8d58047980 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java @@ -15,7 +15,7 @@ import java.io.IOException; -public class MLCreateConnectorResponseTest { +public class MLCreateConnectorResponseTests { @Test public void toXContent() throws IOException { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java index 21081cd48a..533b96ecdf 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java @@ -17,6 +17,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; public class MLModelDeleteRequestTest { private String modelId; @@ -36,6 +38,14 @@ public void writeTo_Success() throws IOException { assertEquals(parsedModel.getModelId(), modelId); } + @Test + public void validate_Success() { + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() + .modelId(modelId).build(); + ActionRequestValidationException actionRequestValidationException = mlModelDeleteRequest.validate(); + assertNull(actionRequestValidationException); + } + @Test public void validate_Exception_NullModelId() { MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().build(); @@ -79,4 +89,14 @@ public void writeTo(StreamOutput out) throws IOException { }; MLModelDeleteRequest.fromActionRequest(actionRequest); } + + + @Test + public void fromActionRequestWithModelDeleteRequest_Success() { + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() + .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequestFromActionRequest = MLModelDeleteRequest.fromActionRequest(mlModelDeleteRequest); + assertSame(mlModelDeleteRequest, mlModelDeleteRequestFromActionRequest); + assertEquals(mlModelDeleteRequest.getModelId(), mlModelDeleteRequestFromActionRequest.getModelId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java index f7d0c679c5..97f784d868 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java @@ -17,6 +17,8 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; public class MLModelGetRequestTest { private String modelId; @@ -79,4 +81,19 @@ public void writeTo(StreamOutput out) throws IOException { }; MLModelGetRequest.fromActionRequest(actionRequest); } + + @Test + public void validate_Success() { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); + ActionRequestValidationException actionRequestValidationException = mlModelGetRequest.validate(); + assertNull(actionRequestValidationException); + } + + @Test + public void fromActionRequestWithMLModelGetRequest_Success() { + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); + MLModelGetRequest mlModelGetRequestFromActionRequest = MLModelGetRequest.fromActionRequest(mlModelGetRequest); + assertSame(mlModelGetRequest, mlModelGetRequestFromActionRequest); + assertEquals(mlModelGetRequest.getModelId(), mlModelGetRequestFromActionRequest.getModelId()); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java index fc957f0dbe..81a25ce8d8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java @@ -7,11 +7,11 @@ import org.junit.Before; import org.junit.Test; -import org.opensearch.core.common.Strings; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -20,6 +20,7 @@ import org.opensearch.ml.common.model.MLModelState; import java.io.IOException; +import java.io.UncheckedIOException; import static org.junit.Assert.*; @@ -67,4 +68,37 @@ public void toXContentTest() throws IOException { "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null},\"model_state\":\"TRAINED\"}", jsonStr); } + + @Test + public void fromActionResponseWithMLModelGetResponse_Success() { + MLModelGetResponse mlModelGetResponse = MLModelGetResponse.builder().mlModel(mlModel).build(); + MLModelGetResponse mlModelGetResponseFromActionResponse = MLModelGetResponse.fromActionResponse(mlModelGetResponse); + assertSame(mlModelGetResponse, mlModelGetResponseFromActionResponse); + assertEquals(mlModelGetResponse.mlModel, mlModelGetResponseFromActionResponse.mlModel); + } + + @Test + public void fromActionResponse_Success() { + MLModelGetResponse mlModelGetResponse = MLModelGetResponse.builder().mlModel(mlModel).build(); + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + mlModelGetResponse.writeTo(out); + } + }; + MLModelGetResponse mlModelGetResponseFromActionResponse = MLModelGetResponse.fromActionResponse(actionResponse); + assertNotSame(mlModelGetResponse, mlModelGetResponseFromActionResponse); + assertNotEquals(mlModelGetResponse.mlModel, mlModelGetResponseFromActionResponse.mlModel); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionResponse_IOException() { + ActionResponse actionResponse = new ActionResponse() { + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLModelGetResponse.fromActionResponse(actionResponse); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java index cf948cc1d9..8e27325e47 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java @@ -35,18 +35,17 @@ public void setUp(){ @Test public void writeTo_Success() throws IOException { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() .registerModelGroupInput(mlRegisterModelGroupInput) .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - request = new MLRegisterModelGroupRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals("name", request.getRegisterModelGroupInput().getName()); - assertEquals("description", request.getRegisterModelGroupInput().getDescription()); - assertEquals("IT", request.getRegisterModelGroupInput().getBackendRoles().get(0)); - assertEquals(AccessMode.RESTRICTED, request.getRegisterModelGroupInput().getModelAccessMode()); - assertEquals(true, request.getRegisterModelGroupInput().getIsAddAllBackendRoles()); + MLRegisterModelGroupRequest parsedRequest = new MLRegisterModelGroupRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(request.getRegisterModelGroupInput().getName(), parsedRequest.getRegisterModelGroupInput().getName()); + assertEquals(request.getRegisterModelGroupInput().getDescription(), parsedRequest.getRegisterModelGroupInput().getDescription()); + assertEquals(request.getRegisterModelGroupInput().getBackendRoles().get(0), parsedRequest.getRegisterModelGroupInput().getBackendRoles().get(0)); + assertEquals(request.getRegisterModelGroupInput().getModelAccessMode(), parsedRequest.getRegisterModelGroupInput().getModelAccessMode()); + assertEquals(request.getRegisterModelGroupInput().getIsAddAllBackendRoles() ,parsedRequest.getRegisterModelGroupInput().getIsAddAllBackendRoles()); } @Test