Skip to content

Commit

Permalink
Remove expiration check from Device#isEnabled()
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Jun 7, 2024
1 parent b376458 commit 2f55747
Show file tree
Hide file tree
Showing 25 changed files with 99 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

/**
* This {@link WebsocketRefreshRequirementProvider} observes intra-request changes in {@link Account#isEnabled()} and
* {@link Device#isEnabled()}.
* {@link Device#hasMessageDeliveryChannel()}.
* <p>
* If a change in {@link Account#isEnabled()} or any associated {@link Device#isEnabled()} is observed, then any active
* If a change in {@link Account#isEnabled()} or any associated {@link Device#hasMessageDeliveryChannel()} is observed, then any active
* WebSocket connections for the account must be closed in order for clients to get a refreshed
* {@link io.dropwizard.auth.Auth} object with a current device list.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class ContainerRequestUtil {

private static Map<Byte, Boolean> buildDevicesEnabledMap(final Account account) {
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::isEnabled));
return account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::hasMessageDeliveryChannel));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public static void verify(Optional<Account> requestAccount,

Optional<Device> targetDevice = targetAccount.get().getDevice(deviceId);

if (targetDevice.isPresent() && targetDevice.get().isEnabled()) {
if (targetDevice.isPresent() && targetDevice.get().hasMessageDeliveryChannel()) {
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,11 @@ public CompletableFuture<Response> setSignedKey(

private List<Device> parseDeviceId(String deviceId, Account account) {
if (deviceId.equals("*")) {
return account.getDevices().stream().filter(Device::isEnabled).toList();
return account.getDevices().stream().filter(Device::hasMessageDeliveryChannel).toList();
}
try {
byte id = Byte.parseByte(deviceId);
return account.getDevice(id).filter(Device::isEnabled).map(List::of).orElse(List.of());
return account.getDevice(id).filter(Device::hasMessageDeliveryChannel).map(List::of).orElse(List.of());
} catch (NumberFormatException e) {
throw new WebApplicationException(Response.status(422).build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ static Mono<GetPreKeysResponse> getPreKeys(final Account targetAccount,
: Flux.from(Mono.justOrEmpty(targetAccount.getDevice(targetDeviceId)));

return devices
.filter(Device::isEnabled)
.filter(Device::hasMessageDeliveryChannel)
.switchIfEmpty(Mono.error(Status.NOT_FOUND.asException()))
.flatMap(device -> Flux.merge(
Mono.fromFuture(() -> keysManager.takeEC(targetAccount.getIdentifier(identityType), device.getId())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,14 @@ private boolean allEnabledDevicesHaveCapability(final Predicate<DeviceCapabiliti
requireNotStale();

return devices.stream()
.filter(Device::isEnabled)
.filter(Device::hasMessageDeliveryChannel)
.allMatch(device -> device.getCapabilities() != null && predicate.test(device.getCapabilities()));
}

public boolean isEnabled() {
requireNotStale();

return getPrimaryDevice().isEnabled();
return getPrimaryDevice().hasMessageDeliveryChannel();
}

public byte getNextDeviceId() {
Expand All @@ -327,7 +327,7 @@ public boolean hasEnabledLinkedDevice() {

return devices.stream()
.filter(d -> Device.PRIMARY_ID != d.getId())
.anyMatch(Device::isEnabled);
.anyMatch(Device::hasMessageDeliveryChannel);
}

public void setIdentityKey(final IdentityKey identityKey) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ private void setPniKeys(final Account account,

account.getDevices()
.stream()
.filter(Device::isEnabled)
.filter(Device::hasMessageDeliveryChannel)
.forEach(device -> device.setPhoneNumberIdentityRegistrationId(pniRegistrationIds.get(device.getId())));

account.setPhoneNumberIdentityKey(pniIdentityKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import java.time.Duration;
import java.util.List;
import java.util.OptionalInt;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -199,10 +198,8 @@ public void setCapabilities(DeviceCapabilities capabilities) {
this.capabilities = capabilities;
}

public boolean isEnabled() {
boolean hasChannel = fetchesMessages || StringUtils.isNotEmpty(getApnId()) || StringUtils.isNotEmpty(getGcmId());

return (id == PRIMARY_ID && hasChannel) || (id != PRIMARY_ID && hasChannel && !isExpired());
public boolean hasMessageDeliveryChannel() {
return fetchesMessages || StringUtils.isNotEmpty(getApnId()) || StringUtils.isNotEmpty(getGcmId());
}

public boolean isExpired() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ void unlinkLeastActiveDevice(final Account account, byte destinationDeviceId) th
// its messages) is unlinked
final Device deviceToDelete = account.getDevices()
.stream()
.filter(d -> !d.isPrimary() && !d.isEnabled())
.filter(d -> !d.isPrimary() && !d.hasMessageDeliveryChannel())
.min(Comparator.comparing(Device::getLastSeen))
.or(() ->
Flux.fromIterable(account.getDevices())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public static void validateCompleteDeviceList(final Account account,
final Set<Byte> excludedDeviceIds) throws MismatchedDevicesException {

final Set<Byte> accountDeviceIds = account.getDevices().stream()
.filter(Device::isEnabled)
.filter(Device::hasMessageDeliveryChannel)
.map(Device::getId)
.filter(deviceId -> !excludedDeviceIds.contains(deviceId))
.collect(Collectors.toSet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ boolean deviceExpired(final Device device) {

@VisibleForTesting
boolean deviceNeedsUpdate(final Device device) {
return pushFeedbackIntervalElapsed(device) && (device.isEnabled() || device.getLastSeen() > device.getUninstalledFeedbackTimestamp());
return pushFeedbackIntervalElapsed(device) && (device.hasMessageDeliveryChannel() || device.getLastSeen() > device.getUninstalledFeedbackTimestamp());
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void testAuthenticate() {
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(SaltedTokenHash.CURRENT_VERSION);
Expand Down Expand Up @@ -193,7 +193,7 @@ void testAuthenticateNonDefaultDevice() {
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(SaltedTokenHash.CURRENT_VERSION);
Expand Down Expand Up @@ -226,7 +226,7 @@ void testAuthenticateEnabled(
when(account.getDevice(deviceId)).thenReturn(Optional.of(authenticatedDevice));
when(account.isEnabled()).thenReturn(accountEnabled);
when(authenticatedDevice.getId()).thenReturn(deviceId);
when(authenticatedDevice.isEnabled()).thenReturn(deviceEnabled);
when(authenticatedDevice.hasMessageDeliveryChannel()).thenReturn(deviceEnabled);
when(authenticatedDevice.getAuthTokenHash()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(SaltedTokenHash.CURRENT_VERSION);
Expand Down Expand Up @@ -262,7 +262,7 @@ void testAuthenticateV1() {
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(SaltedTokenHash.Version.V1);
Expand Down Expand Up @@ -299,7 +299,7 @@ void testAuthenticateDeviceNotFound() {
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(SaltedTokenHash.CURRENT_VERSION);
Expand Down Expand Up @@ -327,7 +327,7 @@ void testAuthenticateIncorrectPassword() {
when(account.getDevice(deviceId)).thenReturn(Optional.of(device));
when(account.isEnabled()).thenReturn(true);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.hasMessageDeliveryChannel()).thenReturn(true);
when(device.getAuthTokenHash()).thenReturn(credentials);
when(credentials.verify(password)).thenReturn(true);
when(credentials.getVersion()).thenReturn(SaltedTokenHash.CURRENT_VERSION);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void setup() {
void testDeviceEnabledChanged(final Map<Byte, Boolean> initialEnabled, final Map<Byte, Boolean> finalEnabled) {
assert initialEnabled.size() == finalEnabled.size();

assert account.getPrimaryDevice().isEnabled();
assert account.getPrimaryDevice().hasMessageDeliveryChannel();

initialEnabled.forEach((deviceId, enabled) ->
DevicesHelper.setEnabled(account.getDevice(deviceId).orElseThrow(), enabled));
Expand Down Expand Up @@ -177,7 +177,7 @@ static Stream<Arguments> testDeviceEnabledChanged() {

@Test
void testDeviceAdded() {
assert account.getPrimaryDevice().isEnabled();
assert account.getPrimaryDevice().hasMessageDeliveryChannel();

final int initialDeviceCount = account.getDevices().size();

Expand All @@ -204,7 +204,7 @@ void testDeviceAdded() {
@ParameterizedTest
@ValueSource(ints = {1, 2})
void testDeviceRemoved(final int removedDeviceCount) {
assert account.getPrimaryDevice().isEnabled();
assert account.getPrimaryDevice().hasMessageDeliveryChannel();

final List<Byte> initialDeviceIds = account.getDevices().stream().map(Device::getId).toList();

Expand Down Expand Up @@ -367,7 +367,7 @@ public String setAccountEnabled(@Auth TestPrincipal principal, @PathParam("enabl

DevicesHelper.setEnabled(device, enabled);

assert device.isEnabled() == enabled;
assert device.hasMessageDeliveryChannel() == enabled;

return String.format("Set account to %s", enabled);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void testBuildDevicesEnabled() {
.forEach(id -> {
final Device device = mock(Device.class);
when(device.getId()).thenReturn((byte) id);
when(device.isEnabled()).thenReturn(id != disabledDeviceId);
when(device.hasMessageDeliveryChannel()).thenReturn(id != disabledDeviceId);
devices.add(device);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ void setup() {
when(sampleDevice3.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID2);
when(sampleDevice4.getRegistrationId()).thenReturn(SAMPLE_REGISTRATION_ID4);
when(sampleDevice.getPhoneNumberIdentityRegistrationId()).thenReturn(OptionalInt.of(SAMPLE_PNI_REGISTRATION_ID));
when(sampleDevice.isEnabled()).thenReturn(true);
when(sampleDevice2.isEnabled()).thenReturn(true);
when(sampleDevice3.isEnabled()).thenReturn(false);
when(sampleDevice4.isEnabled()).thenReturn(true);
when(sampleDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(sampleDevice2.hasMessageDeliveryChannel()).thenReturn(true);
when(sampleDevice3.hasMessageDeliveryChannel()).thenReturn(false);
when(sampleDevice4.hasMessageDeliveryChannel()).thenReturn(true);
when(sampleDevice.getId()).thenReturn(sampleDeviceId);
when(sampleDevice2.getId()).thenReturn(sampleDevice2Id);
when(sampleDevice3.getId()).thenReturn(sampleDevice3Id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -211,13 +210,13 @@ class MessageControllerTest {
@BeforeEach
void setup() {
final List<Device> singleDeviceList = List.of(
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, System.currentTimeMillis(), System.currentTimeMillis())
generateTestDevice(SINGLE_DEVICE_ID1, SINGLE_DEVICE_REG_ID1, SINGLE_DEVICE_PNI_REG_ID1, true)
);

final List<Device> multiDeviceList = List.of(
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, System.currentTimeMillis(), System.currentTimeMillis()),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, System.currentTimeMillis(), System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31))
generateTestDevice(MULTI_DEVICE_ID1, MULTI_DEVICE_REG_ID1, MULTI_DEVICE_PNI_REG_ID1, true),
generateTestDevice(MULTI_DEVICE_ID2, MULTI_DEVICE_REG_ID2, MULTI_DEVICE_PNI_REG_ID2, true),
generateTestDevice(MULTI_DEVICE_ID3, MULTI_DEVICE_REG_ID3, MULTI_DEVICE_PNI_REG_ID3, false)
);

Account singleDeviceAccount = AccountsHelper.generateTestAccount(SINGLE_DEVICE_RECIPIENT, SINGLE_DEVICE_UUID, SINGLE_DEVICE_PNI, singleDeviceList, UNIDENTIFIED_ACCESS_BYTES);
Expand Down Expand Up @@ -260,14 +259,12 @@ void setup() {
}

private static Device generateTestDevice(final byte id, final int registrationId, final int pniRegistrationId,
final long createdAt, final long lastSeen) {
final boolean enabled) {
final Device device = new Device();
device.setId(id);
device.setRegistrationId(registrationId);
device.setPhoneNumberIdentityRegistrationId(pniRegistrationId);
device.setCreated(createdAt);
device.setLastSeen(lastSeen);
device.setGcmId("isgcm");
device.setFetchesMessages(enabled);

return device;
}
Expand Down Expand Up @@ -1125,8 +1122,7 @@ void testManyRecipientMessage() throws Exception {
IntStream.range(1, devicesPerRecipient + 1)
.mapToObj(
d -> generateTestDevice(
(byte) d, 100 + d, 10 * d, System.currentTimeMillis(),
System.currentTimeMillis()))
(byte) d, 100 + d, 10 * d, true))
.collect(Collectors.toList());
final UUID aci = new UUID(0L, (long) i);
final UUID pni = new UUID(1L, (long) i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ void getPreKeys(final org.signal.chat.common.IdentityType grpcIdentityType) {

final Device device = mock(Device.class);
when(device.getId()).thenReturn(deviceId);
when(device.isEnabled()).thenReturn(true);
when(device.hasMessageDeliveryChannel()).thenReturn(true);

devices.put(deviceId, device);
when(targetAccount.getDevice(deviceId)).thenReturn(Optional.of(device));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,35 +53,35 @@ class AccountTest {
@BeforeEach
void setup() {
when(oldPrimaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366));
when(oldPrimaryDevice.isEnabled()).thenReturn(true);
when(oldPrimaryDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(oldPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);

when(recentPrimaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(1));
when(recentPrimaryDevice.isEnabled()).thenReturn(true);
when(recentPrimaryDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(recentPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);

when(agingSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(31));
when(agingSecondaryDevice.isEnabled()).thenReturn(false);
when(agingSecondaryDevice.hasMessageDeliveryChannel()).thenReturn(false);
final byte deviceId2 = 2;
when(agingSecondaryDevice.getId()).thenReturn(deviceId2);

when(recentSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(1));
when(recentSecondaryDevice.isEnabled()).thenReturn(true);
when(recentSecondaryDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(recentSecondaryDevice.getId()).thenReturn(deviceId2);

when(oldSecondaryDevice.getLastSeen()).thenReturn(System.currentTimeMillis() - TimeUnit.DAYS.toMillis(366));
when(oldSecondaryDevice.isEnabled()).thenReturn(false);
when(oldSecondaryDevice.hasMessageDeliveryChannel()).thenReturn(false);
when(oldSecondaryDevice.getId()).thenReturn(deviceId2);

when(paymentActivationCapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, true));
when(paymentActivationCapableDevice.isEnabled()).thenReturn(true);
when(paymentActivationCapableDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(paymentActivationIncapableDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false));
when(paymentActivationIncapableDevice.isEnabled()).thenReturn(true);
when(paymentActivationIncapableDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(paymentActivationIncapableExpiredDevice.getCapabilities()).thenReturn(
new DeviceCapabilities(true, true, false));
when(paymentActivationIncapableExpiredDevice.isEnabled()).thenReturn(false);
when(paymentActivationIncapableExpiredDevice.hasMessageDeliveryChannel()).thenReturn(false);

}

Expand All @@ -92,10 +92,10 @@ void testIsEnabled() {
final Device disabledPrimaryDevice = mock(Device.class);
final Device disabledLinkedDevice = mock(Device.class);

when(enabledPrimaryDevice.isEnabled()).thenReturn(true);
when(enabledLinkedDevice.isEnabled()).thenReturn(true);
when(disabledPrimaryDevice.isEnabled()).thenReturn(false);
when(disabledLinkedDevice.isEnabled()).thenReturn(false);
when(enabledPrimaryDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledLinkedDevice.hasMessageDeliveryChannel()).thenReturn(true);
when(disabledPrimaryDevice.hasMessageDeliveryChannel()).thenReturn(false);
when(disabledLinkedDevice.hasMessageDeliveryChannel()).thenReturn(false);

when(enabledPrimaryDevice.getId()).thenReturn(Device.PRIMARY_ID);
final byte deviceId2 = 2;
Expand Down Expand Up @@ -296,23 +296,23 @@ public void testHasEnabledLinkedDevice(final Account account, final boolean expe

static Stream<Arguments> testHasEnabledLinkedDevice() {
final Device enabledPrimary = mock(Device.class);
when(enabledPrimary.isEnabled()).thenReturn(true);
when(enabledPrimary.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);

final Device disabledPrimary = mock(Device.class);
when(disabledPrimary.getId()).thenReturn(Device.PRIMARY_ID);

final byte linked1DeviceId = Device.PRIMARY_ID + 1;
final Device enabledLinked1 = mock(Device.class);
when(enabledLinked1.isEnabled()).thenReturn(true);
when(enabledLinked1.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledLinked1.getId()).thenReturn(linked1DeviceId);

final Device disabledLinked1 = mock(Device.class);
when(disabledLinked1.getId()).thenReturn(linked1DeviceId);

final byte linked2DeviceId = Device.PRIMARY_ID + 2;
final Device enabledLinked2 = mock(Device.class);
when(enabledLinked2.isEnabled()).thenReturn(true);
when(enabledLinked2.hasMessageDeliveryChannel()).thenReturn(true);
when(enabledLinked2.getId()).thenReturn(linked2DeviceId);

final Device disabledLinked2 = mock(Device.class);
Expand Down
Loading

0 comments on commit 2f55747

Please sign in to comment.