From 80d3a75018b22d6001919207e43ef51a9b6a56d3 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 13 Feb 2024 17:07:37 -0600 Subject: [PATCH] enable auto redeploy for hidden model Signed-off-by: Bhavana Ramaram --- .../transport/deploy/MLDeployModelRequest.java | 12 +++++++++--- .../ml/action/deploy/TransportDeployModelAction.java | 3 ++- .../ml/autoredeploy/MLModelAutoReDeployer.java | 2 +- .../java/org/opensearch/ml/model/MLModelManager.java | 2 +- .../deploy/TransportDeployModelActionTests.java | 1 + 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java index b0ad113d95..8df4d278a7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java @@ -38,17 +38,21 @@ public class MLDeployModelRequest extends MLTaskRequest { private String modelId; private String[] modelNodeIds; boolean async; + // This is to identify if the get request is initiated by user or not. During auto redeploy, we also perform deploy operation. This field is to distinguish between + // these two situations. + private final boolean isUserInitiatedDeployRequest; @Builder - public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask) { + public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask, boolean isUserInitiatedDeployRequest) { super(dispatchTask); this.modelId = modelId; this.modelNodeIds = modelNodeIds; this.async = async; + this.isUserInitiatedDeployRequest = isUserInitiatedDeployRequest; } public MLDeployModelRequest(String modelId, boolean async) { - this(modelId, null, async, true); + this(modelId, null, async, true, true); } public MLDeployModelRequest(StreamInput in) throws IOException { @@ -56,6 +60,7 @@ public MLDeployModelRequest(StreamInput in) throws IOException { this.modelId = in.readString(); this.modelNodeIds = in.readOptionalStringArray(); this.async = in.readBoolean(); + this.isUserInitiatedDeployRequest = in.readOptionalBoolean(); } @Override @@ -74,6 +79,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeOptionalStringArray(modelNodeIds); out.writeBoolean(async); + out.writeOptionalBoolean(isUserInitiatedDeployRequest); } public static MLDeployModelRequest parse(XContentParser parser, String modelId) throws IOException { @@ -96,7 +102,7 @@ public static MLDeployModelRequest parse(XContentParser parser, String modelId) } } String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]); - return new MLDeployModelRequest(modelId, nodeIds, false, true); + return new MLDeployModelRequest(modelId, nodeIds, false, true, true); } public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 8d1c4f706e..aa561248d1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -131,6 +131,7 @@ public TransportDeployModelAction( protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLDeployModelRequest deployModelRequest = MLDeployModelRequest.fromActionRequest(request); String modelId = deployModelRequest.getModelId(); + Boolean isUserInitiatedDeployRequest = deployModelRequest.isUserInitiatedDeployRequest(); User user = RestActionUtils.getUserContext(client); boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; @@ -144,7 +145,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener = ActionListener .wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e)); client.execute(MLDeployModelAction.INSTANCE, request, listener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 02629c392a..8fe76842b7 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -286,6 +286,7 @@ public void testDoExecute_no_permission_hidden_model() { MLModel mlModel = mock(MLModel.class); when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION); when(mlModel.getIsHidden()).thenReturn(true); + when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(mlModel);