diff --git a/integration-tests/src/main/java/org/signal/integration/Operations.java b/integration-tests/src/main/java/org/signal/integration/Operations.java index f603f2fda..88c037bea 100644 --- a/integration-tests/src/main/java/org/signal/integration/Operations.java +++ b/integration-tests/src/main/java/org/signal/integration/Operations.java @@ -77,7 +77,7 @@ public static TestUser newRegisteredUser(final String number) { // register account final RegistrationRequest registrationRequest = new RegistrationRequest( - null, registrationPassword, accountAttributes, true, + null, registrationPassword, accountAttributes, true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); final AccountIdentityResponse registrationResponse = apiPost("/v1/registration", registrationRequest) @@ -113,6 +113,7 @@ public static TestUser newRegisteredUserAtomic(final String number) { registrationPassword, accountAttributes, true, + true, Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())), Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())), Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java index 19300d686..68cc6b006 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/RegistrationRequest.java @@ -50,6 +50,13 @@ the calling device and the device associated with the existing account (if any). """) boolean skipDeviceTransfer, + @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ + If true, indicates that this is a request for "atomic" registration. If any properties + needed for atomic account creation are not present, the request will fail. If false, + atomic account creation can still occur, but only if all required fields are present. + """) + boolean requireAtomic, + @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """ The ACI-associated identity key for the account, encoded as a base64 string. If provided, an account will be created "atomically," and all other properties needed for @@ -78,6 +85,7 @@ public RegistrationRequest(@JsonProperty("sessionId") String sessionId, @JsonProperty("recoveryPassword") byte[] recoveryPassword, @JsonProperty("accountAttributes") AccountAttributes accountAttributes, @JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer, + @JsonProperty("requireAtomic") boolean requireAtomic, @JsonProperty("aciIdentityKey") Optional aciIdentityKey, @JsonProperty("pniIdentityKey") Optional pniIdentityKey, @JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey, @@ -90,7 +98,7 @@ public RegistrationRequest(@JsonProperty("sessionId") String sessionId, // This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in // records, and this is a workaround. Please see // https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context. - this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, aciIdentityKey, pniIdentityKey, + this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, requireAtomic, aciIdentityKey, pniIdentityKey, new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken)); } @@ -122,7 +130,7 @@ && deviceActivationRequest().pniSignedPreKey().isEmpty() && deviceActivationRequest().aciPqLastResortPreKey().isEmpty() && deviceActivationRequest().pniPqLastResortPreKey().isEmpty(); - return supportsAtomicAccountCreation() || hasNoAtomicAccountCreationParameters; + return supportsAtomicAccountCreation() || (!requireAtomic() && hasNoAtomicAccountCreationParameters); } public boolean supportsAtomicAccountCreation() { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index 9597596cc..d775533f8 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -47,6 +47,7 @@ import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.EnumSource; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.signal.libsignal.protocol.IdentityKey; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECKeyPair; @@ -117,10 +118,10 @@ void setUp() { @Test public void testRegistrationRequest() throws Exception { - assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); - assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); - assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); - assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); + assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); + assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); + assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); + assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid()); } @Test @@ -447,6 +448,7 @@ static Stream atomicAccountCreationConflictingChannel() { new byte[0], fetchesMessagesAccountAttributes, true, + false, aciIdentityKey, pniIdentityKey, aciSignedPreKey, @@ -461,6 +463,7 @@ static Stream atomicAccountCreationConflictingChannel() { new byte[0], fetchesMessagesAccountAttributes, true, + false, aciIdentityKey, pniIdentityKey, aciSignedPreKey, @@ -475,6 +478,7 @@ static Stream atomicAccountCreationConflictingChannel() { new byte[0], pushAccountAttributes, true, + false, aciIdentityKey, pniIdentityKey, aciSignedPreKey, @@ -533,6 +537,7 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { new byte[0], accountAttributes, true, + false, aciIdentityKey, pniIdentityKey, aciSignedPreKey, @@ -547,6 +552,7 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { new byte[0], accountAttributes, true, + false, aciIdentityKey, pniIdentityKey, Optional.empty(), @@ -561,6 +567,7 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { new byte[0], accountAttributes, true, + false, aciIdentityKey, pniIdentityKey, aciSignedPreKey, @@ -575,6 +582,7 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { new byte[0], accountAttributes, true, + false, aciIdentityKey, pniIdentityKey, aciSignedPreKey, @@ -589,6 +597,7 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { new byte[0], accountAttributes, true, + false, Optional.empty(), pniIdentityKey, aciSignedPreKey, @@ -603,6 +612,7 @@ static Stream atomicAccountCreationPartialSignedPreKeys() { new byte[0], accountAttributes, true, + false, aciIdentityKey, Optional.empty(), aciSignedPreKey, @@ -686,6 +696,43 @@ void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, () -> verify(device, never()).setGcmId(any())); } + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void nonAtomicAccountCreationWithNoAtomicFields(boolean requireAtomic) throws InterruptedException { + when(registrationServiceClient.getSession(any(), any())) + .thenReturn( + CompletableFuture.completedFuture( + Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null, + SESSION_EXPIRATION_SECONDS)))); + + final Invocation.Builder request = resources.getJerseyTest() + .target("/v1/registration") + .request() + .header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD)); + + when(accountsManager.create(any(), any(), any(), any(), any())) + .thenReturn(mock(Account.class)); + + RegistrationRequest reg = new RegistrationRequest("session-id", + new byte[0], + new AccountAttributes(true, 1, "test", null, true, new Device.DeviceCapabilities()), + true, + requireAtomic, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + + try (final Response response = request.post(Entity.json(reg))) { + int expected = requireAtomic ? 422 : 200; + assertEquals(expected, response.getStatus()); + } + } + private static Stream atomicAccountCreationSuccess() { final Optional aciIdentityKey; final Optional pniIdentityKey; @@ -715,75 +762,105 @@ private static Stream atomicAccountCreationSuccess() { final String apnsVoipToken = "apns-voip-token"; final String gcmToken = "gcm-token"; - return Stream.of( - // Fetches messages; no push tokens - Arguments.of(new RegistrationRequest("session-id", - new byte[0], - fetchesMessagesAccountAttributes, - true, - aciIdentityKey, - pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.empty(), - Optional.empty()), - aciIdentityKey.get(), - pniIdentityKey.get(), - aciSignedPreKey.get(), - pniSignedPreKey.get(), - aciPqLastResortPreKey.get(), - pniPqLastResortPreKey.get(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - - // Has APNs tokens - Arguments.of(new RegistrationRequest("session-id", - new byte[0], - pushAccountAttributes, - true, - aciIdentityKey, - pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), - Optional.empty()), - aciIdentityKey.get(), - pniIdentityKey.get(), - aciSignedPreKey.get(), - pniSignedPreKey.get(), - aciPqLastResortPreKey.get(), - pniPqLastResortPreKey.get(), - Optional.of(apnsToken), - Optional.of(apnsVoipToken), - Optional.empty()), - - // Fetches messages; no push tokens - Arguments.of(new RegistrationRequest("session-id", - new byte[0], - pushAccountAttributes, - true, - aciIdentityKey, - pniIdentityKey, - aciSignedPreKey, - pniSignedPreKey, - aciPqLastResortPreKey, - pniPqLastResortPreKey, - Optional.empty(), - Optional.of(new GcmRegistrationId(gcmToken))), - aciIdentityKey.get(), - pniIdentityKey.get(), - aciSignedPreKey.get(), - pniSignedPreKey.get(), - aciPqLastResortPreKey.get(), - pniPqLastResortPreKey.get(), - Optional.empty(), - Optional.empty(), - Optional.of(gcmToken))); + return Stream.of(false, true) + // try with and without strict atomic checking + .flatMap(requireAtomic -> + Stream.of( + // Fetches messages; no push tokens + Arguments.of(new RegistrationRequest("session-id", + new byte[0], + fetchesMessagesAccountAttributes, + true, + requireAtomic, + aciIdentityKey, + pniIdentityKey, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + Optional.empty(), + Optional.empty()), + aciIdentityKey.get(), + pniIdentityKey.get(), + aciSignedPreKey.get(), + pniSignedPreKey.get(), + aciPqLastResortPreKey.get(), + pniPqLastResortPreKey.get(), + Optional.empty(), + Optional.empty(), + Optional.empty()), + + // Has APNs tokens + Arguments.of(new RegistrationRequest("session-id", + new byte[0], + pushAccountAttributes, + true, + requireAtomic, + aciIdentityKey, + pniIdentityKey, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), + Optional.empty()), + aciIdentityKey.get(), + pniIdentityKey.get(), + aciSignedPreKey.get(), + pniSignedPreKey.get(), + aciPqLastResortPreKey.get(), + pniPqLastResortPreKey.get(), + Optional.of(apnsToken), + Optional.of(apnsVoipToken), + Optional.empty()), + + // requires the request to be atomic + Arguments.of(new RegistrationRequest("session-id", + new byte[0], + pushAccountAttributes, + true, + requireAtomic, + aciIdentityKey, + pniIdentityKey, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)), + Optional.empty()), + aciIdentityKey.get(), + pniIdentityKey.get(), + aciSignedPreKey.get(), + pniSignedPreKey.get(), + aciPqLastResortPreKey.get(), + pniPqLastResortPreKey.get(), + Optional.of(apnsToken), + Optional.of(apnsVoipToken), + Optional.empty()), + + // Fetches messages; no push tokens + Arguments.of(new RegistrationRequest("session-id", + new byte[0], + pushAccountAttributes, + true, + requireAtomic, + aciIdentityKey, + pniIdentityKey, + aciSignedPreKey, + pniSignedPreKey, + aciPqLastResortPreKey, + pniPqLastResortPreKey, + Optional.empty(), + Optional.of(new GcmRegistrationId(gcmToken))), + aciIdentityKey.get(), + pniIdentityKey.get(), + aciSignedPreKey.get(), + pniSignedPreKey.get(), + aciPqLastResortPreKey.get(), + pniPqLastResortPreKey.get(), + Optional.empty(), + Optional.empty(), + Optional.of(gcmToken)))); } /**