Skip to content

Commit

Permalink
Retire the returnPqKey flag when fetching pre-keys
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Jan 23, 2024
1 parent 91b0c36 commit 595cc55
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -186,10 +185,6 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@Parameter(description="the device id of a single device to retrieve prekeys for, or `*` for all enabled devices")
@PathParam("device_id") String deviceId,

@Parameter(allowEmptyValue=true, description="whether to retrieve post-quantum prekeys")
@Schema(defaultValue="false")
@QueryParam("pq") boolean returnPqKey,

@HeaderParam(HttpHeaders.USER_AGENT) String userAgent)
throws RateLimitExceededException {

Expand Down Expand Up @@ -229,9 +224,8 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
final CompletableFuture<Optional<ECSignedPreKey>> signedEcPreKeyFuture =
keysManager.getEcSignedPreKey(targetIdentifier.uuid(), device.getId());

final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
? keysManager.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.completedFuture(Optional.empty());
final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture =
keysManager.takePQ(targetIdentifier.uuid(), device.getId());

return CompletableFuture.allOf(unsignedEcPreKeyFuture, signedEcPreKeyFuture, pqPreKeyFuture)
.thenAccept(ignored -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
Expand Down Expand Up @@ -110,6 +109,7 @@ class KeysControllerTest {
private final KEMSignedPreKey SAMPLE_PQ_KEY = KeysHelper.signedKEMPreKey(2424, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY2 = KeysHelper.signedKEMPreKey(6868, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY3 = KeysHelper.signedKEMPreKey(1313, Curve.generateKeyPair());
private final KEMSignedPreKey SAMPLE_PQ_KEY4 = KeysHelper.signedKEMPreKey(7676, Curve.generateKeyPair());

private final KEMSignedPreKey SAMPLE_PQ_KEY_PNI = KeysHelper.signedKEMPreKey(8888, Curve.generateKeyPair());

Expand Down Expand Up @@ -349,11 +349,12 @@ void validSingleRequestTestV2() {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.ACI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
Expand Down Expand Up @@ -415,11 +416,12 @@ void validSingleRequestByPhoneNumberIdentifierTestV2() {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_PNI_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
Expand Down Expand Up @@ -459,11 +461,12 @@ void validSingleRequestByPhoneNumberIdentifierNoPniRegistrationIdTestV2() {
assertThat(result.getIdentityKey()).isEqualTo(existsAccount.getIdentityKey(IdentityType.PNI));
assertThat(result.getDevicesCount()).isEqualTo(1);
assertEquals(SAMPLE_KEY_PNI, result.getDevice(SAMPLE_DEVICE_ID).getPreKey());
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isNull();
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getPqPreKey()).isEqualTo(SAMPLE_PQ_KEY_PNI);
assertThat(result.getDevice(SAMPLE_DEVICE_ID).getRegistrationId()).isEqualTo(SAMPLE_REGISTRATION_ID);
assertEquals(SAMPLE_SIGNED_PNI_KEY, result.getDevice(SAMPLE_DEVICE_ID).getSignedPreKey());

verify(KEYS).takeEC(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_PNI, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_PNI, SAMPLE_DEVICE_ID);
verifyNoMoreInteractions(KEYS);
}
Expand Down Expand Up @@ -555,6 +558,15 @@ void validMultiRequestTestV2() {
when(KEYS.takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_KEY4)));

when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY2)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID3)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY3)));
when(KEYS.takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4)).thenReturn(
CompletableFuture.completedFuture(Optional.of(SAMPLE_PQ_KEY4)));

PreKeyResponse results = resources.getJerseyTest()
.target(String.format("/v2/keys/%s/*", EXISTS_UUID))
.request()
Expand Down Expand Up @@ -597,6 +609,9 @@ void validMultiRequestTestV2() {
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takeEC(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).takePQ(EXISTS_UUID, SAMPLE_DEVICE_ID4);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID2);
verify(KEYS).getEcSignedPreKey(EXISTS_UUID, SAMPLE_DEVICE_ID4);
Expand Down

0 comments on commit 595cc55

Please sign in to comment.