Skip to content

Commit

Permalink
Refine RegistrationController logic
Browse files Browse the repository at this point in the history
Local device transfer on iOS uses the `409` status code to prompt the
transfer UI. This needs to happen before sending a `423` and locking
an existing account, since the device transfer
includes the local device database verbatim.
  • Loading branch information
eager-signal committed Sep 25, 2023
1 parent f9fabbe commit 8d1135a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,18 @@ public AccountIdentityResponse register(
REREGISTRATION_IDLE_DAYS_DISTRIBUTION.record(timeSinceLastSeen.toDays());
});

if (existingAccount.isPresent()) {
registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(),
registrationRequest.accountAttributes().getRegistrationLock(),
userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType);
}

if (!registrationRequest.skipDeviceTransfer() && existingAccount.map(Account::isTransferSupported).orElse(false)) {
// If a device transfer is possible, clients must explicitly opt out of a transfer (i.e. after prompting the user)
// before we'll let them create a new account "from scratch"
throw new WebApplicationException(Response.status(409, "device transfer available").build());
}

if (existingAccount.isPresent()) {
registrationLockVerificationManager.verifyRegistrationLock(existingAccount.get(),
registrationRequest.accountAttributes().getRegistrationLock(),
userAgent, RegistrationLockVerificationManager.Flow.REGISTRATION, verificationType);
}

Account account = accounts.create(number, password, signalAgent, registrationRequest.accountAttributes(),
existingAccount.map(Account::getBadges).orElseGet(ArrayList::new));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
Expand All @@ -45,9 +48,10 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.junitpioneer.jupiter.cartesian.ArgumentSets;
import org.junitpioneer.jupiter.cartesian.CartesianTest;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
Expand Down Expand Up @@ -315,33 +319,58 @@ void recoveryPasswordManagerVerificationFalse() throws InterruptedException {
}
}

@ParameterizedTest
@EnumSource(RegistrationLockError.class)
void registrationLock(final RegistrationLockError error) throws Exception {
@CartesianTest
@CartesianTest.MethodFactory("registrationLockAndDeviceTransfer")
void registrationLockAndDeviceTransfer(
final boolean deviceTransferSupported,
@Nullable final RegistrationLockError error)
throws Exception {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));

when(accountsManager.getByE164(any())).thenReturn(Optional.of(mock(Account.class)));
final Account account = mock(Account.class);
when(accountsManager.getByE164(any())).thenReturn(Optional.of(account));
when(account.isTransferSupported()).thenReturn(deviceTransferSupported);

final Exception e = switch (error) {
final int expectedStatus;
if (deviceTransferSupported) {
expectedStatus = 409;
} else if (error != null) {
final Exception e = switch (error) {
case MISMATCH -> new WebApplicationException(error.getExpectedStatus());
case RATE_LIMITED -> new RateLimitExceededException(null, true);
};
doThrow(e)
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any());
doThrow(e)
.when(registrationLockVerificationManager).verifyRegistrationLock(any(), any(), any(), any(), any());
expectedStatus = error.getExpectedStatus();
} else {
when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));
expectedStatus = 200;
}

final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));
try (Response response = request.post(Entity.json(requestJson("sessionId")))) {
assertEquals(error.getExpectedStatus(), response.getStatus());
assertEquals(expectedStatus, response.getStatus());
}
}

@SuppressWarnings("unused")
static ArgumentSets registrationLockAndDeviceTransfer() {
final Set<RegistrationLockError> registrationLockErrors = new HashSet<>(EnumSet.allOf(RegistrationLockError.class));
registrationLockErrors.add(null);

return ArgumentSets.argumentsForFirstParameter(true, false)
.argumentsForNextParameter(registrationLockErrors);
}


@ParameterizedTest
@CsvSource({
"false, false, false, 200",
Expand Down

0 comments on commit 8d1135a

Please sign in to comment.