diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index 63fcc49b8..11a2f4bb4 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -64,6 +64,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.KeysManager; +import org.whispersystems.textsecuregcm.util.Util; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v2/keys") @@ -95,7 +96,7 @@ public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManag description = "Gets the number of one-time prekeys uploaded for this device and still available") @ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true) @ApiResponse(responseCode = "401", description = "Account authentication check failed.") - public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth, + public CompletableFuture getStatus(@Auth final AuthenticatedAccount auth, @QueryParam("identity") final Optional identityType) { final CompletableFuture ecCountFuture = @@ -104,7 +105,7 @@ public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth, final CompletableFuture pqCountFuture = keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId()); - return new PreKeyCount(ecCountFuture.join(), pqCountFuture.join()); + return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new); } @PUT @@ -120,7 +121,7 @@ public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth, @ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.") @ApiResponse(responseCode = "422", description = "Invalid request format.") - public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth, + public CompletableFuture setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth, @RequestBody @NotNull @Valid final PreKeyState preKeys, @Parameter(allowEmptyValue=true) @@ -187,10 +188,8 @@ public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPe }); } - keys.store( - getIdentifier(account, identityType), device.getId(), - preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey()) - .join(); + return keys.store(getIdentifier(account, identityType), device.getId(), + preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey()); } @GET @@ -241,27 +240,45 @@ public PreKeyResponse getDeviceKeys(@Auth Optional auth, + "." + deviceId); } - List devices = parseDeviceId(deviceId, target); - List responseItems = new ArrayList<>(devices.size()); + final List devices = parseDeviceId(deviceId, target); + final List responseItems = new ArrayList<>(devices.size()); - for (Device device : devices) { - ECSignedPreKey signedECPreKey = device.getSignedPreKey(targetIdentifier.identityType()); + final List> tasks = devices.stream().map(device -> { - ECPreKey unsignedECPreKey = keys.takeEC(targetIdentifier.uuid(), device.getId()).join().orElse(null); - KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()).join().orElse(null) : null; + ECSignedPreKey signedECPreKey = device.getSignedPreKey(targetIdentifier.identityType()); - compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), - keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())); + final CompletableFuture> unsignedEcPreKeyFuture = keys.takeEC(targetIdentifier.uuid(), + device.getId()); + final CompletableFuture> pqPreKeyFuture = returnPqKey + ? keys.takePQ(targetIdentifier.uuid(), device.getId()) + : CompletableFuture.completedFuture(Optional.empty()); - if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { - final int registrationId = switch (targetIdentifier.identityType()) { - case ACI -> device.getRegistrationId(); - case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); - }; + return pqPreKeyFuture.thenCombine(unsignedEcPreKeyFuture, + (maybePqPreKey, maybeUnsignedEcPreKey) -> { - responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey)); - } - } + KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null); + ECPreKey unsignedECPreKey = unsignedEcPreKeyFuture.join().orElse(null); + + compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey), + keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId())); + + if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) { + final int registrationId = switch (targetIdentifier.identityType()) { + case ACI -> device.getRegistrationId(); + case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId()); + }; + + responseItems.add( + new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, + pqPreKey)); + } + + return null; + }).thenRun(Util.NOOP); + }) + .toList(); + + CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join(); final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType()); @@ -282,7 +299,7 @@ public PreKeyResponse getDeviceKeys(@Auth Optional auth, @ApiResponse(responseCode = "200", description = "Indicates that new prekey was successfully stored.") @ApiResponse(responseCode = "401", description = "Account authentication check failed.") @ApiResponse(responseCode = "422", description = "Invalid request format.") - public void setSignedKey(@Auth final AuthenticatedAccount auth, + public CompletableFuture setSignedKey(@Auth final AuthenticatedAccount auth, @Valid final ECSignedPreKey signedPreKey, @QueryParam("identity") final Optional identityType) { @@ -296,7 +313,8 @@ public void setSignedKey(@Auth final AuthenticatedAccount auth, } }); - keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType), Map.of(device.getId(), signedPreKey)).join(); + return keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType), + Map.of(device.getId(), signedPreKey)); } private static boolean usePhoneNumberIdentity(final Optional identityType) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java index c230c8243..d05df2ce5 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/KeysControllerTest.java @@ -38,6 +38,7 @@ import javax.ws.rs.client.Entity; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import org.glassfish.jersey.server.ServerProperties; import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -62,6 +63,7 @@ import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper; import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper; import org.whispersystems.textsecuregcm.storage.Account; @@ -124,7 +126,9 @@ class KeysControllerTest { private static final RateLimiter rateLimiter = mock(RateLimiter.class ); private static final ResourceExtension resources = ResourceExtension.builder() + .addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE) .addProvider(AuthHelper.getAuthFilter()) + .addProvider(CompletionExceptionMapper.class) .addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of( AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class))) .setTestContainerFactory(new GrizzlyWebTestContainerFactory())