Skip to content

Commit

Permalink
Add an optional parameter to require atomic account creation
Browse files Browse the repository at this point in the history
By default, if a registration request has no optional fields for atomic
account creation set, the request will proceed non-atomically. If a
client sets the `atomic` field, now such a request would be rejected.
  • Loading branch information
ravi-signal committed Jul 5, 2023
1 parent b593d49 commit fedeef4
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public static TestUser newRegisteredUser(final String number) {

// register account
final RegistrationRequest registrationRequest = new RegistrationRequest(
null, registrationPassword, accountAttributes, true,
null, registrationPassword, accountAttributes, true, false,
Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());

final AccountIdentityResponse registrationResponse = apiPost("/v1/registration", registrationRequest)
Expand Down Expand Up @@ -113,6 +113,7 @@ public static TestUser newRegisteredUserAtomic(final String number) {
registrationPassword,
accountAttributes,
true,
true,
Optional.of(new IdentityKey(aciIdentityKeyPair.getPublicKey())),
Optional.of(new IdentityKey(pniIdentityKeyPair.getPublicKey())),
Optional.of(generateSignedECPreKey(1, aciIdentityKeyPair)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ the calling device and the device associated with the existing account (if any).
""")
boolean skipDeviceTransfer,

@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
If true, indicates that this is a request for "atomic" registration. If any properties
needed for atomic account creation are not present, the request will fail. If false,
atomic account creation can still occur, but only if all required fields are present.
""")
boolean requireAtomic,

@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, description = """
The ACI-associated identity key for the account, encoded as a base64 string. If
provided, an account will be created "atomically," and all other properties needed for
Expand Down Expand Up @@ -78,6 +85,7 @@ public RegistrationRequest(@JsonProperty("sessionId") String sessionId,
@JsonProperty("recoveryPassword") byte[] recoveryPassword,
@JsonProperty("accountAttributes") AccountAttributes accountAttributes,
@JsonProperty("skipDeviceTransfer") boolean skipDeviceTransfer,
@JsonProperty("requireAtomic") boolean requireAtomic,
@JsonProperty("aciIdentityKey") Optional<IdentityKey> aciIdentityKey,
@JsonProperty("pniIdentityKey") Optional<IdentityKey> pniIdentityKey,
@JsonProperty("aciSignedPreKey") Optional<@Valid ECSignedPreKey> aciSignedPreKey,
Expand All @@ -90,7 +98,7 @@ public RegistrationRequest(@JsonProperty("sessionId") String sessionId,
// This may seem a little verbose, but at the time of writing, Jackson struggles with `@JsonUnwrapped` members in
// records, and this is a workaround. Please see
// https://github.com/FasterXML/jackson-databind/issues/3726#issuecomment-1525396869 for additional context.
this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, aciIdentityKey, pniIdentityKey,
this(sessionId, recoveryPassword, accountAttributes, skipDeviceTransfer, requireAtomic, aciIdentityKey, pniIdentityKey,
new DeviceActivationRequest(aciSignedPreKey, pniSignedPreKey, aciPqLastResortPreKey, pniPqLastResortPreKey, apnToken, gcmToken));
}

Expand Down Expand Up @@ -122,7 +130,7 @@ && deviceActivationRequest().pniSignedPreKey().isEmpty()
&& deviceActivationRequest().aciPqLastResortPreKey().isEmpty()
&& deviceActivationRequest().pniPqLastResortPreKey().isEmpty();

return supportsAtomicAccountCreation() || hasNoAtomicAccountCreationParameters;
return supportsAtomicAccountCreation() || (!requireAtomic() && hasNoAtomicAccountCreationParameters);
}

public boolean supportsAtomicAccountCreation() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.signal.libsignal.protocol.IdentityKey;
import org.signal.libsignal.protocol.ecc.Curve;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
Expand Down Expand Up @@ -117,10 +118,10 @@ void setUp() {

@Test
public void testRegistrationRequest() throws Exception {
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertFalse(new RegistrationRequest("some", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("", new byte[32], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
assertTrue(new RegistrationRequest("some", new byte[0], new AccountAttributes(), true, false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()).isValid());
}

@Test
Expand Down Expand Up @@ -447,6 +448,7 @@ static Stream<Arguments> atomicAccountCreationConflictingChannel() {
new byte[0],
fetchesMessagesAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Expand All @@ -461,6 +463,7 @@ static Stream<Arguments> atomicAccountCreationConflictingChannel() {
new byte[0],
fetchesMessagesAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Expand All @@ -475,6 +478,7 @@ static Stream<Arguments> atomicAccountCreationConflictingChannel() {
new byte[0],
pushAccountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Expand Down Expand Up @@ -533,6 +537,7 @@ static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Expand All @@ -547,6 +552,7 @@ static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
Optional.empty(),
Expand All @@ -561,6 +567,7 @@ static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Expand All @@ -575,6 +582,7 @@ static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
Expand All @@ -589,6 +597,7 @@ static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
new byte[0],
accountAttributes,
true,
false,
Optional.empty(),
pniIdentityKey,
aciSignedPreKey,
Expand All @@ -603,6 +612,7 @@ static Stream<Arguments> atomicAccountCreationPartialSignedPreKeys() {
new byte[0],
accountAttributes,
true,
false,
aciIdentityKey,
Optional.empty(),
aciSignedPreKey,
Expand Down Expand Up @@ -686,6 +696,43 @@ void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest,
() -> verify(device, never()).setGcmId(any()));
}

@ParameterizedTest
@ValueSource(booleans = {false, true})
void nonAtomicAccountCreationWithNoAtomicFields(boolean requireAtomic) throws InterruptedException {
when(registrationServiceClient.getSession(any(), any()))
.thenReturn(
CompletableFuture.completedFuture(
Optional.of(new RegistrationServiceSession(new byte[16], NUMBER, true, null, null, null,
SESSION_EXPIRATION_SECONDS))));

final Invocation.Builder request = resources.getJerseyTest()
.target("/v1/registration")
.request()
.header(HttpHeaders.AUTHORIZATION, AuthHelper.getProvisioningAuthHeader(NUMBER, PASSWORD));

when(accountsManager.create(any(), any(), any(), any(), any()))
.thenReturn(mock(Account.class));

RegistrationRequest reg = new RegistrationRequest("session-id",
new byte[0],
new AccountAttributes(true, 1, "test", null, true, new Device.DeviceCapabilities()),
true,
requireAtomic,
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
Optional.empty());

try (final Response response = request.post(Entity.json(reg))) {
int expected = requireAtomic ? 422 : 200;
assertEquals(expected, response.getStatus());
}
}

private static Stream<Arguments> atomicAccountCreationSuccess() {
final Optional<IdentityKey> aciIdentityKey;
final Optional<IdentityKey> pniIdentityKey;
Expand Down Expand Up @@ -715,75 +762,105 @@ private static Stream<Arguments> atomicAccountCreationSuccess() {
final String apnsVoipToken = "apns-voip-token";
final String gcmToken = "gcm-token";

return Stream.of(
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
fetchesMessagesAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),

// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),

// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken)));
return Stream.of(false, true)
// try with and without strict atomic checking
.flatMap(requireAtomic ->
Stream.of(
// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
fetchesMessagesAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.empty()),

// Has APNs tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),

// requires the request to be atomic
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.of(new ApnRegistrationId(apnsToken, apnsVoipToken)),
Optional.empty()),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.of(apnsToken),
Optional.of(apnsVoipToken),
Optional.empty()),

// Fetches messages; no push tokens
Arguments.of(new RegistrationRequest("session-id",
new byte[0],
pushAccountAttributes,
true,
requireAtomic,
aciIdentityKey,
pniIdentityKey,
aciSignedPreKey,
pniSignedPreKey,
aciPqLastResortPreKey,
pniPqLastResortPreKey,
Optional.empty(),
Optional.of(new GcmRegistrationId(gcmToken))),
aciIdentityKey.get(),
pniIdentityKey.get(),
aciSignedPreKey.get(),
pniSignedPreKey.get(),
aciPqLastResortPreKey.get(),
pniPqLastResortPreKey.get(),
Optional.empty(),
Optional.empty(),
Optional.of(gcmToken))));
}

/**
Expand Down

0 comments on commit fedeef4

Please sign in to comment.