diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index 7336200bb..0ef2eb5e6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -130,18 +130,18 @@ public AccountIdentityResponse register( REREGISTRATION_IDLE_DAYS_DISTRIBUTION.record(timeSinceLastSeen.toDays()); }); - if (existingAccount.isPresent()) { - registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), - registrationRequest.accountAttributes().getRegistrationLock(), - userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType); - } - if (!registrationRequest.skipDeviceTransfer() && existingAccount.map(Account::isTransferSupported).orElse(false)) { // If a device transfer is possible, clients must explicitly opt out of a transfer (i.e. after prompting the user) // before we'll let them create a new account "from scratch" throw new WebApplicationException(Response.status(409, "device transfer available").build()); } + if (existingAccount.isPresent()) { + registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(), + registrationRequest.accountAttributes().getRegistrationLock(), + userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType); + } + Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(), existingAccount.map(Account::getBadges).orElseGet(ArrayList::new)); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index 552aa36dd..4fee853d5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -23,9 +23,12 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Base64; +import java.util.EnumSet; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.function.Consumer; @@ -45,9 +48,10 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; -import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; +import org.junitpioneer.jupiter.cartesian.ArgumentSets; +import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -315,33 +319,58 @@ void recoveryPasswordManagerVerificationFalse() throws InterruptedException { } } - @ParameterizedTest - @EnumSource(RegistrationLockError.class) - void registrationLock(final RegistrationLockError error) throws Exception { + @CartesianTest + @CartesianTest.MethodFactory("registrationLockAndDeviceTransfer") + void registrationLockAndDeviceTransfer( + final boolean deviceTransferSupported, + @Nullable final RegistrationLockError error) + throws Exception { when(registrationServiceClient.getSession(any(), any())) .thenReturn( CompletableFuture.completedFuture( Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, SESSION_EXPIRATION_SECONDS)))); - when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class))); + final Account account = mock(Account.class); + when(accountsManager.getByE164(any())).thenReturn(Optional.of(account)); + when(account.isTransferSupported()).thenReturn(deviceTransferSupported); - final Exception e = switch (error) { + final int expectedStatus; + if (deviceTransferSupported) { + expectedStatus = 409; + } else if (error != null) { + final Exception e = switch (error) { case MISMATCH -> new WebApplicationException(error.getExpectedStatus()); case RATE_LIMITED -> new RateLimitExceededException(null, true); }; - doThrow(e) - .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any()); + doThrow(e) + .when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any()); + expectedStatus = error.getExpectedStatus(); + } else { + when(accountsManager.create(any(), any(), any(), any(), any())) + .thenReturn(mock(Account.class)); + expectedStatus = 200; + } final Invocation.Builder request = resources.getJerseyTest() .target("/v1/registration") .request() .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); try (Response response = request.post(Entity.json(requestJson("sessionId")))) { - assertEquals(error.getExpectedStatus(), response.getStatus()); + assertEquals(expectedStatus, response.getStatus()); } } + @SuppressWarnings("unused") + static ArgumentSets registrationLockAndDeviceTransfer() { + final Set registrationLockErrors = new HashSet<>(EnumSet.allOf(RegistrationLockError.class)); + registrationLockErrors.add(null); + + return ArgumentSets.argumentsForFirstParameter(true, false) + .argumentsForNextParameter(registrationLockErrors); + } + + @ParameterizedTest @CsvSource({ "false, false, false, 200",