From a10693e2b17344b04583b6e2f003bc8e2632cc8f Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 11 Jul 2023 12:09:22 -0700 Subject: [PATCH] PenTest fixes: error codes and update model group fix Signed-off-by: Bhavana Ramaram --- docs/model_access_control.md | 6 +-- .../DeleteModelGroupTransportAction.java | 13 +++-- .../TransportUpdateModelGroupAction.java | 16 +++--- .../models/GetModelTransportAction.java | 10 +++- .../action/tasks/GetTaskTransportAction.java | 4 +- .../opensearch/ml/model/MLModelManager.java | 3 +- .../ml/rest/RestMLPredictionAction.java | 28 +++++----- .../ml/rest/RestMLUpdateModelGroupAction.java | 5 +- .../DeleteModelTransportActionTests.java | 53 +++++-------------- .../ml/model/MLModelGroupManagerTests.java | 33 ++++-------- .../ml/rest/MLCommonsRestTestCase.java | 3 +- .../RestMLUpdateModelGroupActionTests.java | 6 +-- 12 files changed, 72 insertions(+), 108 deletions(-) diff --git a/docs/model_access_control.md b/docs/model_access_control.md index fb80b49e96..2fac5bca73 100644 --- a/docs/model_access_control.md +++ b/docs/model_access_control.md @@ -178,7 +178,7 @@ Updating a model group request is very similar to register model group request. ### Path and HTTP method ``` -PUT /_plugins/_ml/model_groups//_update +PUT /_plugins/_ml/model_groups/ ``` A user can make updates to a model group to which he/she has access which is determined by the access mode of the model group. @@ -196,7 +196,7 @@ For example, Sample request allowed by admin/owner ``` -PUT /_plugins/_ml/model_groups//_update +PUT /_plugins/_ml/model_groups/ { "name": "model_group_test", "description": "This is an example description", @@ -215,7 +215,7 @@ PUT /_plugins/_ml/model_groups//_update Sample update request allowed by any other user with access to model group. ``` -PUT /_plugins/_ml/model_groups//_update +PUT /_plugins/_ml/model_groups/ { "name": "model_group_test", "description": "This is an example description" diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index 295afab26c..1e78aef1c4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -5,10 +5,9 @@ package org.opensearch.ml.action.model_group; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; - +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.delete.DeleteRequest; @@ -35,9 +34,9 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import lombok.AccessLevel; -import lombok.experimental.FieldDefaults; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID; @Log4j2 @FieldDefaults(level = AccessLevel.PRIVATE) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 49efb64b65..039d9cb4f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -14,6 +14,7 @@ import java.util.Map; import org.apache.commons.lang3.StringUtils; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -41,6 +42,7 @@ import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -101,7 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { @@ -197,18 +199,18 @@ private void validateRequestForAccessControl(MLUpdateModelGroupInput input, User if (hasAccessControlChange(input)) { if (!modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) && !modelAccessControlHelper.isAdmin(user)) { throw new IllegalArgumentException("Only owner or admin can update access control data."); - } else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) - && !modelAccessControlHelper.isAdmin(user) - && !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) { - throw new IllegalArgumentException( - "You don’t have the specified backend role to update access control data. For more information, contact your administrator." - ); } } if (!modelAccessControlHelper.isAdmin(user) && !modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) && !modelAccessControlHelper.isUserHasBackendRole(user, mlModelGroup)) { throw new IllegalArgumentException("You don't have permission to update this model group."); + } else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) + && !modelAccessControlHelper.isAdmin(user) + && !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) { + throw new IllegalArgumentException( + "You don’t have the specified backend role to update access control data. For more information, contact your administrator." + ); } AccessMode accessMode = input.getModelAccessMode(); if ((AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index d73c1c7bbf..05c4deaa86 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -34,6 +35,7 @@ import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.rest.RestStatus; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -110,7 +112,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 6fd84e372f..11e9c67d2c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -25,6 +26,7 @@ import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -68,7 +70,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index b9208187fc..ab0ca575a3 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -63,6 +63,7 @@ import java.util.function.Supplier; import org.apache.logging.log4j.util.Strings; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.get.GetRequest; @@ -891,7 +892,7 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio listener.onFailure(e); } } else { - listener.onFailure(new MLResourceNotFoundException("Fail to find model")); + listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); } }, e -> { listener.onFailure(e); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index df5b098576..18d3c3480e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -5,17 +5,9 @@ package org.opensearch.ml.rest; -import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; -import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; -import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; -import static org.opensearch.ml.utils.RestActionUtils.getParameterId; - -import java.io.IOException; -import java.util.List; -import java.util.Locale; -import java.util.Optional; - +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; @@ -34,10 +26,16 @@ import org.opensearch.rest.RestStatus; import org.opensearch.rest.action.RestToXContentListener; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Optional; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; @Log4j2 public class RestMLPredictionAction extends BaseRestHandler { @@ -99,7 +97,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client }, e -> { log.error("Failed to get ML model", e); try { - channel.sendResponse(new BytesRestResponse(channel, RestStatus.BAD_REQUEST, e)); + channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e)); } catch (IOException ex) { log.error("Failed to send error response", ex); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java index 34ccca9c15..87f64e0c5b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java @@ -38,10 +38,7 @@ public String getName() { public List routes() { return ImmutableList .of( - new Route( - RestRequest.Method.PUT, - String.format(Locale.ROOT, "%s/model_groups/{%s}/_update", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID) - ) + new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/model_groups/{%s}", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID)) ); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 73bd333985..245c7a86e2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -147,7 +147,7 @@ public void testDeleteModel_Success() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -172,16 +172,7 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { return null; }).when(client).execute(any(), any(), any()); - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -213,17 +204,8 @@ public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOExcepti return null; }).when(client).search(any(), isA(ActionListener.class)); - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelGroupId("modelGroupID") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -296,17 +278,8 @@ public void test_Failure_FailedToSearchLastModel() throws IOException { return null; }).when(client).search(any(), isA(ActionListener.class)); - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelGroupId("modelGroupID") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -320,7 +293,7 @@ public void test_Failure_FailedToSearchLastModel() throws IOException { } public void test_UserHasNoAccessException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -340,7 +313,7 @@ public void test_UserHasNoAccessException() throws IOException { } public void testDeleteModel_CheckModelState() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING); + GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -383,7 +356,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -397,7 +370,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { } public void test_ValidationFailedException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -441,7 +414,7 @@ public void testDeleteModelChunks_Success() { } public void testDeleteModel_RuntimeException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -535,8 +508,8 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() { assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } - public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException { - MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build(); + public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID) throws IOException { + MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).build(); XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 0bb2c10632..a17151f46b 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -5,14 +5,6 @@ package org.opensearch.ml.model; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.Arrays; -import java.util.List; - import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; @@ -38,6 +30,15 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import java.util.Arrays; +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@Ignore public class MLModelGroupManagerTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -105,7 +106,6 @@ public void setup() { when(threadPool.getThreadContext()).thenReturn(threadContext); } - @Ignore public void test_SuccessAddAllBackendRolesTrue() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -116,7 +116,6 @@ public void test_SuccessAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_SuccessPublic() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -126,7 +125,6 @@ public void test_SuccessPublic() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_ExceptionAllAccessFieldsNull() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -140,7 +138,6 @@ public void test_ExceptionAllAccessFieldsNull() { ); } - @Ignore public void test_ModelAccessModeNullAddAllBackendRolesTrue() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -151,7 +148,6 @@ public void test_ModelAccessModeNullAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_BackendRolesProvidedWithPublic() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -162,7 +158,6 @@ public void test_BackendRolesProvidedWithPublic() { assertEquals("You can specify backend roles only for a model group with the restricted access mode.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_BackendRolesProvidedWithPrivate() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -173,7 +168,6 @@ public void test_BackendRolesProvidedWithPrivate() { assertEquals("You can specify backend roles only for a model group with the restricted access mode.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_AdminSpecifiedAddAllBackendRolesForRestricted() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "admin|admin|all_access"); when(modelAccessControlHelper.isAdmin(any())).thenReturn(true); @@ -186,7 +180,6 @@ public void test_AdminSpecifiedAddAllBackendRolesForRestricted() { assertEquals("Admin users cannot add all backend roles to a model group.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_UserWithNoBackendRolesSpecifiedRestricted() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex||engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -201,7 +194,6 @@ public void test_UserWithNoBackendRolesSpecifiedRestricted() { ); } - @Ignore public void test_UserSpecifiedRestrictedButNoBackendRolesFieldF() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -216,7 +208,6 @@ public void test_UserSpecifiedRestrictedButNoBackendRolesFieldF() { ); } - @Ignore public void test_RestrictedAndUserSpecifiedBothBackendRolesField() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -231,7 +222,6 @@ public void test_RestrictedAndUserSpecifiedBothBackendRolesField() { ); } - @Ignore public void test_RestrictedAndUserSpecifiedIncorrectBackendRoles() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -245,7 +235,6 @@ public void test_RestrictedAndUserSpecifiedIncorrectBackendRoles() { assertEquals("You don't have the backend roles specified.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_SuccessSecurityDisabledCluster() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); @@ -255,7 +244,6 @@ public void test_SuccessSecurityDisabledCluster() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_ExceptionSecurityDisabledCluster() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); @@ -269,7 +257,6 @@ public void test_ExceptionSecurityDisabledCluster() { ); } - @Ignore public void test_ExceptionFailedToInitModelGroupIndex() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -279,7 +266,6 @@ public void test_ExceptionFailedToInitModelGroupIndex() { verify(actionListener).onFailure(argumentCaptor.capture()); } - @Ignore public void test_ExceptionFailedToIndexModelGroup() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); doAnswer(invocation -> { @@ -295,7 +281,6 @@ public void test_ExceptionFailedToIndexModelGroup() { assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_ExceptionInitModelGroupIndexIfAbsent() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index e2cbf33f15..13e40ed292 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -697,8 +697,7 @@ public void registerModelGroup(RestClient client, String input, Consumer> function) throws IOException { - Response response = TestHelper - .makeRequest(client, "PUT", "/_plugins/_ml/model_groups/" + modelGroupId + "/_update", null, input, null); + Response response = TestHelper.makeRequest(client, "PUT", "/_plugins/_ml/model_groups/" + modelGroupId, null, input, null); verifyResponse(function, response); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java index b51a1d279e..4be74e6370 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java @@ -91,7 +91,7 @@ public void testRoutes() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.PUT, route.getMethod()); - assertEquals("/_plugins/_ml/model_groups/{model_group_id}/_update", route.getPath()); + assertEquals("/_plugins/_ml/model_groups/{model_group_id}", route.getPath()); } public void testUpdateModelGroupRequest() throws Exception { @@ -119,7 +119,7 @@ private RestRequest getRestRequest() { params.put("model_group_id", "test_modelGroupId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_groups/{model_group_id}/_update") + .withPath("/_plugins/_ml/model_groups/{model_group_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -132,7 +132,7 @@ private RestRequest getRestRequestWithEmptyContent() { params.put("model_group_id", "test_modelGroupId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_groups/{model_group_id}/_update") + .withPath("/_plugins/_ml/model_groups/{model_group_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build();