Skip to content

Commit

Permalink
addressing comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Feb 16, 2024
1 parent 80d3a75 commit 81f619b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ public class MLDeployModelRequest extends MLTaskRequest {
private String modelId;
private String[] modelNodeIds;
boolean async;
// This is to identify if the get request is initiated by user or not. During auto redeploy, we also perform deploy operation. This field is to distinguish between
// these two situations.
// This is to identify if the get request is initiated by user or not. During auto redeploy also, we perform deploy operation.
// This field is mainly to distinguish between these two situations.
private final boolean isUserInitiatedDeployRequest;

@Builder
Expand All @@ -51,6 +51,9 @@ public MLDeployModelRequest(String modelId, String[] modelNodeIds, boolean async
this.isUserInitiatedDeployRequest = isUserInitiatedDeployRequest;
}

// In this constructor, isUserInitiatedDeployRequest to always set to true. So, it can be used only when
// deploy request is coming directly from the user. DO NOT use this when the
// deploy call is from the code or system initiated.
public MLDeployModelRequest(String modelId, boolean async) {
this(modelId, null, async, true, true);
}
Expand All @@ -60,7 +63,7 @@ public MLDeployModelRequest(StreamInput in) throws IOException {
this.modelId = in.readString();
this.modelNodeIds = in.readOptionalStringArray();
this.async = in.readBoolean();
this.isUserInitiatedDeployRequest = in.readOptionalBoolean();
this.isUserInitiatedDeployRequest = in.readBoolean();
}

@Override
Expand All @@ -79,7 +82,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeOptionalStringArray(modelNodeIds);
out.writeBoolean(async);
out.writeOptionalBoolean(isUserInitiatedDeployRequest);
out.writeBoolean(isUserInitiatedDeployRequest);
}

public static MLDeployModelRequest parse(XContentParser parser, String modelId) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
if (isHidden != null && isHidden) {
if (isSuperAdmin || !isUserInitiatedDeployRequest) {
if (!isUserInitiatedDeployRequest) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
} else if (isHidden != null && isHidden) {
if (isSuperAdmin) {
deployModel(deployModelRequest, mlModel, modelId, wrappedListener, listener);
} else {
wrappedListener
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ public void setup() {
return null;
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());

when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);

when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true);

MLStat mlStat = mock(MLStat.class);
Expand Down Expand Up @@ -218,6 +220,30 @@ public void testDoExecute_success() {
verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class));
}

public void testDoExecute_success_not_userInitiatedRequest() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(anyString(), isNull(), any(String[].class), Mockito.isA(ActionListener.class));

when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(false);

IndexResponse indexResponse = mock(IndexResponse.class);
when(indexResponse.getId()).thenReturn("mockIndexId");
doAnswer(invocation -> {
ActionListener<IndexResponse> listener = invocation.getArgument(1);
listener.onResponse(indexResponse);
return null;
}).when(mlTaskManager).createMLTask(any(MLTask.class), Mockito.isA(ActionListener.class));

ActionListener<MLDeployModelResponse> deployModelResponseListener = mock(ActionListener.class);
transportDeployModelAction.doExecute(mock(Task.class), mlDeployModelRequest, deployModelResponseListener);
verify(deployModelResponseListener).onResponse(any(MLDeployModelResponse.class));
}

public void testDoExecute_success_hidden_model() {
transportDeployModelAction = spy(
new TransportDeployModelAction(
Expand Down Expand Up @@ -286,7 +312,6 @@ public void testDoExecute_no_permission_hidden_model() {
MLModel mlModel = mock(MLModel.class);
when(mlModel.getAlgorithm()).thenReturn(FunctionName.ANOMALY_LOCALIZATION);
when(mlModel.getIsHidden()).thenReturn(true);
when(mlDeployModelRequest.isUserInitiatedDeployRequest()).thenReturn(true);
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(3);
listener.onResponse(mlModel);
Expand Down

0 comments on commit 81f619b

Please sign in to comment.