Skip to content

Commit

Permalink
Treat the stand-alone signed pre-keys table as the source of truth fo…
Browse files Browse the repository at this point in the history
…r signed pre-keys
  • Loading branch information
jon-signal committed Dec 11, 2023
1 parent c7cc300 commit feb933b
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator;
import org.whispersystems.textsecuregcm.entities.SetKeysRequest;
import org.whispersystems.textsecuregcm.entities.SignedPreKey;
import org.whispersystems.textsecuregcm.experiment.Experiment;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.identity.ServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
Expand All @@ -67,7 +66,6 @@ public class KeysController {
private final RateLimiters rateLimiters;
private final KeysManager keys;
private final AccountsManager accounts;
private final Experiment compareSignedEcPreKeysExperiment = new Experiment("compareSignedEcPreKeys");

private static final CompletableFuture<?>[] EMPTY_FUTURE_ARRAY = new CompletableFuture[0];

Expand Down Expand Up @@ -237,37 +235,33 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
final List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());

final List<CompletableFuture<Void>> tasks = devices.stream().map(device -> {
final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture =
keys.takeEC(targetIdentifier.uuid(), device.getId());

ECSignedPreKey signedECPreKey = device.getSignedPreKey(targetIdentifier.identityType());
final CompletableFuture<Optional<ECSignedPreKey>> signedEcPreKeyFuture =
keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId());

final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture = keys.takeEC(targetIdentifier.uuid(),
device.getId());
final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.completedFuture(Optional.empty());

return pqPreKeyFuture.thenCombine(unsignedEcPreKeyFuture,
(maybePqPreKey, maybeUnsignedEcPreKey) -> {
return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture)
.thenAccept(ignored -> {
final KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
final ECPreKey unsignedEcPreKey = unsignedEcPreKeyFuture.join().orElse(null);
final ECSignedPreKey signedEcPreKey = signedEcPreKeyFuture.join().orElse(null);

KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
ECPreKey unsignedECPreKey = unsignedEcPreKeyFuture.join().orElse(null);

compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey),
keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()));

if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
if (signedEcPreKey != null || unsignedEcPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};

responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey,
new PreKeyResponseItem(device.getId(), registrationId, signedEcPreKey, unsignedEcPreKey,
pqPreKey));
}

return null;
}).thenRun(Util.NOOP);
});
})
.toList();

Expand All @@ -278,6 +272,7 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
if (responseItems.isEmpty()) {
throw new WebApplicationException(Response.Status.NOT_FOUND);
}

return new PreKeyResponse(identityKey, responseItems);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,41 +41,37 @@ static Mono<GetPreKeysResponse> getPreKeys(final Account targetAccount,
return devices
.filter(Device::isEnabled)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(device -> {
final ECSignedPreKey ecSignedPreKey = device.getSignedPreKey(identityType);
.flatMap(device -> Flux.merge(
Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())),
Mono.fromFuture(() -> keysManager.getEcSignedPreKey(targetAccount.getIdentifier(identityType), device.getId())),
Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId())))
.flatMap(Mono::justOrEmpty)
.reduce(GetPreKeysResponse.PreKeyBundle.newBuilder(), (builder, preKey) -> {
if (preKey instanceof ECPreKey ecPreKey) {
builder.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecPreKey.keyId())
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
.build());
} else if (preKey instanceof ECSignedPreKey ecSignedPreKey) {
builder.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKey.keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
.build());
} else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) {
builder.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemSignedPreKey.keyId())
.setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
.build());
} else {
throw new AssertionError("Unexpected pre-key type: " + preKey.getClass());
}

