Skip to content

Commit

Permalink
add UTs for register model via local file class
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Jul 10, 2023
1 parent 4f49ec1 commit 3445fa1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
3 changes: 2 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.rest.RestMLCreateConnectorAction',
'org.opensearch.ml.action.connector.SearchConnectorTransportAction',
'org.opensearch.ml.model.MLModelGroupManager',
'org.opensearch.ml.helper.ModelAccessControlHelper'
'org.opensearch.ml.helper.ModelAccessControlHelper',
'org.opensearch.ml.action.models.DeleteModelTransportAction.2'
]

jacocoTestCoverageVerification {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,39 @@ public void testTransportRegisterModelMetaActionConstructor() {
public void testTransportRegisterModelMetaActionDoExecute() {
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations");

MLRegisterModelMetaRequest actionRequest = prepareRequest();
MLRegisterModelMetaRequest actionRequest = prepareRequest("modelGroupID");
action.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<MLRegisterModelMetaResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void testDoExecute_successWithCreateModelGroup() {
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(1);
listener.onResponse("modelGroupID");
return null;
}).when(mlModelGroupManager).createModelGroup(any(), any());

MLRegisterModelMetaRequest actionRequest = prepareRequest(null);
action.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<MLRegisterModelMetaResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelMetaResponse.class);
verify(actionListener).onResponse(argumentCaptor.capture());
}

public void testDoExecute_failureWithCreateModelGroup() {
doAnswer(invocation -> {
ActionListener<String> listener = invocation.getArgument(1);
listener.onFailure(new Exception("Failed to create Model Group"));
return null;
}).when(mlModelGroupManager).createModelGroup(any(), any());

MLRegisterModelMetaRequest actionRequest = prepareRequest(null);
action.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to create Model Group", argumentCaptor.getValue().getMessage());
}

public void testDoExecute_userHasNoAccessException() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(3);
Expand All @@ -119,7 +146,7 @@ public void testDoExecute_userHasNoAccessException() {

threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations");

MLRegisterModelMetaRequest actionRequest = prepareRequest();
MLRegisterModelMetaRequest actionRequest = prepareRequest("modelGroupID");
action.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
Expand All @@ -135,18 +162,18 @@ public void test_ValidationFailedException() {

threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alex|IT,HR|engineering,operations");

MLRegisterModelMetaRequest actionRequest = prepareRequest();
MLRegisterModelMetaRequest actionRequest = prepareRequest("modelGroupID");
action.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
}

private MLRegisterModelMetaRequest prepareRequest() {
private MLRegisterModelMetaRequest prepareRequest(String modelGroupID) {
MLRegisterModelMetaInput input = MLRegisterModelMetaInput
.builder()
.name("Model Name")
.modelGroupId("1")
.modelGroupId(modelGroupID)
.description("Custom Model Test")
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.functionName(FunctionName.BATCH_RCF)
Expand Down

0 comments on commit 3445fa1

Please sign in to comment.