From 2b08742c0af449fcd62640f69b9a9bfa8949714b Mon Sep 17 00:00:00 2001 From: Jon Chambers <63609320+jon-signal@users.noreply.github.com> Date: Tue, 6 Jun 2023 17:08:26 -0400 Subject: [PATCH] Create separate key stores for different kinds of pre-keys --- LICENSE | 2 +- .../textsecuregcm/WhisperServerService.java | 9 +- .../configuration/DynamoDbTables.java | 20 +- .../controllers/DeviceController.java | 6 +- .../controllers/KeysController.java | 6 +- .../controllers/RegistrationController.java | 12 +- .../entities/ChangeNumberRequest.java | 2 +- .../entities/ChangePhoneNumberRequest.java | 2 +- ...eNumberIdentityKeyDistributionRequest.java | 2 +- .../storage/AccountsManager.java | 28 +- .../textsecuregcm/storage/Keys.java | 417 ------------------ .../textsecuregcm/storage/KeysManager.java | 111 +++++ .../storage/RepeatedUseSignedPreKeyStore.java | 228 ++++++++++ .../storage/SingleUseECPreKeyStore.java | 36 ++ .../storage/SingleUseKEMPreKeyStore.java | 38 ++ .../storage/SingleUsePreKeyStore.java | 312 +++++++++++++ .../workers/AssignUsernameCommand.java | 9 +- .../workers/CommandDependencies.java | 11 +- .../workers/UnlinkDeviceCommand.java | 2 +- .../RegistrationControllerTest.java | 10 +- ...ntsManagerChangeNumberIntegrationTest.java | 4 +- ...ConcurrentModificationIntegrationTest.java | 4 +- .../storage/AccountsManagerTest.java | 56 +-- ...ccountsManagerUsernameIntegrationTest.java | 2 +- .../textsecuregcm/storage/AccountsTest.java | 2 +- .../storage/DynamoDbExtensionSchema.java | 28 +- .../storage/KeysManagerTest.java | 257 +++++++++++ .../textsecuregcm/storage/KeysTest.java | 314 ------------- .../RepeatedUseSignedPreKeyStoreTest.java | 149 +++++++ .../storage/SingleUseECPreKeyStoreTest.java | 35 ++ .../storage/SingleUseKEMPreKeyStoreTest.java | 39 ++ .../storage/SingleUsePreKeyStoreTest.java | 155 +++++++ .../controllers/DeviceControllerTest.java | 17 +- .../tests/controllers/KeysControllerTest.java | 4 +- 34 files changed, 1482 insertions(+), 847 deletions(-) delete mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java delete mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java diff --git a/LICENSE b/LICENSE index be3f7b28e..33fd343a7 100644 --- a/LICENSE +++ b/LICENSE @@ -296,7 +296,7 @@ commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install +procedures, authorization keysManager, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index ab7dee05c..62b809492 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -176,7 +176,7 @@ import org.whispersystems.textsecuregcm.storage.DeletedAccounts; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; import org.whispersystems.textsecuregcm.storage.IssuedReceiptsManager; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagePersister; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; @@ -345,10 +345,11 @@ public void run(WhisperServerConfiguration config, Environment environment) thro config.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, config.getDynamoDbTables().getProfiles().getTableName()); - Keys keys = new Keys(dynamoDbClient, + KeysManager keys = new KeysManager( + dynamoDbAsyncClient, config.getDynamoDbTables().getEcKeys().getTableName(), - config.getDynamoDbTables().getPqKeys().getTableName(), - config.getDynamoDbTables().getPqLastResortKeys().getTableName()); + config.getDynamoDbTables().getKemKeys().getTableName(), + config.getDynamoDbTables().getKemLastResortKeys().getTableName()); MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, config.getDynamoDbTables().getMessages().getTableName(), config.getDynamoDbTables().getMessages().getExpiration(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java index eeaede95b..48b8ffe91 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/configuration/DynamoDbTables.java @@ -51,8 +51,8 @@ public Duration getExpiration() { private final Table deletedAccountsLock; private final IssuedReceiptsTableConfiguration issuedReceipts; private final Table ecKeys; - private final Table pqKeys; - private final Table pqLastResortKeys; + private final Table kemKeys; + private final Table kemLastResortKeys; private final TableWithExpiration messages; private final Table pendingAccounts; private final Table pendingDevices; @@ -72,8 +72,8 @@ public DynamoDbTables( @JsonProperty("deletedAccountsLock") final Table deletedAccountsLock, @JsonProperty("issuedReceipts") final IssuedReceiptsTableConfiguration issuedReceipts, @JsonProperty("ecKeys") final Table ecKeys, - @JsonProperty("pqKeys") final Table pqKeys, - @JsonProperty("pqLastResortKeys") final Table pqLastResortKeys, + @JsonProperty("pqKeys") final Table kemKeys, + @JsonProperty("pqLastResortKeys") final Table kemLastResortKeys, @JsonProperty("messages") final TableWithExpiration messages, @JsonProperty("pendingAccounts") final Table pendingAccounts, @JsonProperty("pendingDevices") final Table pendingDevices, @@ -92,8 +92,8 @@ public DynamoDbTables( this.deletedAccountsLock = deletedAccountsLock; this.issuedReceipts = issuedReceipts; this.ecKeys = ecKeys; - this.pqKeys = pqKeys; - this.pqLastResortKeys = pqLastResortKeys; + this.kemKeys = kemKeys; + this.kemLastResortKeys = kemLastResortKeys; this.messages = messages; this.pendingAccounts = pendingAccounts; this.pendingDevices = pendingDevices; @@ -140,14 +140,14 @@ public Table getEcKeys() { @NotNull @Valid - public Table getPqKeys() { - return pqKeys; + public Table getKemKeys() { + return kemKeys; } @NotNull @Valid - public Table getPqLastResortKeys() { - return pqLastResortKeys; + public Table getKemLastResortKeys() { + return kemLastResortKeys; } @NotNull diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index efb3db938..5abec485a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -51,7 +51,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.util.Pair; @@ -67,14 +67,14 @@ public class DeviceController { private final StoredVerificationCodeManager pendingDevices; private final AccountsManager accounts; private final MessagesManager messages; - private final Keys keys; + private final KeysManager keys; private final RateLimiters rateLimiters; private final Map maxDeviceConfiguration; public DeviceController(StoredVerificationCodeManager pendingDevices, AccountsManager accounts, MessagesManager messages, - Keys keys, + KeysManager keys, RateLimiters rateLimiters, Map maxDeviceConfiguration) { this.pendingDevices = pendingDevices; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java index b1fe45376..f4ed5eb0b 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/KeysController.java @@ -53,7 +53,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; @SuppressWarnings("OptionalUsedAsFieldOrParameterType") @Path("/v2/keys") @@ -61,7 +61,7 @@ public class KeysController { private final RateLimiters rateLimiters; - private final Keys keys; + private final KeysManager keys; private final AccountsManager accounts; private static final String IDENTITY_KEY_CHANGE_COUNTER_NAME = name(KeysController.class, "identityKeyChange"); @@ -70,7 +70,7 @@ public class KeysController { private static final String IDENTITY_TYPE_TAG_NAME = "identityType"; private static final String HAS_IDENTITY_KEY_TAG_NAME = "hasIdentityKey"; - public KeysController(RateLimiters rateLimiters, Keys keys, AccountsManager accounts) { + public KeysController(RateLimiters rateLimiters, KeysManager keys, AccountsManager accounts) { this.rateLimiters = rateLimiters; this.keys = keys; this.accounts = accounts; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java index 3ddf998d0..87f55863a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/RegistrationController.java @@ -48,7 +48,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.util.HeaderUtils; import org.whispersystems.textsecuregcm.util.Util; @@ -74,18 +74,18 @@ public class RegistrationController { private final AccountsManager accounts; private final PhoneVerificationTokenManager phoneVerificationTokenManager; private final RegistrationLockVerificationManager registrationLockVerificationManager; - private final Keys keys; + private final KeysManager keysManager; private final RateLimiters rateLimiters; public RegistrationController(final AccountsManager accounts, final PhoneVerificationTokenManager phoneVerificationTokenManager, final RegistrationLockVerificationManager registrationLockVerificationManager, - final Keys keys, + final KeysManager keysManager, final RateLimiters rateLimiters) { this.accounts = accounts; this.phoneVerificationTokenManager = phoneVerificationTokenManager; this.registrationLockVerificationManager = registrationLockVerificationManager; - this.keys = keys; + this.keysManager = keysManager; this.rateLimiters = rateLimiters; } @@ -176,8 +176,8 @@ public AccountIdentityResponse register( registrationRequest.deviceActivationRequest().gcmToken().ifPresent(gcmRegistrationId -> device.setGcmId(gcmRegistrationId.gcmRegistrationId())); - keys.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())); - keys.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())); + keysManager.storePqLastResort(a.getUuid(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().aciPqLastResortPreKey().get())); + keysManager.storePqLastResort(a.getPhoneNumberIdentifier(), Map.of(Device.MASTER_ID, registrationRequest.deviceActivationRequest().pniPqLastResortPreKey().get())); }); } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java index cc7c2c046..c54bfa9b8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangeNumberRequest.java @@ -43,7 +43,7 @@ public record ChangeNumberRequest( @NotEmpty byte[] pniIdentityKey, @Schema(description=""" - A list of synchronization messages to send to companion devices to supply the private keys + A list of synchronization messages to send to companion devices to supply the private keysManager associated with the new identity key and their new prekeys. Exactly one message must be supplied for each enabled device other than the sending (primary) device.""") @NotNull @Valid List<@NotNull @Valid IncomingMessage> deviceMessages, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java index d82412f85..d8f771c1d 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/ChangePhoneNumberRequest.java @@ -36,7 +36,7 @@ public record ChangePhoneNumberRequest( @Nullable byte[] pniIdentityKey, @Schema(description=""" - A list of synchronization messages to send to companion devices to supply the private keys + A list of synchronization messages to send to companion devices to supply the private keysManager associated with the new identity key and their new prekeys. Exactly one message must be supplied for each enabled device other than the sending (primary) device.""") @Nullable List deviceMessages, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java index 3c5f13e2a..58d048bcb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/PhoneNumberIdentityKeyDistributionRequest.java @@ -30,7 +30,7 @@ public record PhoneNumberIdentityKeyDistributionRequest( @NotNull @Valid @Schema(description=""" - A list of synchronization messages to send to companion devices to supply the private keys + A list of synchronization messages to send to companion devices to supply the private keysManager associated with the new identity key and their new prekeys. Exactly one message must be supplied for each enabled device other than the sending (primary) device.""") List<@NotNull @Valid IncomingMessage> deviceMessages, diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java index c37cc6ea0..e41eef878 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/AccountsManager.java @@ -90,7 +90,7 @@ public class AccountsManager { private final FaultTolerantRedisCluster cacheCluster; private final AccountLockManager accountLockManager; private final DeletedAccounts deletedAccounts; - private final Keys keys; + private final KeysManager keysManager; private final MessagesManager messagesManager; private final ProfilesManager profilesManager; private final StoredVerificationCodeManager pendingAccounts; @@ -134,7 +134,7 @@ public AccountsManager(final Accounts accounts, final FaultTolerantRedisCluster cacheCluster, final AccountLockManager accountLockManager, final DeletedAccounts deletedAccounts, - final Keys keys, + final KeysManager keysManager, final MessagesManager messagesManager, final ProfilesManager profilesManager, final StoredVerificationCodeManager pendingAccounts, @@ -150,7 +150,7 @@ public AccountsManager(final Accounts accounts, this.cacheCluster = cacheCluster; this.accountLockManager = accountLockManager; this.deletedAccounts = deletedAccounts; - this.keys = keys; + this.keysManager = keysManager; this.messagesManager = messagesManager; this.profilesManager = profilesManager; this.pendingAccounts = pendingAccounts; @@ -223,8 +223,8 @@ public Account create(final String number, // account and need to clear out messages and keys that may have been stored for the old account. if (!originalUuid.equals(actualUuid)) { messagesManager.clear(actualUuid); - keys.delete(actualUuid); - keys.delete(account.getPhoneNumberIdentifier()); + keysManager.delete(actualUuid); + keysManager.delete(account.getPhoneNumberIdentifier()); profilesManager.deleteAll(actualUuid); clientPresenceManager.disconnectAllPresencesForUuid(actualUuid); } @@ -315,13 +315,13 @@ public Account changeNumber(final Account account, updatedAccount.set(numberChangedAccount); - keys.delete(phoneNumberIdentifier); - keys.delete(originalPhoneNumberIdentifier); + keysManager.delete(phoneNumberIdentifier); + keysManager.delete(originalPhoneNumberIdentifier); if (pniPqLastResortPreKeys != null) { - keys.storePqLastResort( + keysManager.storePqLastResort( phoneNumberIdentifier, - keys.getPqEnabledDevices(uuid).stream().collect( + keysManager.getPqEnabledDevices(uuid).stream().collect( Collectors.toMap( Function.identity(), pniPqLastResortPreKeys::get))); @@ -356,10 +356,10 @@ public Account updatePniKeys(final Account account, final UUID pni = account.getPhoneNumberIdentifier(); final Account updatedAccount = update(account, a -> { return setPniKeys(a, pniIdentityKey, pniSignedPreKeys, pniRegistrationIds); }); - final List pqEnabledDeviceIDs = keys.getPqEnabledDevices(pni); - keys.delete(pni); + final List pqEnabledDeviceIDs = keysManager.getPqEnabledDevices(pni); + keysManager.delete(pni); if (pniPqLastResortPreKeys != null) { - keys.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get))); + keysManager.storePqLastResort(pni, pqEnabledDeviceIDs.stream().collect(Collectors.toMap(Function.identity(), pniPqLastResortPreKeys::get))); } return updatedAccount; @@ -740,8 +740,8 @@ private void delete(final Account account) { account.getUuid()); profilesManager.deleteAll(account.getUuid()); - keys.delete(account.getUuid()); - keys.delete(account.getPhoneNumberIdentifier()); + keysManager.delete(account.getUuid()); + keysManager.delete(account.getPhoneNumberIdentifier()); messagesManager.clear(account.getUuid()); messagesManager.clear(account.getPhoneNumberIdentifier()); registrationRecoveryPasswordsManager.removeForNumber(account.getNumber()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java deleted file mode 100644 index 43fbcf69b..000000000 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Keys.java +++ /dev/null @@ -1,417 +0,0 @@ -/* - * Copyright 2021 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Multimap; -import com.google.common.collect.MultimapBuilder; -import com.google.common.collect.Multimaps; -import io.micrometer.core.instrument.DistributionSummary; -import io.micrometer.core.instrument.Metrics; -import io.micrometer.core.instrument.Timer; -import io.micrometer.core.instrument.Counter; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.function.Function; -import java.util.stream.Collectors; -import javax.annotation.Nullable; - -import org.apache.commons.lang3.StringUtils; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import org.whispersystems.textsecuregcm.util.AttributeValues; -import software.amazon.awssdk.services.dynamodb.DynamoDbClient; -import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; -import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; -import software.amazon.awssdk.services.dynamodb.model.DeleteRequest; -import software.amazon.awssdk.services.dynamodb.model.PutRequest; -import software.amazon.awssdk.services.dynamodb.model.QueryRequest; -import software.amazon.awssdk.services.dynamodb.model.QueryResponse; -import software.amazon.awssdk.services.dynamodb.model.ReturnValue; -import software.amazon.awssdk.services.dynamodb.model.Select; -import software.amazon.awssdk.services.dynamodb.model.WriteRequest; - -import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; - -public class Keys extends AbstractDynamoDbStore { - - private final String ecTableName; - private final String pqTableName; - private final String pqLastResortTableName; - - static final String KEY_ACCOUNT_UUID = "U"; - static final String KEY_DEVICE_ID_KEY_ID = "DK"; - static final String KEY_PUBLIC_KEY = "P"; - static final String KEY_SIGNATURE = "S"; - - private static final Timer STORE_KEYS_TIMER = Metrics.timer(name(Keys.class, "storeKeys")); - private static final Timer TAKE_KEY_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "takeKeyForDevice")); - private static final Timer GET_KEY_COUNT_TIMER = Metrics.timer(name(Keys.class, "getKeyCount")); - private static final Timer DELETE_KEYS_FOR_DEVICE_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForDevice")); - private static final Timer DELETE_KEYS_FOR_ACCOUNT_TIMER = Metrics.timer(name(Keys.class, "deleteKeysForAccount")); - private static final DistributionSummary CONTESTED_KEY_DISTRIBUTION = Metrics.summary(name(Keys.class, "contestedKeys")); - private static final DistributionSummary KEY_COUNT_DISTRIBUTION = Metrics.summary(name(Keys.class, "keyCount")); - private static final Counter KEYS_EMPTY_TAKE_COUNTER = Metrics.counter(name(Keys.class, "takeKeyEmpty")); - private static final Counter TOO_MANY_LAST_RESORT_KEYS_COUNTER = Metrics.counter(name(Keys.class, "tooManyLastResortKeys")); - private static final Counter PARSE_BYTES_FROM_STRING_COUNTER = Metrics.counter(name(Keys.class, "parseByteArray"), "format", "string"); - private static final Counter READ_BYTES_FROM_BYTE_ARRAY_COUNTER = Metrics.counter(name(Keys.class, "parseByteArray"), "format", "bytes"); - - public Keys( - final DynamoDbClient dynamoDB, - final String ecTableName, - final String pqTableName, - final String pqLastResortTableName) { - super(dynamoDB); - this.ecTableName = ecTableName; - this.pqTableName = pqTableName; - this.pqLastResortTableName = pqLastResortTableName; - } - - public void store(final UUID identifier, final long deviceId, final List keys) { - store(identifier, deviceId, keys, null, null); - } - - public void store( - final UUID identifier, final long deviceId, - @Nullable final List ecKeys, - @Nullable final List pqKeys, - @Nullable final SignedPreKey pqLastResortKey) { - Multimap keys = MultimapBuilder.hashKeys().arrayListValues().build(); - List tablesToClear = new ArrayList<>(); - - if (ecKeys != null && !ecKeys.isEmpty()) { - keys.putAll(ecTableName, ecKeys); - tablesToClear.add(ecTableName); - } - if (pqKeys != null && !pqKeys.isEmpty()) { - keys.putAll(pqTableName, pqKeys); - tablesToClear.add(pqTableName); - } - if (pqLastResortKey != null) { - keys.put(pqLastResortTableName, pqLastResortKey); - tablesToClear.add(pqLastResortTableName); - } - - STORE_KEYS_TIMER.record(() -> { - delete(tablesToClear, identifier, deviceId); - - writeInBatches( - keys.entries(), - batch -> { - Multimap writes = batch.stream() - .collect( - Multimaps.toMultimap( - Map.Entry::getKey, - entry -> WriteRequest.builder() - .putRequest(PutRequest.builder() - .item(getItemFromPreKey(identifier, deviceId, entry.getValue())) - .build()) - .build(), - MultimapBuilder.hashKeys().arrayListValues()::build)); - executeTableWriteItemsUntilComplete(writes.asMap()); - }); - }); - } - - public void storePqLastResort(final UUID identifier, final Map keys) { - final AttributeValue partitionKey = getPartitionKey(identifier); - final QueryRequest queryRequest = QueryRequest.builder() - .tableName(pqLastResortTableName) - .keyConditionExpression("#uuid = :uuid") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) - .expressionAttributeValues(Map.of(":uuid", partitionKey)) - .projectionExpression(KEY_DEVICE_ID_KEY_ID) - .consistentRead(true) - .build(); - - final List writes = new ArrayList<>(2 * keys.size()); - final Map> newItems = keys.entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getKey, e -> getItemFromPreKey(identifier, e.getKey(), e.getValue()))); - - for (final Map item : db().query(queryRequest).items()) { - final AttributeValue oldSortKey = item.get(KEY_DEVICE_ID_KEY_ID); - final Long oldDeviceId = oldSortKey.b().asByteBuffer().getLong(); - if (newItems.containsKey(oldDeviceId)) { - final Map replacement = newItems.get(oldDeviceId); - if (!replacement.get(KEY_DEVICE_ID_KEY_ID).equals(oldSortKey)) { - writes.add(WriteRequest.builder() - .deleteRequest(DeleteRequest.builder() - .key(Map.of( - KEY_ACCOUNT_UUID, partitionKey, - KEY_DEVICE_ID_KEY_ID, oldSortKey)) - .build()) - .build()); - } - } - } - - newItems.forEach((unusedKey, item) -> - writes.add(WriteRequest.builder().putRequest(PutRequest.builder().item(item).build()).build())); - - executeTableWriteItemsUntilComplete(Map.of(pqLastResortTableName, writes)); - } - - public Optional takeEC(final UUID identifier, final long deviceId) { - return take(ecTableName, identifier, deviceId); - } - - public Optional takePQ(final UUID identifier, final long deviceId) { - return take(pqTableName, identifier, deviceId) - .or(() -> getLastResort(identifier, deviceId)) - .map(pk -> (SignedPreKey) pk); - } - - private Optional take(final String tableName, final UUID identifier, final long deviceId) { - return TAKE_KEY_FOR_DEVICE_TIMER.record(() -> { - final AttributeValue partitionKey = getPartitionKey(identifier); - QueryRequest queryRequest = QueryRequest.builder() - .tableName(tableName) - .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) - .expressionAttributeValues(Map.of( - ":uuid", partitionKey, - ":sortprefix", getSortKeyPrefix(deviceId))) - .projectionExpression(KEY_DEVICE_ID_KEY_ID) - .consistentRead(false) - .build(); - - int contestedKeys = 0; - - try { - QueryResponse response = db().query(queryRequest); - for (Map candidate : response.items()) { - DeleteItemRequest deleteItemRequest = DeleteItemRequest.builder() - .tableName(tableName) - .key(Map.of( - KEY_ACCOUNT_UUID, partitionKey, - KEY_DEVICE_ID_KEY_ID, candidate.get(KEY_DEVICE_ID_KEY_ID))) - .returnValues(ReturnValue.ALL_OLD) - .build(); - DeleteItemResponse deleteItemResponse = db().deleteItem(deleteItemRequest); - if (deleteItemResponse.hasAttributes()) { - return Optional.of(getPreKeyFromItem(deleteItemResponse.attributes())); - } - - contestedKeys++; - } - - KEYS_EMPTY_TAKE_COUNTER.increment(); - return Optional.empty(); - } finally { - CONTESTED_KEY_DISTRIBUTION.record(contestedKeys); - } - }); - } - - @VisibleForTesting - Optional getLastResort(final UUID identifier, final long deviceId) { - final AttributeValue partitionKey = getPartitionKey(identifier); - QueryRequest queryRequest = QueryRequest.builder() - .tableName(pqLastResortTableName) - .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) - .expressionAttributeValues(Map.of( - ":uuid", partitionKey, - ":sortprefix", getSortKeyPrefix(deviceId))) - .consistentRead(false) - .select(Select.ALL_ATTRIBUTES) - .build(); - - QueryResponse response = db().query(queryRequest); - if (response.count() > 1) { - TOO_MANY_LAST_RESORT_KEYS_COUNTER.increment(); - } - return response.items().stream().findFirst().map(this::getPreKeyFromItem); - } - - public List getPqEnabledDevices(final UUID identifier) { - final AttributeValue partitionKey = getPartitionKey(identifier); - final QueryRequest queryRequest = QueryRequest.builder() - .tableName(pqLastResortTableName) - .keyConditionExpression("#uuid = :uuid") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) - .expressionAttributeValues(Map.of(":uuid", partitionKey)) - .projectionExpression(KEY_DEVICE_ID_KEY_ID) - .consistentRead(false) - .build(); - - final QueryResponse response = db().query(queryRequest); - return response.items().stream() - .map(item -> item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong()) - .toList(); - } - - public int getEcCount(final UUID identifier, final long deviceId) { - return getCount(ecTableName, identifier, deviceId); - } - - public int getPqCount(final UUID identifier, final long deviceId) { - return getCount(pqTableName, identifier, deviceId); - } - - private int getCount(final String tableName, final UUID identifier, final long deviceId) { - return GET_KEY_COUNT_TIMER.record(() -> { - QueryRequest queryRequest = QueryRequest.builder() - .tableName(tableName) - .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) - .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(identifier), - ":sortprefix", getSortKeyPrefix(deviceId))) - .select(Select.COUNT) - .consistentRead(false) - .build(); - - int keyCount = 0; - // This is very confusing, but does appear to be the intended behavior. See: - // - // - https://github.com/aws/aws-sdk-java/issues/693 - // - https://github.com/aws/aws-sdk-java/issues/915 - // - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count - for (final QueryResponse page : db().queryPaginator(queryRequest)) { - keyCount += page.count(); - } - KEY_COUNT_DISTRIBUTION.record(keyCount); - return keyCount; - }); - } - - public void delete(final UUID accountUuid) { - DELETE_KEYS_FOR_ACCOUNT_TIMER.record(() -> { - final QueryRequest queryRequest = QueryRequest.builder() - .keyConditionExpression("#uuid = :uuid") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) - .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(accountUuid))) - .projectionExpression(KEY_DEVICE_ID_KEY_ID) - .consistentRead(true) - .build(); - - deleteItemsForAccountMatchingQuery(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, queryRequest); - }); - } - - public void delete(final UUID accountUuid, final long deviceId) { - delete(List.of(ecTableName, pqTableName, pqLastResortTableName), accountUuid, deviceId); - } - - private void delete(final List tableNames, final UUID accountUuid, final long deviceId) { - DELETE_KEYS_FOR_DEVICE_TIMER.record(() -> { - final QueryRequest queryRequest = QueryRequest.builder() - .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") - .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) - .expressionAttributeValues(Map.of( - ":uuid", getPartitionKey(accountUuid), - ":sortprefix", getSortKeyPrefix(deviceId))) - .projectionExpression(KEY_DEVICE_ID_KEY_ID) - .consistentRead(true) - .build(); - - deleteItemsForAccountMatchingQuery(tableNames, accountUuid, queryRequest); - }); - } - - private void deleteItemsForAccountMatchingQuery(final List tableNames, final UUID accountUuid, final QueryRequest querySpec) { - final AttributeValue partitionKey = getPartitionKey(accountUuid); - - Multimap> itemStream = tableNames.stream() - .collect( - Multimaps.flatteningToMultimap( - Function.identity(), - tableName -> - db().query(querySpec.toBuilder().tableName(tableName).build()) - .items() - .stream(), - MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build)); - - writeInBatches( - itemStream.entries(), - batch -> { - Multimap deletes = batch.stream() - .collect(Multimaps.toMultimap( - Map.Entry>::getKey, - entry -> WriteRequest.builder() - .deleteRequest(DeleteRequest.builder() - .key(Map.of( - KEY_ACCOUNT_UUID, partitionKey, - KEY_DEVICE_ID_KEY_ID, entry.getValue().get(KEY_DEVICE_ID_KEY_ID))) - .build()) - .build(), - MultimapBuilder.hashKeys(tableNames.size()).arrayListValues()::build)); - executeTableWriteItemsUntilComplete(deletes.asMap()); - }); - } - - private static AttributeValue getPartitionKey(final UUID accountUuid) { - return AttributeValues.fromUUID(accountUuid); - } - - private static AttributeValue getSortKey(final long deviceId, final long keyId) { - final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); - byteBuffer.putLong(deviceId); - byteBuffer.putLong(keyId); - return AttributeValues.fromByteBuffer(byteBuffer.flip()); - } - - @VisibleForTesting - static AttributeValue getSortKeyPrefix(final long deviceId) { - final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); - byteBuffer.putLong(deviceId); - return AttributeValues.fromByteBuffer(byteBuffer.flip()); - } - - private Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final PreKey preKey) { - if (preKey instanceof final SignedPreKey spk) { - return Map.of( - KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), - KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, spk.getKeyId()), - KEY_PUBLIC_KEY, AttributeValues.fromByteArray(spk.getPublicKey()), - KEY_SIGNATURE, AttributeValues.fromByteArray(spk.getSignature())); - } - return Map.of( - KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), - KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()), - KEY_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey())); - } - - private PreKey getPreKeyFromItem(Map item) { - final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); - final byte[] publicKey = extractByteArray(item.get(KEY_PUBLIC_KEY)); - - if (item.containsKey(KEY_SIGNATURE)) { - // All PQ prekeys are signed, and therefore have this attribute. Signed EC prekeys are stored - // in the Accounts table, so EC prekeys retrieved by this class are never SignedPreKeys. - return new SignedPreKey(keyId, publicKey, extractByteArray(item.get(KEY_SIGNATURE))); - } - return new PreKey(keyId, publicKey); - } - - /** - * Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string. - * - * @param attributeValue the {@code AttributeValue} from which to extract a byte array - * - * @return the byte array represented by the given {@code AttributeValue} - */ - @VisibleForTesting - static byte[] extractByteArray(final AttributeValue attributeValue) { - if (attributeValue.b() != null) { - READ_BYTES_FROM_BYTE_ARRAY_COUNTER.increment(); - return attributeValue.b().asByteArray(); - } else if (StringUtils.isNotBlank(attributeValue.s())) { - PARSE_BYTES_FROM_STRING_COUNTER.increment(); - return Base64.getDecoder().decode(attributeValue.s()); - } - - throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value"); - } -} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java new file mode 100644 index 000000000..eac1dc5f7 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/KeysManager.java @@ -0,0 +1,111 @@ +/* + * Copyright 2021 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; + +public class KeysManager { + + private final SingleUseECPreKeyStore ecPreKeys; + private final SingleUseKEMPreKeyStore pqPreKeys; + private final RepeatedUseSignedPreKeyStore pqLastResortKeys; + + public KeysManager( + final DynamoDbAsyncClient dynamoDbAsyncClient, + final String ecTableName, + final String pqTableName, + final String pqLastResortTableName) { + this.ecPreKeys = new SingleUseECPreKeyStore(dynamoDbAsyncClient, ecTableName); + this.pqPreKeys = new SingleUseKEMPreKeyStore(dynamoDbAsyncClient, pqTableName); + this.pqLastResortKeys = new RepeatedUseSignedPreKeyStore(dynamoDbAsyncClient, pqLastResortTableName); + } + + public void store(final UUID identifier, final long deviceId, final List keys) { + store(identifier, deviceId, keys, null, null); + } + + public void store( + final UUID identifier, final long deviceId, + @Nullable final List ecKeys, + @Nullable final List pqKeys, + @Nullable final SignedPreKey pqLastResortKey) { + + final List> storeFutures = new ArrayList<>(); + + if (ecKeys != null && !ecKeys.isEmpty()) { + storeFutures.add(ecPreKeys.store(identifier, deviceId, ecKeys)); + } + + if (pqKeys != null && !pqKeys.isEmpty()) { + storeFutures.add(pqPreKeys.store(identifier, deviceId, pqKeys)); + } + + if (pqLastResortKey != null) { + storeFutures.add(pqLastResortKeys.store(identifier, deviceId, pqLastResortKey)); + } + + CompletableFuture.allOf(storeFutures.toArray(new CompletableFuture[0])).join(); + } + + public void storePqLastResort(final UUID identifier, final Map keys) { + pqLastResortKeys.store(identifier, keys).join(); + } + + public Optional takeEC(final UUID identifier, final long deviceId) { + return ecPreKeys.take(identifier, deviceId).join(); + } + + public Optional takePQ(final UUID identifier, final long deviceId) { + return pqPreKeys.take(identifier, deviceId) + .thenCompose(maybeSingleUsePreKey -> maybeSingleUsePreKey + .map(singleUsePreKey -> CompletableFuture.completedFuture(maybeSingleUsePreKey)) + .orElseGet(() -> pqLastResortKeys.find(identifier, deviceId))).join(); + } + + @VisibleForTesting + Optional getLastResort(final UUID identifier, final long deviceId) { + return pqLastResortKeys.find(identifier, deviceId).join() + .map(signedPreKey -> signedPreKey); + } + + public List getPqEnabledDevices(final UUID identifier) { + return pqLastResortKeys.getDeviceIdsWithKeys(identifier).collectList().block(); + } + + public int getEcCount(final UUID identifier, final long deviceId) { + return ecPreKeys.getCount(identifier, deviceId).join(); + } + + public int getPqCount(final UUID identifier, final long deviceId) { + return pqPreKeys.getCount(identifier, deviceId).join(); + } + + public void delete(final UUID accountUuid) { + CompletableFuture.allOf( + ecPreKeys.delete(accountUuid), + pqPreKeys.delete(accountUuid), + pqLastResortKeys.delete(accountUuid)) + .join(); + } + + public void delete(final UUID accountUuid, final long deviceId) { + CompletableFuture.allOf( + ecPreKeys.delete(accountUuid, deviceId), + pqPreKeys.delete(accountUuid, deviceId), + pqLastResortKeys.delete(accountUuid, deviceId)) + .join(); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java new file mode 100644 index 000000000..275eba4e5 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStore.java @@ -0,0 +1,228 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.metrics.MetricsUtil; +import org.whispersystems.textsecuregcm.util.AttributeValues; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.GetItemRequest; +import software.amazon.awssdk.services.dynamodb.model.Put; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItem; +import software.amazon.awssdk.services.dynamodb.model.TransactWriteItemsRequest; + +/** + * A repeated-use signed pre-key store manages storage for pre-keys that may be used more than once. Generally, these + * are considered "last resort" keys and should only be used when a device's supply of single-use pre-keys has been + * exhausted. + *

+ * Each {@link Account} may have one or more {@link Device devices}. Each "active" (i.e. those that have completed + * provisioning and are capable of sending and receiving messages) must have exactly one "last resort" pre-key. + */ +public class RepeatedUseSignedPreKeyStore { + + private final DynamoDbAsyncClient dynamoDbAsyncClient; + private final String tableName; + + static final String KEY_ACCOUNT_UUID = "U"; + static final String KEY_DEVICE_ID = "D"; + static final String ATTR_KEY_ID = "I"; + static final String ATTR_PUBLIC_KEY = "P"; + static final String ATTR_SIGNATURE = "S"; + + private static final Timer STORE_SINGLE_KEY_TIMER = + Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "storeSingleKey")); + + private static final Timer STORE_KEY_BATCH_TIMER = + Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "storeKeyBatch")); + + private static final Timer DELETE_FOR_DEVICE_TIMER = + Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "deleteForDevice")); + + private static final Timer DELETE_FOR_ACCOUNT_TIMER = + Metrics.timer(MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "deleteForAccount")); + + private static final String FIND_KEY_TIMER_NAME = MetricsUtil.name(RepeatedUseSignedPreKeyStore.class, "findKey"); + private static final String KEY_PRESENT_TAG_NAME = "keyPresent"; + + public RepeatedUseSignedPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { + this.dynamoDbAsyncClient = dynamoDbAsyncClient; + this.tableName = tableName; + } + + /** + * Stores a repeated-use pre-key for a specific device, displacing any previously-stored repeated-use pre-key for that + * device. + * + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + * @param signedPreKey the key to store for the target device + * + * @return a future that completes once the key has been stored + */ + public CompletableFuture store(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) { + final Timer.Sample sample = Timer.start(); + + return dynamoDbAsyncClient.putItem(PutItemRequest.builder() + .tableName(tableName) + .item(getItemFromPreKey(identifier, deviceId, signedPreKey)) + .build()) + .thenRun(() -> sample.stop(STORE_SINGLE_KEY_TIMER)); + } + + /** + * Stores repeated-use pre-keys for a collection of devices associated with a single account/identity, displacing any + * previously-stored repeated-use pre-keys for the targeted devices. Note that this method is transactional; either + * all keys will be stored or none will. + * + * @param identifier the identifier for the account/identity with which the target devices are associated + * @param signedPreKeysByDeviceId a map of device identifiers to pre-keys + * + * @return a future that completes once all keys have been stored + */ + public CompletableFuture store(final UUID identifier, final Map signedPreKeysByDeviceId) { + final Timer.Sample sample = Timer.start(); + + return dynamoDbAsyncClient.transactWriteItems(TransactWriteItemsRequest.builder() + .transactItems(signedPreKeysByDeviceId.entrySet().stream() + .map(entry -> { + final long deviceId = entry.getKey(); + final SignedPreKey signedPreKey = entry.getValue(); + + return TransactWriteItem.builder() + .put(Put.builder() + .tableName(tableName) + .item(getItemFromPreKey(identifier, deviceId, signedPreKey)) + .build()) + .build(); + }) + .toList()) + .build()) + .thenRun(() -> sample.stop(STORE_KEY_BATCH_TIMER)); + } + + /** + * Finds a repeated-use pre-key for a specific device. + * + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + * + * @return a future that yields an optional signed pre-key if one is available for the target device or empty if no + * key could be found for the target device + */ + public CompletableFuture> find(final UUID identifier, final long deviceId) { + final Timer.Sample sample = Timer.start(); + + final CompletableFuture> findFuture = dynamoDbAsyncClient.getItem(GetItemRequest.builder() + .tableName(tableName) + .key(getPrimaryKey(identifier, deviceId)) + .consistentRead(true) + .build()) + .thenApply(response -> response.hasItem() ? Optional.of(getPreKeyFromItem(response.item())) : Optional.empty()); + + findFuture.whenComplete((maybeSignedPreKey, throwable) -> + sample.stop(Metrics.timer(FIND_KEY_TIMER_NAME, KEY_PRESENT_TAG_NAME, String.valueOf(maybeSignedPreKey != null && maybeSignedPreKey.isPresent())))); + + return findFuture; + } + + /** + * Clears all repeated-use pre-keys associated with the given account/identity. + * + * @param identifier the identifier for the account/identity for which to clear repeated-use pre-keys + * + * @return a future that completes once repeated-use pre-keys have been cleared from all devices associated with the + * target account/identity + */ + public CompletableFuture delete(final UUID identifier) { + final Timer.Sample sample = Timer.start(); + + return getDeviceIdsWithKeys(identifier) + .map(deviceId -> DeleteItemRequest.builder() + .tableName(tableName) + .key(getPrimaryKey(identifier, deviceId)) + .build()) + .flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest))) + // Idiom: wait for everything to finish, but discard the results + .reduce(0, (a, b) -> 0) + .toFuture() + .thenRun(() -> sample.stop(DELETE_FOR_ACCOUNT_TIMER)); + } + + /** + * Removes the repeated-use pre-key associated with a specific device. + * + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + * + * @return a future that completes once the repeated-use pre-key has been removed from the target device + */ + public CompletableFuture delete(final UUID identifier, final long deviceId) { + final Timer.Sample sample = Timer.start(); + + return dynamoDbAsyncClient.deleteItem(DeleteItemRequest.builder() + .tableName(tableName) + .key(getPrimaryKey(identifier, deviceId)) + .build()) + .thenRun(() -> sample.stop(DELETE_FOR_DEVICE_TIMER)); + } + + public Flux getDeviceIdsWithKeys(final UUID identifier) { + return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#uuid = :uuid") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of( + ":uuid", getPartitionKey(identifier))) + .projectionExpression(KEY_DEVICE_ID) + .consistentRead(true) + .build()) + .items()) + .map(item -> Long.parseLong(item.get(KEY_DEVICE_ID).n())); + } + + private static Map getPrimaryKey(final UUID identifier, final long deviceId) { + return Map.of( + KEY_ACCOUNT_UUID, getPartitionKey(identifier), + KEY_DEVICE_ID, getSortKey(deviceId)); + } + + private static AttributeValue getPartitionKey(final UUID accountUuid) { + return AttributeValues.fromUUID(accountUuid); + } + + private static AttributeValue getSortKey(final long deviceId) { + return AttributeValues.fromLong(deviceId); + } + + private static Map getItemFromPreKey(final UUID accountUuid, final long deviceId, final SignedPreKey signedPreKey) { + return Map.of( + KEY_ACCOUNT_UUID, getPartitionKey(accountUuid), + KEY_DEVICE_ID, getSortKey(deviceId), + ATTR_KEY_ID, AttributeValues.fromLong(signedPreKey.getKeyId()), + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()), + ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature())); + } + + private static SignedPreKey getPreKeyFromItem(final Map item) { + return new SignedPreKey( + Long.parseLong(item.get(ATTR_KEY_ID).n()), + item.get(ATTR_PUBLIC_KEY).b().asByteArray(), + item.get(ATTR_SIGNATURE).b().asByteArray()); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java new file mode 100644 index 000000000..82ec333a9 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.util.AttributeValues; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import java.util.Map; +import java.util.UUID; + +public class SingleUseECPreKeyStore extends SingleUsePreKeyStore { + + protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { + super(dynamoDbAsyncClient, tableName); + } + + @Override + protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final PreKey preKey) { + return Map.of( + KEY_ACCOUNT_UUID, getPartitionKey(identifier), + KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.getKeyId()), + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(preKey.getPublicKey())); + } + + @Override + protected PreKey getPreKeyFromItem(final Map item) { + final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); + final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); + + return new PreKey(keyId, publicKey); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java new file mode 100644 index 000000000..a35235823 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStore.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.util.AttributeValues; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import java.util.Map; +import java.util.UUID; + +public class SingleUseKEMPreKeyStore extends SingleUsePreKeyStore { + + protected SingleUseKEMPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { + super(dynamoDbAsyncClient, tableName); + } + + @Override + protected Map getItemFromPreKey(final UUID identifier, final long deviceId, final SignedPreKey signedPreKey) { + return Map.of( + KEY_ACCOUNT_UUID, getPartitionKey(identifier), + KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, signedPreKey.getKeyId()), + ATTR_PUBLIC_KEY, AttributeValues.fromByteArray(signedPreKey.getPublicKey()), + ATTR_SIGNATURE, AttributeValues.fromByteArray(signedPreKey.getSignature())); + } + + @Override + protected SignedPreKey getPreKeyFromItem(final Map item) { + final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); + final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); + final byte[] signature = extractByteArray(item.get(ATTR_SIGNATURE)); + + return new SignedPreKey(keyId, publicKey, signature); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java new file mode 100644 index 000000000..1a5864f38 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java @@ -0,0 +1,312 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Metrics; +import io.micrometer.core.instrument.Timer; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.StringUtils; +import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.util.AttributeValues; +import org.whispersystems.textsecuregcm.util.Util; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; +import software.amazon.awssdk.services.dynamodb.model.PutItemRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.model.QueryResponse; +import software.amazon.awssdk.services.dynamodb.model.ReturnValue; +import software.amazon.awssdk.services.dynamodb.model.Select; + +/** + * A single-use pre-key store stores single-use pre-keys of a specific type. Keys returned by a single-use pre-key + * store's {@link #take(UUID, long)} method are guaranteed to be returned exactly once, and repeated calls will never + * yield the same key. + *

+ * Each {@link Account} may have one or more {@link Device devices}. Clients should regularly check their + * supply of single-use pre-keys (see {@link #getCount(UUID, long)}) and upload new keys when their supply runs low. In + * the event that a party wants to begin a session with a device that has no single-use pre-keys remaining, that party + * may fall back to using the device's repeated-use ("last-resort") signed pre-key instead. + */ +public abstract class SingleUsePreKeyStore { + + private final DynamoDbAsyncClient dynamoDbAsyncClient; + private final String tableName; + + private final Timer storeKeyTimer = Metrics.timer(name(getClass(), "storeKey")); + private final Timer storeKeyBatchTimer = Metrics.timer(name(getClass(), "storeKeyBatch")); + private final Timer getKeyCountTimer = Metrics.timer(name(getClass(), "getCount")); + private final Timer deleteForDeviceTimer = Metrics.timer(name(getClass(), "deleteForDevice")); + private final Timer deleteForAccountTimer = Metrics.timer(name(getClass(), "deleteForAccount")); + + final DistributionSummary keysConsideredForTakeDistributionSummary = DistributionSummary + .builder(name(getClass(), "keysConsideredForTake")) + .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) + .distributionStatisticExpiry(Duration.ofMinutes(10)) + .register(Metrics.globalRegistry); + + final DistributionSummary availableKeyCountDistributionSummary = DistributionSummary + .builder(name(getClass(), "availableKeyCount")) + .publishPercentiles(0.5, 0.75, 0.95, 0.99, 0.999) + .distributionStatisticExpiry(Duration.ofMinutes(10)) + .register(Metrics.globalRegistry); + + private final String takeKeyTimerName = name(getClass(), "takeKey"); + private static final String KEY_PRESENT_TAG_NAME = "keyPresent"; + + private final Counter parseBytesFromStringCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "string"); + private final Counter readBytesFromByteArrayCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "bytes"); + + static final String KEY_ACCOUNT_UUID = "U"; + static final String KEY_DEVICE_ID_KEY_ID = "DK"; + static final String ATTR_PUBLIC_KEY = "P"; + static final String ATTR_SIGNATURE = "S"; + + protected SingleUsePreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { + this.dynamoDbAsyncClient = dynamoDbAsyncClient; + this.tableName = tableName; + } + + /** + * Stores a batch of single-use pre-keys for a specific device. All previously-stored keys for the device are cleared + * before storing new keys. + * + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + * @param preKeys a collection of single-use pre-keys to store for the target device + * + * @return a future that completes when all previously-stored keys have been removed and the given collection of + * pre-keys has been stored in its place + */ + public CompletableFuture store(final UUID identifier, final long deviceId, final List preKeys) { + final Timer.Sample sample = Timer.start(); + + return delete(identifier, deviceId) + .thenCompose(ignored -> CompletableFuture.allOf(preKeys.stream() + .map(preKey -> store(identifier, deviceId, preKey)) + .toList() + .toArray(new CompletableFuture[0]))) + .thenRun(() -> sample.stop(storeKeyBatchTimer)); + } + + private CompletableFuture store(final UUID identifier, final long deviceId, final K preKey) { + final Timer.Sample sample = Timer.start(); + + return dynamoDbAsyncClient.putItem(PutItemRequest.builder() + .tableName(tableName) + .item(getItemFromPreKey(identifier, deviceId, preKey)) + .build()) + .thenRun(() -> sample.stop(storeKeyTimer)); + } + + /** + * Attempts to retrieve a single-use pre-key for a specific device. Keys may only be returned by this method at most + * once; once the key is returned, it is removed from the key store and subsequent calls to this method will never + * return the same key. + * + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + * + * @return a future that yields a single-use pre-key if one is available or empty if no single-use pre-keys are + * available for the target device + */ + public CompletableFuture> take(final UUID identifier, final long deviceId) { + final Timer.Sample sample = Timer.start(); + final AttributeValue partitionKey = getPartitionKey(identifier); + final AtomicInteger keysConsidered = new AtomicInteger(0); + + return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .expressionAttributeValues(Map.of( + ":uuid", partitionKey, + ":sortprefix", getSortKeyPrefix(deviceId))) + .projectionExpression(KEY_DEVICE_ID_KEY_ID) + .consistentRead(false) + .build()) + .items()) + .map(item -> DeleteItemRequest.builder() + .tableName(tableName) + .key(Map.of( + KEY_ACCOUNT_UUID, partitionKey, + KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID))) + .returnValues(ReturnValue.ALL_OLD) + .build()) + .flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest)), 1) + .doOnNext(deleteItemResponse -> keysConsidered.incrementAndGet()) + .filter(DeleteItemResponse::hasAttributes) + .next() + .map(deleteItemResponse -> getPreKeyFromItem(deleteItemResponse.attributes())) + .toFuture() + .thenApply(Optional::ofNullable) + .whenComplete((maybeKey, throwable) -> { + sample.stop(Metrics.timer(takeKeyTimerName, KEY_PRESENT_TAG_NAME, String.valueOf(maybeKey != null && maybeKey.isPresent()))); + keysConsideredForTakeDistributionSummary.record(keysConsidered.get()); + }); + } + + /** + * Estimates the number of single-use pre-keys available for a given device. + + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + + * @return a future that yields the approximate number of single-use pre-keys currently available for the target + * device + */ + public CompletableFuture getCount(final UUID identifier, final long deviceId) { + final Timer.Sample sample = Timer.start(); + + // Getting an accurate count from DynamoDB can be very confusing. See: + // + // - https://github.com/aws/aws-sdk-java/issues/693 + // - https://github.com/aws/aws-sdk-java/issues/915 + // - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Query.html#Query.Count + return Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .expressionAttributeValues(Map.of( + ":uuid", getPartitionKey(identifier), + ":sortprefix", getSortKeyPrefix(deviceId))) + .select(Select.COUNT) + .consistentRead(false) + .build())) + .map(QueryResponse::count) + .reduce(0, Integer::sum) + .toFuture() + .whenComplete((keyCount, throwable) -> { + sample.stop(getKeyCountTimer); + + if (throwable == null && keyCount != null) { + availableKeyCountDistributionSummary.record(keyCount); + } + }); + } + + /** + * Removes all single-use pre-keys for all devices associated with the given account/identity. + * + * @param identifier the identifier for the account/identity for which to remove single-use pre-keys + * + * @return a future that completes when all single-use pre-keys have been removed for all devices associated with the + * given account/identity + */ + public CompletableFuture delete(final UUID identifier) { + final Timer.Sample sample = Timer.start(); + + return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#uuid = :uuid") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID)) + .expressionAttributeValues(Map.of(":uuid", getPartitionKey(identifier))) + .projectionExpression(KEY_DEVICE_ID_KEY_ID) + .consistentRead(true) + .build()) + .items())) + .thenRun(() -> sample.stop(deleteForAccountTimer)); + } + + /** + * Removes all single-use pre-keys for a specific device. + * + * @param identifier the identifier for the account/identity with which the target device is associated + * @param deviceId the identifier for the device within the given account/identity + + * @return a future that completes when all single-use pre-keys have been removed for the target device + */ + public CompletableFuture delete(final UUID identifier, final long deviceId) { + final Timer.Sample sample = Timer.start(); + + return deleteItems(getPartitionKey(identifier), Flux.from(dynamoDbAsyncClient.queryPaginator(QueryRequest.builder() + .tableName(tableName) + .keyConditionExpression("#uuid = :uuid AND begins_with (#sort, :sortprefix)") + .expressionAttributeNames(Map.of("#uuid", KEY_ACCOUNT_UUID, "#sort", KEY_DEVICE_ID_KEY_ID)) + .expressionAttributeValues(Map.of( + ":uuid", getPartitionKey(identifier), + ":sortprefix", getSortKeyPrefix(deviceId))) + .projectionExpression(KEY_DEVICE_ID_KEY_ID) + .consistentRead(true) + .build()) + .items())) + .thenRun(() -> sample.stop(deleteForDeviceTimer)); + } + + private CompletableFuture deleteItems(final AttributeValue partitionKey, final Flux> items) { + return items + .map(item -> DeleteItemRequest.builder() + .tableName(tableName) + .key(Map.of( + KEY_ACCOUNT_UUID, partitionKey, + KEY_DEVICE_ID_KEY_ID, item.get(KEY_DEVICE_ID_KEY_ID) + )) + .build()) + .flatMap(deleteItemRequest -> Mono.fromFuture(dynamoDbAsyncClient.deleteItem(deleteItemRequest))) + // Idiom: wait for everything to finish, but discard the results + .reduce(0, (a, b) -> 0) + .toFuture() + .thenRun(Util.NOOP); + } + + protected static AttributeValue getPartitionKey(final UUID accountUuid) { + return AttributeValues.fromUUID(accountUuid); + } + + protected static AttributeValue getSortKey(final long deviceId, final long keyId) { + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); + byteBuffer.putLong(deviceId); + byteBuffer.putLong(keyId); + return AttributeValues.fromByteBuffer(byteBuffer.flip()); + } + + private static AttributeValue getSortKeyPrefix(final long deviceId) { + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[8]); + byteBuffer.putLong(deviceId); + return AttributeValues.fromByteBuffer(byteBuffer.flip()); + } + + protected abstract Map getItemFromPreKey(final UUID identifier, final long deviceId, + final K preKey); + + protected abstract K getPreKeyFromItem(final Map item); + + /** + * Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string. + * + * @param attributeValue the {@code AttributeValue} from which to extract a byte array + * + * @return the byte array represented by the given {@code AttributeValue} + */ + @VisibleForTesting + byte[] extractByteArray(final AttributeValue attributeValue) { + if (attributeValue.b() != null) { + readBytesFromByteArrayCounter.increment(); + return attributeValue.b().asByteArray(); + } else if (StringUtils.isNotBlank(attributeValue.s())) { + parseBytesFromStringCounter.increment(); + return Base64.getDecoder().decode(attributeValue.s()); + } + + throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value"); + } +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java index 490f99ef6..5495a1411 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/AssignUsernameCommand.java @@ -42,7 +42,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DeletedAccounts; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -171,10 +171,11 @@ protected void run(Environment environment, Namespace namespace, configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getProfiles().getTableName()); - Keys keys = new Keys(dynamoDbClient, + KeysManager keys = new KeysManager( + dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName(), - configuration.getDynamoDbTables().getPqKeys().getTableName(), - configuration.getDynamoDbTables().getPqLastResortKeys().getTableName()); + configuration.getDynamoDbTables().getKemKeys().getTableName(), + configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()); MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getExpiration(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java index d03969da8..86bb6c4c3 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/CommandDependencies.java @@ -36,7 +36,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.DeletedAccounts; import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesCache; import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; @@ -66,7 +66,7 @@ record CommandDependencies( MessagesManager messagesManager, StoredVerificationCodeManager pendingAccountsManager, ClientPresenceManager clientPresenceManager, - Keys keys, + KeysManager keysManager, FaultTolerantRedisCluster cacheCluster, ClientResources redisClusterClientResources) { @@ -153,10 +153,11 @@ static CommandDependencies build( configuration.getDynamoDbTables().getPhoneNumberIdentifiers().getTableName()); Profiles profiles = new Profiles(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getProfiles().getTableName()); - Keys keys = new Keys(dynamoDbClient, + KeysManager keys = new KeysManager( + dynamoDbAsyncClient, configuration.getDynamoDbTables().getEcKeys().getTableName(), - configuration.getDynamoDbTables().getPqKeys().getTableName(), - configuration.getDynamoDbTables().getPqLastResortKeys().getTableName()); + configuration.getDynamoDbTables().getKemKeys().getTableName(), + configuration.getDynamoDbTables().getKemLastResortKeys().getTableName()); MessagesDynamoDb messagesDynamoDb = new MessagesDynamoDb(dynamoDbClient, dynamoDbAsyncClient, configuration.getDynamoDbTables().getMessages().getTableName(), configuration.getDynamoDbTables().getMessages().getExpiration(), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java index 2f61251b6..9c2451c44 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/workers/UnlinkDeviceCommand.java @@ -65,7 +65,7 @@ protected void run(final Environment environment, final Namespace namespace, account = deps.accountsManager().update(account, a -> a.removeDevice(deviceId)); System.out.format("Removing keys for device %s::%d\n", aci, deviceId); - deps.keys().delete(account.getUuid(), deviceId); + deps.keysManager().delete(account.getUuid(), deviceId); System.out.format("Clearing additional messages for %s::%d\n", aci, deviceId); deps.messagesManager().clear(account.getUuid(), deviceId); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java index e7d2ad175..8f9e53707 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/RegistrationControllerTest.java @@ -67,7 +67,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.RegistrationRecoveryPasswordsManager; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -90,7 +90,7 @@ class RegistrationControllerTest { RegistrationLockVerificationManager.class); private final RegistrationRecoveryPasswordsManager registrationRecoveryPasswordsManager = mock( RegistrationRecoveryPasswordsManager.class); - private final Keys keys = mock(Keys.class); + private final KeysManager keysManager = mock(KeysManager.class); private final RateLimiters rateLimiters = mock(RateLimiters.class); private final RateLimiter registrationLimiter = mock(RateLimiter.class); @@ -105,7 +105,7 @@ class RegistrationControllerTest { .addResource( new RegistrationController(accountsManager, new PhoneVerificationTokenManager(registrationServiceClient, registrationRecoveryPasswordsManager), - registrationLockVerificationManager, keys, rateLimiters)) + registrationLockVerificationManager, keysManager, rateLimiters)) .build(); @BeforeEach @@ -669,8 +669,8 @@ void atomicAccountCreationSuccess(final RegistrationRequest registrationRequest, verify(device).setSignedPreKey(expectedAciSignedPreKey); verify(device).setPhoneNumberIdentitySignedPreKey(expectedPniSignedPreKey); - verify(keys).storePqLastResort(accountIdentifier, Map.of(Device.MASTER_ID, expectedAciPqLastResortPreKey)); - verify(keys).storePqLastResort(phoneNumberIdentifier, Map.of(Device.MASTER_ID, expectedPniPqLastResortPreKey)); + verify(keysManager).storePqLastResort(accountIdentifier, Map.of(Device.MASTER_ID, expectedAciPqLastResortPreKey)); + verify(keysManager).storePqLastResort(phoneNumberIdentifier, Map.of(Device.MASTER_ID, expectedPniPqLastResortPreKey)); expectedApnsToken.ifPresentOrElse(expectedToken -> verify(device).setApnId(expectedToken), () -> verify(device, never()).setApnId(any())); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java index 9747f819d..01c94b996 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerChangeNumberIntegrationTest.java @@ -101,7 +101,9 @@ void setup() throws InterruptedException { accounts, phoneNumberIdentifiers, CACHE_CLUSTER_EXTENSION.getRedisCluster(), - accountLockManager, deletedAccounts, mock(Keys.class), + accountLockManager, + deletedAccounts, + mock(KeysManager.class), mock(MessagesManager.class), mock(ProfilesManager.class), mock(StoredVerificationCodeManager.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java index eb8b5efa3..7a105ffd3 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerConcurrentModificationIntegrationTest.java @@ -112,7 +112,9 @@ void setup() throws InterruptedException { accounts, phoneNumberIdentifiers, RedisClusterHelper.builder().stringCommands(commands).build(), - accountLockManager, deletedAccounts, mock(Keys.class), + accountLockManager, + deletedAccounts, + mock(KeysManager.class), mock(MessagesManager.class), mock(ProfilesManager.class), mock(StoredVerificationCodeManager.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java index 99288b621..66665c3a2 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerTest.java @@ -71,7 +71,7 @@ class AccountsManagerTest { private Accounts accounts; private DeletedAccounts deletedAccounts; - private Keys keys; + private KeysManager keysManager; private MessagesManager messagesManager; private ProfilesManager profilesManager; private ClientPresenceManager clientPresenceManager; @@ -94,7 +94,7 @@ class AccountsManagerTest { void setup() throws InterruptedException { accounts = mock(Accounts.class); deletedAccounts = mock(DeletedAccounts.class); - keys = mock(Keys.class); + keysManager = mock(KeysManager.class); messagesManager = mock(MessagesManager.class); profilesManager = mock(ProfilesManager.class); clientPresenceManager = mock(ClientPresenceManager.class); @@ -157,7 +157,7 @@ void setup() throws InterruptedException { RedisClusterHelper.builder().stringCommands(commands).build(), accountLockManager, deletedAccounts, - keys, + keysManager, messagesManager, profilesManager, mock(StoredVerificationCodeManager.class), @@ -542,7 +542,7 @@ void testCreateFreshAccount() throws InterruptedException { accountsManager.create(e164, "password", null, attributes, new ArrayList<>()); verify(accounts).create(argThat(account -> e164.equals(account.getNumber()))); - verifyNoInteractions(keys); + verifyNoInteractions(keysManager); verifyNoInteractions(messagesManager); verifyNoInteractions(profilesManager); } @@ -565,8 +565,8 @@ void testReregisterAccount() throws InterruptedException { verify(accounts) .create(argThat(account -> e164.equals(account.getNumber()) && existingUuid.equals(account.getUuid()))); - verify(keys).delete(existingUuid); - verify(keys).delete(phoneNumberIdentifiersByE164.get(e164)); + verify(keysManager).delete(existingUuid); + verify(keysManager).delete(phoneNumberIdentifiersByE164.get(e164)); verify(messagesManager).clear(existingUuid); verify(profilesManager).deleteAll(existingUuid); verify(clientPresenceManager).disconnectAllPresencesForUuid(existingUuid); @@ -585,7 +585,7 @@ void testCreateAccountRecentlyDeleted() throws InterruptedException { verify(accounts).create( argThat(account -> e164.equals(account.getNumber()) && recentlyDeletedUuid.equals(account.getUuid()))); - verifyNoInteractions(keys); + verifyNoInteractions(keysManager); verifyNoInteractions(messagesManager); verifyNoInteractions(profilesManager); } @@ -646,8 +646,8 @@ void testChangePhoneNumber() throws InterruptedException, MismatchedDevicesExcep assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); - verify(keys).delete(originalPni); - verify(keys).delete(phoneNumberIdentifiersByE164.get(targetNumber)); + verify(keysManager).delete(originalPni); + verify(keysManager).delete(phoneNumberIdentifiersByE164.get(targetNumber)); } @Test @@ -659,7 +659,7 @@ void testChangePhoneNumberSameNumber() throws InterruptedException, MismatchedDe assertEquals(number, account.getNumber()); verify(deletedAccounts, never()).put(any(), any()); - verify(keys, never()).delete(any()); + verify(keysManager, never()).delete(any()); } @Test @@ -674,7 +674,7 @@ void testChangePhoneNumberSameNumberWithPniData() { verify(accounts, never()).update(any()); verifyNoInteractions(deletedAccounts); - verifyNoInteractions(keys); + verifyNoInteractions(keysManager); } @Test @@ -697,11 +697,11 @@ void testChangePhoneNumberExistingAccount() throws InterruptedException, Mismatc assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber); - verify(keys).delete(existingAccountUuid); - verify(keys).delete(originalPni); - verify(keys, atLeastOnce()).delete(targetPni); - verify(keys).delete(newPni); - verifyNoMoreInteractions(keys); + verify(keysManager).delete(existingAccountUuid); + verify(keysManager).delete(originalPni); + verify(keysManager, atLeastOnce()).delete(targetPni); + verify(keysManager).delete(newPni); + verifyNoMoreInteractions(keysManager); } @Test @@ -723,7 +723,7 @@ void testChangePhoneNumberWithPqKeysExistingAccount() throws InterruptedExceptio final Account existingAccount = AccountsHelper.generateTestAccount(targetNumber, existingAccountUuid, targetPni, new ArrayList<>(), new byte[16]); when(accounts.getByE164(targetNumber)).thenReturn(Optional.of(existingAccount)); - when(keys.getPqEnabledDevices(uuid)).thenReturn(List.of(1L)); + when(keysManager.getPqEnabledDevices(uuid)).thenReturn(List.of(1L)); final List devices = List.of(DevicesHelper.createDevice(1L, 0L, 101), DevicesHelper.createDevice(2L, 0L, 102)); final Account account = AccountsHelper.generateTestAccount(originalNumber, uuid, originalPni, devices, new byte[16]); @@ -735,13 +735,13 @@ void testChangePhoneNumberWithPqKeysExistingAccount() throws InterruptedExceptio assertTrue(phoneNumberIdentifiersByE164.containsKey(targetNumber)); final UUID newPni = phoneNumberIdentifiersByE164.get(targetNumber); - verify(keys).delete(existingAccountUuid); - verify(keys, atLeastOnce()).delete(targetPni); - verify(keys).delete(newPni); - verify(keys).delete(originalPni); - verify(keys).getPqEnabledDevices(uuid); - verify(keys).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); - verifyNoMoreInteractions(keys); + verify(keysManager).delete(existingAccountUuid); + verify(keysManager, atLeastOnce()).delete(targetPni); + verify(keysManager).delete(newPni); + verify(keysManager).delete(originalPni); + verify(keysManager).getPqEnabledDevices(uuid); + verify(keysManager).storePqLastResort(eq(newPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); + verifyNoMoreInteractions(keysManager); } @Test @@ -792,7 +792,7 @@ void testPniUpdate() throws MismatchedDevicesException { verify(accounts).update(any()); verifyNoInteractions(deletedAccounts); - verify(keys).delete(oldPni); + verify(keysManager).delete(oldPni); } @Test @@ -813,7 +813,7 @@ void testPniPqUpdate() throws MismatchedDevicesException { UUID oldUuid = account.getUuid(); UUID oldPni = account.getPhoneNumberIdentifier(); - when(keys.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); + when(keysManager.getPqEnabledDevices(oldPni)).thenReturn(List.of(1L)); Map oldSignedPreKeys = account.getDevices().stream().collect(Collectors.toMap(Device::getId, Device::getSignedPreKey)); @@ -839,10 +839,10 @@ void testPniPqUpdate() throws MismatchedDevicesException { verify(accounts).update(any()); verifyNoInteractions(deletedAccounts); - verify(keys).delete(oldPni); + verify(keysManager).delete(oldPni); // only the pq key for the already-pq-enabled device should be saved - verify(keys).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); + verify(keysManager).storePqLastResort(eq(oldPni), eq(Map.of(1L, newSignedPqKeys.get(1L)))); } @Test diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java index c6817b611..05e94ca4a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsManagerUsernameIntegrationTest.java @@ -116,7 +116,7 @@ private void buildAccountsManager(final int initialWidth, int discriminatorMaxWi CACHE_CLUSTER_EXTENSION.getRedisCluster(), accountLockManager, deletedAccounts, - mock(Keys.class), + mock(KeysManager.class), mock(MessagesManager.class), mock(ProfilesManager.class), mock(StoredVerificationCodeManager.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java index 101ff7f05..45b0d0094 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/AccountsTest.java @@ -156,7 +156,7 @@ public void testUsernameLinksViaAccountsManager() throws Exception { mock(FaultTolerantRedisCluster.class), mock(AccountLockManager.class), mock(DeletedAccounts.class), - mock(Keys.class), + mock(KeysManager.class), mock(MessagesManager.class), mock(ProfilesManager.class), mock(StoredVerificationCodeManager.class), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java index 7024fe994..3bde90801 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/DynamoDbExtensionSchema.java @@ -88,44 +88,44 @@ public enum Tables implements DynamoDbExtension.TableSchema { List.of(), List.of()), EC_KEYS("keys_test", - Keys.KEY_ACCOUNT_UUID, - Keys.KEY_DEVICE_ID_KEY_ID, + SingleUsePreKeyStore.KEY_ACCOUNT_UUID, + SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, List.of( AttributeDefinition.builder() - .attributeName(Keys.KEY_ACCOUNT_UUID) + .attributeName(SingleUsePreKeyStore.KEY_ACCOUNT_UUID) .attributeType(ScalarAttributeType.B) .build(), AttributeDefinition.builder() - .attributeName(Keys.KEY_DEVICE_ID_KEY_ID) + .attributeName(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID) .attributeType(ScalarAttributeType.B) .build()), List.of(), List.of()), PQ_KEYS("pq_keys_test", - Keys.KEY_ACCOUNT_UUID, - Keys.KEY_DEVICE_ID_KEY_ID, + SingleUsePreKeyStore.KEY_ACCOUNT_UUID, + SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID, List.of( AttributeDefinition.builder() - .attributeName(Keys.KEY_ACCOUNT_UUID) + .attributeName(SingleUsePreKeyStore.KEY_ACCOUNT_UUID) .attributeType(ScalarAttributeType.B) .build(), AttributeDefinition.builder() - .attributeName(Keys.KEY_DEVICE_ID_KEY_ID) + .attributeName(SingleUsePreKeyStore.KEY_DEVICE_ID_KEY_ID) .attributeType(ScalarAttributeType.B) .build()), List.of(), List.of()), - PQ_LAST_RESORT_KEYS("pq_last_resort_keys_test", - Keys.KEY_ACCOUNT_UUID, - Keys.KEY_DEVICE_ID_KEY_ID, + REPEATED_USE_SIGNED_PRE_KEYS("repeated_use_signed_pre_keys_test", + RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID, + RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, List.of( AttributeDefinition.builder() - .attributeName(Keys.KEY_ACCOUNT_UUID) + .attributeName(RepeatedUseSignedPreKeyStore.KEY_ACCOUNT_UUID) .attributeType(ScalarAttributeType.B) .build(), AttributeDefinition.builder() - .attributeName(Keys.KEY_DEVICE_ID_KEY_ID) - .attributeType(ScalarAttributeType.B) + .attributeName(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID) + .attributeType(ScalarAttributeType.N) .build()), List.of(), List.of()), diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java new file mode 100644 index 000000000..83b19a07f --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysManagerTest.java @@ -0,0 +1,257 @@ +/* + * Copyright 2021-2022 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.security.SecureRandom; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.entities.PreKey; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; + +class KeysManagerTest { + + private KeysManager keysManager; + + @RegisterExtension + static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( + Tables.EC_KEYS, Tables.PQ_KEYS, Tables.REPEATED_USE_SIGNED_PRE_KEYS); + + private static final UUID ACCOUNT_UUID = UUID.randomUUID(); + private static final long DEVICE_ID = 1L; + + @BeforeEach + void setup() { + keysManager = new KeysManager( + DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + Tables.EC_KEYS.tableName(), + Tables.PQ_KEYS.tableName(), + Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName()); + } + + @Test + void testStore() { + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Initial pre-key count for an account should be zero"); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Initial pre-key count for an account should be zero"); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(), + "Initial last-resort pre-key for an account should be missing"); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Repeatedly storing same key should have no effect"); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new PQ prekeys should have no effect on EC prekeys"); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001)); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); + assertEquals(1001, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId()); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Uploading new EC prekeys should have no effect on PQ prekeys"); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, + List.of(generateTestPreKey(4), generateTestPreKey(5)), + List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)), + generateTestSignedPreKey(1002)); + assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting multiple new keys should overwrite all prior keys for the given account/device"); + assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID), + "Inserting multiple new keys should overwrite all prior keys for the given account/device"); + assertEquals(1002, keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(), + "Uploading new last-resort key should overwrite prior last-resort key for the account/device"); + } + + @Test + void testTakeAccountAndDeviceId() { + assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); + + final PreKey preKey = generateTestPreKey(1); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); + final Optional takenKey = keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID); + assertEquals(Optional.of(preKey), takenKey); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + } + + @Test + void testTakePQ() { + assertEquals(Optional.empty(), keysManager.takeEC(ACCOUNT_UUID, DEVICE_ID)); + + final SignedPreKey preKey1 = generateTestSignedPreKey(1); + final SignedPreKey preKey2 = generateTestSignedPreKey(2); + final SignedPreKey preKeyLast = generateTestSignedPreKey(1001); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast); + + assertEquals(Optional.of(preKey1), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(Optional.of(preKey2), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + assertEquals(Optional.of(preKeyLast), keysManager.takePQ(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + } + + @Test + void testGetCount() { + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + } + + @Test + void testDeleteByAccount() { + keysManager.store(ACCOUNT_UUID, DEVICE_ID, + List.of(generateTestPreKey(1), generateTestPreKey(2)), + List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), + generateTestSignedPreKey(5)); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, + List.of(generateTestPreKey(6)), + List.of(generateTestSignedPreKey(7)), + generateTestSignedPreKey(8)); + + assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + + keysManager.delete(ACCOUNT_UUID); + + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + } + + @Test + void testDeleteByAccountAndDevice() { + keysManager.store(ACCOUNT_UUID, DEVICE_ID, + List.of(generateTestPreKey(1), generateTestPreKey(2)), + List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), + generateTestSignedPreKey(5)); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, + List.of(generateTestPreKey(6)), + List.of(generateTestSignedPreKey(7)), + generateTestSignedPreKey(8)); + + assertEquals(2, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(2, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + + keysManager.delete(ACCOUNT_UUID, DEVICE_ID); + + assertEquals(0, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID)); + assertEquals(0, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID)); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); + assertEquals(1, keysManager.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertEquals(1, keysManager.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); + assertTrue(keysManager.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); + } + + @Test + void testStorePqLastResort() { + assertEquals(0, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); + + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + + keysManager.storePqLastResort( + ACCOUNT_UUID, + Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))); + assertEquals(2, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size()); + assertEquals(1L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId()); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId()); + assertFalse(keysManager.getLastResort(ACCOUNT_UUID, 3L).isPresent()); + + keysManager.storePqLastResort( + ACCOUNT_UUID, + Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))); + assertEquals(3, keysManager.getPqEnabledDevices(ACCOUNT_UUID).size(), "storing new last-resort keys should not create duplicates"); + assertEquals(3L, keysManager.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); + assertEquals(2L, keysManager.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone"); + assertEquals(4L, keysManager.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); + } + + @Test + void testGetPqEnabledDevices() { + final ECKeyPair identityKeyPair = Curve.generateKeyPair(); + + keysManager.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null); + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)); + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), KeysHelper.signedKEMPreKey(4, identityKeyPair)); + keysManager.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null); + assertIterableEquals( + Set.of(DEVICE_ID + 1, DEVICE_ID + 2), + Set.copyOf(keysManager.getPqEnabledDevices(ACCOUNT_UUID))); + } + + private static PreKey generateTestPreKey(final long keyId) { + final byte[] key = new byte[32]; + new SecureRandom().nextBytes(key); + + return new PreKey(keyId, key); + } + + private static SignedPreKey generateTestSignedPreKey(final long keyId) { + final byte[] key = new byte[32]; + final byte[] signature = new byte[32]; + + final SecureRandom secureRandom = new SecureRandom(); + secureRandom.nextBytes(key); + secureRandom.nextBytes(signature); + + return new SignedPreKey(keyId, key, signature); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java deleted file mode 100644 index 79282732b..000000000 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/KeysTest.java +++ /dev/null @@ -1,314 +0,0 @@ -/* - * Copyright 2021-2022 Signal Messenger, LLC - * SPDX-License-Identifier: AGPL-3.0-only - */ - -package org.whispersystems.textsecuregcm.storage; - -import java.security.SecureRandom; -import java.util.Base64; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.stream.Stream; - -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.signal.libsignal.protocol.ecc.Curve; -import org.signal.libsignal.protocol.ecc.ECKeyPair; -import org.whispersystems.textsecuregcm.entities.PreKey; -import org.whispersystems.textsecuregcm.entities.SignedPreKey; -import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; -import org.whispersystems.textsecuregcm.tests.util.KeysHelper; -import org.whispersystems.textsecuregcm.util.AttributeValues; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.dynamodb.model.AttributeValue; -import software.amazon.awssdk.services.dynamodb.model.QueryRequest; -import software.amazon.awssdk.services.dynamodb.model.QueryResponse; -import software.amazon.awssdk.services.dynamodb.model.Select; - -import static org.junit.jupiter.api.Assertions.*; - -class KeysTest { - - private Keys keys; - - @RegisterExtension - static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension( - Tables.EC_KEYS, Tables.PQ_KEYS, Tables.PQ_LAST_RESORT_KEYS); - - private static final UUID ACCOUNT_UUID = UUID.randomUUID(); - private static final long DEVICE_ID = 1L; - - @BeforeEach - void setup() { - keys = new Keys( - DYNAMO_DB_EXTENSION.getDynamoDbClient(), - Tables.EC_KEYS.tableName(), - Tables.PQ_KEYS.tableName(), - Tables.PQ_LAST_RESORT_KEYS.tableName()); - } - - @Test - void testStore() { - assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Initial pre-key count for an account should be zero"); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), - "Initial pre-key count for an account should be zero"); - assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent(), - "Initial last-resort pre-key for an account should be missing"); - - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1))); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Repeatedly storing same key should have no effect"); - - keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(generateTestSignedPreKey(1)), null); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Uploading new PQ prekeys should have no effect on EC prekeys"); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - - keys.store(ACCOUNT_UUID, DEVICE_ID, null, null, generateTestSignedPreKey(1001)); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Uploading new PQ last-resort prekey should have no effect on EC prekeys"); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), - "Uploading new PQ last-resort prekey should have no effect on one-time PQ prekeys"); - assertEquals(1001, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId()); - - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(2)), null, null); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), - "Uploading new EC prekeys should have no effect on PQ prekeys"); - - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(3)), List.of(generateTestSignedPreKey(2)), null); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), - "Inserting a new key should overwrite all prior keys of the same type for the given account/device"); - - keys.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(4), generateTestPreKey(5)), - List.of(generateTestSignedPreKey(6), generateTestSignedPreKey(7)), - generateTestSignedPreKey(1002)); - assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID), - "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID), - "Inserting multiple new keys should overwrite all prior keys for the given account/device"); - assertEquals(1002, keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).get().getKeyId(), - "Uploading new last-resort key should overwrite prior last-resort key for the account/device"); - } - - @Test - void testTakeAccountAndDeviceId() { - assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); - - final PreKey preKey = generateTestPreKey(1); - - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(preKey, generateTestPreKey(2))); - final Optional takenKey = keys.takeEC(ACCOUNT_UUID, DEVICE_ID); - assertEquals(Optional.of(preKey), takenKey); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - } - - @Test - void testTakePQ() { - assertEquals(Optional.empty(), keys.takeEC(ACCOUNT_UUID, DEVICE_ID)); - - final SignedPreKey preKey1 = generateTestSignedPreKey(1); - final SignedPreKey preKey2 = generateTestSignedPreKey(2); - final SignedPreKey preKeyLast = generateTestSignedPreKey(1001); - - keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(preKey1, preKey2), preKeyLast); - - assertEquals(Optional.of(preKey1), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - - assertEquals(Optional.of(preKey2), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - - assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - - assertEquals(Optional.of(preKeyLast), keys.takePQ(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - } - - @Test - void testGetCount() { - assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - - keys.store(ACCOUNT_UUID, DEVICE_ID, List.of(generateTestPreKey(1)), List.of(generateTestSignedPreKey(1)), null); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - } - - @Test - void testDeleteByAccount() { - keys.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), - generateTestSignedPreKey(5)); - - keys.store(ACCOUNT_UUID, DEVICE_ID + 1, - List.of(generateTestPreKey(6)), - List.of(generateTestSignedPreKey(7)), - generateTestSignedPreKey(8)); - - assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); - - keys.delete(ACCOUNT_UUID); - - assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); - } - - @Test - void testDeleteByAccountAndDevice() { - keys.store(ACCOUNT_UUID, DEVICE_ID, - List.of(generateTestPreKey(1), generateTestPreKey(2)), - List.of(generateTestSignedPreKey(3), generateTestSignedPreKey(4)), - generateTestSignedPreKey(5)); - - keys.store(ACCOUNT_UUID, DEVICE_ID + 1, - List.of(generateTestPreKey(6)), - List.of(generateTestSignedPreKey(7)), - generateTestSignedPreKey(8)); - - assertEquals(2, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(2, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); - - keys.delete(ACCOUNT_UUID, DEVICE_ID); - - assertEquals(0, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID)); - assertEquals(0, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID)); - assertFalse(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID).isPresent()); - assertEquals(1, keys.getEcCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertEquals(1, keys.getPqCount(ACCOUNT_UUID, DEVICE_ID + 1)); - assertTrue(keys.getLastResort(ACCOUNT_UUID, DEVICE_ID + 1).isPresent()); - } - - @Test - void testStorePqLastResort() { - assertEquals(0, getLastResortCount(ACCOUNT_UUID)); - - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - - keys.storePqLastResort( - ACCOUNT_UUID, - Map.of(1L, KeysHelper.signedKEMPreKey(1, identityKeyPair), 2L, KeysHelper.signedKEMPreKey(2, identityKeyPair))); - assertEquals(2, getLastResortCount(ACCOUNT_UUID)); - assertEquals(1L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId()); - assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId()); - assertFalse(keys.getLastResort(ACCOUNT_UUID, 3L).isPresent()); - - keys.storePqLastResort( - ACCOUNT_UUID, - Map.of(1L, KeysHelper.signedKEMPreKey(3, identityKeyPair), 3L, KeysHelper.signedKEMPreKey(4, identityKeyPair))); - assertEquals(3, getLastResortCount(ACCOUNT_UUID), "storing new last-resort keys should not create duplicates"); - assertEquals(3L, keys.getLastResort(ACCOUNT_UUID, 1L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); - assertEquals(2L, keys.getLastResort(ACCOUNT_UUID, 2L).get().getKeyId(), "storing new last-resort keys should leave untouched ones alone"); - assertEquals(4L, keys.getLastResort(ACCOUNT_UUID, 3L).get().getKeyId(), "storing new last-resort keys should overwrite old ones"); - } - - private int getLastResortCount(UUID uuid) { - QueryRequest queryRequest = QueryRequest.builder() - .tableName(Tables.PQ_LAST_RESORT_KEYS.tableName()) - .keyConditionExpression("#uuid = :uuid") - .expressionAttributeNames(Map.of("#uuid", Keys.KEY_ACCOUNT_UUID)) - .expressionAttributeValues(Map.of(":uuid", AttributeValues.fromUUID(uuid))) - .select(Select.COUNT) - .build(); - QueryResponse response = DYNAMO_DB_EXTENSION.getDynamoDbClient().query(queryRequest); - return response.count(); - } - - @Test - void testGetPqEnabledDevices() { - final ECKeyPair identityKeyPair = Curve.generateKeyPair(); - - keys.store(ACCOUNT_UUID, DEVICE_ID, null, List.of(KeysHelper.signedKEMPreKey(1, identityKeyPair)), null); - keys.store(ACCOUNT_UUID, DEVICE_ID + 1, null, null, KeysHelper.signedKEMPreKey(2, identityKeyPair)); - keys.store(ACCOUNT_UUID, DEVICE_ID + 2, null, List.of(KeysHelper.signedKEMPreKey(3, identityKeyPair)), KeysHelper.signedKEMPreKey(4, identityKeyPair)); - keys.store(ACCOUNT_UUID, DEVICE_ID + 3, null, null, null); - assertIterableEquals( - Set.of(DEVICE_ID + 1, DEVICE_ID + 2), - Set.copyOf(keys.getPqEnabledDevices(ACCOUNT_UUID))); - } - - @Test - void testSortKeyPrefix() { - AttributeValue got = Keys.getSortKeyPrefix(123); - assertArrayEquals(new byte[]{0, 0, 0, 0, 0, 0, 0, 123}, got.b().asByteArray()); - } - - @ParameterizedTest - @MethodSource - void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) { - assertArrayEquals(expectedByteArray, Keys.extractByteArray(attributeValue)); - } - - private static Stream extractByteArray() { - final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc="); - - return Stream.of( - Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key), - Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key), - Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key) - ); - } - - @ParameterizedTest - @MethodSource - void extractByteArrayIllegalArgument(final AttributeValue attributeValue) { - assertThrows(IllegalArgumentException.class, () -> Keys.extractByteArray(attributeValue)); - } - - private static Stream extractByteArrayIllegalArgument() { - return Stream.of( - Arguments.of(AttributeValue.fromN("12")), - Arguments.of(AttributeValue.fromS("")), - Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎")) - ); - } - - private static PreKey generateTestPreKey(final long keyId) { - final byte[] key = new byte[32]; - new SecureRandom().nextBytes(key); - - return new PreKey(keyId, key); - } - - private static SignedPreKey generateTestSignedPreKey(final long keyId) { - final byte[] key = new byte[32]; - final byte[] signature = new byte[32]; - - final SecureRandom secureRandom = new SecureRandom(); - secureRandom.nextBytes(key); - secureRandom.nextBytes(signature); - - return new SignedPreKey(keyId, key, signature); - } -} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java new file mode 100644 index 000000000..40a4757d8 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/RepeatedUseSignedPreKeyStoreTest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.reactivestreams.Subscriber; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; +import org.whispersystems.textsecuregcm.util.AttributeValues; +import reactor.core.publisher.Flux; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemRequest; +import software.amazon.awssdk.services.dynamodb.model.DeleteItemResponse; +import software.amazon.awssdk.services.dynamodb.model.QueryRequest; +import software.amazon.awssdk.services.dynamodb.paginators.QueryPublisher; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class RepeatedUseSignedPreKeyStoreTest { + + private RepeatedUseSignedPreKeyStore keys; + + private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + + @RegisterExtension + static final DynamoDbExtension DYNAMO_DB_EXTENSION = + new DynamoDbExtension(DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS); + + @BeforeEach + void setUp() { + keys = new RepeatedUseSignedPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName()); + } + + @Test + void storeFind() { + assertEquals(Optional.empty(), keys.find(UUID.randomUUID(), 1).join()); + + { + final UUID identifier = UUID.randomUUID(); + final long deviceId = 1; + final SignedPreKey signedPreKey = generateSignedPreKey(); + + assertDoesNotThrow(() -> keys.store(identifier, deviceId, signedPreKey).join()); + assertEquals(Optional.of(signedPreKey), keys.find(identifier, deviceId).join()); + } + + { + final UUID identifier = UUID.randomUUID(); + final Map signedPreKeys = Map.of( + 1L, generateSignedPreKey(), + 2L, generateSignedPreKey() + ); + + assertDoesNotThrow(() -> keys.store(identifier, signedPreKeys).join()); + assertEquals(Optional.of(signedPreKeys.get(1L)), keys.find(identifier, 1).join()); + assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join()); + } + } + + @Test + void delete() { + assertDoesNotThrow(() -> keys.delete(UUID.randomUUID()).join()); + + { + final UUID identifier = UUID.randomUUID(); + final Map signedPreKeys = Map.of( + 1L, generateSignedPreKey(), + 2L, generateSignedPreKey() + ); + + keys.store(identifier, signedPreKeys).join(); + keys.delete(identifier, 1).join(); + + assertEquals(Optional.empty(), keys.find(identifier, 1).join()); + assertEquals(Optional.of(signedPreKeys.get(2L)), keys.find(identifier, 2).join()); + } + + { + final UUID identifier = UUID.randomUUID(); + final Map signedPreKeys = Map.of( + 1L, generateSignedPreKey(), + 2L, generateSignedPreKey() + ); + + keys.store(identifier, signedPreKeys).join(); + keys.delete(identifier).join(); + + assertEquals(Optional.empty(), keys.find(identifier, 1).join()); + assertEquals(Optional.empty(), keys.find(identifier, 2).join()); + } + } + + @Test + void deleteWithError() { + final DynamoDbAsyncClient mockClient = mock(DynamoDbAsyncClient.class); + final QueryPublisher queryPublisher = mock(QueryPublisher.class); + + final SdkPublisher> itemPublisher = new SdkPublisher>() { + final Flux> items = Flux.just( + Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(1)), + Map.of(RepeatedUseSignedPreKeyStore.KEY_DEVICE_ID, AttributeValues.fromLong(2))); + + @Override + public void subscribe(final Subscriber> subscriber) { + items.subscribe(subscriber); + } + }; + + when(queryPublisher.items()).thenReturn(itemPublisher); + when(mockClient.queryPaginator(any(QueryRequest.class))).thenReturn(queryPublisher); + + final Exception deleteItemException = new IllegalArgumentException("OH NO"); + + when(mockClient.deleteItem(any(DeleteItemRequest.class))) + .thenReturn(CompletableFuture.completedFuture(DeleteItemResponse.builder().build())) + .thenReturn(CompletableFuture.failedFuture(deleteItemException)); + + final RepeatedUseSignedPreKeyStore keyStore = new RepeatedUseSignedPreKeyStore(mockClient, + DynamoDbExtensionSchema.Tables.REPEATED_USE_SIGNED_PRE_KEYS.tableName()); + + final CompletionException completionException = + assertThrows(CompletionException.class, () -> keyStore.delete(UUID.randomUUID()).join()); + + assertEquals(deleteItemException, completionException.getCause()); + } + + private static SignedPreKey generateSignedPreKey() { + return KeysHelper.signedECPreKey(1, IDENTITY_KEY_PAIR); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java new file mode 100644 index 000000000..ca6b14654 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStoreTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.libsignal.protocol.ecc.Curve; +import org.whispersystems.textsecuregcm.entities.PreKey; + +class SingleUseECPreKeyStoreTest extends SingleUsePreKeyStoreTest { + + private SingleUseECPreKeyStore preKeyStore; + + @RegisterExtension + static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.EC_KEYS); + + @BeforeEach + void setUp() { + preKeyStore = new SingleUseECPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.EC_KEYS.tableName()); + } + + @Override + protected SingleUsePreKeyStore getPreKeyStore() { + return preKeyStore; + } + + @Override + protected PreKey generatePreKey(final long keyId) { + return new PreKey(keyId, Curve.generateKeyPair().getPublicKey().serialize()); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java new file mode 100644 index 000000000..b0df5cd31 --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUseKEMPreKeyStoreTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.signal.libsignal.protocol.ecc.Curve; +import org.signal.libsignal.protocol.ecc.ECKeyPair; +import org.whispersystems.textsecuregcm.entities.SignedPreKey; +import org.whispersystems.textsecuregcm.tests.util.KeysHelper; + +class SingleUseKEMPreKeyStoreTest extends SingleUsePreKeyStoreTest { + + private SingleUseKEMPreKeyStore preKeyStore; + + private static final ECKeyPair IDENTITY_KEY_PAIR = Curve.generateKeyPair(); + + @RegisterExtension + static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(DynamoDbExtensionSchema.Tables.PQ_KEYS); + + @BeforeEach + void setUp() { + preKeyStore = new SingleUseKEMPreKeyStore(DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), + DynamoDbExtensionSchema.Tables.PQ_KEYS.tableName()); + } + + @Override + protected SingleUsePreKeyStore getPreKeyStore() { + return preKeyStore; + } + + @Override + protected SignedPreKey generatePreKey(final long keyId) { + return KeysHelper.signedKEMPreKey(keyId, IDENTITY_KEY_PAIR); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java new file mode 100644 index 000000000..dd8563ceb --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2023 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.whispersystems.textsecuregcm.entities.PreKey; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; + +abstract class SingleUsePreKeyStoreTest { + + private static final int KEY_COUNT = 100; + + protected abstract SingleUsePreKeyStore getPreKeyStore(); + + protected abstract K generatePreKey(final long keyId); + + @Test + void storeTake() { + final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); + + final UUID accountIdentifier = UUID.randomUUID(); + final long deviceId = 1; + + assertEquals(Optional.empty(), preKeyStore.take(accountIdentifier, deviceId).join()); + + final List preKeys = new ArrayList<>(KEY_COUNT); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeys.add(generatePreKey(i)); + } + + assertDoesNotThrow(() -> preKeyStore.store(accountIdentifier, deviceId, preKeys).join()); + + assertEquals(Optional.of(preKeys.get(0)), preKeyStore.take(accountIdentifier, deviceId).join()); + assertEquals(Optional.of(preKeys.get(1)), preKeyStore.take(accountIdentifier, deviceId).join()); + } + + @Test + void getCount() { + final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); + + final UUID accountIdentifier = UUID.randomUUID(); + final long deviceId = 1; + + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); + + final List preKeys = new ArrayList<>(KEY_COUNT); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeys.add(generatePreKey(i)); + } + + preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + + assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId).join()); + } + + @Test + void deleteSingleDevice() { + final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); + + final UUID accountIdentifier = UUID.randomUUID(); + final long deviceId = 1; + + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); + assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join()); + + final List preKeys = new ArrayList<>(KEY_COUNT); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeys.add(generatePreKey(i)); + } + + preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join(); + + assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier, deviceId).join()); + + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); + assertEquals(KEY_COUNT, preKeyStore.getCount(accountIdentifier, deviceId + 1).join()); + } + + @Test + void deleteAllDevices() { + final SingleUsePreKeyStore preKeyStore = getPreKeyStore(); + + final UUID accountIdentifier = UUID.randomUUID(); + final long deviceId = 1; + + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); + assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join()); + + final List preKeys = new ArrayList<>(KEY_COUNT); + + for (int i = 0; i < KEY_COUNT; i++) { + preKeys.add(generatePreKey(i)); + } + + preKeyStore.store(accountIdentifier, deviceId, preKeys).join(); + preKeyStore.store(accountIdentifier, deviceId + 1, preKeys).join(); + + assertDoesNotThrow(() -> preKeyStore.delete(accountIdentifier).join()); + + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); + assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join()); + } + + @ParameterizedTest + @MethodSource + void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) { + assertArrayEquals(expectedByteArray, getPreKeyStore().extractByteArray(attributeValue)); + } + + private static Stream extractByteArray() { + final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc="); + + return Stream.of( + Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key), + Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key), + Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key) + ); + } + + @ParameterizedTest + @MethodSource + void extractByteArrayIllegalArgument(final AttributeValue attributeValue) { + assertThrows(IllegalArgumentException.class, () -> getPreKeyStore().extractByteArray(attributeValue)); + } + + private static Stream extractByteArrayIllegalArgument() { + return Stream.of( + Arguments.of(AttributeValue.fromN("12")), + Arguments.of(AttributeValue.fromS("")), + Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎")) + ); + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java index f74fa10ff..bdf04aec0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/DeviceControllerTest.java @@ -23,7 +23,6 @@ import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; import io.dropwizard.testing.junit5.ResourceExtension; import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -65,7 +64,7 @@ import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.StoredVerificationCodeManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; @@ -82,7 +81,7 @@ static class DumbVerificationDeviceController extends DeviceController { public DumbVerificationDeviceController(StoredVerificationCodeManager pendingDevices, AccountsManager accounts, MessagesManager messages, - Keys keys, + KeysManager keys, RateLimiters rateLimiters, Map deviceConfiguration) { super(pendingDevices, accounts, messages, keys, rateLimiters, deviceConfiguration); @@ -97,7 +96,7 @@ protected VerificationCode generateVerificationCode() { private static StoredVerificationCodeManager pendingDevicesManager = mock(StoredVerificationCodeManager.class); private static AccountsManager accountsManager = mock(AccountsManager.class); private static MessagesManager messagesManager = mock(MessagesManager.class); - private static Keys keys = mock(Keys.class); + private static KeysManager keysManager = mock(KeysManager.class); private static RateLimiters rateLimiters = mock(RateLimiters.class); private static RateLimiter rateLimiter = mock(RateLimiter.class); private static Account account = mock(Account.class); @@ -117,7 +116,7 @@ protected VerificationCode generateVerificationCode() { .addResource(new DumbVerificationDeviceController(pendingDevicesManager, accountsManager, messagesManager, - keys, + keysManager, rateLimiters, deviceConfiguration)) .build(); @@ -161,7 +160,7 @@ void teardown() { pendingDevicesManager, accountsManager, messagesManager, - keys, + keysManager, rateLimiters, rateLimiter, account, @@ -314,8 +313,8 @@ void linkDeviceAtomic(final boolean fetchesMessages, verify(pendingDevicesManager).remove(AuthHelper.VALID_NUMBER); verify(messagesManager).clear(eq(AuthHelper.VALID_UUID), eq(42L)); verify(clientPresenceManager).disconnectPresence(AuthHelper.VALID_UUID, Device.MASTER_ID); - verify(keys).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); - verify(keys).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); + verify(keysManager).storePqLastResort(AuthHelper.VALID_UUID, Map.of(response.getDeviceId(), aciPqLastResortPreKey.get())); + verify(keysManager).storePqLastResort(AuthHelper.VALID_PNI, Map.of(response.getDeviceId(), pniPqLastResortPreKey.get())); } private static Stream linkDeviceAtomic() { @@ -822,7 +821,7 @@ void deviceRemovalClearsMessagesAndKeys() { verify(messagesManager, times(2)).clear(AuthHelper.VALID_UUID, deviceId); verify(accountsManager, times(1)).update(eq(AuthHelper.VALID_ACCOUNT), any()); verify(AuthHelper.VALID_ACCOUNT).removeDevice(deviceId); - verify(keys).delete(AuthHelper.VALID_UUID, deviceId); + verify(keysManager).delete(AuthHelper.VALID_UUID, deviceId); } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java index 2126aecc7..0b5f12a3f 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/tests/controllers/KeysControllerTest.java @@ -58,7 +58,7 @@ import org.whispersystems.textsecuregcm.storage.Account; import org.whispersystems.textsecuregcm.storage.AccountsManager; import org.whispersystems.textsecuregcm.storage.Device; -import org.whispersystems.textsecuregcm.storage.Keys; +import org.whispersystems.textsecuregcm.storage.KeysManager; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -107,7 +107,7 @@ class KeysControllerTest { private final SignedPreKey VALID_DEVICE_SIGNED_KEY = KeysHelper.signedECPreKey(89898, IDENTITY_KEY_PAIR); private final SignedPreKey VALID_DEVICE_PNI_SIGNED_KEY = KeysHelper.signedECPreKey(7777, PNI_IDENTITY_KEY_PAIR); - private final static Keys KEYS = mock(Keys.class ); + private final static KeysManager KEYS = mock(KeysManager.class ); private final static AccountsManager accounts = mock(AccountsManager.class ); private final static Account existsAccount = mock(Account.class );