final GetPreKeysResponse.PreKeyBundle.Builder preKeyBundleBuilder = GetPreKeysResponse.PreKeyBundle.newBuilder()
.setEcSignedPreKey(EcSignedPreKey.newBuilder()
.setKeyId(ecSignedPreKey.keyId())
.setPublicKey(ByteString.copyFrom(ecSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(ecSignedPreKey.signature()))
.build());

return Flux.merge(
Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())),
Mono.fromFuture(() -> keysManager.takePQ(targetAccount.getIdentifier(identityType), device.getId())))
.flatMap(Mono::justOrEmpty)
.reduce(preKeyBundleBuilder, (builder, preKey) -> {
if (preKey instanceof ECPreKey ecPreKey) {
builder.setEcOneTimePreKey(EcPreKey.newBuilder()
.setKeyId(ecPreKey.keyId())
.setPublicKey(ByteString.copyFrom(ecPreKey.serializedPublicKey()))
.build());
} else if (preKey instanceof KEMSignedPreKey kemSignedPreKey) {
preKeyBundleBuilder.setKemOneTimePreKey(KemSignedPreKey.newBuilder()
.setKeyId(kemSignedPreKey.keyId())
.setPublicKey(ByteString.copyFrom(kemSignedPreKey.serializedPublicKey()))
.setSignature(ByteString.copyFrom(kemSignedPreKey.signature()))
.build());
} else {
throw new AssertionError("Unexpected pre-key type: " + preKey.getClass());
}

return builder;
})
// Cast device IDs to `int` to match data types in the response object’s protobuf definition
.map(builder -> Tuples.of((int) device.getId(), builder.build()));
return builder;
})
// Cast device IDs to `int` to match data types in the response object’s protobuf definition
.map(builder -> Tuples.of((int) device.getId(), builder.build())))
.collectMap(Tuple2::getT1, Tuple2::getT2)
.map(preKeyBundles -> GetPreKeysResponse.newBuilder()
.setIdentityKey(ByteString.copyFrom(targetAccount.getIdentityKey(identityType).serialize()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.util.List;
import java.util.OptionalInt;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -238,6 +239,10 @@ public void setPhoneNumberIdentityRegistrationId(final int phoneNumberIdentityRe
this.phoneNumberIdentityRegistrationId = phoneNumberIdentityRegistrationId;
}

/**
* @deprecated Please retrieve signed pre-keys via {@link KeysManager#getEcSignedPreKey(UUID, byte)} instead
*/
@Deprecated
public ECSignedPreKey getSignedPreKey(final IdentityType identityType) {
return switch (identityType) {
case ACI -> signedPreKey;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,6 @@ void setup() {
when(sampleDevice2.isEnabled()).thenReturn(true);
when(sampleDevice3.isEnabled()).thenReturn(false);
when(sampleDevice4.isEnabled()).thenReturn(true);
when(sampleDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY);
when(sampleDevice2.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY2);
when(sampleDevice3.getSignedPreKey(IdentityType.ACI)).thenReturn(SAMPLE_SIGNED_KEY3);
when(sampleDevice4.getSignedPreKey(IdentityType.ACI)).thenReturn(null);
when(sampleDevice.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY);
when(sampleDevice2.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY2);
when(sampleDevice3.getSignedPreKey(IdentityType.PNI)).thenReturn(SAMPLE_SIGNED_PNI_KEY3);
when(sampleDevice4.getSignedPreKey(IdentityType.PNI)).thenReturn(null);
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
Expand Down Expand Up @@ -260,6 +252,24 @@ void setup() {
when(KEYS.getEcSignedPreKey(any(), anyByte())).thenReturn(CompletableFuture.completedFuture(Optional.empty()));
when(KEYS.storeEcSignedPreKeys(any(), any())).thenReturn(CompletableFutureTestUtil.almostCompletedFuture(null));

when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDeviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY)));

when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice2Id))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY2)));

when(KEYS.getEcSignedPreKey(EXISTS_UUID, sampleDevice3Id))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_KEY3)));

when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDeviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY)));

when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice2Id))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY2)));

when(KEYS.getEcSignedPreKey(EXISTS_PNI, sampleDevice3Id))
.thenReturn(CompletableFuture.completedFuture(Optional.of(SAMPLE_SIGNED_PNI_KEY3)));

