From 3445fa125938846e704d67b52ef4ac5a0d2bf6b9 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Mon, 10 Jul 2023 16:20:31 -0700 Subject: [PATCH] add UTs for register model via local file class Signed-off-by: Bhavana Ramaram --- plugin/build.gradle | 3 +- ...TransportRegisterModelMetaActionTests.java | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/plugin/build.gradle b/plugin/build.gradle index 868ee9d52b..558cb0f599 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -296,7 +296,8 @@ List jacocoExclusions = [ 'org.opensearch.ml.rest.RestMLCreateConnectorAction', 'org.opensearch.ml.action.connector.SearchConnectorTransportAction', 'org.opensearch.ml.model.MLModelGroupManager', - 'org.opensearch.ml.helper.ModelAccessControlHelper' + 'org.opensearch.ml.helper.ModelAccessControlHelper', + 'org.opensearch.ml.action.models.DeleteModelTransportAction.2' ] jacocoTestCoverageVerification { diff --git a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java index 0b1d4e17da..c1009ee222 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/upload_chunk/TransportRegisterModelMetaActionTests.java @@ -104,12 +104,39 @@ public void testTransportRegisterModelMetaActionConstructor() { public void testTransportRegisterModelMetaActionDoExecute() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); - MLRegisterModelMetaRequest actionRequest = prepareRequest(); + MLRegisterModelMetaRequest actionRequest = prepareRequest("modelGroupID"); action.doExecute(task, actionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class); verify(actionListener).onResponse(argumentCaptor.capture()); } + public void testDoExecute_successWithCreateModelGroup() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse("modelGroupID"); + return null; + }).when(mlModelGroupManager).createModelGroup(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + + public void testDoExecute_failureWithCreateModelGroup() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("Failed to create Model Group")); + return null; + }).when(mlModelGroupManager).createModelGroup(any(), any()); + + MLRegisterModelMetaRequest actionRequest = prepareRequest(null); + action.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to create Model Group", argumentCaptor.getValue().getMessage()); + } + public void testDoExecute_userHasNoAccessException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -119,7 +146,7 @@ public void testDoExecute_userHasNoAccessException() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); - MLRegisterModelMetaRequest actionRequest = prepareRequest(); + MLRegisterModelMetaRequest actionRequest = prepareRequest("modelGroupID"); action.doExecute(task, actionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -135,18 +162,18 @@ public void test_ValidationFailedException() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); - MLRegisterModelMetaRequest actionRequest = prepareRequest(); + MLRegisterModelMetaRequest actionRequest = prepareRequest("modelGroupID"); action.doExecute(task, actionRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } - private MLRegisterModelMetaRequest prepareRequest() { + private MLRegisterModelMetaRequest prepareRequest(String modelGroupID) { MLRegisterModelMetaInput input = MLRegisterModelMetaInput .builder() .name("Model Name") - .modelGroupId("1") + .modelGroupId(modelGroupID) .description("Custom Model Test") .modelFormat(MLModelFormat.TORCH_SCRIPT) .functionName(FunctionName.BATCH_RCF)