From a71dc48b9bbb6905afbb5b54557ab0fe7633badd Mon Sep 17 00:00:00 2001 From: Katherine Yen Date: Thu, 10 Aug 2023 14:00:35 -0700 Subject: [PATCH] Prepare to read profile data stored as byte arrays --- .../textsecuregcm/storage/Profiles.java | 39 +-- .../storage/SingleUseECPreKeyStore.java | 5 +- .../storage/SingleUsePreKeyStore.java | 23 -- .../textsecuregcm/util/AttributeValues.java | 24 ++ .../textsecuregcm/storage/ProfilesTest.java | 291 +++++++++++------- .../storage/SingleUsePreKeyStoreTest.java | 30 -- .../util/AttributeValuesTest.java | 36 +++ 7 files changed, 252 insertions(+), 196 deletions(-) diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java index 6ee68bfca..5c65745da 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/Profiles.java @@ -8,7 +8,6 @@ import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; import com.google.common.annotations.VisibleForTesting; -import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Timer; import java.util.ArrayList; @@ -77,12 +76,7 @@ public class Profiles { private static final Timer SET_PROFILES_TIMER = Metrics.timer(name(Profiles.class, "set")); private static final Timer GET_PROFILE_TIMER = Metrics.timer(name(Profiles.class, "get")); private static final Timer DELETE_PROFILES_TIMER = Metrics.timer(name(Profiles.class, "delete")); - - private static final Counter INVALID_NAME_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "name"); - private static final Counter INVALID_EMOJI_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "emoji"); - private static final Counter INVALID_ABOUT_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "about"); - private static final Counter INVALID_PAYMENT_ADDRESS_COUNTER = Metrics.counter(name(Profiles.class, "invalidProfileData"), "field", "paymentAddress"); - + private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(Profiles.class, "parseByteArray"); public Profiles(final DynamoDbClient dynamoDbClient, final DynamoDbAsyncClient dynamoDbAsyncClient, @@ -232,34 +226,23 @@ public CompletableFuture> getAsync(final UUID uuid, f } private static VersionedProfile fromItem(final Map item) { - final String name = AttributeValues.getString(item, ATTR_NAME, null); - final String emoji = AttributeValues.getString(item, ATTR_EMOJI, null); - final String about = AttributeValues.getString(item, ATTR_ABOUT, null); - final String paymentAddress = AttributeValues.getString(item, ATTR_PAYMENT_ADDRESS, null); - - checkValidBase64(name, INVALID_NAME_COUNTER); - checkValidBase64(emoji, INVALID_EMOJI_COUNTER); - checkValidBase64(about, INVALID_ABOUT_COUNTER); - checkValidBase64(paymentAddress, INVALID_PAYMENT_ADDRESS_COUNTER); - return new VersionedProfile( AttributeValues.getString(item, ATTR_VERSION, null), - name, + getBase64EncodedBytes(item, ATTR_NAME, PARSE_BYTE_ARRAY_COUNTER_NAME), AttributeValues.getString(item, ATTR_AVATAR, null), - emoji, - about, - paymentAddress, + getBase64EncodedBytes(item, ATTR_EMOJI, PARSE_BYTE_ARRAY_COUNTER_NAME), + getBase64EncodedBytes(item, ATTR_ABOUT, PARSE_BYTE_ARRAY_COUNTER_NAME), + getBase64EncodedBytes(item, ATTR_PAYMENT_ADDRESS, PARSE_BYTE_ARRAY_COUNTER_NAME), AttributeValues.getByteArray(item, ATTR_COMMITMENT, null)); } - private static void checkValidBase64(final String value, final Counter counter) { - if (StringUtils.isNotBlank(value)) { - try { - Base64.getDecoder().decode(value); - } catch (final IllegalArgumentException e) { - counter.increment(); - } + private static String getBase64EncodedBytes(final Map item, final String attributeName, final String counterName) { + final AttributeValue attributeValue = item.get(attributeName); + + if (attributeValue == null) { + return null; } + return Base64.getEncoder().encodeToString(AttributeValues.extractByteArray(attributeValue, counterName)); } public void deleteAll(final UUID uuid) { diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java index 9296aabb9..025f10b81 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUseECPreKeyStore.java @@ -14,7 +14,10 @@ import java.util.Map; import java.util.UUID; +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + public class SingleUseECPreKeyStore extends SingleUsePreKeyStore { + private static final String PARSE_BYTE_ARRAY_COUNTER_NAME = name(SingleUseECPreKeyStore.class, "parseByteArray"); protected SingleUseECPreKeyStore(final DynamoDbAsyncClient dynamoDbAsyncClient, final String tableName) { super(dynamoDbAsyncClient, tableName); @@ -31,7 +34,7 @@ KEY_DEVICE_ID_KEY_ID, getSortKey(deviceId, preKey.keyId()), @Override protected ECPreKey getPreKeyFromItem(final Map item) { final long keyId = item.get(KEY_DEVICE_ID_KEY_ID).b().asByteBuffer().getLong(8); - final byte[] publicKey = extractByteArray(item.get(ATTR_PUBLIC_KEY)); + final byte[] publicKey = AttributeValues.extractByteArray(item.get(ATTR_PUBLIC_KEY), PARSE_BYTE_ARRAY_COUNTER_NAME); try { return new ECPreKey(keyId, new ECPublicKey(publicKey)); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java index 2b87e5b41..f9d815434 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStore.java @@ -73,9 +73,6 @@ public abstract class SingleUsePreKeyStore> { private final String takeKeyTimerName = name(getClass(), "takeKey"); private static final String KEY_PRESENT_TAG_NAME = "keyPresent"; - private final Counter parseBytesFromStringCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "string"); - private final Counter readBytesFromByteArrayCounter = Metrics.counter(name(getClass(), "parseByteArray"), "format", "bytes"); - static final String KEY_ACCOUNT_UUID = "U"; static final String KEY_DEVICE_ID_KEY_ID = "DK"; static final String ATTR_PUBLIC_KEY = "P"; @@ -289,24 +286,4 @@ protected abstract Map getItemFromPreKey(final UUID iden final K preKey); protected abstract K getPreKeyFromItem(final Map item); - - /** - * Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string. - * - * @param attributeValue the {@code AttributeValue} from which to extract a byte array - * - * @return the byte array represented by the given {@code AttributeValue} - */ - @VisibleForTesting - byte[] extractByteArray(final AttributeValue attributeValue) { - if (attributeValue.b() != null) { - readBytesFromByteArrayCounter.increment(); - return attributeValue.b().asByteArray(); - } else if (StringUtils.isNotBlank(attributeValue.s())) { - parseBytesFromStringCounter.increment(); - return Base64.getDecoder().decode(attributeValue.s()); - } - - throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value"); - } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java b/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java index 494703e00..b0d37b7c1 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/util/AttributeValues.java @@ -6,9 +6,13 @@ package org.whispersystems.textsecuregcm.util; import java.nio.ByteBuffer; +import java.util.Base64; import java.util.Map; import java.util.Optional; import java.util.UUID; +import com.google.common.annotations.VisibleForTesting; +import io.micrometer.core.instrument.Metrics; +import org.apache.commons.lang3.StringUtils; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -124,4 +128,24 @@ public static byte[] getByteArray(Map item, String key, public static UUID getUUID(Map item, String key, UUID defaultValue) { return AttributeValues.get(item, key).filter(av -> av.b() != null).map(AttributeValues::toUUID).orElse(defaultValue); } + + /** + * Extracts a byte array from an {@link AttributeValue} that may be either a byte array or a base64-encoded string. + * + * @param attributeValue the {@code AttributeValue} from which to extract a byte array + * + * @return the byte array represented by the given {@code AttributeValue} + */ + @VisibleForTesting + public static byte[] extractByteArray(final AttributeValue attributeValue, final String counterName) { + if (attributeValue.b() != null) { + Metrics.counter(counterName, "format", "bytes").increment(); + return attributeValue.b().asByteArray(); + } else if (StringUtils.isNotBlank(attributeValue.s())) { + Metrics.counter(counterName, "format", "string").increment(); + return Base64.getDecoder().decode(attributeValue.s()); + } + + throw new IllegalArgumentException("Attribute value has neither a byte array nor a string value"); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java index af37b1151..69f96fd54 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/ProfilesTest.java @@ -12,6 +12,9 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.signal.libsignal.protocol.ServiceId; +import org.signal.libsignal.zkgroup.InvalidInputException; +import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.whispersystems.textsecuregcm.util.AttributeValues; import org.whispersystems.textsecuregcm.storage.DynamoDbExtensionSchema.Tables; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; @@ -19,78 +22,85 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; -import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Map; import java.util.Optional; +import java.util.Random; import java.util.UUID; import java.util.stream.Stream; @Timeout(value = 10, threadMode = Timeout.ThreadMode.SEPARATE_THREAD) public class ProfilesTest { - + private static final UUID ACI = UUID.randomUUID(); @RegisterExtension static final DynamoDbExtension DYNAMO_DB_EXTENSION = new DynamoDbExtension(Tables.PROFILES); private Profiles profiles; + private VersionedProfile validProfile; @BeforeEach - void setUp() { + void setUp() throws InvalidInputException { profiles = new Profiles(DYNAMO_DB_EXTENSION.getDynamoDbClient(), DYNAMO_DB_EXTENSION.getDynamoDbAsyncClient(), Tables.PROFILES.tableName()); + final byte[] commitment = new ProfileKey(new byte[32]).getCommitment(new ServiceId.Aci(ACI)).serialize(); + final String version = "someVersion"; + final String name = generateRandomBase64FromByteArray(81); + final String validAboutEmoji = generateRandomBase64FromByteArray(60); + final String validAbout = generateRandomBase64FromByteArray(156); + final String avatar = "profiles/" + generateRandomBase64FromByteArray(16); + + validProfile = new VersionedProfile(version, name, avatar, validAboutEmoji, validAbout, null, commitment); } @Test void testSetGet() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", "emoji", - "the very model of a modern major general", - null, "acommitment".getBytes()); - profiles.set(uuid, profile); + profiles.set(ACI, validProfile); - Optional retrieved = profiles.get(uuid, "123"); + Optional retrieved = profiles.get(ACI, validProfile.getVersion()); assertThat(retrieved.isPresent()).isTrue(); - assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); - assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); - assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); - assertThat(retrieved.get().getAbout()).isEqualTo(profile.getAbout()); - assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profile.getAboutEmoji()); + assertThat(retrieved.get().getName()).isEqualTo(validProfile.getName()); + assertThat(retrieved.get().getAvatar()).isEqualTo(validProfile.getAvatar()); + assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment()); + assertThat(retrieved.get().getAbout()).isEqualTo(validProfile.getAbout()); + assertThat(retrieved.get().getAboutEmoji()).isEqualTo(validProfile.getAboutEmoji()); } @Test void testSetGetAsync() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", "emoji", - "the very model of a modern major general", - null, "acommitment".getBytes()); - profiles.setAsync(uuid, profile).join(); + profiles.setAsync(ACI, validProfile).join(); - Optional retrieved = profiles.getAsync(uuid, "123").join(); + Optional retrieved = profiles.getAsync(ACI, validProfile.getVersion()).join(); assertThat(retrieved.isPresent()).isTrue(); - assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); - assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); - assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); - assertThat(retrieved.get().getAbout()).isEqualTo(profile.getAbout()); - assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profile.getAboutEmoji()); + assertThat(retrieved.get().getName()).isEqualTo(validProfile.getName()); + assertThat(retrieved.get().getAvatar()).isEqualTo(validProfile.getAvatar()); + assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment()); + assertThat(retrieved.get().getAbout()).isEqualTo(validProfile.getAbout()); + assertThat(retrieved.get().getAboutEmoji()).isEqualTo(validProfile.getAboutEmoji()); } @Test - void testDeleteReset() { - UUID uuid = UUID.randomUUID(); - profiles.set(uuid, new VersionedProfile("123", "foo", "avatarLocation", "emoji", - "the very model of a modern major general", - null, "acommitment".getBytes())); + void testDeleteReset() throws InvalidInputException { + profiles.set(ACI, validProfile); + + profiles.deleteAll(ACI); - profiles.deleteAll(uuid); + final String version = "someVersion"; + final String name = generateRandomBase64FromByteArray(81); + final String differentAvatar = "profiles/" + generateRandomBase64FromByteArray(16); + final String differentEmoji = generateRandomBase64FromByteArray(60); + final String differentAbout = generateRandomBase64FromByteArray(156); + final String paymentAddress = generateRandomBase64FromByteArray(582); + final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); - VersionedProfile updatedProfile = new VersionedProfile("123", "name", "differentAvatarLocation", - "differentEmoji", "changed text", "paymentAddress", "differentcommitment".getBytes(StandardCharsets.UTF_8)); + VersionedProfile updatedProfile = new VersionedProfile(version, name, differentAvatar, + differentEmoji, differentAbout, paymentAddress, commitment); - profiles.set(uuid, updatedProfile); + profiles.set(ACI, updatedProfile); - Optional retrieved = profiles.get(uuid, "123"); + Optional retrieved = profiles.get(ACI, version); assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.get().getName()).isEqualTo(updatedProfile.getName()); @@ -101,13 +111,16 @@ void testDeleteReset() { } @Test - void testSetGetNullOptionalFields() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profile = new VersionedProfile("123", "foo", null, null, null, null, - "acommitment".getBytes()); - profiles.set(uuid, profile); + void testSetGetNullOptionalFields() throws InvalidInputException { + final String version = "someVersion"; + final String name = generateRandomBase64FromByteArray(81); + final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); - Optional retrieved = profiles.get(uuid, "123"); + VersionedProfile profile = new VersionedProfile(version, name, null, null, null, null, + commitment); + profiles.set(ACI, profile); + + Optional retrieved = profiles.get(ACI, version); assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); @@ -118,26 +131,30 @@ void testSetGetNullOptionalFields() { } @Test - void testSetReplace() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", null, null, - "paymentAddress", "acommitment".getBytes()); - profiles.set(uuid, profile); + void testSetReplace() throws InvalidInputException { + profiles.set(ACI, validProfile); - Optional retrieved = profiles.get(uuid, "123"); + Optional retrieved = profiles.get(ACI, validProfile.getVersion()); assertThat(retrieved.isPresent()).isTrue(); - assertThat(retrieved.get().getName()).isEqualTo(profile.getName()); - assertThat(retrieved.get().getAvatar()).isEqualTo(profile.getAvatar()); - assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); - assertThat(retrieved.get().getAbout()).isNull(); - assertThat(retrieved.get().getAboutEmoji()).isNull(); + assertThat(retrieved.get().getName()).isEqualTo(validProfile.getName()); + assertThat(retrieved.get().getAvatar()).isEqualTo(validProfile.getAvatar()); + assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment()); + assertThat(retrieved.get().getAbout()).isEqualTo(validProfile.getAbout()); + assertThat(retrieved.get().getAboutEmoji()).isEqualTo(validProfile.getAboutEmoji()); + assertThat(retrieved.get().getPaymentAddress()).isNull(); - VersionedProfile updated = new VersionedProfile("123", "bar", "baz", "emoji", "bio", null, - "boof".getBytes()); - profiles.set(uuid, updated); + final String differentName = generateRandomBase64FromByteArray(81); + final String differentEmoji = generateRandomBase64FromByteArray(60); + final String differentAbout = generateRandomBase64FromByteArray(156); + final String differentAvatar = "profiles/" + generateRandomBase64FromByteArray(16); + final byte[] differentCommitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); - retrieved = profiles.get(uuid, "123"); + VersionedProfile updated = new VersionedProfile(validProfile.getVersion(), differentName, differentAvatar, differentEmoji, differentAbout, null, + differentCommitment); + profiles.set(ACI, updated); + + retrieved = profiles.get(ACI, updated.getVersion()); assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.get().getName()).isEqualTo(updated.getName()); @@ -146,22 +163,34 @@ void testSetReplace() { assertThat(retrieved.get().getAvatar()).isEqualTo(updated.getAvatar()); // Commitment should be unchanged after an overwrite - assertThat(retrieved.get().getCommitment()).isEqualTo(profile.getCommitment()); + assertThat(retrieved.get().getCommitment()).isEqualTo(validProfile.getCommitment()); } @Test - void testMultipleVersions() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profileOne = new VersionedProfile("123", "foo", "avatarLocation", null, null, - null, "acommitmnet".getBytes()); - VersionedProfile profileTwo = new VersionedProfile("345", "bar", "baz", "emoji", - "i keep typing emoju for some reason", - null, "boof".getBytes()); + void testMultipleVersions() throws InvalidInputException { + final String versionOne = "versionOne"; + final String versionTwo = "versionTwo"; + + final String nameOne = generateRandomBase64FromByteArray(81); + final String nameTwo = generateRandomBase64FromByteArray(81); + + final String avatarOne = "profiles/" + generateRandomBase64FromByteArray(16); + final String avatarTwo = "profiles/" + generateRandomBase64FromByteArray(16); - profiles.set(uuid, profileOne); - profiles.set(uuid, profileTwo); + final String aboutEmoji = generateRandomBase64FromByteArray(60); + final String about = generateRandomBase64FromByteArray(156); - Optional retrieved = profiles.get(uuid, "123"); + final byte[] commitmentOne = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); + final byte[] commitmentTwo = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); + + VersionedProfile profileOne = new VersionedProfile(versionOne, nameOne, avatarOne, null, null, + null, commitmentOne); + VersionedProfile profileTwo = new VersionedProfile(versionTwo, nameTwo, avatarTwo, aboutEmoji, about, null, commitmentTwo); + + profiles.set(ACI, profileOne); + profiles.set(ACI, profileTwo); + + Optional retrieved = profiles.get(ACI, versionOne); assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.get().getName()).isEqualTo(profileOne.getName()); @@ -170,7 +199,7 @@ void testMultipleVersions() { assertThat(retrieved.get().getAbout()).isEqualTo(profileOne.getAbout()); assertThat(retrieved.get().getAboutEmoji()).isEqualTo(profileOne.getAboutEmoji()); - retrieved = profiles.get(uuid, "345"); + retrieved = profiles.get(ACI, versionTwo); assertThat(retrieved.isPresent()).isTrue(); assertThat(retrieved.get().getName()).isEqualTo(profileTwo.getName()); @@ -182,33 +211,45 @@ void testMultipleVersions() { @Test void testMissing() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profile = new VersionedProfile("123", "foo", "avatarLocation", null, null, - null, "aDigest".getBytes()); - profiles.set(uuid, profile); + profiles.set(ACI, validProfile); + final String missingVersion = "missingVersion"; - Optional retrieved = profiles.get(uuid, "888"); + Optional retrieved = profiles.get(ACI, missingVersion); assertThat(retrieved.isPresent()).isFalse(); } @Test - void testDelete() { - UUID uuid = UUID.randomUUID(); - VersionedProfile profileOne = new VersionedProfile("123", "foo", "avatarLocation", null, null, - null, "aDigest".getBytes()); - VersionedProfile profileTwo = new VersionedProfile("345", "bar", "baz", null, null, null, "boof".getBytes()); + void testDelete() throws InvalidInputException { + final String versionOne = "versionOne"; + final String versionTwo = "versionTwo"; + + final String nameOne = generateRandomBase64FromByteArray(81); + final String nameTwo = generateRandomBase64FromByteArray(81); - profiles.set(uuid, profileOne); - profiles.set(uuid, profileTwo); + final String aboutEmoji = generateRandomBase64FromByteArray(60); + final String about = generateRandomBase64FromByteArray(156); - profiles.deleteAll(uuid); + final String avatarOne = "profiles/" + generateRandomBase64FromByteArray(16); + final String avatarTwo = "profiles/" + generateRandomBase64FromByteArray(16); - Optional retrieved = profiles.get(uuid, "123"); + final byte[] commitmentOne = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); + final byte[] commitmentTwo = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); + + VersionedProfile profileOne = new VersionedProfile(versionOne, nameOne, avatarOne, null, null, + null, commitmentOne); + VersionedProfile profileTwo = new VersionedProfile(versionTwo, nameTwo, avatarTwo, aboutEmoji, about, null, commitmentTwo); + + profiles.set(ACI, profileOne); + profiles.set(ACI, profileTwo); + + profiles.deleteAll(ACI); + + Optional retrieved = profiles.get(ACI, versionOne); assertThat(retrieved.isPresent()).isFalse(); - retrieved = profiles.get(uuid, "345"); + retrieved = profiles.get(ACI, versionTwo); assertThat(retrieved.isPresent()).isFalse(); } @@ -219,32 +260,38 @@ void buildUpdateExpression(final VersionedProfile profile, final String expected assertEquals(expectedUpdateExpression, Profiles.buildUpdateExpression(profile)); } - private static Stream buildUpdateExpression() { - final byte[] commitment = "commitment".getBytes(StandardCharsets.UTF_8); + private static Stream buildUpdateExpression() throws InvalidInputException { + final String version = "someVersion"; + final String name = generateRandomBase64FromByteArray(81); + final String avatar = "profiles/" + generateRandomBase64FromByteArray(16);; + final String emoji = generateRandomBase64FromByteArray(60); + final String about = generateRandomBase64FromByteArray(156); + final String paymentAddress = generateRandomBase64FromByteArray(582); + final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); return Stream.of( Arguments.of( - new VersionedProfile("version", "name", "avatar", "emoji", "about", "paymentAddress", commitment), + new VersionedProfile(version, name, avatar, emoji, about, paymentAddress, commitment), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #about = :about, #aboutEmoji = :aboutEmoji, #paymentAddress = :paymentAddress"), Arguments.of( - new VersionedProfile("version", "name", "avatar", "emoji", "about", null, commitment), + new VersionedProfile(version, name, avatar, emoji, about, null, commitment), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #about = :about, #aboutEmoji = :aboutEmoji REMOVE #paymentAddress"), Arguments.of( - new VersionedProfile("version", "name", "avatar", "emoji", null, null, commitment), + new VersionedProfile(version, name, avatar, emoji, null, null, commitment), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar, #aboutEmoji = :aboutEmoji REMOVE #about, #paymentAddress"), Arguments.of( - new VersionedProfile("version", "name", "avatar", null, null, null, commitment), + new VersionedProfile(version, name, avatar, null, null, null, commitment), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name, #avatar = :avatar REMOVE #about, #aboutEmoji, #paymentAddress"), Arguments.of( - new VersionedProfile("version", "name", null, null, null, null, commitment), + new VersionedProfile(version, name, null, null, null, null, commitment), "SET #commitment = if_not_exists(#commitment, :commitment), #name = :name REMOVE #avatar, #about, #aboutEmoji, #paymentAddress"), Arguments.of( - new VersionedProfile("version", null, null, null, null, null, commitment), + new VersionedProfile(version, null, null, null, null, null, commitment), "SET #commitment = if_not_exists(#commitment, :commitment) REMOVE #name, #avatar, #about, #aboutEmoji, #paymentAddress") ); } @@ -255,53 +302,69 @@ void buildUpdateExpressionAttributeValues(final VersionedProfile profile, final assertEquals(expectedAttributeValues, Profiles.buildUpdateExpressionAttributeValues(profile)); } - private static Stream buildUpdateExpressionAttributeValues() { - final byte[] commitment = "commitment".getBytes(StandardCharsets.UTF_8); + private static Stream buildUpdateExpressionAttributeValues() throws InvalidInputException { + final String version = "someVersion"; + final String name = generateRandomBase64FromByteArray(81); + final String avatar = "profiles/" + generateRandomBase64FromByteArray(16);; + final String emoji = generateRandomBase64FromByteArray(60); + final String about = generateRandomBase64FromByteArray(156); + final String paymentAddress = generateRandomBase64FromByteArray(582); + final byte[] commitment = new ProfileKey(generateRandomByteArray(32)).getCommitment(new ServiceId.Aci(ACI)).serialize(); return Stream.of( Arguments.of( - new VersionedProfile("version", "name", "avatar", "emoji", "about", "paymentAddress", commitment), + new VersionedProfile(version, name, avatar, emoji, about, paymentAddress, commitment), Map.of( ":commitment", AttributeValues.fromByteArray(commitment), - ":name", AttributeValues.fromString("name"), - ":avatar", AttributeValues.fromString("avatar"), - ":aboutEmoji", AttributeValues.fromString("emoji"), - ":about", AttributeValues.fromString("about"), - ":paymentAddress", AttributeValues.fromString("paymentAddress"))), + ":name", AttributeValues.fromString(name), + ":avatar", AttributeValues.fromString(avatar), + ":aboutEmoji", AttributeValues.fromString(emoji), + ":about", AttributeValues.fromString(about), + ":paymentAddress", AttributeValues.fromString(paymentAddress))), Arguments.of( - new VersionedProfile("version", "name", "avatar", "emoji", "about", null, commitment), + new VersionedProfile(version, name, avatar, emoji, about, null, commitment), Map.of( ":commitment", AttributeValues.fromByteArray(commitment), - ":name", AttributeValues.fromString("name"), - ":avatar", AttributeValues.fromString("avatar"), - ":aboutEmoji", AttributeValues.fromString("emoji"), - ":about", AttributeValues.fromString("about"))), + ":name", AttributeValues.fromString(name), + ":avatar", AttributeValues.fromString(avatar), + ":aboutEmoji", AttributeValues.fromString(emoji), + ":about", AttributeValues.fromString(about))), Arguments.of( - new VersionedProfile("version", "name", "avatar", "emoji", null, null, commitment), + new VersionedProfile(version, name, avatar, emoji, null, null, commitment), Map.of( ":commitment", AttributeValues.fromByteArray(commitment), - ":name", AttributeValues.fromString("name"), - ":avatar", AttributeValues.fromString("avatar"), - ":aboutEmoji", AttributeValues.fromString("emoji"))), + ":name", AttributeValues.fromString(name), + ":avatar", AttributeValues.fromString(avatar), + ":aboutEmoji", AttributeValues.fromString(emoji))), Arguments.of( - new VersionedProfile("version", "name", "avatar", null, null, null, commitment), + new VersionedProfile(version, name, avatar, null, null, null, commitment), Map.of( ":commitment", AttributeValues.fromByteArray(commitment), - ":name", AttributeValues.fromString("name"), - ":avatar", AttributeValues.fromString("avatar"))), + ":name", AttributeValues.fromString(name), + ":avatar", AttributeValues.fromString(avatar))), Arguments.of( - new VersionedProfile("version", "name", null, null, null, null, commitment), + new VersionedProfile(version, name, null, null, null, null, commitment), Map.of( ":commitment", AttributeValues.fromByteArray(commitment), - ":name", AttributeValues.fromString("name"))), + ":name", AttributeValues.fromString(name))), Arguments.of( - new VersionedProfile("version", null, null, null, null, null, commitment), + new VersionedProfile(version, null, null, null, null, null, commitment), Map.of(":commitment", AttributeValues.fromByteArray(commitment))) ); } + + private static String generateRandomBase64FromByteArray(final int byteArrayLength) { + return Base64.getEncoder().encodeToString(generateRandomByteArray(byteArrayLength)); + } + + private static byte[] generateRandomByteArray(final int length) { + byte[] byteArray = new byte[length]; + new Random().nextBytes(byteArray); + return byteArray; + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java index 894be6367..e284a6eb1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/SingleUsePreKeyStoreTest.java @@ -122,34 +122,4 @@ void deleteAllDevices() { assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId).join()); assertEquals(0, preKeyStore.getCount(accountIdentifier, deviceId + 1).join()); } - - @ParameterizedTest - @MethodSource - void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) { - assertArrayEquals(expectedByteArray, getPreKeyStore().extractByteArray(attributeValue)); - } - - private static Stream extractByteArray() { - final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc="); - - return Stream.of( - Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key), - Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key), - Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key) - ); - } - - @ParameterizedTest - @MethodSource - void extractByteArrayIllegalArgument(final AttributeValue attributeValue) { - assertThrows(IllegalArgumentException.class, () -> getPreKeyStore().extractByteArray(attributeValue)); - } - - private static Stream extractByteArrayIllegalArgument() { - return Stream.of( - Arguments.of(AttributeValue.fromN("12")), - Arguments.of(AttributeValue.fromS("")), - Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎")) - ); - } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/util/AttributeValuesTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/util/AttributeValuesTest.java index b9b8115b2..94be5c743 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/util/AttributeValuesTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/util/AttributeValuesTest.java @@ -6,10 +6,16 @@ package org.whispersystems.textsecuregcm.util; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.dynamodb.model.AttributeValue; import java.nio.ByteBuffer; +import java.util.Base64; import java.util.Map; import java.util.UUID; +import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.*; @@ -63,4 +69,34 @@ void testNullUuid() { final Map item = Map.of("key", AttributeValue.builder().nul(true).build()); assertNull(AttributeValues.getUUID(item, "key", null)); } + + @ParameterizedTest + @MethodSource + void extractByteArray(final AttributeValue attributeValue, final byte[] expectedByteArray) { + assertArrayEquals(expectedByteArray, AttributeValues.extractByteArray(attributeValue, "counter")); + } + + private static Stream extractByteArray() { + final byte[] key = Base64.getDecoder().decode("c+k+8zv8WaFdDjR9IOvCk6BcY5OI7rge/YUDkaDGyRc="); + + return Stream.of( + Arguments.of(AttributeValue.fromB(SdkBytes.fromByteArray(key)), key), + Arguments.of(AttributeValue.fromS(Base64.getEncoder().encodeToString(key)), key), + Arguments.of(AttributeValue.fromS(Base64.getEncoder().withoutPadding().encodeToString(key)), key) + ); + } + + @ParameterizedTest + @MethodSource + void extractByteArrayIllegalArgument(final AttributeValue attributeValue) { + assertThrows(IllegalArgumentException.class, () -> AttributeValues.extractByteArray(attributeValue, "counter")); + } + + private static Stream extractByteArrayIllegalArgument() { + return Stream.of( + Arguments.of(AttributeValue.fromN("12")), + Arguments.of(AttributeValue.fromS("")), + Arguments.of(AttributeValue.fromS("Definitely not legitimate base64 👎")) + ); + } }