when(KEYS.takeEC(EXISTS_UUID, sampleDeviceId)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY)));
when(KEYS.takePQ(EXISTS_UUID, sampleDeviceId)).thenReturn(
Expand All @@ -272,9 +282,13 @@ void setup() {
when(KEYS.getEcCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5));
when(KEYS.getPqCount(AuthHelper.VALID_UUID, sampleDeviceId)).thenReturn(CompletableFuture.completedFuture(5));

when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.ACI)).thenReturn(VALID_DEVICE_SIGNED_KEY);
when(AuthHelper.VALID_DEVICE.getSignedPreKey(IdentityType.PNI)).thenReturn(VALID_DEVICE_PNI_SIGNED_KEY);
when(AuthHelper.VALID_ACCOUNT.getIdentityKey(IdentityType.ACI)).thenReturn(null);

when(KEYS.getEcSignedPreKey(AuthHelper.VALID_UUID, AuthHelper.VALID_DEVICE.getId()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_SIGNED_KEY)));

when(KEYS.getEcSignedPreKey(AuthHelper.VALID_PNI, AuthHelper.VALID_DEVICE.getId()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(VALID_DEVICE_PNI_SIGNED_KEY)));
}

@AfterEach
Expand Down Expand Up @@ -365,8 +379,7 @@ void validSingleRequestTestV2() {
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
Expand All @@ -389,8 +402,7 @@ void validSingleRequestPqTestNoPqKeysV2() {
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
Expand All @@ -412,8 +424,7 @@ void validSingleRequestPqTestV2() {
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
Expand All @@ -434,8 +445,7 @@ void validSingleRequestByPhoneNumberIdentifierTestV2() {
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
Expand All @@ -456,8 +466,7 @@ void validSingleRequestPqByPhoneNumberIdentifierTestV2() {
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID);
Expand All @@ -480,8 +489,7 @@ void validSingleRequestByPhoneNumberIdentifierNoPniRegistrationIdTestV2() {
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.PNI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
Expand Down Expand Up @@ -516,8 +524,7 @@ void testUnidentifiedRequest() {
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertEquals(SAMPLE_PQ_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey());
assertEquals(existsAccount.getDevice(SAMPLE_DEVICE_ID).get().getSignedPreKey(IdentityType.ACI),
result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());
assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,14 @@ void getPreKeys() {
final ECSignedPreKey ecSignedPreKey = KeysHelper.signedECPreKey(2, identityKeyPair);
final KEMSignedPreKey kemSignedPreKey = KeysHelper.signedKEMPreKey(3, identityKeyPair);

when(keysManager.takeEC(identifier, Device.PRIMARY_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(ecPreKey)));
when(keysManager.takePQ(identifier, Device.PRIMARY_ID)).thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));
when(targetDevice.getSignedPreKey(IdentityType.ACI)).thenReturn(ecSignedPreKey);
when(keysManager.takeEC(identifier, Device.PRIMARY_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(ecPreKey)));

when(keysManager.takePQ(identifier, Device.PRIMARY_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(kemSignedPreKey)));

when(keysManager.getEcSignedPreKey(identifier, Device.PRIMARY_ID))
.thenReturn(CompletableFuture.completedFuture(Optional.of(ecSignedPreKey)));

final GetPreKeysResponse response = unauthenticatedServiceStub().getPreKeys(GetPreKeysAnonymousRequest.newBuilder()
.setUnidentifiedAccessKey(ByteString.copyFrom(unidentifiedAccessKey))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,6 @@ void getPreKeys(final org.signal.chat.common.IdentityType grpcIdentityType) {
final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.getSignedPreKey(identityType)).thenReturn(ecSignedPreKeys.get(deviceId));

devices.put(deviceId, device);
when(targetAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
Expand All @@ -505,6 +504,9 @@ void getPreKeys(final org.signal.chat.common.IdentityType grpcIdentityType) {
ecOneTimePreKeys.forEach((deviceId, preKey) -> when(keysManager.takeEC(identifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));

ecSignedPreKeys.forEach((deviceId, preKey) -> when(keysManager.getEcSignedPreKey(identifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));

kemPreKeys.forEach((deviceId, preKey) -> when(keysManager.takePQ(identifier, deviceId))
.thenReturn(CompletableFuture.completedFuture(Optional.of(preKey))));

Expand Down

0 comments on commit feb933b

Please sign in to comment.