Skip to content

Commit

Permalink
Allow, but do not require, message delivery to devices without active…
Browse files Browse the repository at this point in the history
… delivery channels
  • Loading branch information
jon-signal committed Jun 25, 2024
1 parent f5ce34f commit d306caf
Show file tree
Hide file tree
Showing 24 changed files with 205 additions and 297 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import org.whispersystems.textsecuregcm.util.Pair;

/**
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in
* {@link Device#hasMessageDeliveryChannel()}.
* <p>
* If a change in {@link Account#isEnabled()} or any associated {@link Device#hasMessageDeliveryChannel()} is observed, then any active
* WebSocket connections for the account must be closed in order for clients to get a refreshed
* {@link io.dropwizard.auth.Auth} object with a current device list.
* If a change in any associated {@link Device#hasMessageDeliveryChannel()} is observed, then any active WebSocket
* connections for the account must be closed in order for clients to get a refreshed {@link io.dropwizard.auth.Auth}
* object with a current device list.
*
* @see AuthenticatedAccount
*/
Expand All @@ -48,9 +48,8 @@ public AuthEnablementRefreshRequirementProvider(final AccountsManager accountsMa
@Override
public void handleRequestFiltered(final RequestEvent requestEvent) {
if (requestEvent.getUriInfo().getMatchedResourceMethod().getInvocable().getHandlingMethod().getAnnotation(ChangesDeviceEnabledState.class) != null) {
// The authenticated principal, if any, will be available after filters have run.
// Now that the account is known, capture a snapshot of `isEnabled` for the account's devices before carrying out
// the request’s business logic.
// The authenticated principal, if any, will be available after filters have run. Now that the account is known,
// capture a snapshot of the account's devices before carrying out the request’s business logic.
ContainerRequestUtil.getAuthenticatedAccount(requestEvent.getContainerRequest()).ifPresent(account ->
setAccount(requestEvent.getContainerRequest(), account));
}
Expand All @@ -66,8 +65,8 @@ private static void setAccount(final ContainerRequest containerRequest, final Co

@Override
public List<Pair<UUID, Byte>> handleRequestFinished(final RequestEvent requestEvent) {
// Now that the request is finished, check whether `isEnabled` changed for any of the devices. If the value did
// change or if a devices was added or removed, all devices must disconnect and reauthenticate.
// Now that the request is finished, check whether `hasMessageDeliveryChannel` changed for any of the devices. If
// the value did change or if a devices was added or removed, all devices must disconnect and reauthenticate.
if (requestEvent.getContainerRequest().getProperty(DEVICES_ENABLED) != null) {

@SuppressWarnings("unchecked") final Map<Byte, Boolean> initialDevicesEnabled =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class OptionalAccess {

public static String ALL_DEVICES_SELECTOR = "*";

public static void verify(Optional<Account> requestAccount,
Optional<Anonymous> accessKey,
Optional<Account> targetAccount,
Expand All @@ -26,12 +28,12 @@ public static void verify(Optional<Account> requestAccount,
try {
verify(requestAccount, accessKey, targetAccount);

if (!deviceSelector.equals("*")) {
if (!ALL_DEVICES_SELECTOR.equals(deviceSelector)) {
byte deviceId = Byte.parseByte(deviceSelector);

Optional<Device> targetDevice = targetAccount.get().getDevice(deviceId);

if (targetDevice.isPresent() && targetDevice.get().hasMessageDeliveryChannel()) {
if (targetDevice.isPresent()) {
return;
}

Expand All @@ -48,11 +50,10 @@ public static void verify(Optional<Account> requestAccount,

public static void verify(Optional<Account> requestAccount,
Optional<Anonymous> accessKey,
Optional<Account> targetAccount)
{
Optional<Account> targetAccount) {
if (requestAccount.isPresent()) {
// Authenticated requests are never unauthorized; if the target exists and is enabled, return OK, otherwise throw not-found.
if (targetAccount.isPresent() && targetAccount.get().isEnabled()) {
// Authenticated requests are never unauthorized; if the target exists, return OK, otherwise throw not-found.
if (targetAccount.isPresent()) {
return;
} else {
throw new NotFoundException();
Expand All @@ -63,7 +64,7 @@ public static void verify(Optional<Account> requestAccount,
// has unrestricted unidentified access, callers need to supply a fake access key. Likewise, if
// the target account does not exist, we *also* report unauthorized here (*not* not-found,
// since that would provide a free exists check).
if (accessKey.isEmpty() || !targetAccount.map(Account::isEnabled).orElse(false)) {
if (accessKey.isEmpty() || targetAccount.isEmpty()) {
throw new NotAuthorizedException(Response.Status.UNAUTHORIZED);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ public CompletableFuture<Response> checkKeys(
@ApiResponse(responseCode = "200", description = "Indicates at least one prekey was available for at least one requested device.", useReturnTypeSchema = true)
@ApiResponse(responseCode = "400", description = "A group send endorsement and other authorization (account authentication or unidentified-access key) were both provided.")
@ApiResponse(responseCode = "401", description = "Account authentication check failed and unidentified-access key or group send endorsement token was not supplied or invalid.")
@ApiResponse(responseCode = "404", description = "Requested identity or device does not exist, is not active, or has no available prekeys.")
@ApiResponse(responseCode = "404", description = "Requested identity or device does not exist or device has no available prekeys.")
@ApiResponse(responseCode = "429", description = "Rate limit exceeded.", headers = @Header(
name = "Retry-After",
description = "If present, a positive integer indicating the number of seconds before a subsequent attempt could succeed"))
Expand Down Expand Up @@ -440,11 +440,11 @@ public CompletableFuture<Response> setSignedKey(

private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::hasMessageDeliveryChannel).toList();
return account.getDevices();
}
try {
byte id = Byte.parseByte(deviceId);
return account.getDevice(id).filter(Device::hasMessageDeliveryChannel).map(List::of).orElse(List.of());
return account.getDevice(id).map(List::of).orElse(List.of());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ public Response sendMessage(@ReadOnly @Auth Optional<AuthenticatedAccount> sourc
OptionalAccess.verify(source.map(AuthenticatedAccount::getAccount), accessKey, destination);
}

boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().hasEnabledLinkedDevice();
boolean needsSync = !isSyncMessage && source.isPresent() && source.get().getAccount().getDevices().size() > 1;

// We return 200 when stories are sent to a non-existent account. Since story sends bypass OptionalAccess.verify
// we leak information about whether a destination UUID exists if we return any other code (e.g. 404) from
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ static Mono<GetPreKeysResponse> getPreKeys(final Account targetAccount,
: Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId)));

return devices
.filter(Device::hasMessageDeliveryChannel)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(device -> Flux.merge(
Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,6 @@ private boolean allDevicesHaveCapability(final Predicate<DeviceCapabilities> pre
.allMatch(device -> device.getCapabilities() != null && predicate.test(device.getCapabilities()));
}

public boolean isEnabled() {
requireNotStale();

return getPrimaryDevice().hasMessageDeliveryChannel();
}

public byte getNextDeviceId() {
requireNotStale();

Expand All @@ -325,14 +319,6 @@ public byte getNextDeviceId() {
return candidateId;
}

public boolean hasEnabledLinkedDevice() {
requireNotStale();

return devices.stream()
.filter(d -> Device.PRIMARY_ID != d.getId())
.anyMatch(Device::hasMessageDeliveryChannel);
}

public void setIdentityKey(final IdentityKey identityKey) {
requireNotStale();

Expand Down Expand Up @@ -503,12 +489,6 @@ public void setDiscoverableByPhoneNumber(final boolean discoverableByPhoneNumber
this.discoverableByPhoneNumber = discoverableByPhoneNumber;
}

public boolean shouldBeVisibleInDirectory() {
requireNotStale();

return isEnabled() && isDiscoverableByPhoneNumber();
}

public int getVersion() {
requireNotStale();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ public void changeNumber(final Account account,
.expressionAttributeValues(Map.of(
":number", numberAttr,
":data", accountDataAttributeValue(account),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()),
":cds", AttributeValues.fromBool(account.isDiscoverableByPhoneNumber()),
":pni", pniAttr,
":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1)))
Expand Down Expand Up @@ -924,7 +924,7 @@ static UpdateAccountSpec forAccount(

final Map<String, AttributeValue> attrValues = new HashMap<>(Map.of(
":data", accountDataAttributeValue(account),
":cds", AttributeValues.fromBool(account.shouldBeVisibleInDirectory()),
":cds", AttributeValues.fromBool(account.isDiscoverableByPhoneNumber()),
":version", AttributeValues.fromInt(account.getVersion()),
":version_increment", AttributeValues.fromInt(1)));

Expand Down Expand Up @@ -1359,7 +1359,7 @@ private TransactWriteItem buildAccountPut(
ATTR_PNI_UUID, pniUuidAttr,
ATTR_ACCOUNT_DATA, accountDataAttributeValue(account),
ATTR_VERSION, AttributeValues.fromInt(account.getVersion()),
ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.shouldBeVisibleInDirectory())));
ATTR_CANONICALLY_DISCOVERABLE, AttributeValues.fromBool(account.isDiscoverableByPhoneNumber())));

// Add the UAK if it's in the account
account.getUnidentifiedAccessKey()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ private void setPniKeys(final Account account,

account.getDevices()
.stream()
.filter(Device::hasMessageDeliveryChannel)
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));

account.setPhoneNumberIdentityKey(pniIdentityKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,19 @@ public static void validateCompleteDeviceList(final Account account,
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {

final Set<Byte> accountDeviceIds = account.getDevices().stream()
.filter(Device::hasMessageDeliveryChannel)
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());

final Set<Byte> missingDeviceIds = new HashSet<>(accountDeviceIds);
missingDeviceIds.removeAll(messageDeviceIds);

// Temporarily "excuse" missing devices if they're missing a message delivery channel as a transitional measure
missingDeviceIds.removeAll(account.getDevices().stream()
.filter(device -> !device.hasMessageDeliveryChannel())
.map(Device::getId)
.collect(Collectors.toSet()));

final Set<Byte> extraDeviceIds = new HashSet<>(messageDeviceIds);
extraDeviceIds.removeAll(accountDeviceIds);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ void testAuthenticate() {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
Expand Down Expand Up @@ -191,7 +190,6 @@ void testAuthenticateNonDefaultDevice() {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
Expand Down Expand Up @@ -224,7 +222,6 @@ void testAuthenticateEnabled(
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(authenticatedDevice));
when(account.isEnabled()).thenReturn(accountEnabled);
when(authenticatedDevice.getId()).thenReturn(deviceId);
when(authenticatedDevice.hasMessageDeliveryChannel()).thenReturn(deviceEnabled);
when(authenticatedDevice.getAuthTokenHash()).thenReturn(credentials);
Expand Down Expand Up @@ -260,7 +257,6 @@ void testAuthenticateV1() {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
Expand Down Expand Up @@ -297,7 +293,6 @@ void testAuthenticateDeviceNotFound() {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
Expand Down Expand Up @@ -325,7 +320,6 @@ void testAuthenticateIncorrectPassword() {
when(accountsManager.getByAccountIdentifier(uuid)).thenReturn(Optional.of(account));
when(account.getUuid()).thenReturn(uuid);
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
Expand Down
Loading

0 comments on commit d306caf

Please sign in to comment.