Skip to content

Commit

Permalink
Convert some KeysController methods return CompletableFutures
Browse files Browse the repository at this point in the history
  • Loading branch information
eager-signal committed Aug 24, 2023
1 parent f181397 commit d338ba5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.KeysManager;
import org.whispersystems.textsecuregcm.util.Util;

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@Path("/v2/keys")
Expand Down Expand Up @@ -95,7 +96,7 @@ public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManag
description = "Gets the number of one-time prekeys uploaded for this device and still available")
@ApiResponse(responseCode = "200", description = "Body contains the number of available one-time prekeys for the device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
public CompletableFuture<PreKeyCount> getStatus(@Auth final AuthenticatedAccount auth,
@QueryParam("identity") final Optional<String> identityType) {

final CompletableFuture<Integer> ecCountFuture =
Expand All @@ -104,7 +105,7 @@ public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
final CompletableFuture<Integer> pqCountFuture =
keys.getPqCount(getIdentifier(auth.getAccount(), identityType), auth.getAuthenticatedDevice().getId());

return new PreKeyCount(ecCountFuture.join(), pqCountFuture.join());
return ecCountFuture.thenCombine(pqCountFuture, PreKeyCount::new);
}

@PUT
Expand All @@ -120,7 +121,7 @@ public PreKeyCount getStatus(@Auth final AuthenticatedAccount auth,
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "403", description = "Attempt to change identity key from a non-primary device.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
public CompletableFuture<Void> setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPermittedAuth,
@RequestBody @NotNull @Valid final PreKeyState preKeys,

@Parameter(allowEmptyValue=true)
Expand Down Expand Up @@ -187,10 +188,8 @@ public void setKeys(@Auth final DisabledPermittedAuthenticatedAccount disabledPe
});
}

keys.store(
getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey())
.join();
return keys.store(getIdentifier(account, identityType), device.getId(),
preKeys.getPreKeys(), preKeys.getPqPreKeys(), preKeys.getSignedPreKey(), preKeys.getPqLastResortPreKey());
}

@GET
Expand Down Expand Up @@ -241,27 +240,45 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
+ "." + deviceId);
}

List<Device> devices = parseDeviceId(deviceId, target);
List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());
final List<Device> devices = parseDeviceId(deviceId, target);
final List<PreKeyResponseItem> responseItems = new ArrayList<>(devices.size());

for (Device device : devices) {
ECSignedPreKey signedECPreKey = device.getSignedPreKey(targetIdentifier.identityType());
final List<CompletableFuture<Void>> tasks = devices.stream().map(device -> {

ECPreKey unsignedECPreKey = keys.takeEC(targetIdentifier.uuid(), device.getId()).join().orElse(null);
KEMSignedPreKey pqPreKey = returnPqKey ? keys.takePQ(targetIdentifier.uuid(), device.getId()).join().orElse(null) : null;
ECSignedPreKey signedECPreKey = device.getSignedPreKey(targetIdentifier.identityType());

compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey),
keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()));
final CompletableFuture<Optional<ECPreKey>> unsignedEcPreKeyFuture = keys.takeEC(targetIdentifier.uuid(),
device.getId());
final CompletableFuture<Optional<KEMSignedPreKey>> pqPreKeyFuture = returnPqKey
? keys.takePQ(targetIdentifier.uuid(), device.getId())
: CompletableFuture.completedFuture(Optional.empty());

if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};
return pqPreKeyFuture.thenCombine(unsignedEcPreKeyFuture,
(maybePqPreKey, maybeUnsignedEcPreKey) -> {

responseItems.add(new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey, pqPreKey));
}
}
KEMSignedPreKey pqPreKey = pqPreKeyFuture.join().orElse(null);
ECPreKey unsignedECPreKey = unsignedEcPreKeyFuture.join().orElse(null);

compareSignedEcPreKeysExperiment.compareFutureResult(Optional.ofNullable(signedECPreKey),
keys.getEcSignedPreKey(targetIdentifier.uuid(), device.getId()));

if (signedECPreKey != null || unsignedECPreKey != null || pqPreKey != null) {
final int registrationId = switch (targetIdentifier.identityType()) {
case ACI -> device.getRegistrationId();
case PNI -> device.getPhoneNumberIdentityRegistrationId().orElse(device.getRegistrationId());
};

responseItems.add(
new PreKeyResponseItem(device.getId(), registrationId, signedECPreKey, unsignedECPreKey,
pqPreKey));
}

return null;
}).thenRun(Util.NOOP);
})
.toList();

CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();

final IdentityKey identityKey = target.getIdentityKey(targetIdentifier.identityType());

Expand All @@ -282,7 +299,7 @@ public PreKeyResponse getDeviceKeys(@Auth Optional<AuthenticatedAccount> auth,
@ApiResponse(responseCode = "200", description = "Indicates that new prekey was successfully stored.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed.")
@ApiResponse(responseCode = "422", description = "Invalid request format.")
public void setSignedKey(@Auth final AuthenticatedAccount auth,
public CompletableFuture<Void> setSignedKey(@Auth final AuthenticatedAccount auth,
@Valid final ECSignedPreKey signedPreKey,
@QueryParam("identity") final Optional<String> identityType) {

Expand All @@ -296,7 +313,8 @@ public void setSignedKey(@Auth final AuthenticatedAccount auth,
}
});

keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType), Map.of(device.getId(), signedPreKey)).join();
return keys.storeEcSignedPreKeys(getIdentifier(auth.getAccount(), identityType),
Map.of(device.getId(), signedPreKey));
}

private static boolean usePhoneNumberIdentity(final Optional<String> identityType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import javax.ws.rs.client.Entity;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.glassfish.jersey.server.ServerProperties;
import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -62,6 +63,7 @@
import org.whispersystems.textsecuregcm.identity.PniServiceIdentifier;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.mappers.CompletionExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper;
import org.whispersystems.textsecuregcm.mappers.ServerRejectedExceptionMapper;
import org.whispersystems.textsecuregcm.storage.Account;
Expand Down Expand Up @@ -124,7 +126,9 @@ class KeysControllerTest {
private static final RateLimiter rateLimiter = mock(RateLimiter.class );

private static final ResourceExtension resources = ResourceExtension.builder()
.addProperty(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE)
.addProvider(AuthHelper.getAuthFilter())
.addProvider(CompletionExceptionMapper.class)
.addProvider(new PolymorphicAuthValueFactoryProvider.Binder<>(ImmutableSet.of(
AuthenticatedAccount.class, DisabledPermittedAuthenticatedAccount.class)))
.setTestContainerFactory(new GrizzlyWebTestContainerFactory())
Expand Down

0 comments on commit d338ba5

Please sign in to comment.