Skip to content

Commit

Permalink
enable auto redeploy for hidden model
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Feb 13, 2024
1 parent 67e904e commit 80d3a75
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,29 @@ public class MLDeployModelRequest extends MLTaskRequest {
private String modelId;
private String[] modelNodeIds;
boolean async;
// This is to identify if the get request is initiated by user or not. During auto redeploy, we also perform deploy operation. This field is to distinguish between
// these two situations.
private final boolean isUserInitiatedDeployRequest;

@Builder
public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask) {
public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async, boolean dispatchTask, boolean isUserInitiatedDeployRequest) {
super(dispatchTask);
this.modelId = modelId;
this.modelNodeIds = modelNodeIds;
this.async = async;
this.isUserInitiatedDeployRequest = isUserInitiatedDeployRequest;
}

public MLDeployModelRequest(String modelId, boolean async) {
this(modelId, null, async, true);
this(modelId, null, async, true, true);
}

public MLDeployModelRequest(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.modelNodeIds = in.readOptionalStringArray();
this.async = in.readBoolean();
this.isUserInitiatedDeployRequest = in.readOptionalBoolean();
}

@Override
Expand All @@ -74,6 +79,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeOptionalStringArray(modelNodeIds);
out.writeBoolean(async);
out.writeOptionalBoolean(isUserInitiatedDeployRequest);
}

public static MLDeployModelRequest parse(XContentParser parser, String modelId) throws IOException {
Expand All @@ -96,7 +102,7 @@ public static MLDeployModelRequest parse(XContentParser parser, String modelId)
}
}
String[] nodeIds = nodeIdList == null ? null : nodeIdList.toArray(new String[0]);
return new MLDeployModelRequest(modelId, nodeIds, false, true);
return new MLDeployModelRequest(modelId, nodeIds, false, true, true);
}

public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ public TransportDeployModelAction(
protected void doExecute(Task task, ActionRequest request, ActionListener<MLDeployModelResponse> listener) {
MLDeployModelRequest deployModelRequest = MLDeployModelRequest.fromActionRequest(request);
String modelId = deployModelRequest.getModelId();
Boolean isUserInitiatedDeployRequest = deployModelRequest.isUserInitiatedDeployRequest();
User user = RestActionUtils.getUserContext(client);
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
Expand All @@ -144,7 +145,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (isHidden != null && isHidden) {
if (isSuperAdmin) {
if (isSuperAdmin || !isUserInitiatedDeployRequest) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
} else {
wrappedListener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ private void triggerModelRedeploy(ModelAutoRedeployArrangement modelAutoRedeploy
ImmutableMap.of(MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD, Optional.ofNullable(autoRedeployRetryTimes).orElse(0) + 1)
);

MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, nodeIds, false, true);
MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, nodeIds, false, true, false);
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, listener);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ private void updateModelRegisterStateAsDone(
void deployModelAfterRegistering(MLRegisterModelInput registerModelInput, String modelId) {
String[] modelNodeIds = registerModelInput.getModelNodeIds();
log.debug("start deploying model after registering, modelId: {} on nodes: {}", modelId, Arrays.toString(modelNodeIds));
MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true);
MLDeployModelRequest request = new MLDeployModelRequest(modelId, modelNodeIds, false, true, true);
ActionListener<MLDeployModelResponse> listener = ActionListener
.wrap(r -> log.debug("model deployed, response {}", r), e -> log.error("Failed to deploy model", e));
client.execute(MLDeployModelAction.INSTANCE, request, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ public void testDoExecute_no_permission_hidden_model() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
when(mlModel.getIsHidden()).thenReturn(true);
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
Expand Down

0 comments on commit 80d3a75

Please sign in to comment.