Skip to content

Commit

Permalink
fix failed UT
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jul 11, 2023
1 parent 412049d commit 500add5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles";
public static final String OWNER_FIELD = "owner";
public static final String ACCESS_MODE_FIELD = "access_mode";
public static final String DRY_RUN_FIELD = "dry_run";

public static final String DRY_RUN_CONNECTOR_NAME = "dryRunConnector";

Expand All @@ -52,6 +53,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable {
private List<String> backendRoles;
private Boolean addAllBackendRoles;
private AccessMode access;
private boolean dryRun = false;

@Builder(toBuilder = true)
public MLCreateConnectorInput(String name,
Expand All @@ -63,16 +65,19 @@ public MLCreateConnectorInput(String name,
List<ConnectorAction> actions,
List<String> backendRoles,
Boolean addAllBackendRoles,
AccessMode access
AccessMode access,
boolean dryRun
) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
if (version == null) {
throw new IllegalArgumentException("Connector version is null");
}
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
if (!dryRun) {
if (name == null) {
throw new IllegalArgumentException("Connector name is null");
}
if (version == null) {
throw new IllegalArgumentException("Connector version is null");
}
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
}
this.name = name;
this.description = description;
Expand All @@ -97,6 +102,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
List<String> backendRoles = null;
Boolean addAllBackendRoles = null;
AccessMode access = null;
boolean dryRun = false;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -142,12 +148,15 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep
case ACCESS_MODE_FIELD:
access = AccessMode.from(parser.text());
break;
case DRY_RUN_FIELD:
dryRun = parser.booleanValue();
break;
default:
parser.skipChildren();
break;
}
}
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access);
return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun);
}

@Override
Expand Down Expand Up @@ -227,6 +236,7 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
output.writeBoolean(dryRun);
}

public MLCreateConnectorInput(StreamInput input) throws IOException {
Expand Down Expand Up @@ -254,5 +264,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.access = input.readEnum(AccessMode.class);
}
dryRun = input.readBoolean();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
log.error(e.getMessage(), e);
listener.onFailure(e);
});
MLCreateConnectorRequest mlCreateConnectorRequest = createConnectorRequest();
MLCreateConnectorRequest mlCreateConnectorRequest = createDryRunConnectorRequest();
client.execute(MLCreateConnectorAction.INSTANCE, mlCreateConnectorRequest, dryRunResultListener);
}
} else {
Expand All @@ -207,8 +207,8 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis
}
}

private MLCreateConnectorRequest createConnectorRequest() {
MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().name("dryRunConnector").build();
private MLCreateConnectorRequest createDryRunConnectorRequest() {
MLCreateConnectorInput createConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build();
return new MLCreateConnectorRequest(createConnectorInput);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ public void setup() {
Map<String, String> credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret");
input = MLCreateConnectorInput
.builder()
.name("test_name")
.version("1")
.actions(actions)
.parameters(parameters)
.protocol(ConnectorProtocols.HTTP)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ public void test_execute_registerRemoteModel_withInternalConnector_success() {
MLRegisterModelInput input = mock(MLRegisterModelInput.class);
when(request.getRegisterModelInput()).thenReturn(input);
when(input.getModelName()).thenReturn("Test Model");
when(input.getVersion()).thenReturn("1");
when(input.getModelGroupId()).thenReturn("modelGroupID");
when(input.getFunctionName()).thenReturn(FunctionName.REMOTE);
Connector connector = mock(Connector.class);
Expand Down

0 comments on commit 500add5

Please sign in to comment.