diff --git a/docs/model_access_control.md b/docs/model_access_control.md index fb80b49e96..2fac5bca73 100644 --- a/docs/model_access_control.md +++ b/docs/model_access_control.md @@ -178,7 +178,7 @@ Updating a model group request is very similar to register model group request. ### Path and HTTP method ``` -PUT /_plugins/_ml/model_groups//_update +PUT /_plugins/_ml/model_groups/ ``` A user can make updates to a model group to which he/she has access which is determined by the access mode of the model group. @@ -196,7 +196,7 @@ For example, Sample request allowed by admin/owner ``` -PUT /_plugins/_ml/model_groups//_update +PUT /_plugins/_ml/model_groups/ { "name": "model_group_test", "description": "This is an example description", @@ -215,7 +215,7 @@ PUT /_plugins/_ml/model_groups//_update Sample update request allowed by any other user with access to model group. ``` -PUT /_plugins/_ml/model_groups//_update +PUT /_plugins/_ml/model_groups/ { "name": "model_group_test", "description": "This is an example description" diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 49efb64b65..039d9cb4f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -14,6 +14,7 @@ import java.util.Map; import org.apache.commons.lang3.StringUtils; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -41,6 +42,7 @@ import org.opensearch.ml.model.MLModelGroupManager; import org.opensearch.ml.utils.MLNodeUtils; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -101,7 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { @@ -197,18 +199,18 @@ private void validateRequestForAccessControl(MLUpdateModelGroupInput input, User if (hasAccessControlChange(input)) { if (!modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) && !modelAccessControlHelper.isAdmin(user)) { throw new IllegalArgumentException("Only owner or admin can update access control data."); - } else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) - && !modelAccessControlHelper.isAdmin(user) - && !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) { - throw new IllegalArgumentException( - "You don’t have the specified backend role to update access control data. For more information, contact your administrator." - ); } } if (!modelAccessControlHelper.isAdmin(user) && !modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) && !modelAccessControlHelper.isUserHasBackendRole(user, mlModelGroup)) { throw new IllegalArgumentException("You don't have permission to update this model group."); + } else if (modelAccessControlHelper.isOwner(mlModelGroup.getOwner(), user) + && !modelAccessControlHelper.isAdmin(user) + && !modelAccessControlHelper.isOwnerStillHasPermission(user, mlModelGroup)) { + throw new IllegalArgumentException( + "You don’t have the specified backend role to update access control data. For more information, contact your administrator." + ); } AccessMode accessMode = input.getModelAccessMode(); if ((AccessMode.PUBLIC == accessMode || AccessMode.PRIVATE == accessMode) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index d73c1c7bbf..05c4deaa86 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -34,6 +35,7 @@ import org.opensearch.ml.common.transport.model.MLModelGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.rest.RestStatus; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -110,7 +112,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 6fd84e372f..11e9c67d2c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -25,6 +26,7 @@ import org.opensearch.ml.common.transport.task.MLTaskGetAction; import org.opensearch.ml.common.transport.task.MLTaskGetRequest; import org.opensearch.ml.common.transport.task.MLTaskGetResponse; +import org.opensearch.rest.RestStatus; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -68,7 +70,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index b9208187fc..02d3889464 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -63,6 +63,7 @@ import java.util.function.Supplier; import org.apache.logging.log4j.util.Strings; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.get.GetRequest; @@ -891,7 +892,7 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio listener.onFailure(e); } } else { - listener.onFailure(new MLResourceNotFoundException("Fail to find model")); + listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); } }, e -> { listener.onFailure(e); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index df5b098576..14f4d882d3 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -99,7 +99,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client }, e -> { log.error("Failed to get ML model", e); try { - channel.sendResponse(new BytesRestResponse(channel, RestStatus.BAD_REQUEST, e)); + channel.sendResponse(new BytesRestResponse(channel, RestStatus.NOT_FOUND, e)); } catch (IOException ex) { log.error("Failed to send error response", ex); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java index 34ccca9c15..87f64e0c5b 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelGroupAction.java @@ -38,10 +38,7 @@ public String getName() { public List routes() { return ImmutableList .of( - new Route( - RestRequest.Method.PUT, - String.format(Locale.ROOT, "%s/model_groups/{%s}/_update", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID) - ) + new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/model_groups/{%s}", ML_BASE_URI, PARAMETER_MODEL_GROUP_ID)) ); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index 73bd333985..245c7a86e2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -147,7 +147,7 @@ public void testDeleteModel_Success() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -172,16 +172,7 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { return null; }).when(client).execute(any(), any(), any()); - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -213,17 +204,8 @@ public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOExcepti return null; }).when(client).search(any(), isA(ActionListener.class)); - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelGroupId("modelGroupID") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -296,17 +278,8 @@ public void test_Failure_FailedToSearchLastModel() throws IOException { return null; }).when(client).search(any(), isA(ActionListener.class)); - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelGroupId("modelGroupID") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); + doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -320,7 +293,7 @@ public void test_Failure_FailedToSearchLastModel() throws IOException { } public void test_UserHasNoAccessException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -340,7 +313,7 @@ public void test_UserHasNoAccessException() throws IOException { } public void testDeleteModel_CheckModelState() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING); + GetResponse getResponse = prepareMLModel(MLModelState.DEPLOYING, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -383,7 +356,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { return null; }).when(client).execute(any(), any(), any()); - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -397,7 +370,7 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { } public void test_ValidationFailedException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -441,7 +414,7 @@ public void testDeleteModelChunks_Success() { } public void testDeleteModel_RuntimeException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED); + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); @@ -535,8 +508,8 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() { assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } - public GetResponse prepareMLModel(MLModelState mlModelState) throws IOException { - MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).build(); + public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID) throws IOException { + MLModel mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).build(); XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java index 0bb2c10632..54a78fc921 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelGroupManagerTests.java @@ -38,6 +38,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +@Ignore public class MLModelGroupManagerTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -105,7 +106,6 @@ public void setup() { when(threadPool.getThreadContext()).thenReturn(threadContext); } - @Ignore public void test_SuccessAddAllBackendRolesTrue() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -116,7 +116,6 @@ public void test_SuccessAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_SuccessPublic() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -126,7 +125,6 @@ public void test_SuccessPublic() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_ExceptionAllAccessFieldsNull() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -140,7 +138,6 @@ public void test_ExceptionAllAccessFieldsNull() { ); } - @Ignore public void test_ModelAccessModeNullAddAllBackendRolesTrue() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -151,7 +148,6 @@ public void test_ModelAccessModeNullAddAllBackendRolesTrue() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_BackendRolesProvidedWithPublic() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -162,7 +158,6 @@ public void test_BackendRolesProvidedWithPublic() { assertEquals("You can specify backend roles only for a model group with the restricted access mode.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_BackendRolesProvidedWithPrivate() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -173,7 +168,6 @@ public void test_BackendRolesProvidedWithPrivate() { assertEquals("You can specify backend roles only for a model group with the restricted access mode.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_AdminSpecifiedAddAllBackendRolesForRestricted() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "admin|admin|all_access"); when(modelAccessControlHelper.isAdmin(any())).thenReturn(true); @@ -186,7 +180,6 @@ public void test_AdminSpecifiedAddAllBackendRolesForRestricted() { assertEquals("Admin users cannot add all backend roles to a model group.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_UserWithNoBackendRolesSpecifiedRestricted() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex||engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -201,7 +194,6 @@ public void test_UserWithNoBackendRolesSpecifiedRestricted() { ); } - @Ignore public void test_UserSpecifiedRestrictedButNoBackendRolesFieldF() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -216,7 +208,6 @@ public void test_UserSpecifiedRestrictedButNoBackendRolesFieldF() { ); } - @Ignore public void test_RestrictedAndUserSpecifiedBothBackendRolesField() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -231,7 +222,6 @@ public void test_RestrictedAndUserSpecifiedBothBackendRolesField() { ); } - @Ignore public void test_RestrictedAndUserSpecifiedIncorrectBackendRoles() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations"); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -245,7 +235,6 @@ public void test_RestrictedAndUserSpecifiedIncorrectBackendRoles() { assertEquals("You don't have the backend roles specified.", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_SuccessSecurityDisabledCluster() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); @@ -255,7 +244,6 @@ public void test_SuccessSecurityDisabledCluster() { verify(actionListener).onResponse(argumentCaptor.capture()); } - @Ignore public void test_ExceptionSecurityDisabledCluster() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); @@ -269,7 +257,6 @@ public void test_ExceptionSecurityDisabledCluster() { ); } - @Ignore public void test_ExceptionFailedToInitModelGroupIndex() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(true); @@ -279,7 +266,6 @@ public void test_ExceptionFailedToInitModelGroupIndex() { verify(actionListener).onFailure(argumentCaptor.capture()); } - @Ignore public void test_ExceptionFailedToIndexModelGroup() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); doAnswer(invocation -> { @@ -295,7 +281,6 @@ public void test_ExceptionFailedToIndexModelGroup() { assertEquals("Index Not Found", argumentCaptor.getValue().getMessage()); } - @Ignore public void test_ExceptionInitModelGroupIndexIfAbsent() { when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); doAnswer(invocation -> { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 1e2ebefea3..07c87c65fd 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -483,7 +483,7 @@ public void testDeployModel_NullGetModelResponse() { assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); - assertEquals("Fail to find model", exception.getValue().getMessage()); + assertEquals("Failed to find model", exception.getValue().getMessage()); verify(mlStats) .createCounterStatIfAbsent( eq(FunctionName.TEXT_EMBEDDING), @@ -503,7 +503,7 @@ public void testDeployModel_GetModelResponse_NotExist() { assertFalse(modelManager.isModelRunningOnNode(modelId)); ArgumentCaptor exception = ArgumentCaptor.forClass(Exception.class); verify(listener).onFailure(exception.capture()); - assertEquals("Fail to find model", exception.getValue().getMessage()); + assertEquals("Failed to find model", exception.getValue().getMessage()); verify(mlStats) .createCounterStatIfAbsent( eq(FunctionName.TEXT_EMBEDDING), diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index e2cbf33f15..13e40ed292 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -697,8 +697,7 @@ public void registerModelGroup(RestClient client, String input, Consumer> function) throws IOException { - Response response = TestHelper - .makeRequest(client, "PUT", "/_plugins/_ml/model_groups/" + modelGroupId + "/_update", null, input, null); + Response response = TestHelper.makeRequest(client, "PUT", "/_plugins/_ml/model_groups/" + modelGroupId, null, input, null); verifyResponse(function, response); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java index b51a1d279e..4be74e6370 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelGroupActionTests.java @@ -91,7 +91,7 @@ public void testRoutes() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.PUT, route.getMethod()); - assertEquals("/_plugins/_ml/model_groups/{model_group_id}/_update", route.getPath()); + assertEquals("/_plugins/_ml/model_groups/{model_group_id}", route.getPath()); } public void testUpdateModelGroupRequest() throws Exception { @@ -119,7 +119,7 @@ private RestRequest getRestRequest() { params.put("model_group_id", "test_modelGroupId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_groups/{model_group_id}/_update") + .withPath("/_plugins/_ml/model_groups/{model_group_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -132,7 +132,7 @@ private RestRequest getRestRequestWithEmptyContent() { params.put("model_group_id", "test_modelGroupId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_groups/{model_group_id}/_update") + .withPath("/_plugins/_ml/model_groups/{model_group_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build();