From 2634bdd4d9f5512514779830412c7a1076ff8bc1 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Mon, 16 Dec 2024 16:20:35 +0100 Subject: [PATCH 1/5] SNOW-1858529 Implement header handling in FLOE --- .../cloud/storage/EncryptionProvider.java | 3 +- .../cloud/storage/GcmEncryptionProvider.java | 2 +- .../client/jdbc/cloud/storage/floe/Aead.java | 16 +++++ .../client/jdbc/cloud/storage/floe/Floe.java | 23 +++++++ .../jdbc/cloud/storage/floe/FloeAad.java | 15 +++++ .../jdbc/cloud/storage/floe/FloeBase.java | 18 +++++ .../cloud/storage/floe/FloeDecryptor.java | 3 + .../cloud/storage/floe/FloeDecryptorImpl.java | 40 +++++++++++ .../cloud/storage/floe/FloeEncryptor.java | 5 ++ .../cloud/storage/floe/FloeEncryptorImpl.java | 38 +++++++++++ .../jdbc/cloud/storage/floe/FloeIv.java | 21 ++++++ .../jdbc/cloud/storage/floe/FloeIvLength.java | 13 ++++ .../jdbc/cloud/storage/floe/FloeKdf.java | 47 +++++++++++++ .../jdbc/cloud/storage/floe/FloeKey.java | 15 +++++ .../cloud/storage/floe/FloeParameterSpec.java | 51 ++++++++++++++ .../jdbc/cloud/storage/floe/FloePurpose.java | 17 +++++ .../jdbc/cloud/storage/floe/FloeRandom.java | 5 ++ .../client/jdbc/cloud/storage/floe/Hash.java | 21 ++++++ .../cloud/storage/floe/SecureFloeRandom.java | 15 +++++ .../cloud/storage/floe/FixedFloeRandom.java | 17 +++++ .../storage/floe/FloeEncryptorImplTest.java | 45 +++++++++++++ .../jdbc/cloud/storage/floe/FloeTest.java | 66 +++++++++++++++++++ 22 files changed, 493 insertions(+), 3 deletions(-) create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java create mode 100644 src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java create mode 100644 src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java create mode 100644 src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java index faa74ce86..5f6ddd801 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java @@ -41,8 +41,7 @@ public class EncryptionProvider { private static final String FILE_CIPHER = "AES/CBC/PKCS5Padding"; private static final String KEY_CIPHER = "AES/ECB/PKCS5Padding"; private static final int BUFFER_SIZE = 2 * 1024 * 1024; // 2 MB - private static ThreadLocal secRnd = - new ThreadLocal<>().withInitial(SecureRandom::new); + private static ThreadLocal secRnd = ThreadLocal.withInitial(SecureRandom::new); /** * Decrypt a InputStream diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java index b4a4682d8..1e2e25e64 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java @@ -37,7 +37,7 @@ class GcmEncryptionProvider { private static final String KEY_CIPHER = "AES/GCM/NoPadding"; private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB private static final ThreadLocal random = - new ThreadLocal<>().withInitial(SecureRandom::new); + ThreadLocal.withInitial(SecureRandom::new); private static final Base64.Decoder base64Decoder = Base64.getDecoder(); static InputStream encrypt( diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java new file mode 100644 index 000000000..861343163 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -0,0 +1,16 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public enum Aead { + AES_GCM_128((byte) 0), + AES_GCM_256((byte) 1); + + private byte id; + + Aead(byte id) { + this.id = id; + } + + byte getId() { + return id; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java new file mode 100644 index 000000000..b3147a097 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java @@ -0,0 +1,23 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +public class Floe { + private final FloeParameterSpec parameterSpec; + + private Floe(FloeParameterSpec parameterSpec) { + this.parameterSpec = parameterSpec; + } + + public static Floe getInstance(FloeParameterSpec parameterSpec) { + return new Floe(parameterSpec); + } + + public FloeEncryptor createEncryptor(SecretKey key, byte[] aad) { + return new FloeEncryptorImpl(parameterSpec, new FloeKey(key), new FloeAad(aad)); + } + + public FloeDecryptor createDecryptor(SecretKey key, byte[] aad, byte[] floeHeader) { + return new FloeDecryptorImpl(parameterSpec, new FloeKey(key), new FloeAad(aad), floeHeader); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java new file mode 100644 index 000000000..3c24eb136 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.util.Optional; + +class FloeAad { + private final byte[] aad; + + FloeAad(byte[] aad) { + this.aad = Optional.ofNullable(aad).orElse(new byte[0]); + } + + byte[] getBytes() { + return aad; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java new file mode 100644 index 000000000..7328d480c --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java @@ -0,0 +1,18 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +abstract class FloeBase { + protected static final int headerTagLength = 32; + + protected final FloeParameterSpec parameterSpec; + protected final FloeKey floeKey; + protected final FloeAad floeAad; + + protected final FloeKdf floeKdf; + + FloeBase(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { + this.parameterSpec = parameterSpec; + this.floeKey = floeKey; + this.floeAad = floeAad; + this.floeKdf = new FloeKdf(parameterSpec); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java new file mode 100644 index 000000000..87e2463fd --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java @@ -0,0 +1,3 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public interface FloeDecryptor {} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java new file mode 100644 index 000000000..7139d9e73 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -0,0 +1,40 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +public class FloeDecryptorImpl extends FloeBase implements FloeDecryptor { + FloeDecryptorImpl( + FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) { + super(parameterSpec, floeKey, floeAad); + validate(floeHeaderAsBytes); + } + + public void validate(byte[] floeHeaderAsBytes) { + byte[] encodedParams = parameterSpec.paramEncode(); + if (floeHeaderAsBytes.length + != encodedParams.length + parameterSpec.getFloeIvLength().getLength() + headerTagLength) { + throw new IllegalArgumentException("invalid header length"); + } + ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes); + + byte[] encodedParamsFromHeader = new byte[10]; + floeHeader.get(encodedParamsFromHeader, 0, encodedParamsFromHeader.length); + if (!Arrays.equals(encodedParams, encodedParamsFromHeader)) { + throw new IllegalArgumentException("invalid parameters header"); + } + + byte[] floeIvBytes = new byte[parameterSpec.getFloeIvLength().getLength()]; + floeHeader.get(floeIvBytes, 0, floeIvBytes.length); + FloeIv floeIv = new FloeIv(floeIvBytes); + + byte[] headerTagFromHeader = new byte[headerTagLength]; + floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length); + + byte[] headerTag = + floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength); + if (!Arrays.equals(headerTag, headerTagFromHeader)) { + throw new IllegalArgumentException("invalid header tag"); + } + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java new file mode 100644 index 000000000..b629869f8 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java @@ -0,0 +1,5 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public interface FloeEncryptor { + byte[] getHeader(); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java new file mode 100644 index 000000000..ed993962f --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -0,0 +1,38 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class FloeEncryptorImpl extends FloeBase implements FloeEncryptor { + private final FloeIv floeIv; + + private final byte[] header; + + FloeEncryptorImpl(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { + super(parameterSpec, floeKey, floeAad); + this.floeIv = + FloeIv.generateRandom(parameterSpec.getFloeRandom(), parameterSpec.getFloeIvLength()); + this.header = buildHeader(); + } + + private byte[] buildHeader() { + byte[] parametersEncoded = parameterSpec.paramEncode(); + byte[] floeIvBytes = floeIv.getBytes(); + byte[] headerTag = + floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength); + + ByteBuffer result = + ByteBuffer.allocate(parametersEncoded.length + floeIvBytes.length + headerTag.length); + result.put(parametersEncoded); + result.put(floeIvBytes); + result.put(headerTag); + if (result.hasRemaining()) { + throw new IllegalArgumentException("Header is too long"); + } + return result.array(); + } + + @Override + public byte[] getHeader() { + return header; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java new file mode 100644 index 000000000..1022510b5 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java @@ -0,0 +1,21 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +class FloeIv { + private final byte[] bytes; + + FloeIv(byte[] bytes) { + this.bytes = bytes; + } + + static FloeIv generateRandom(FloeRandom floeRandom, FloeIvLength floeIvLength) { + return new FloeIv(floeRandom.ofLength(floeIvLength.getLength())); + } + + byte[] getBytes() { + return bytes; + } + + int lengthInBytes() { + return bytes.length; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java new file mode 100644 index 000000000..a0cf8f05d --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java @@ -0,0 +1,13 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public class FloeIvLength { + private final int length; + + public FloeIvLength(int length) { + this.length = length; + } + + public int getLength() { + return length; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java new file mode 100644 index 000000000..0d39e0a52 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java @@ -0,0 +1,47 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import javax.crypto.Mac; + +class FloeKdf { + private final FloeParameterSpec parameterSpec; + + FloeKdf(FloeParameterSpec parameterSpec) { + this.parameterSpec = parameterSpec; + } + + byte[] hkdfExpand( + FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, FloePurpose purpose, int length) { + byte[] encodedParams = parameterSpec.paramEncode(); + ByteBuffer info = + ByteBuffer.allocate( + encodedParams.length + + floeIv.getBytes().length + + purpose.getBytes().length + + floeAad.getBytes().length); + info.put(encodedParams); + info.put(floeIv.getBytes()); + info.put(purpose.getBytes()); + info.put(floeAad.getBytes()); + return jceHkdfExpand(parameterSpec.getHash(), floeKey, info.array(), length); + } + + private byte[] jceHkdfExpand(Hash hash, FloeKey prk, byte[] info, int len) { + try { + Mac mac = Mac.getInstance(hash.getJceName()); + mac.init(prk.getKey()); + mac.update(info); + mac.update((byte) 1); + byte[] bytes = mac.doFinal(); + if (bytes.length != len) { + return Arrays.copyOf(bytes, len); + } + return bytes; + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java new file mode 100644 index 000000000..6b6bf9991 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +class FloeKey { + private final SecretKey key; + + FloeKey(SecretKey key) { + this.key = key; + } + + SecretKey getKey() { + return key; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java new file mode 100644 index 000000000..53b5db779 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java @@ -0,0 +1,51 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class FloeParameterSpec { + private final Aead aead; + private final Hash hash; + private final int encryptedSegmentLength; + private final FloeIvLength floeIvLength; + private final FloeRandom floeRandom; + + public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int floeIvLength) { + this( + aead, hash, encryptedSegmentLength, new FloeIvLength(floeIvLength), new SecureFloeRandom()); + } + + FloeParameterSpec( + Aead aead, + Hash hash, + int encryptedSegmentLength, + FloeIvLength floeIvLength, + FloeRandom floeRandom) { + this.aead = aead; + this.hash = hash; + this.encryptedSegmentLength = encryptedSegmentLength; + this.floeIvLength = floeIvLength; + this.floeRandom = floeRandom; + } + + byte[] paramEncode() { + ByteBuffer result = ByteBuffer.allocate(10).order(ByteOrder.BIG_ENDIAN); + result.put(aead.getId()); + result.put(hash.getId()); + result.putInt(encryptedSegmentLength); + result.putInt(floeIvLength.getLength()); + return result.array(); + } + + public Hash getHash() { + return hash; + } + + public FloeIvLength getFloeIvLength() { + return floeIvLength; + } + + FloeRandom getFloeRandom() { + return floeRandom; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java new file mode 100644 index 000000000..ad4627035 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java @@ -0,0 +1,17 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.charset.StandardCharsets; + +public enum FloePurpose { + HEADER_TAG("HEADER_TAG:".getBytes(StandardCharsets.UTF_8)); + + private final byte[] bytes; + + FloePurpose(byte[] bytes) { + this.bytes = bytes; + } + + public byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java new file mode 100644 index 000000000..a0bd176f1 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java @@ -0,0 +1,5 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +interface FloeRandom { + byte[] ofLength(int length); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java new file mode 100644 index 000000000..45f5b8c3a --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java @@ -0,0 +1,21 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public enum Hash { + SHA384((byte) 0, "HmacSHA384"); + + private byte id; + private final String jceName; + + Hash(byte id, String jceName) { + this.id = id; + this.jceName = jceName; + } + + byte getId() { + return id; + } + + public String getJceName() { + return jceName; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java new file mode 100644 index 000000000..9302c8c83 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.security.SecureRandom; + +class SecureFloeRandom implements FloeRandom { + private static final ThreadLocal random = + ThreadLocal.withInitial(SecureRandom::new); + + @Override + public byte[] ofLength(int length) { + byte[] bytes = new byte[length]; + random.get().nextBytes(bytes); + return bytes; + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java new file mode 100644 index 000000000..1a790d7ca --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java @@ -0,0 +1,17 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public class FixedFloeRandom implements FloeRandom { + private final byte[] bytes; + + public FixedFloeRandom(byte[] bytes) { + this.bytes = bytes; + } + + @Override + public byte[] ofLength(int length) { + if (bytes.length != length) { + throw new IllegalArgumentException("allowed only " + bytes.length + " bytes"); + } + return bytes; + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java new file mode 100644 index 000000000..b50f5d280 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -0,0 +1,45 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.nio.charset.StandardCharsets; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Test; + +class FloeEncryptorImplTest { + @Test + void shouldCreateCorrectHeader() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 12345678, + new FloeIvLength(4), + new FixedFloeRandom(new byte[] {11, 22, 33, 44})); + FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); + FloeAad floeAad = new FloeAad("test aad".getBytes(StandardCharsets.UTF_8)); + FloeEncryptorImpl floeEncryptor = new FloeEncryptorImpl(parameterSpec, floeKey, floeAad); + byte[] header = floeEncryptor.getHeader(); + // AEAD ID + assertEquals(Aead.AES_GCM_256.getId(), header[0]); + // HASH ID + assertEquals(Hash.SHA384.getId(), header[1]); + // Segment length in BE + // 12345678(10) = BC614E(16) + assertEquals(0, header[2]); + assertEquals((byte) 188, header[3]); + assertEquals((byte) 97, header[4]); + assertEquals((byte) 78, header[5]); + // FLOE IV length in BE + // 4(10) = 4(16) = 00,00,00,04 + assertEquals(0, header[6]); + assertEquals(0, header[7]); + assertEquals(0, header[8]); + assertEquals(4, header[9]); + // FLOE IV + assertEquals(11, header[10]); + assertEquals(22, header[11]); + assertEquals(33, header[12]); + assertEquals(44, header[13]); + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java new file mode 100644 index 000000000..8395f9f8d --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -0,0 +1,66 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import static org.junit.jupiter.api.Assertions.*; + +import java.nio.charset.StandardCharsets; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Test; + +class FloeTest { + @Test + void validateHeaderMatchesForEncryptionAndDecryption() { + byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + assertDoesNotThrow(() -> floe.createDecryptor(secretKey, aad, header)); + } + + @Test + void validateHeaderDoesNotMatchInParams() { + byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[0] = 12; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid parameters header"); + } + + @Test + void validateHeaderDoesNotMatchInIV() { + byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[11]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + } + + @Test + void validateHeaderDoesNotMatchInHeaderTag() { + byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[header.length - 3]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + } +} From 68b11d239678ee5bfe7deeeb5b084c831070694e Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Fri, 20 Dec 2024 13:45:08 +0100 Subject: [PATCH 2/5] Implement processing segments --- .../cloud/storage/GcmEncryptionProvider.java | 16 +- .../client/jdbc/cloud/storage/floe/Aead.java | 45 +++++- .../jdbc/cloud/storage/floe/AeadAad.java | 22 +++ .../jdbc/cloud/storage/floe/AeadIv.java | 25 +++ .../jdbc/cloud/storage/floe/AeadKey.java | 15 ++ .../jdbc/cloud/storage/floe/AeadProvider.java | 12 ++ .../storage/floe/BaseSegmentProcessor.java | 44 ++++++ .../jdbc/cloud/storage/floe/FloeBase.java | 18 --- .../cloud/storage/floe/FloeDecryptor.java | 2 +- .../cloud/storage/floe/FloeDecryptorImpl.java | 61 ++++++-- .../cloud/storage/floe/FloeEncryptor.java | 2 +- .../cloud/storage/floe/FloeEncryptorImpl.java | 47 +++++- .../cloud/storage/floe/FloeParameterSpec.java | 29 +++- .../jdbc/cloud/storage/floe/FloePurpose.java | 32 +++- .../floe/{FloeKdf.java => KeyDerivator.java} | 13 +- .../cloud/storage/floe/SegmentProcessor.java | 5 + .../jdbc/cloud/storage/floe/aead/Gcm.java | 51 ++++++ .../snowflake/client/AbstractDriverIT.java | 4 + .../cloud/storage/floe/FixedFloeRandom.java | 17 -- .../storage/floe/FloeEncryptorImplTest.java | 60 +++++++- .../jdbc/cloud/storage/floe/FloeTest.java | 145 +++++++++++++----- .../storage/floe/IncrementingFloeRandom.java | 14 ++ 22 files changed, 564 insertions(+), 115 deletions(-) create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java delete mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java rename src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/{FloeKdf.java => KeyDerivator.java} (76%) create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java delete mode 100644 src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java create mode 100644 src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java index 1e2e25e64..37475083f 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java @@ -26,15 +26,15 @@ import javax.crypto.SecretKey; import javax.crypto.spec.GCMParameterSpec; import javax.crypto.spec.SecretKeySpec; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.jdbc.MatDesc; import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial; -class GcmEncryptionProvider { +@SnowflakeJdbcInternalApi +public class GcmEncryptionProvider { private static final int TAG_LENGTH_IN_BITS = 128; private static final int IV_LENGTH_IN_BYTES = 12; private static final String AES = "AES"; - private static final String FILE_CIPHER = "AES/GCM/NoPadding"; - private static final String KEY_CIPHER = "AES/GCM/NoPadding"; private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB private static final ThreadLocal random = ThreadLocal.withInitial(SecureRandom::new); @@ -85,7 +85,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData); - Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); + Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec); if (aad != null) { keyCipher.updateAAD(aad); @@ -99,7 +99,7 @@ private static CipherInputStream encryptContent( NoSuchAlgorithmException { SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes); - Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); + Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -172,7 +172,7 @@ private static CipherInputStream decryptContentFromStream( NoSuchAlgorithmException { GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES); - Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); + Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -187,7 +187,7 @@ private static void decryptContentFromFile( SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes); byte[] buffer = new byte[BUFFER_SIZE]; - Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); + Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -215,7 +215,7 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); - Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); + Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec); if (aad != null) { keyCipher.updateAAD(aad); diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java index 861343163..04562d47b 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -1,16 +1,55 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm; + public enum Aead { - AES_GCM_128((byte) 0), - AES_GCM_256((byte) 1); + // TODO confirm id + AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new Gcm(16)), + AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new Gcm(16)); private byte id; + private String jceName; + private int keyLength; + private int ivLength; + private int authTagLength; + private AeadProvider aeadProvider; - Aead(byte id) { + Aead( + byte id, + String jceName, + int keyLength, + int ivLength, + int authTagLength, + AeadProvider aeadProvider) { + this.jceName = jceName; + this.keyLength = keyLength; this.id = id; + this.ivLength = ivLength; + this.authTagLength = authTagLength; + this.aeadProvider = aeadProvider; } byte getId() { return id; } + + String getJceName() { + return jceName; + } + + int getKeyLength() { + return keyLength; + } + + int getIvLength() { + return ivLength; + } + + int getAuthTagLength() { + return authTagLength; + } + + AeadProvider getAeadProvider() { + return aeadProvider; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java new file mode 100644 index 000000000..6586c28bd --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java @@ -0,0 +1,22 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class AeadAad { + private final byte[] bytes; + + private AeadAad(long segmentCounter, byte terminalityByte) { + ByteBuffer buf = ByteBuffer.allocate(9); + buf.putLong(segmentCounter); + buf.put(terminalityByte); + this.bytes = buf.array(); + } + + static AeadAad nonTerminal(long segmentCounter) { + return new AeadAad(segmentCounter, (byte) 0); + } + + byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java new file mode 100644 index 000000000..c2a559b47 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java @@ -0,0 +1,25 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class AeadIv { + private final byte[] bytes; + + AeadIv(byte[] bytes) { + this.bytes = bytes; + } + + public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) { + return new AeadIv(floeRandom.ofLength(ivLength)); + } + + public static AeadIv from(ByteBuffer buffer, int ivLength) { + byte[] bytes = new byte[ivLength]; + buffer.get(bytes); + return new AeadIv(bytes); + } + + byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java new file mode 100644 index 000000000..bfbd01976 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +class AeadKey { + private final SecretKey key; + + AeadKey(SecretKey key) { + this.key = key; + } + + SecretKey getKey() { + return key; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java new file mode 100644 index 000000000..106d604cd --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java @@ -0,0 +1,12 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.security.GeneralSecurityException; +import javax.crypto.SecretKey; + +public interface AeadProvider { + byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) + throws GeneralSecurityException; + + byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) + throws GeneralSecurityException; +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java new file mode 100644 index 000000000..63453e18c --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java @@ -0,0 +1,44 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +abstract class BaseSegmentProcessor { + protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1; + protected static final int headerTagLength = 32; + + protected final FloeParameterSpec parameterSpec; + protected final FloeKey floeKey; + protected final FloeAad floeAad; + + protected final KeyDerivator keyDerivator; + + private AeadKey currentAeadKey; + + BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { + this.parameterSpec = parameterSpec; + this.floeKey = floeKey; + this.floeAad = floeAad; + this.keyDerivator = new KeyDerivator(parameterSpec); + } + + protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { + if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) { + currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter); + } + return currentAeadKey; + } + + private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { + byte[] keyBytes = + keyDerivator.hkdfExpand( + floeKey, + floeIv, + floeAad, + new DekTagFloePurpose(segmentCounter), + parameterSpec.getAead().getKeyLength()); + SecretKey key = + new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD + return new AeadKey(key); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java deleted file mode 100644 index 7328d480c..000000000 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java +++ /dev/null @@ -1,18 +0,0 @@ -package net.snowflake.client.jdbc.cloud.storage.floe; - -abstract class FloeBase { - protected static final int headerTagLength = 32; - - protected final FloeParameterSpec parameterSpec; - protected final FloeKey floeKey; - protected final FloeAad floeAad; - - protected final FloeKdf floeKdf; - - FloeBase(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { - this.parameterSpec = parameterSpec; - this.floeKey = floeKey; - this.floeAad = floeAad; - this.floeKdf = new FloeKdf(parameterSpec); - } -} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java index 87e2463fd..085f23789 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java @@ -1,3 +1,3 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public interface FloeDecryptor {} +public interface FloeDecryptor extends SegmentProcessor {} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java index 7139d9e73..9fc91d2ef 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -1,19 +1,21 @@ package net.snowflake.client.jdbc.cloud.storage.floe; import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; import java.util.Arrays; -public class FloeDecryptorImpl extends FloeBase implements FloeDecryptor { +public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { + private final FloeIv floeIv; + private long segmentCounter; + FloeDecryptorImpl( FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) { super(parameterSpec, floeKey, floeAad); - validate(floeHeaderAsBytes); - } - - public void validate(byte[] floeHeaderAsBytes) { - byte[] encodedParams = parameterSpec.paramEncode(); + byte[] encodedParams = this.parameterSpec.paramEncode(); if (floeHeaderAsBytes.length - != encodedParams.length + parameterSpec.getFloeIvLength().getLength() + headerTagLength) { + != encodedParams.length + + this.parameterSpec.getFloeIvLength().getLength() + + headerTagLength) { throw new IllegalArgumentException("invalid header length"); } ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes); @@ -24,17 +26,56 @@ public void validate(byte[] floeHeaderAsBytes) { throw new IllegalArgumentException("invalid parameters header"); } - byte[] floeIvBytes = new byte[parameterSpec.getFloeIvLength().getLength()]; + byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()]; floeHeader.get(floeIvBytes, 0, floeIvBytes.length); - FloeIv floeIv = new FloeIv(floeIvBytes); + this.floeIv = new FloeIv(floeIvBytes); byte[] headerTagFromHeader = new byte[headerTagLength]; floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length); byte[] headerTag = - floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength); + keyDerivator.hkdfExpand( + this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength); if (!Arrays.equals(headerTag, headerTagFromHeader)) { throw new IllegalArgumentException("invalid header tag"); } } + + @Override + public byte[] processSegment(byte[] input) { + try { + verifySegmentLength(input); + ByteBuffer inputBuf = ByteBuffer.wrap(input); + verifySegmentSizeMarker(inputBuf); + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++); + AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); + byte[] ciphertext = new byte[inputBuf.remaining()]; + inputBuf.get(ciphertext); + return aeadProvider.decrypt( + aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private void verifySegmentLength(byte[] input) { + if (input.length != parameterSpec.getEncryptedSegmentLength()) { + throw new IllegalArgumentException( + String.format( + "segment length mismatch, expected %d, got %d", + parameterSpec.getEncryptedSegmentLength(), input.length)); + } + } + + private void verifySegmentSizeMarker(ByteBuffer inputBuf) { + int segmentSizeMarker = inputBuf.getInt(); + if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) { + throw new IllegalStateException( + String.format( + "segment length marker mismatch, expected: %d, got :%d", + NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker)); + } + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java index b629869f8..f1ab85496 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java @@ -1,5 +1,5 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public interface FloeEncryptor { +public interface FloeEncryptor extends SegmentProcessor { byte[] getHeader(); } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java index ed993962f..93479e789 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -1,9 +1,14 @@ package net.snowflake.client.jdbc.cloud.storage.floe; import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; + +class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { -class FloeEncryptorImpl extends FloeBase implements FloeEncryptor { private final FloeIv floeIv; + private AeadKey currentAeadKey; + + private long segmentCounter; private final byte[] header; @@ -18,7 +23,8 @@ private byte[] buildHeader() { byte[] parametersEncoded = parameterSpec.paramEncode(); byte[] floeIvBytes = floeIv.getBytes(); byte[] headerTag = - floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength); + keyDerivator.hkdfExpand( + floeKey, floeIv, floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength); ByteBuffer result = ByteBuffer.allocate(parametersEncoded.length + floeIvBytes.length + headerTag.length); @@ -35,4 +41,41 @@ private byte[] buildHeader() { public byte[] getHeader() { return header; } + + @Override + public byte[] processSegment(byte[] input) { + verifySegmentLength(input); + // TODO assert State.Counter != 2^40-1 # Prevent overflow + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = + AeadIv.generateRandom( + parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++); + AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); + // it works as long as AEAD returns auth tag as a part of the ciphertext + byte[] ciphertextWithAuthTag = + aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + return segmentToBytes(aeadIv, ciphertextWithAuthTag); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private byte[] segmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { + ByteBuffer output = ByteBuffer.allocate(parameterSpec.getEncryptedSegmentLength()); + output.putInt(NON_TERMINAL_SEGMENT_SIZE_MARKER); + output.put(aeadIv.getBytes()); + output.put(ciphertextWithAuthTag); + return output.array(); + } + + private void verifySegmentLength(byte[] input) { + if (input.length != parameterSpec.getPlainTextSegmentLength()) { + throw new IllegalArgumentException( + String.format( + "segment length mismatch, expected %d, got %d", + parameterSpec.getPlainTextSegmentLength(), input.length)); + } + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java index 53b5db779..65e937f8c 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java @@ -9,10 +9,16 @@ public class FloeParameterSpec { private final int encryptedSegmentLength; private final FloeIvLength floeIvLength; private final FloeRandom floeRandom; + private final int keyRotationModulo; public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int floeIvLength) { this( - aead, hash, encryptedSegmentLength, new FloeIvLength(floeIvLength), new SecureFloeRandom()); + aead, + hash, + encryptedSegmentLength, + new FloeIvLength(floeIvLength), + new SecureFloeRandom(), + 1 << 20); } FloeParameterSpec( @@ -20,12 +26,14 @@ public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int f Hash hash, int encryptedSegmentLength, FloeIvLength floeIvLength, - FloeRandom floeRandom) { + FloeRandom floeRandom, + int keyRotationModulo) { this.aead = aead; this.hash = hash; this.encryptedSegmentLength = encryptedSegmentLength; this.floeIvLength = floeIvLength; this.floeRandom = floeRandom; + this.keyRotationModulo = keyRotationModulo; } byte[] paramEncode() { @@ -37,6 +45,10 @@ byte[] paramEncode() { return result.array(); } + public Aead getAead() { + return aead; + } + public Hash getHash() { return hash; } @@ -48,4 +60,17 @@ public FloeIvLength getFloeIvLength() { FloeRandom getFloeRandom() { return floeRandom; } + + int getEncryptedSegmentLength() { + return encryptedSegmentLength; + } + + int getPlainTextSegmentLength() { + // sizeof(int) == 4, file size is a part of the segment ciphertext + return encryptedSegmentLength - aead.getIvLength() - aead.getAuthTagLength() - 4; + } + + int getKeyRotationModulo() { + return keyRotationModulo; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java index ad4627035..41fda867a 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java @@ -1,17 +1,39 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -public enum FloePurpose { - HEADER_TAG("HEADER_TAG:".getBytes(StandardCharsets.UTF_8)); +interface FloePurpose { + byte[] generate(); +} + +class HeaderTagFloePurpose implements FloePurpose { + private static final byte[] bytes = "HEADER_TAG:".getBytes(StandardCharsets.UTF_8); + + static final HeaderTagFloePurpose INSTANCE = new HeaderTagFloePurpose(); + + private HeaderTagFloePurpose() {} + + @Override + public byte[] generate() { + return bytes; + } +} + +class DekTagFloePurpose implements FloePurpose { + private static final byte[] prefix = "DEK:".getBytes(StandardCharsets.UTF_8); private final byte[] bytes; - FloePurpose(byte[] bytes) { - this.bytes = bytes; + DekTagFloePurpose(long segmentCount) { + ByteBuffer buffer = ByteBuffer.allocate(prefix.length + 8 /*size of long*/); + buffer.put(prefix); + buffer.putLong(segmentCount); + this.bytes = buffer.array(); } - public byte[] getBytes() { + @Override + public byte[] generate() { return bytes; } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java similarity index 76% rename from src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java rename to src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java index 0d39e0a52..e3d064ba1 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKdf.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java @@ -6,30 +6,31 @@ import java.util.Arrays; import javax.crypto.Mac; -class FloeKdf { +class KeyDerivator { private final FloeParameterSpec parameterSpec; - FloeKdf(FloeParameterSpec parameterSpec) { + KeyDerivator(FloeParameterSpec parameterSpec) { this.parameterSpec = parameterSpec; } byte[] hkdfExpand( FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, FloePurpose purpose, int length) { byte[] encodedParams = parameterSpec.paramEncode(); + byte[] purposeBytes = purpose.generate(); ByteBuffer info = ByteBuffer.allocate( encodedParams.length + floeIv.getBytes().length - + purpose.getBytes().length + + purposeBytes.length + floeAad.getBytes().length); info.put(encodedParams); info.put(floeIv.getBytes()); - info.put(purpose.getBytes()); + info.put(purposeBytes); info.put(floeAad.getBytes()); - return jceHkdfExpand(parameterSpec.getHash(), floeKey, info.array(), length); + return hkdfExpandInternal(parameterSpec.getHash(), floeKey, info.array(), length); } - private byte[] jceHkdfExpand(Hash hash, FloeKey prk, byte[] info, int len) { + private byte[] hkdfExpandInternal(Hash hash, FloeKey prk, byte[] info, int len) { try { Mac mac = Mac.getInstance(hash.getJceName()); mac.init(prk.getKey()); diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java new file mode 100644 index 000000000..45e4f2872 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java @@ -0,0 +1,5 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +interface SegmentProcessor { + byte[] processSegment(byte[] input); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java new file mode 100644 index 000000000..ec5662f11 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java @@ -0,0 +1,51 @@ +package net.snowflake.client.jdbc.cloud.storage.floe.aead; + +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; +import net.snowflake.client.jdbc.cloud.storage.floe.AeadProvider; + +// This class is not thread safe! +public class Gcm implements AeadProvider { + private final Cipher keyCipher; + private final int tagLengthInBits; + + public Gcm(int tagLengthInBytes) { + try { + keyCipher = Cipher.getInstance("AES/GCM/NoPadding"); + this.tagLengthInBits = tagLengthInBytes * 8; + } catch (NoSuchAlgorithmException | NoSuchPaddingException e) { + throw new ExceptionInInitializerError(e); + } + } + + @Override + public byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) + throws GeneralSecurityException { + return process(key, iv, aad, plaintext, Cipher.ENCRYPT_MODE); + } + + @Override + public byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) + throws GeneralSecurityException { + return process(key, iv, aad, ciphertext, Cipher.DECRYPT_MODE); + } + + private byte[] process(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext, int encryptMode) + throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException, + BadPaddingException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(tagLengthInBits, iv); + keyCipher.init(encryptMode, key, gcmParameterSpec); + if (aad != null) { + keyCipher.updateAAD(aad); + } + return keyCipher.doFinal(plaintext); + } +} diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index 3104ce7e9..c75b81818 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -324,6 +324,10 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. + properties.put("useProxy", "true"); + properties.put("proxyHost", "localhost"); + properties.put("proxyPort", "8080"); + if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java deleted file mode 100644 index 1a790d7ca..000000000 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FixedFloeRandom.java +++ /dev/null @@ -1,17 +0,0 @@ -package net.snowflake.client.jdbc.cloud.storage.floe; - -public class FixedFloeRandom implements FloeRandom { - private final byte[] bytes; - - public FixedFloeRandom(byte[] bytes) { - this.bytes = bytes; - } - - @Override - public byte[] ofLength(int length) { - if (bytes.length != length) { - throw new IllegalArgumentException("allowed only " + bytes.length + " bytes"); - } - return bytes; - } -} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java index b50f5d280..9ada35480 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -3,10 +3,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; import org.junit.jupiter.api.Test; class FloeEncryptorImplTest { + byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); + SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); + @Test void shouldCreateCorrectHeader() { FloeParameterSpec parameterSpec = @@ -15,7 +21,9 @@ void shouldCreateCorrectHeader() { Hash.SHA384, 12345678, new FloeIvLength(4), - new FixedFloeRandom(new byte[] {11, 22, 33, 44})); + new IncrementingFloeRandom(), + 4); + parameterSpec.getFloeRandom().ofLength(4); // just to trigger incrementation FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); FloeAad floeAad = new FloeAad("test aad".getBytes(StandardCharsets.UTF_8)); FloeEncryptorImpl floeEncryptor = new FloeEncryptorImpl(parameterSpec, floeKey, floeAad); @@ -37,9 +45,51 @@ void shouldCreateCorrectHeader() { assertEquals(0, header[8]); assertEquals(4, header[9]); // FLOE IV - assertEquals(11, header[10]); - assertEquals(22, header[11]); - assertEquals(33, header[12]); - assertEquals(44, header[13]); + assertEquals(0, header[10]); + assertEquals(0, header[11]); + assertEquals(0, header[12]); + assertEquals(1, header[13]); + } + + @Test + void testEncryptionMatchesReference() { + List referenceCiphertextSegments = + Arrays.asList( + "ffffffff0000000100000000000000000100007f5713b9827bb806318311fcde197146a144c6b485", // pragma: allowlist secret + "ffffffff000000020000000000000000f926dfc0a0bac6263d1634ad9a72f86900872033a271a037", // pragma: allowlist secret + "ffffffff00000003000000000000000080df8fdee872febe574c2b8df0bb34b3fb25bfc5802703a2", // pragma: allowlist secret + "ffffffff000000040000000000000000f4d81083e57451dbfa538827942245019b8bc3354ecc31e0", // pragma: allowlist secret + "ffffffff000000050000000000000000d91b774b5b460bd665910114e155f1cbc55a9a262a54f65e", // pragma: allowlist secret + "ffffffff000000060000000000000000ec723f3807eb71ea42ff03f5420daf34e1a8f4fb58931db1", // pragma: allowlist secret + "ffffffff00000007000000000000000072960c06ec19ce94c27c9fc72d79164f187f37e86325d849", // pragma: allowlist secret + "ffffffff000000080000000000000000c00a40fb140d797da818ab57399cb986bddddd174b8d3d6a", // pragma: allowlist secret + "ffffffff000000090000000000000000065e959cd1ffa521896fb54949a57ad1c1f8291a531c6d60", // pragma: allowlist secret + "ffffffff0000000a0000000000000000dfde3da3f67a081fb31229ac11e43a629ed120fbf9942513" // pragma: allowlist secret + ); + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(), + 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + byte[] testData = new byte[8]; + for (int i = 0; i < referenceCiphertextSegments.size(); i++) { + byte[] ciphertextBytes = encryptor.processSegment(testData); + String ciphertextHex = toHex(ciphertextBytes); + assertEquals(referenceCiphertextSegments.get(i), ciphertextHex); + } + } + + private String toHex(byte[] input) { + StringBuilder result = new StringBuilder(); + for (byte b : input) { + result.append(String.format("%02x", b)); + } + return result.toString(); } } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java index 8395f9f8d..2308f287b 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -1,66 +1,137 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; class FloeTest { - @Test - void validateHeaderMatchesForEncryptionAndDecryption() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); - Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - assertDoesNotThrow(() -> floe.createDecryptor(secretKey, aad, header)); + byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); + SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); + + @Nested + class HeaderTests { + @Test + void validateHeaderMatchesForEncryptionAndDecryption() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + assertDoesNotThrow(() -> floe.createDecryptor(secretKey, aad, header)); + } + + @Test + void validateHeaderDoesNotMatchInParams() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[0] = 12; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid parameters header"); + } + + @Test + void validateHeaderDoesNotMatchInIV() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[11]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + } + + @Test + void validateHeaderDoesNotMatchInHeaderTag() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 4096, 4); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + header[header.length - 3]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + } } @Test - void validateHeaderDoesNotMatchInParams() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + void testSegmentEncryptedAndDecrypted() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(), + 4); Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); - header[0] = 12; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid parameters header"); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); } @Test - void validateHeaderDoesNotMatchInIV() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + void testSegmentEncryptedAndDecryptedWithRandomData() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(), + 4); Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); - header[11]++; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid header tag"); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + new SecureRandom().nextBytes(testData); + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); } @Test - void validateHeaderDoesNotMatchInHeaderTag() { - byte[] aad = "test aad".getBytes(StandardCharsets.UTF_8); - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_128, Hash.SHA384, 1024, 4); + void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(), + 4); Floe floe = Floe.getInstance(parameterSpec); - SecretKey secretKey = new SecretKeySpec(new byte[16], "FLOE"); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); - header[header.length - 3]++; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid header tag"); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + for (int i = 0; i < 10; i++) { + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } } } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java new file mode 100644 index 000000000..0d954a7fe --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java @@ -0,0 +1,14 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +public class IncrementingFloeRandom implements FloeRandom { + private int seed; + + @Override + public byte[] ofLength(int length) { + ByteBuffer buffer = ByteBuffer.allocate(length); + buffer.putInt(seed++); + return buffer.array(); + } +} From 0d8d62425af30191ee5798341eedf481e3ca15ed Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Thu, 16 Jan 2025 13:23:31 +0100 Subject: [PATCH 3/5] Fix not thread safe usage of Gcm class, that caused that the FLOE itself couldn't have been used in parallel + review fixes --- .../cloud/storage/EncryptionProvider.java | 11 +++++---- .../cloud/storage/GcmEncryptionProvider.java | 24 +++++++++---------- .../client/jdbc/cloud/storage/floe/Aead.java | 11 +++++---- .../jdbc/cloud/storage/floe/AeadIv.java | 4 ++-- .../jdbc/cloud/storage/floe/FloeAad.java | 3 ++- .../cloud/storage/floe/FloeDecryptorImpl.java | 13 ++++++---- .../cloud/storage/floe/FloeEncryptorImpl.java | 9 ++++--- .../jdbc/cloud/storage/floe/FloeIvLength.java | 6 ++--- .../storage/floe/FloeEncryptorImplTest.java | 6 ++--- .../jdbc/cloud/storage/floe/FloeTest.java | 6 ++--- .../storage/floe/IncrementingFloeRandom.java | 4 ++++ 11 files changed, 56 insertions(+), 41 deletions(-) diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java index 5f6ddd801..67faec889 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java @@ -41,7 +41,8 @@ public class EncryptionProvider { private static final String FILE_CIPHER = "AES/CBC/PKCS5Padding"; private static final String KEY_CIPHER = "AES/ECB/PKCS5Padding"; private static final int BUFFER_SIZE = 2 * 1024 * 1024; // 2 MB - private static ThreadLocal secRnd = ThreadLocal.withInitial(SecureRandom::new); + private static final ThreadLocal SEC_RND = + ThreadLocal.withInitial(SecureRandom::new); /** * Decrypt a InputStream @@ -69,7 +70,7 @@ public static InputStream decryptStream( byte[] kekBytes = Base64.getDecoder().decode(encMat.getQueryStageMasterKey()); byte[] keyBytes = Base64.getDecoder().decode(keyBase64); byte[] ivBytes = Base64.getDecoder().decode(ivBase64); - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); keyCipher.init(Cipher.DECRYPT_MODE, kek); byte[] fileKeyBytes = keyCipher.doFinal(keyBytes); @@ -97,7 +98,7 @@ public static void decrypt( // Decrypt file key { final Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); keyCipher.init(Cipher.DECRYPT_MODE, kek); byte[] fileKeyBytes = keyCipher.doFinal(keyBytes); @@ -165,11 +166,11 @@ public static CipherInputStream encrypt( // Create IV ivData = new byte[blockSize]; - secRnd.get().nextBytes(ivData); + SEC_RND.get().nextBytes(ivData); final IvParameterSpec iv = new IvParameterSpec(ivData); // Create file key - secRnd.get().nextBytes(fileKeyBytes); + SEC_RND.get().nextBytes(fileKeyBytes); SecretKey fileKey = new SecretKeySpec(fileKeyBytes, 0, keySize, AES); // Init cipher diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java index 37475083f..6859b609b 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java @@ -26,18 +26,18 @@ import javax.crypto.SecretKey; import javax.crypto.spec.GCMParameterSpec; import javax.crypto.spec.SecretKeySpec; -import net.snowflake.client.core.SnowflakeJdbcInternalApi; import net.snowflake.client.jdbc.MatDesc; import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial; -@SnowflakeJdbcInternalApi -public class GcmEncryptionProvider { +class GcmEncryptionProvider { private static final int TAG_LENGTH_IN_BITS = 128; private static final int IV_LENGTH_IN_BYTES = 12; private static final String AES = "AES"; + private static final String FILE_CIPHER = "AES/GCM/NoPadding"; + private static final String KEY_CIPHER = "AES/GCM/NoPadding"; private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB private static final ThreadLocal random = - ThreadLocal.withInitial(SecureRandom::new); + new ThreadLocal<>().withInitial(SecureRandom::new); private static final Base64.Decoder base64Decoder = Base64.getDecoder(); static InputStream encrypt( @@ -83,9 +83,9 @@ private static void initRandomIvsAndFileKey( private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvData, byte[] aad) throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData); - Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); + Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec); if (aad != null) { keyCipher.updateAAD(aad); @@ -97,9 +97,9 @@ private static CipherInputStream encryptContent( InputStream src, byte[] keyBytes, byte[] dataIvBytes, byte[] aad) throws InvalidKeyException, InvalidAlgorithmParameterException, NoSuchPaddingException, NoSuchAlgorithmException { - SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES); + SecretKey fileKey = new SecretKeySpec(keyBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes); - Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); + Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -172,7 +172,7 @@ private static CipherInputStream decryptContentFromStream( NoSuchAlgorithmException { GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES); - Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); + Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -187,7 +187,7 @@ private static void decryptContentFromFile( SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes); byte[] buffer = new byte[BUFFER_SIZE]; - Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME); + Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec); if (aad != null) { fileCipher.updateAAD(aad); @@ -213,9 +213,9 @@ private static void decryptContentFromFile( private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyBytes, byte[] aad) throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); - Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME); + Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec); if (aad != null) { keyCipher.updateAAD(aad); diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java index 04562d47b..0f823c133 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -1,18 +1,19 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import java.util.function.Supplier; import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm; public enum Aead { // TODO confirm id - AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new Gcm(16)), - AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new Gcm(16)); + AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, () -> new Gcm(16)), + AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, () -> new Gcm(16)); private byte id; private String jceName; private int keyLength; private int ivLength; private int authTagLength; - private AeadProvider aeadProvider; + private Supplier aeadProvider; Aead( byte id, @@ -20,7 +21,7 @@ public enum Aead { int keyLength, int ivLength, int authTagLength, - AeadProvider aeadProvider) { + Supplier aeadProvider) { this.jceName = jceName; this.keyLength = keyLength; this.id = id; @@ -50,6 +51,6 @@ int getAuthTagLength() { } AeadProvider getAeadProvider() { - return aeadProvider; + return aeadProvider.get(); } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java index c2a559b47..471fa7204 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java @@ -9,11 +9,11 @@ class AeadIv { this.bytes = bytes; } - public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) { + static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) { return new AeadIv(floeRandom.ofLength(ivLength)); } - public static AeadIv from(ByteBuffer buffer, int ivLength) { + static AeadIv from(ByteBuffer buffer, int ivLength) { byte[] bytes = new byte[ivLength]; buffer.get(bytes); return new AeadIv(bytes); diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java index 3c24eb136..f135d9b68 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java @@ -3,10 +3,11 @@ import java.util.Optional; class FloeAad { + private static final byte[] EMPTY_AAD = new byte[0]; private final byte[] aad; FloeAad(byte[] aad) { - this.aad = Optional.ofNullable(aad).orElse(new byte[0]); + this.aad = Optional.ofNullable(aad).orElse(EMPTY_AAD); } byte[] getBytes() { diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java index 9fc91d2ef..aa89a0fe9 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -6,7 +6,7 @@ public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { private final FloeIv floeIv; - private long segmentCounter; + private final AeadProvider aeadProvider; FloeDecryptorImpl( FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) { @@ -29,6 +29,7 @@ public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecry byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()]; floeHeader.get(floeIvBytes, 0, floeIvBytes.length); this.floeIv = new FloeIv(floeIvBytes); + this.aeadProvider = parameterSpec.getAead().getAeadProvider(); byte[] headerTagFromHeader = new byte[headerTagLength]; floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length); @@ -41,6 +42,8 @@ public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecry } } + private long segmentCounter; + @Override public byte[] processSegment(byte[] input) { try { @@ -49,12 +52,14 @@ public byte[] processSegment(byte[] input) { verifySegmentSizeMarker(inputBuf); AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); - AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); byte[] ciphertext = new byte[inputBuf.remaining()]; inputBuf.get(ciphertext); - return aeadProvider.decrypt( - aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + byte[] decrypted = + aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + segmentCounter++; + return decrypted; } catch (GeneralSecurityException e) { throw new RuntimeException(e); } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java index 93479e789..41b65909b 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -6,6 +6,7 @@ class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { private final FloeIv floeIv; + private final AeadProvider aeadProvider; private AeadKey currentAeadKey; private long segmentCounter; @@ -16,6 +17,7 @@ class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { super(parameterSpec, floeKey, floeAad); this.floeIv = FloeIv.generateRandom(parameterSpec.getFloeRandom(), parameterSpec.getFloeIvLength()); + this.aeadProvider = parameterSpec.getAead().getAeadProvider(); this.header = buildHeader(); } @@ -51,12 +53,13 @@ public byte[] processSegment(byte[] input) { AeadIv aeadIv = AeadIv.generateRandom( parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); - AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++); - AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); // it works as long as AEAD returns auth tag as a part of the ciphertext byte[] ciphertextWithAuthTag = aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); - return segmentToBytes(aeadIv, ciphertextWithAuthTag); + byte[] encoded = segmentToBytes(aeadIv, ciphertextWithAuthTag); + segmentCounter++; + return encoded; } catch (GeneralSecurityException e) { throw new RuntimeException(e); } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java index a0cf8f05d..466005265 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java @@ -1,13 +1,13 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public class FloeIvLength { +class FloeIvLength { private final int length; - public FloeIvLength(int length) { + FloeIvLength(int length) { this.length = length; } - public int getLength() { + int getLength() { return length; } } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java index 9ada35480..592eb1920 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -21,7 +21,7 @@ void shouldCreateCorrectHeader() { Hash.SHA384, 12345678, new FloeIvLength(4), - new IncrementingFloeRandom(), + new IncrementingFloeRandom(17), 4); parameterSpec.getFloeRandom().ofLength(4); // just to trigger incrementation FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); @@ -48,7 +48,7 @@ void shouldCreateCorrectHeader() { assertEquals(0, header[10]); assertEquals(0, header[11]); assertEquals(0, header[12]); - assertEquals(1, header[13]); + assertEquals(18, header[13]); } @Test @@ -72,7 +72,7 @@ void testEncryptionMatchesReference() { Hash.SHA384, 40, new FloeIvLength(32), - new IncrementingFloeRandom(), + new IncrementingFloeRandom(0), 4); Floe floe = Floe.getInstance(parameterSpec); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java index 2308f287b..d521866b6 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -80,7 +80,7 @@ void testSegmentEncryptedAndDecrypted() { Hash.SHA384, 40, new FloeIvLength(32), - new IncrementingFloeRandom(), + new IncrementingFloeRandom(678765), 4); Floe floe = Floe.getInstance(parameterSpec); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); @@ -100,7 +100,7 @@ void testSegmentEncryptedAndDecryptedWithRandomData() { Hash.SHA384, 40, new FloeIvLength(32), - new IncrementingFloeRandom(), + new IncrementingFloeRandom(37665), 4); Floe floe = Floe.getInstance(parameterSpec); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); @@ -121,7 +121,7 @@ void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { Hash.SHA384, 40, new FloeIvLength(32), - new IncrementingFloeRandom(), + new IncrementingFloeRandom(6546), 4); Floe floe = Floe.getInstance(parameterSpec); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java index 0d954a7fe..fb5a152d1 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java @@ -5,6 +5,10 @@ public class IncrementingFloeRandom implements FloeRandom { private int seed; + public IncrementingFloeRandom(int seed) { + this.seed = seed; + } + @Override public byte[] ofLength(int length) { ByteBuffer buffer = ByteBuffer.allocate(length); From 45d194d726781c3e351f0b3e408588c9741e52ed Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Tue, 21 Jan 2025 10:07:13 +0100 Subject: [PATCH 4/5] Add last segment processing + test tests for edge cases --- .../client/jdbc/cloud/storage/floe/Aead.java | 7 +- .../jdbc/cloud/storage/floe/AeadAad.java | 4 + .../storage/floe/BaseSegmentProcessor.java | 1 + .../cloud/storage/floe/FloeDecryptorImpl.java | 43 ++++- .../cloud/storage/floe/FloeEncryptorImpl.java | 56 +++++- .../cloud/storage/floe/FloeParameterSpec.java | 18 +- .../cloud/storage/floe/SegmentProcessor.java | 1 + .../jdbc/cloud/storage/floe/aead/Gcm.java | 12 +- .../storage/floe/FloeDecryptorImplTest.java | 120 ++++++++++++ .../storage/floe/FloeEncryptorImplTest.java | 75 +++++++- .../jdbc/cloud/storage/floe/FloeTest.java | 179 ++++++++++++------ 11 files changed, 423 insertions(+), 93 deletions(-) create mode 100644 src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java index 0f823c133..cc75fb1d1 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -1,12 +1,11 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import java.util.function.Supplier; import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm; +import java.util.function.Supplier; + public enum Aead { - // TODO confirm id - AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, () -> new Gcm(16)), - AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, () -> new Gcm(16)); + AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, () -> new Gcm(16)); private byte id; private String jceName; diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java index 6586c28bd..36bf52b4c 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java @@ -16,6 +16,10 @@ static AeadAad nonTerminal(long segmentCounter) { return new AeadAad(segmentCounter, (byte) 0); } + static AeadAad terminal(long segmentCounter) { + return new AeadAad(segmentCounter, (byte) 1); + } + byte[] getBytes() { return bytes; } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java index 63453e18c..65fda1351 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java @@ -24,6 +24,7 @@ abstract class BaseSegmentProcessor { protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) { + // TODO should we mask segments here? currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter); } return currentAeadKey; diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java index aa89a0fe9..dd8d07116 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -4,7 +4,7 @@ import java.security.GeneralSecurityException; import java.util.Arrays; -public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { +class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { private final FloeIv floeIv; private final AeadProvider aeadProvider; @@ -53,7 +53,6 @@ public byte[] processSegment(byte[] input) { AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); - AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider(); byte[] ciphertext = new byte[inputBuf.remaining()]; inputBuf.get(ciphertext); byte[] decrypted = @@ -77,10 +76,46 @@ private void verifySegmentLength(byte[] input) { private void verifySegmentSizeMarker(ByteBuffer inputBuf) { int segmentSizeMarker = inputBuf.getInt(); if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) { - throw new IllegalStateException( + throw new IllegalArgumentException( String.format( - "segment length marker mismatch, expected: %d, got :%d", + "segment length marker mismatch, expected: %d, got: %d", NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker)); } } + + @Override + public byte[] processLastSegment(byte[] input) { + verifyLastSegmentLength(input); + ByteBuffer inputBuf = ByteBuffer.wrap(input); + verifyLastSegmentSizeMarker(inputBuf); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.terminal(segmentCounter); + byte[] ciphertext = new byte[inputBuf.remaining()]; + inputBuf.get(ciphertext); + byte[] decrypted = + aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + return decrypted; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private void verifyLastSegmentLength(byte[] input) { + // TODO <= ? + if (input.length < 4 + parameterSpec.getAead().getIvLength() + parameterSpec.getAead().getAuthTagLength()) { + throw new IllegalArgumentException("last segment is too short"); + } + if (input.length > parameterSpec.getEncryptedSegmentLength()) { + throw new IllegalArgumentException("last segment is too long"); + } + } + + private void verifyLastSegmentSizeMarker(ByteBuffer inputBuf) { + int segmentLengthFromSegment = inputBuf.getInt(); + if (segmentLengthFromSegment != inputBuf.capacity()) { + throw new IllegalArgumentException(String.format("last segment length marker mismatch, expected: %d, got: %d", inputBuf.capacity(), segmentLengthFromSegment)); + } + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java index 41b65909b..cf35ec994 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -4,7 +4,6 @@ import java.security.GeneralSecurityException; class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { - private final FloeIv floeIv; private final AeadProvider aeadProvider; private AeadKey currentAeadKey; @@ -48,6 +47,7 @@ public byte[] getHeader() { public byte[] processSegment(byte[] input) { verifySegmentLength(input); // TODO assert State.Counter != 2^40-1 # Prevent overflow + verifyMaxSegmentNumberNotReached(); try { AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); AeadIv aeadIv = @@ -65,6 +65,21 @@ public byte[] processSegment(byte[] input) { } } + private void verifySegmentLength(byte[] input) { + if (input.length != parameterSpec.getPlainTextSegmentLength()) { + throw new IllegalArgumentException( + String.format( + "segment length mismatch, expected %d, got %d", + parameterSpec.getPlainTextSegmentLength(), input.length)); + } + } + + private void verifyMaxSegmentNumberNotReached() { + if (segmentCounter >= parameterSpec.getMaxSegmentNumber() - 1) { + throw new IllegalStateException("maximum segment number reached"); + } + } + private byte[] segmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { ByteBuffer output = ByteBuffer.allocate(parameterSpec.getEncryptedSegmentLength()); output.putInt(NON_TERMINAL_SEGMENT_SIZE_MARKER); @@ -73,12 +88,39 @@ private byte[] segmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { return output.array(); } - private void verifySegmentLength(byte[] input) { - if (input.length != parameterSpec.getPlainTextSegmentLength()) { - throw new IllegalArgumentException( - String.format( - "segment length mismatch, expected %d, got %d", - parameterSpec.getPlainTextSegmentLength(), input.length)); + @Override + public byte[] processLastSegment(byte[] input) { + verifyLastSegmentNotEmpty(input); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = + AeadIv.generateRandom( + parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.terminal(segmentCounter); + byte[] ciphertextWithAuthTag = + aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + return lastSegmentToBytes(aeadIv, ciphertextWithAuthTag); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private byte[] lastSegmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { + int lastSegmentLength = 4 + aeadIv.getBytes().length + ciphertextWithAuthTag.length; + ByteBuffer output = ByteBuffer.allocate(lastSegmentLength); + output.putInt(lastSegmentLength); + output.put(aeadIv.getBytes()); + output.put(ciphertextWithAuthTag); + return output.array(); + } + + private void verifyLastSegmentNotEmpty(byte[] input) { + // TODO +// if (input.length == 0) { +// throw new IllegalArgumentException("last segment is empty"); +// } + if (input.length > parameterSpec.getPlainTextSegmentLength()) { + throw new IllegalArgumentException(String.format("last segment is too long, got %d, max is %d", input.length, parameterSpec.getPlainTextSegmentLength())); } } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java index 65e937f8c..bb79314d4 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java @@ -10,6 +10,7 @@ public class FloeParameterSpec { private final FloeIvLength floeIvLength; private final FloeRandom floeRandom; private final int keyRotationModulo; + private final long maxSegmentNumber; public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int floeIvLength) { this( @@ -18,7 +19,8 @@ public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int f encryptedSegmentLength, new FloeIvLength(floeIvLength), new SecureFloeRandom(), - 1 << 20); + 1 << 20, + 1L << 40); } FloeParameterSpec( @@ -27,13 +29,21 @@ public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int f int encryptedSegmentLength, FloeIvLength floeIvLength, FloeRandom floeRandom, - int keyRotationModulo) { + int keyRotationModulo, + long maxSegmentNumber) { this.aead = aead; this.hash = hash; this.encryptedSegmentLength = encryptedSegmentLength; this.floeIvLength = floeIvLength; this.floeRandom = floeRandom; this.keyRotationModulo = keyRotationModulo; + this.maxSegmentNumber = maxSegmentNumber; + if (encryptedSegmentLength <= 0) { + throw new IllegalArgumentException("encryptedSegmentLength must be > 0"); + } + if (floeIvLength.getLength() <= 0) { + throw new IllegalArgumentException("floeIvLength must be > 0"); + } } byte[] paramEncode() { @@ -73,4 +83,8 @@ int getPlainTextSegmentLength() { int getKeyRotationModulo() { return keyRotationModulo; } + + long getMaxSegmentNumber() { + return maxSegmentNumber; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java index 45e4f2872..8c1b90cd5 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java @@ -2,4 +2,5 @@ interface SegmentProcessor { byte[] processSegment(byte[] input); + byte[] processLastSegment(byte[] input); } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java index ec5662f11..b9b44813e 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java @@ -1,18 +1,20 @@ package net.snowflake.client.jdbc.cloud.storage.floe.aead; -import java.security.GeneralSecurityException; -import java.security.InvalidAlgorithmParameterException; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; +import net.snowflake.client.jdbc.cloud.storage.floe.AeadProvider; + import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; import javax.crypto.SecretKey; import javax.crypto.spec.GCMParameterSpec; -import net.snowflake.client.jdbc.cloud.storage.floe.AeadProvider; +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; // This class is not thread safe! +// But as long as it is used only for FLOE, it is fine, as FLOE instance keeps its own instance of GCM. public class Gcm implements AeadProvider { private final Cipher keyCipher; private final int tagLengthInBits; diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java new file mode 100644 index 000000000..6488b9faa --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java @@ -0,0 +1,120 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import org.junit.jupiter.api.Test; + +import javax.crypto.AEADBadTagException; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class FloeDecryptorImplTest { + private final SecretKey secretKey = new SecretKeySpec(new byte[32], "AES"); + private final byte[] aad = "Test AAD".getBytes(StandardCharsets.UTF_8); + + @Test + void shouldDecryptCiphertext() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] firstSegment = encryptor.processSegment(new byte[8]); + byte[] lastSegment = encryptor.processLastSegment(new byte[4]); + + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + assertArrayEquals(new byte[8], decryptor.processSegment(firstSegment)); + assertArrayEquals(new byte[4], decryptor.processLastSegment(lastSegment)); + } + + @Test + void shouldDecryptLastSegmentZeroLength() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] lastSegment = encryptor.processLastSegment(new byte[0]); + + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + assertArrayEquals(new byte[0], decryptor.processLastSegment(lastSegment)); + } + + @Test + void shouldDecryptLastSegmentFullLength() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] lastSegment = encryptor.processLastSegment(new byte[8]); + + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + assertArrayEquals(new byte[8], decryptor.processLastSegment(lastSegment)); + } + + @Test + void shouldThrowExceptionIfSegmentLengthIsMismatched() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processSegment(new byte[12])); + assertEquals("segment length mismatch, expected 40, got 12", e.getMessage()); + e = assertThrows(IllegalArgumentException.class, () -> decryptor.processSegment(new byte[1024])); + assertEquals("segment length mismatch, expected 40, got 1024", e.getMessage()); + } + + @Test + void shouldThrowExceptionIfLastSegmentLengthIsMismatched() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processLastSegment(new byte[12])); + assertEquals("last segment is too short", e.getMessage()); + e = assertThrows(IllegalArgumentException.class, () -> decryptor.processLastSegment(new byte[1024])); + assertEquals("last segment is too long", e.getMessage()); + } + + @Test + void shouldThrowExceptionIfSegmentLengthInSegmentIsNotMinusOne() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processSegment(new byte[40])); + assertEquals("segment length marker mismatch, expected: -1, got: 0", e.getMessage()); + } + + @Test + void shouldThrowExceptionIfLastSegmentLengthInSegmentIsNotMinusOne() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processLastSegment(new byte[40])); + assertEquals("last segment length marker mismatch, expected: 40, got: 0", e.getMessage()); + } + + @Test + void shouldThrowExceptionIfSegmentIsTampered() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] ciphertext = encryptor.processSegment(new byte[8]); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + ciphertext[39]++; + RuntimeException e = assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext)); + assertEquals(e.getCause().getClass(), AEADBadTagException.class); + } + + @Test + void shouldThrowExceptionIfSegmentAreOutOfOrder() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] ciphertext1 = encryptor.processSegment(new byte[8]); + byte[] ciphertext2 = encryptor.processSegment(new byte[8]); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + RuntimeException e = assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext2)); + assertEquals(e.getCause().getClass(), AEADBadTagException.class); + } +} \ No newline at end of file diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java index 592eb1920..2d74b40bb 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -1,13 +1,17 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.junit.jupiter.api.Test; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import org.junit.jupiter.api.Test; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; class FloeEncryptorImplTest { byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); @@ -22,7 +26,7 @@ void shouldCreateCorrectHeader() { 12345678, new FloeIvLength(4), new IncrementingFloeRandom(17), - 4); + 4, 1L << 40); parameterSpec.getFloeRandom().ofLength(4); // just to trigger incrementation FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); FloeAad floeAad = new FloeAad("test aad".getBytes(StandardCharsets.UTF_8)); @@ -73,19 +77,72 @@ void testEncryptionMatchesReference() { 40, new FloeIvLength(32), new IncrementingFloeRandom(0), - 4); + 4, 1L << 40); Floe floe = Floe.getInstance(parameterSpec); FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); byte[] header = encryptor.getHeader(); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); byte[] testData = new byte[8]; - for (int i = 0; i < referenceCiphertextSegments.size(); i++) { + for (String referenceCiphertextSegment : referenceCiphertextSegments) { byte[] ciphertextBytes = encryptor.processSegment(testData); String ciphertextHex = toHex(ciphertextBytes); - assertEquals(referenceCiphertextSegments.get(i), ciphertextHex); + assertEquals(referenceCiphertextSegment, ciphertextHex); + byte[] plaintextBytes = decryptor.processSegment(ciphertextBytes); + assertArrayEquals(testData, plaintextBytes); } } - private String toHex(byte[] input) { + @Test + void shouldThrowExceptionOnMaxSegmentReached() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new SecureFloeRandom(), 20, 3); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] plaintext = new byte[8]; + encryptor.processSegment(plaintext); + encryptor.processSegment(plaintext); + assertThrows(IllegalStateException.class, () -> encryptor.processSegment(plaintext)); + assertDoesNotThrow(() -> encryptor.processLastSegment(plaintext)); + } + + @Test + void shouldThrowExceptionIfPlaintextIsTooShort() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[0])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 0"); + e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[7])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 7"); + } + + @Test + void shouldThrowEncryptionIfPlaintextIsTooLong() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[9])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 9"); + e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[1024])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 1024"); + + e = assertThrows(IllegalArgumentException.class, () -> encryptor.processLastSegment(new byte[9])); + assertEquals(e.getMessage(), "last segment is too long, got 9, max is 8"); + } + + @Test + void shouldAcceptSegmentWithCorrectSize() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + + assertDoesNotThrow(() -> encryptor.processSegment(new byte[8])); + assertDoesNotThrow(() -> encryptor.processLastSegment(new byte[8])); + assertDoesNotThrow(() -> encryptor.processLastSegment(new byte[0])); + } + + String toHex(byte[] input) { StringBuilder result = new StringBuilder(); for (byte b : input) { result.append(String.format("%02x", b)); diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java index d521866b6..f148ac4e3 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -1,17 +1,19 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; + +import static com.amazonaws.util.BinaryUtils.toHex; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -import java.nio.charset.StandardCharsets; -import java.security.SecureRandom; -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; - class FloeTest { byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); @@ -72,66 +74,119 @@ void validateHeaderDoesNotMatchInHeaderTag() { } } - @Test - void testSegmentEncryptedAndDecrypted() { - FloeParameterSpec parameterSpec = - new FloeParameterSpec( - Aead.AES_GCM_256, - Hash.SHA384, - 40, - new FloeIvLength(32), - new IncrementingFloeRandom(678765), - 4); - Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - byte[] ciphertext = encryptor.processSegment(testData); - byte[] result = decryptor.processSegment(ciphertext); - assertArrayEquals(testData, result); - } + @Nested + class SegmentTests { - @Test - void testSegmentEncryptedAndDecryptedWithRandomData() { - FloeParameterSpec parameterSpec = - new FloeParameterSpec( - Aead.AES_GCM_256, - Hash.SHA384, - 40, - new FloeIvLength(32), - new IncrementingFloeRandom(37665), - 4); - Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - new SecureRandom().nextBytes(testData); - byte[] ciphertext = encryptor.processSegment(testData); - byte[] result = decryptor.processSegment(ciphertext); - assertArrayEquals(testData, result); - } + @Test + void testSegmentEncryptedAndDecrypted() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(678765), + 4, 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } - @Test - void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { - FloeParameterSpec parameterSpec = - new FloeParameterSpec( - Aead.AES_GCM_256, - Hash.SHA384, - 40, - new FloeIvLength(32), - new IncrementingFloeRandom(6546), - 4); - Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - for (int i = 0; i < 10; i++) { + @Test + void testSegmentEncryptedAndDecryptedWithRandomData() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(37665), + 4, 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + new SecureRandom().nextBytes(testData); byte[] ciphertext = encryptor.processSegment(testData); byte[] result = decryptor.processSegment(ciphertext); assertArrayEquals(testData, result); } + + @Test + void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(6546), + 4, 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] header = encryptor.getHeader(); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); + byte[] testData = new byte[8]; + for (int i = 0; i < 10; i++) { + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } + } + } + + @Nested + class LastSegmentTests { + @Test + void testLastSegmentEncryptedAndDecrypted() { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 32); + Floe floe = Floe.getInstance(parameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + byte[] plaintext = new byte[3]; + byte[] encrypted = encryptor.processLastSegment(plaintext); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); + byte[] decrypted = decryptor.processLastSegment(encrypted); + assertArrayEquals(plaintext, decrypted); + } + + @Test + void testDecryptLastSegmentWithReferenceDataWithEmptyLastSegment() { + FloeParameterSpec floeParameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(0), 16, 1L << 40); + Floe floe = Floe.getInstance(floeParameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); + byte[] plaintext = new byte[8]; + byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); + byte[] encryptedLastSegment = encryptor.processLastSegment(new byte[0]); + + assertEquals(toHex(encryptedFirstSegment), "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret + assertEquals(toHex(encryptedLastSegment), "000000200000000200000000000000004a4082e6b94a8b1b2053f40879402df1"); // pragma: allowlist secret + + FloeDecryptor decryptor = floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader()); + assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); + assertArrayEquals(new byte[0], decryptor.processLastSegment(encryptedLastSegment)); + } + + @Test + void testDecryptLastSegmentWithReferenceDataWithNonEmptyLastSegment() { + FloeParameterSpec floeParameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(0), 16, 1L << 40); + Floe floe = Floe.getInstance(floeParameterSpec); + FloeEncryptor encryptor = floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); + byte[] plaintext = new byte[8]; + byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); + byte[] encryptedLastSegment = encryptor.processLastSegment(plaintext); + + assertEquals(toHex(encryptedFirstSegment), "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret + assertEquals(toHex(encryptedLastSegment), "000000280000000200000000000000003b14259ad693c7df7a2d6b9d9912dc70a81205d41ac43a41"); // pragma: allowlist secret + + FloeDecryptor decryptor = floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader()); + assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); + assertArrayEquals(plaintext, decryptor.processLastSegment(encryptedLastSegment)); + } } } From 866a14b0476a71f9003050a2645f417cbbf45b32 Mon Sep 17 00:00:00 2001 From: Piotr Fus Date: Wed, 22 Jan 2025 15:04:08 +0100 Subject: [PATCH 5/5] Add close checking --- .../client/jdbc/cloud/storage/floe/Aead.java | 23 +- .../jdbc/cloud/storage/floe/AeadProvider.java | 12 - .../storage/floe/BaseSegmentProcessor.java | 37 ++- .../cloud/storage/floe/FloeDecryptor.java | 6 +- .../cloud/storage/floe/FloeDecryptorImpl.java | 59 ++++- .../cloud/storage/floe/FloeEncryptor.java | 8 +- .../cloud/storage/floe/FloeEncryptorImpl.java | 82 +++--- .../cloud/storage/floe/SegmentProcessor.java | 6 - .../cloud/storage/floe/aead/AeadProvider.java | 20 ++ .../jdbc/cloud/storage/floe/aead/Gcm.java | 13 +- .../storage/floe/FloeDecryptorImplTest.java | 157 ++++++----- .../storage/floe/FloeEncryptorImplTest.java | 192 ++++++++------ .../jdbc/cloud/storage/floe/FloeTest.java | 244 +++++++++++------- 13 files changed, 531 insertions(+), 328 deletions(-) delete mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java delete mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java create mode 100644 src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java index cc75fb1d1..1591780e4 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -1,14 +1,15 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm; - import java.util.function.Supplier; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.AeadProvider; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm; public enum Aead { - AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, () -> new Gcm(16)); + AES_GCM_256((byte) 0, "AES", "AES/GCM/NoPadding", 32, 12, 16, () -> new Gcm(16)); private byte id; - private String jceName; + private String jceKeyTypeName; + private String jceFullName; private int keyLength; private int ivLength; private int authTagLength; @@ -16,12 +17,14 @@ public enum Aead { Aead( byte id, - String jceName, + String jceKeyTypeName, + String jceFullName, int keyLength, int ivLength, int authTagLength, Supplier aeadProvider) { - this.jceName = jceName; + this.jceKeyTypeName = jceKeyTypeName; + this.jceFullName = jceFullName; this.keyLength = keyLength; this.id = id; this.ivLength = ivLength; @@ -33,8 +36,12 @@ byte getId() { return id; } - String getJceName() { - return jceName; + public String getJceKeyTypeName() { + return jceKeyTypeName; + } + + String getJceFullName() { + return jceFullName; } int getKeyLength() { diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java deleted file mode 100644 index 106d604cd..000000000 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadProvider.java +++ /dev/null @@ -1,12 +0,0 @@ -package net.snowflake.client.jdbc.cloud.storage.floe; - -import java.security.GeneralSecurityException; -import javax.crypto.SecretKey; - -public interface AeadProvider { - byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) - throws GeneralSecurityException; - - byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) - throws GeneralSecurityException; -} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java index 65fda1351..aa00aa8cd 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java @@ -1,9 +1,11 @@ package net.snowflake.client.jdbc.cloud.storage.floe; +import java.io.Closeable; +import java.io.IOException; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -abstract class BaseSegmentProcessor { +abstract class BaseSegmentProcessor implements Closeable { protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1; protected static final int headerTagLength = 32; @@ -15,6 +17,9 @@ abstract class BaseSegmentProcessor { private AeadKey currentAeadKey; + private boolean isClosed; + private boolean completedExceptionally; + BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { this.parameterSpec = parameterSpec; this.floeKey = floeKey; @@ -24,7 +29,7 @@ abstract class BaseSegmentProcessor { protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) { - // TODO should we mask segments here? + // we don't need masking, because we derive a new key only when key rotation happens currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter); } return currentAeadKey; @@ -38,8 +43,32 @@ private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long floeAad, new DekTagFloePurpose(segmentCounter), parameterSpec.getAead().getKeyLength()); - SecretKey key = - new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD + SecretKey key = new SecretKeySpec(keyBytes, parameterSpec.getAead().getJceKeyTypeName()); return new AeadKey(key); } + + protected void closeInternal() { + isClosed = true; + } + + protected void markAsCompletedExceptionally() { + completedExceptionally = true; + } + + protected void assertNotClosed() { + if (isClosed) { + throw new IllegalStateException("stream has already been closed"); + } + } + + @Override + public void close() throws IOException { + if (!isClosed && !completedExceptionally) { + throw new IllegalStateException("last segment was not processed"); + } + } + + protected boolean isClosed() { + return isClosed; + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java index 085f23789..986a738fe 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java @@ -1,3 +1,7 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public interface FloeDecryptor extends SegmentProcessor {} +public interface FloeDecryptor extends AutoCloseable { + byte[] processSegment(byte[] ciphertext); + + boolean isClosed(); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java index dd8d07116..ad440f8de 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -3,6 +3,7 @@ import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.util.Arrays; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.AeadProvider; class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { private final FloeIv floeIv; @@ -46,9 +47,32 @@ class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { @Override public byte[] processSegment(byte[] input) { + assertNotClosed(); + ByteBuffer inputBuffer = ByteBuffer.wrap(input); try { - verifySegmentLength(input); - ByteBuffer inputBuf = ByteBuffer.wrap(input); + if (isLastSegment(inputBuffer)) { + return processLastSegment(inputBuffer); + } else { + return processNonLastSegment(inputBuffer); + } + } catch (Exception e) { + markAsCompletedExceptionally(); + throw e; + } + } + + private boolean isLastSegment(ByteBuffer inputBuffer) { + int segmentSizeMarker = inputBuffer.getInt(); + try { + return segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER; + } finally { + inputBuffer.rewind(); + } + } + + private byte[] processNonLastSegment(ByteBuffer inputBuf) { + try { + verifyNonLastSegmentLength(inputBuf); verifySegmentSizeMarker(inputBuf); AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); @@ -64,12 +88,12 @@ public byte[] processSegment(byte[] input) { } } - private void verifySegmentLength(byte[] input) { - if (input.length != parameterSpec.getEncryptedSegmentLength()) { + private void verifyNonLastSegmentLength(ByteBuffer inputBuf) { + if (inputBuf.capacity() != parameterSpec.getEncryptedSegmentLength()) { throw new IllegalArgumentException( String.format( "segment length mismatch, expected %d, got %d", - parameterSpec.getEncryptedSegmentLength(), input.length)); + parameterSpec.getEncryptedSegmentLength(), inputBuf.capacity())); } } @@ -83,10 +107,8 @@ private void verifySegmentSizeMarker(ByteBuffer inputBuf) { } } - @Override - public byte[] processLastSegment(byte[] input) { - verifyLastSegmentLength(input); - ByteBuffer inputBuf = ByteBuffer.wrap(input); + private byte[] processLastSegment(ByteBuffer inputBuf) { + verifyLastSegmentLength(inputBuf); verifyLastSegmentSizeMarker(inputBuf); try { AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); @@ -96,18 +118,19 @@ public byte[] processLastSegment(byte[] input) { inputBuf.get(ciphertext); byte[] decrypted = aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + closeInternal(); return decrypted; } catch (GeneralSecurityException e) { throw new RuntimeException(e); } } - private void verifyLastSegmentLength(byte[] input) { - // TODO <= ? - if (input.length < 4 + parameterSpec.getAead().getIvLength() + parameterSpec.getAead().getAuthTagLength()) { + private void verifyLastSegmentLength(ByteBuffer inputBuf) { + if (inputBuf.capacity() + < 4 + parameterSpec.getAead().getIvLength() + parameterSpec.getAead().getAuthTagLength()) { throw new IllegalArgumentException("last segment is too short"); } - if (input.length > parameterSpec.getEncryptedSegmentLength()) { + if (inputBuf.capacity() > parameterSpec.getEncryptedSegmentLength()) { throw new IllegalArgumentException("last segment is too long"); } } @@ -115,7 +138,15 @@ private void verifyLastSegmentLength(byte[] input) { private void verifyLastSegmentSizeMarker(ByteBuffer inputBuf) { int segmentLengthFromSegment = inputBuf.getInt(); if (segmentLengthFromSegment != inputBuf.capacity()) { - throw new IllegalArgumentException(String.format("last segment length marker mismatch, expected: %d, got: %d", inputBuf.capacity(), segmentLengthFromSegment)); + throw new IllegalArgumentException( + String.format( + "last segment length marker mismatch, expected: %d, got: %d", + inputBuf.capacity(), segmentLengthFromSegment)); } } + + @Override + public boolean isClosed() { + return super.isClosed(); + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java index f1ab85496..911ba2e81 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java @@ -1,5 +1,11 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -public interface FloeEncryptor extends SegmentProcessor { +public interface FloeEncryptor extends AutoCloseable { + byte[] processSegment(byte[] plaintext); + + byte[] processLastSegment(byte[] plaintext); + byte[] getHeader(); + + boolean isClosed(); } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java index cf35ec994..a7677fc9a 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -2,6 +2,7 @@ import java.nio.ByteBuffer; import java.security.GeneralSecurityException; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.AeadProvider; class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { private final FloeIv floeIv; @@ -45,23 +46,28 @@ public byte[] getHeader() { @Override public byte[] processSegment(byte[] input) { - verifySegmentLength(input); - // TODO assert State.Counter != 2^40-1 # Prevent overflow - verifyMaxSegmentNumberNotReached(); + assertNotClosed(); try { - AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); - AeadIv aeadIv = - AeadIv.generateRandom( - parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); - AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); - // it works as long as AEAD returns auth tag as a part of the ciphertext - byte[] ciphertextWithAuthTag = - aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); - byte[] encoded = segmentToBytes(aeadIv, ciphertextWithAuthTag); - segmentCounter++; - return encoded; - } catch (GeneralSecurityException e) { - throw new RuntimeException(e); + verifySegmentLength(input); + verifyMaxSegmentNumberNotReached(); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = + AeadIv.generateRandom( + parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); + // it works as long as AEAD returns auth tag as a part of the ciphertext + byte[] ciphertextWithAuthTag = + aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + byte[] encoded = segmentToBytes(aeadIv, ciphertextWithAuthTag); + segmentCounter++; + return encoded; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } catch (Exception e) { + markAsCompletedExceptionally(); + throw e; } } @@ -90,18 +96,26 @@ private byte[] segmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { @Override public byte[] processLastSegment(byte[] input) { - verifyLastSegmentNotEmpty(input); + assertNotClosed(); try { - AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); - AeadIv aeadIv = - AeadIv.generateRandom( - parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); - AeadAad aeadAad = AeadAad.terminal(segmentCounter); - byte[] ciphertextWithAuthTag = - aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); - return lastSegmentToBytes(aeadIv, ciphertextWithAuthTag); - } catch (GeneralSecurityException e) { - throw new RuntimeException(e); + verifyLastSegmentNotEmpty(input); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = + AeadIv.generateRandom( + parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.terminal(segmentCounter); + byte[] ciphertextWithAuthTag = + aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + byte[] lastSegmentBytes = lastSegmentToBytes(aeadIv, ciphertextWithAuthTag); + closeInternal(); + return lastSegmentBytes; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } catch (Exception e) { + markAsCompletedExceptionally(); + throw e; } } @@ -115,12 +129,16 @@ private byte[] lastSegmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { } private void verifyLastSegmentNotEmpty(byte[] input) { - // TODO -// if (input.length == 0) { -// throw new IllegalArgumentException("last segment is empty"); -// } if (input.length > parameterSpec.getPlainTextSegmentLength()) { - throw new IllegalArgumentException(String.format("last segment is too long, got %d, max is %d", input.length, parameterSpec.getPlainTextSegmentLength())); + throw new IllegalArgumentException( + String.format( + "last segment is too long, got %d, max is %d", + input.length, parameterSpec.getPlainTextSegmentLength())); } } + + @Override + public boolean isClosed() { + return super.isClosed(); + } } diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java deleted file mode 100644 index 8c1b90cd5..000000000 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SegmentProcessor.java +++ /dev/null @@ -1,6 +0,0 @@ -package net.snowflake.client.jdbc.cloud.storage.floe; - -interface SegmentProcessor { - byte[] processSegment(byte[] input); - byte[] processLastSegment(byte[] input); -} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java new file mode 100644 index 000000000..63fee639c --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java @@ -0,0 +1,20 @@ +package net.snowflake.client.jdbc.cloud.storage.floe.aead; + +import java.security.GeneralSecurityException; +import javax.crypto.SecretKey; + +// Consideration for implementations: +// 1. Implementations does not have to be thread safe, they are used in FLOE in a thread safe manner +// (FLOE encryptor and decryptor creates their own instances). +// 2. Authentication tag is a part of ciphertext: +// a) For encrypt function - auth tag is returned with ciphertext. +// b) For decrypt function - auth tag is passed with ciphertext. +// As long as it isn't strictly required to be at the end of the ciphertext, it is needed to be in +// the correct position for the underlying algorithm. +public interface AeadProvider { + byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) + throws GeneralSecurityException; + + byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) + throws GeneralSecurityException; +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java index b9b44813e..616ff2aea 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java @@ -1,20 +1,19 @@ package net.snowflake.client.jdbc.cloud.storage.floe.aead; -import net.snowflake.client.jdbc.cloud.storage.floe.AeadProvider; - +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; import javax.crypto.BadPaddingException; import javax.crypto.Cipher; import javax.crypto.IllegalBlockSizeException; import javax.crypto.NoSuchPaddingException; import javax.crypto.SecretKey; import javax.crypto.spec.GCMParameterSpec; -import java.security.GeneralSecurityException; -import java.security.InvalidAlgorithmParameterException; -import java.security.InvalidKeyException; -import java.security.NoSuchAlgorithmException; // This class is not thread safe! -// But as long as it is used only for FLOE, it is fine, as FLOE instance keeps its own instance of GCM. +// But as long as it is used only for FLOE, it is fine, as FLOE instance keeps its own instance of +// GCM. public class Gcm implements AeadProvider { private final Cipher keyCipher; private final int tagLengthInBits; diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java index 6488b9faa..c127ee069 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java @@ -1,120 +1,137 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import javax.crypto.AEADBadTagException; import javax.crypto.SecretKey; import javax.crypto.spec.SecretKeySpec; -import java.nio.charset.StandardCharsets; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import org.junit.jupiter.api.Test; class FloeDecryptorImplTest { private final SecretKey secretKey = new SecretKeySpec(new byte[32], "AES"); private final byte[] aad = "Test AAD".getBytes(StandardCharsets.UTF_8); @Test - void shouldDecryptCiphertext() { + void shouldDecryptCiphertext() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] firstSegment = encryptor.processSegment(new byte[8]); - byte[] lastSegment = encryptor.processLastSegment(new byte[4]); - - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - assertArrayEquals(new byte[8], decryptor.processSegment(firstSegment)); - assertArrayEquals(new byte[4], decryptor.processLastSegment(lastSegment)); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] firstSegment = encryptor.processSegment(new byte[8]); + byte[] lastSegment = encryptor.processLastSegment(new byte[4]); + + assertArrayEquals(new byte[8], decryptor.processSegment(firstSegment)); + assertArrayEquals(new byte[4], decryptor.processSegment(lastSegment)); + } } @Test - void shouldDecryptLastSegmentZeroLength() { + void shouldDecryptLastSegmentZeroLength() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] lastSegment = encryptor.processLastSegment(new byte[0]); - - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - assertArrayEquals(new byte[0], decryptor.processLastSegment(lastSegment)); - } - - @Test - void shouldDecryptLastSegmentFullLength() { - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); - Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] lastSegment = encryptor.processLastSegment(new byte[8]); - - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - assertArrayEquals(new byte[8], decryptor.processLastSegment(lastSegment)); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] lastSegment = encryptor.processLastSegment(new byte[0]); + assertArrayEquals(new byte[0], decryptor.processSegment(lastSegment)); + } } @Test - void shouldThrowExceptionIfSegmentLengthIsMismatched() { + void shouldDecryptLastSegmentFullLength() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processSegment(new byte[12])); - assertEquals("segment length mismatch, expected 40, got 12", e.getMessage()); - e = assertThrows(IllegalArgumentException.class, () -> decryptor.processSegment(new byte[1024])); - assertEquals("segment length mismatch, expected 40, got 1024", e.getMessage()); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] lastSegment = encryptor.processLastSegment(new byte[8]); + assertArrayEquals(new byte[8], decryptor.processSegment(lastSegment)); + } } @Test - void shouldThrowExceptionIfLastSegmentLengthIsMismatched() { + void shouldThrowExceptionIfSegmentLengthIsMismatched() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processLastSegment(new byte[12])); - assertEquals("last segment is too short", e.getMessage()); - e = assertThrows(IllegalArgumentException.class, () -> decryptor.processLastSegment(new byte[1024])); - assertEquals("last segment is too long", e.getMessage()); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] ciphertext = encryptor.processSegment(new byte[8]); + byte[] prunedCiphertext = new byte[12]; + ByteBuffer.wrap(ciphertext).get(prunedCiphertext); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(prunedCiphertext)); + assertEquals("segment length mismatch, expected 40, got 12", e.getMessage()); + byte[] extendedCiphertext = new byte[1024]; + ByteBuffer.wrap(extendedCiphertext).put(ciphertext); + e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(extendedCiphertext)); + assertEquals("segment length mismatch, expected 40, got 1024", e.getMessage()); + encryptor.processLastSegment(new byte[4]); + } } @Test - void shouldThrowExceptionIfSegmentLengthInSegmentIsNotMinusOne() { + void shouldThrowExceptionIfLastSegmentLengthIsMismatched() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processSegment(new byte[40])); - assertEquals("segment length marker mismatch, expected: -1, got: 0", e.getMessage()); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + encryptor.processLastSegment(new byte[4]); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(new byte[12])); + assertEquals("last segment is too short", e.getMessage()); + e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(new byte[1024])); + assertEquals("last segment is too long", e.getMessage()); + } } @Test - void shouldThrowExceptionIfLastSegmentLengthInSegmentIsNotMinusOne() { + void shouldThrowExceptionIfLastSegmentLengthMarkerIsNotMinusOne() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> decryptor.processLastSegment(new byte[40])); - assertEquals("last segment length marker mismatch, expected: 40, got: 0", e.getMessage()); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + encryptor.processLastSegment(new byte[4]); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(new byte[40])); + assertEquals("last segment length marker mismatch, expected: 40, got: 0", e.getMessage()); + } } @Test - void shouldThrowExceptionIfSegmentIsTampered() { + void shouldThrowExceptionIfSegmentIsTampered() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] ciphertext = encryptor.processSegment(new byte[8]); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - ciphertext[39]++; - RuntimeException e = assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext)); - assertEquals(e.getCause().getClass(), AEADBadTagException.class); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] ciphertext = encryptor.processLastSegment(new byte[8]); + ciphertext[39]++; + RuntimeException e = + assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext)); + assertEquals(e.getCause().getClass(), AEADBadTagException.class); + } } @Test - void shouldThrowExceptionIfSegmentAreOutOfOrder() { + void shouldThrowExceptionIfSegmentAreOutOfOrder() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] ciphertext1 = encryptor.processSegment(new byte[8]); - byte[] ciphertext2 = encryptor.processSegment(new byte[8]); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - RuntimeException e = assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext2)); - assertEquals(e.getCause().getClass(), AEADBadTagException.class); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] ciphertext1 = encryptor.processSegment(new byte[8]); + byte[] ciphertext2 = encryptor.processSegment(new byte[8]); + encryptor.processLastSegment(new byte[4]); + RuntimeException e = + assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext2)); + assertEquals(e.getCause().getClass(), AEADBadTagException.class); + } } -} \ No newline at end of file +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java index 2d74b40bb..4681176e7 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -1,24 +1,27 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import org.junit.jupiter.api.Test; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; - import static org.junit.Assert.assertThrows; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; class FloeEncryptorImplTest { byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); @Test - void shouldCreateCorrectHeader() { + void shouldCreateCorrectHeader() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec( Aead.AES_GCM_256, @@ -26,37 +29,45 @@ void shouldCreateCorrectHeader() { 12345678, new FloeIvLength(4), new IncrementingFloeRandom(17), - 4, 1L << 40); + 4, + 1L << 40); parameterSpec.getFloeRandom().ofLength(4); // just to trigger incrementation FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); FloeAad floeAad = new FloeAad("test aad".getBytes(StandardCharsets.UTF_8)); - FloeEncryptorImpl floeEncryptor = new FloeEncryptorImpl(parameterSpec, floeKey, floeAad); - byte[] header = floeEncryptor.getHeader(); - // AEAD ID - assertEquals(Aead.AES_GCM_256.getId(), header[0]); - // HASH ID - assertEquals(Hash.SHA384.getId(), header[1]); - // Segment length in BE - // 12345678(10) = BC614E(16) - assertEquals(0, header[2]); - assertEquals((byte) 188, header[3]); - assertEquals((byte) 97, header[4]); - assertEquals((byte) 78, header[5]); - // FLOE IV length in BE - // 4(10) = 4(16) = 00,00,00,04 - assertEquals(0, header[6]); - assertEquals(0, header[7]); - assertEquals(0, header[8]); - assertEquals(4, header[9]); - // FLOE IV - assertEquals(0, header[10]); - assertEquals(0, header[11]); - assertEquals(0, header[12]); - assertEquals(18, header[13]); + try (FloeEncryptor encryptor = new FloeEncryptorImpl(parameterSpec, floeKey, floeAad)) { + byte[] header = encryptor.getHeader(); + // AEAD ID + assertEquals(Aead.AES_GCM_256.getId(), header[0]); + // HASH ID + assertEquals(Hash.SHA384.getId(), header[1]); + // Segment length in BE + // 12345678(10) = BC614E(16) + assertEquals(0, header[2]); + assertEquals((byte) 188, header[3]); + assertEquals((byte) 97, header[4]); + assertEquals((byte) 78, header[5]); + // FLOE IV length in BE + // 4(10) = 4(16) = 00,00,00,04 + assertEquals(0, header[6]); + assertEquals(0, header[7]); + assertEquals(0, header[8]); + assertEquals(4, header[9]); + // FLOE IV + assertEquals(0, header[10]); + assertEquals(0, header[11]); + assertEquals(0, header[12]); + assertEquals(18, header[13]); + + close(encryptor); + } + } + + private static byte[] close(FloeEncryptor encryptor) { + return encryptor.processLastSegment(new byte[0]); } @Test - void testEncryptionMatchesReference() { + void testEncryptionMatchesReference() throws Exception { List referenceCiphertextSegments = Arrays.asList( "ffffffff0000000100000000000000000100007f5713b9827bb806318311fcde197146a144c6b485", // pragma: allowlist secret @@ -77,69 +88,104 @@ void testEncryptionMatchesReference() { 40, new FloeIvLength(32), new IncrementingFloeRandom(0), - 4, 1L << 40); + 4, + 1L << 40); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - for (String referenceCiphertextSegment : referenceCiphertextSegments) { - byte[] ciphertextBytes = encryptor.processSegment(testData); - String ciphertextHex = toHex(ciphertextBytes); - assertEquals(referenceCiphertextSegment, ciphertextHex); - byte[] plaintextBytes = decryptor.processSegment(ciphertextBytes); - assertArrayEquals(testData, plaintextBytes); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + for (String referenceCiphertextSegment : referenceCiphertextSegments) { + byte[] ciphertextBytes = encryptor.processSegment(testData); + String ciphertextHex = toHex(ciphertextBytes); + assertEquals(referenceCiphertextSegment, ciphertextHex); + byte[] plaintextBytes = decryptor.processSegment(ciphertextBytes); + assertArrayEquals(testData, plaintextBytes); + } + + byte[] lastSegment = encryptor.processLastSegment(new byte[0]); + decryptor.processSegment(lastSegment); } } @Test - void shouldThrowExceptionOnMaxSegmentReached() { - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new SecureFloeRandom(), 20, 3); + void shouldThrowExceptionOnMaxSegmentReached() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new SecureFloeRandom(), 20, 3); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] plaintext = new byte[8]; - encryptor.processSegment(plaintext); - encryptor.processSegment(plaintext); - assertThrows(IllegalStateException.class, () -> encryptor.processSegment(plaintext)); - assertDoesNotThrow(() -> encryptor.processLastSegment(plaintext)); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] plaintext = new byte[8]; + encryptor.processSegment(plaintext); + encryptor.processSegment(plaintext); + assertThrows(IllegalStateException.class, () -> encryptor.processSegment(plaintext)); + assertDoesNotThrow(() -> encryptor.processLastSegment(plaintext)); + } } @Test - void shouldThrowExceptionIfPlaintextIsTooShort() { + void shouldThrowExceptionIfPlaintextIsTooShort() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[0])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 0"); + e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[7])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 7"); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[0])); - assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 0"); - e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[7])); - assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 7"); + close(encryptor); + } } @Test - void shouldThrowEncryptionIfPlaintextIsTooLong() { + void shouldThrowEncryptionIfPlaintextIsTooLong() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[9])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 9"); + e = + assertThrows( + IllegalArgumentException.class, () -> encryptor.processSegment(new byte[1024])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 1024"); - IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[9])); - assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 9"); - e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[1024])); - assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 1024"); + e = + assertThrows( + IllegalArgumentException.class, () -> encryptor.processLastSegment(new byte[9])); + assertEquals(e.getMessage(), "last segment is too long, got 9, max is 8"); - e = assertThrows(IllegalArgumentException.class, () -> encryptor.processLastSegment(new byte[9])); - assertEquals(e.getMessage(), "last segment is too long, got 9, max is 8"); + close(encryptor); + } } - @Test - void shouldAcceptSegmentWithCorrectSize() { + @ParameterizedTest + @ValueSource(ints = {0, 8}) + void shouldAcceptSegmentWithCorrectSize(int lastSegmentSize) throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + assertDoesNotThrow(() -> encryptor.processSegment(new byte[8])); + assertDoesNotThrow(() -> encryptor.processLastSegment(new byte[lastSegmentSize])); + } + } - assertDoesNotThrow(() -> encryptor.processSegment(new byte[8])); - assertDoesNotThrow(() -> encryptor.processLastSegment(new byte[8])); - assertDoesNotThrow(() -> encryptor.processLastSegment(new byte[0])); + @Test + void shouldNotAcceptNewSegmentsAfterLastOneIsProcessed() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + assertFalse(encryptor.isClosed()); + encryptor.processLastSegment(new byte[4]); + assertTrue(encryptor.isClosed()); + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> encryptor.processSegment(new byte[4])); + assertEquals("stream has already been closed", e.getMessage()); + e = + assertThrows( + IllegalStateException.class, () -> encryptor.processLastSegment(new byte[4])); + assertEquals("stream has already been closed", e.getMessage()); + } } String toHex(byte[] input) { diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java index f148ac4e3..a6914acf5 100644 --- a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -1,19 +1,17 @@ package net.snowflake.client.jdbc.cloud.storage.floe; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; - -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; -import java.nio.charset.StandardCharsets; -import java.security.SecureRandom; - import static com.amazonaws.util.BinaryUtils.toHex; import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + class FloeTest { byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); @@ -21,56 +19,63 @@ class FloeTest { @Nested class HeaderTests { @Test - void validateHeaderMatchesForEncryptionAndDecryption() { + void validateHeaderMatchesForEncryptionAndDecryption() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - assertDoesNotThrow(() -> floe.createDecryptor(secretKey, aad, header)); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + decryptor.processSegment(encryptor.processLastSegment(new byte[0])); + } } @Test - void validateHeaderDoesNotMatchInParams() { + void validateHeaderDoesNotMatchInParams() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - header[0] = 12; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid parameters header"); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] header = encryptor.getHeader(); + header[0] = 12; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid parameters header"); + encryptor.processLastSegment(new byte[0]); + } } @Test - void validateHeaderDoesNotMatchInIV() { + void validateHeaderDoesNotMatchInIV() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - header[11]++; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid header tag"); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] header = encryptor.getHeader(); + header[11]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + encryptor.processLastSegment(new byte[0]); + } } @Test - void validateHeaderDoesNotMatchInHeaderTag() { + void validateHeaderDoesNotMatchInHeaderTag() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 4096, 4); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - header[header.length - 3]++; - IllegalArgumentException e = - assertThrows( - IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); - assertEquals(e.getMessage(), "invalid header tag"); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] header = encryptor.getHeader(); + header[header.length - 3]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + encryptor.processLastSegment(new byte[0]); + } } } @@ -78,7 +83,7 @@ void validateHeaderDoesNotMatchInHeaderTag() { class SegmentTests { @Test - void testSegmentEncryptedAndDecrypted() { + void testSegmentEncryptedAndDecrypted() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec( Aead.AES_GCM_256, @@ -86,19 +91,20 @@ void testSegmentEncryptedAndDecrypted() { 40, new FloeIvLength(32), new IncrementingFloeRandom(678765), - 4, 1L << 40); + 4, + 1L << 40); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - byte[] ciphertext = encryptor.processSegment(testData); - byte[] result = decryptor.processSegment(ciphertext); - assertArrayEquals(testData, result); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + byte[] ciphertext = encryptor.processLastSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } } @Test - void testSegmentEncryptedAndDecryptedWithRandomData() { + void testSegmentEncryptedAndDecryptedWithRandomData() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec( Aead.AES_GCM_256, @@ -106,20 +112,22 @@ void testSegmentEncryptedAndDecryptedWithRandomData() { 40, new FloeIvLength(32), new IncrementingFloeRandom(37665), - 4, 1L << 40); + 4, + 1L << 40); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - new SecureRandom().nextBytes(testData); - byte[] ciphertext = encryptor.processSegment(testData); - byte[] result = decryptor.processSegment(ciphertext); - assertArrayEquals(testData, result); + byte[] ciphertext; + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + new SecureRandom().nextBytes(testData); + ciphertext = encryptor.processLastSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } } @Test - void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { + void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() throws Exception { FloeParameterSpec parameterSpec = new FloeParameterSpec( Aead.AES_GCM_256, @@ -127,16 +135,19 @@ void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { 40, new FloeIvLength(32), new IncrementingFloeRandom(6546), - 4, 1L << 40); + 4, + 1L << 40); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] header = encryptor.getHeader(); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, header); - byte[] testData = new byte[8]; - for (int i = 0; i < 10; i++) { - byte[] ciphertext = encryptor.processSegment(testData); - byte[] result = decryptor.processSegment(ciphertext); - assertArrayEquals(testData, result); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + for (int i = 0; i < 10; i++) { + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } + byte[] ciphertext = encryptor.processLastSegment(testData); + decryptor.processSegment(ciphertext); } } } @@ -144,49 +155,82 @@ void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() { @Nested class LastSegmentTests { @Test - void testLastSegmentEncryptedAndDecrypted() { - FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 32); + void testLastSegmentEncryptedAndDecrypted() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 32); Floe floe = Floe.getInstance(parameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); - byte[] plaintext = new byte[3]; - byte[] encrypted = encryptor.processLastSegment(plaintext); - FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader()); - byte[] decrypted = decryptor.processLastSegment(encrypted); - assertArrayEquals(plaintext, decrypted); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] plaintext = new byte[3]; + byte[] encrypted = encryptor.processLastSegment(plaintext); + byte[] decrypted = decryptor.processSegment(encrypted); + assertArrayEquals(plaintext, decrypted); + } } @Test - void testDecryptLastSegmentWithReferenceDataWithEmptyLastSegment() { - FloeParameterSpec floeParameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(0), 16, 1L << 40); + void testDecryptLastSegmentWithReferenceDataWithEmptyLastSegment() throws Exception { + FloeParameterSpec floeParameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(0), + 16, + 1L << 40); Floe floe = Floe.getInstance(floeParameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); - byte[] plaintext = new byte[8]; - byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); - byte[] encryptedLastSegment = encryptor.processLastSegment(new byte[0]); - - assertEquals(toHex(encryptedFirstSegment), "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret - assertEquals(toHex(encryptedLastSegment), "000000200000000200000000000000004a4082e6b94a8b1b2053f40879402df1"); // pragma: allowlist secret - - FloeDecryptor decryptor = floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader()); - assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); - assertArrayEquals(new byte[0], decryptor.processLastSegment(encryptedLastSegment)); + try (FloeEncryptor encryptor = + floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); + FloeDecryptor decryptor = + floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader())) { + byte[] plaintext = new byte[8]; + byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); + byte[] encryptedLastSegment = encryptor.processLastSegment(new byte[0]); + + assertEquals( + toHex(encryptedFirstSegment), + "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret + assertEquals( + toHex(encryptedLastSegment), + "000000200000000200000000000000004a4082e6b94a8b1b2053f40879402df1"); // pragma: + // allowlist secret + + assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); + assertArrayEquals(new byte[0], decryptor.processSegment(encryptedLastSegment)); + } } @Test - void testDecryptLastSegmentWithReferenceDataWithNonEmptyLastSegment() { - FloeParameterSpec floeParameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new IncrementingFloeRandom(0), 16, 1L << 40); + void testDecryptLastSegmentWithReferenceDataWithNonEmptyLastSegment() throws Exception { + FloeParameterSpec floeParameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(0), + 16, + 1L << 40); Floe floe = Floe.getInstance(floeParameterSpec); - FloeEncryptor encryptor = floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); - byte[] plaintext = new byte[8]; - byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); - byte[] encryptedLastSegment = encryptor.processLastSegment(plaintext); - - assertEquals(toHex(encryptedFirstSegment), "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret - assertEquals(toHex(encryptedLastSegment), "000000280000000200000000000000003b14259ad693c7df7a2d6b9d9912dc70a81205d41ac43a41"); // pragma: allowlist secret - - FloeDecryptor decryptor = floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader()); - assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); - assertArrayEquals(plaintext, decryptor.processLastSegment(encryptedLastSegment)); + try (FloeEncryptor encryptor = + floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); + FloeDecryptor decryptor = + floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader())) { + byte[] plaintext = new byte[8]; + byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); + byte[] encryptedLastSegment = encryptor.processLastSegment(plaintext); + + assertEquals( + toHex(encryptedFirstSegment), + "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret + assertEquals( + toHex(encryptedLastSegment), + "000000280000000200000000000000003b14259ad693c7df7a2d6b9d9912dc70a81205d41ac43a41"); // pragma: allowlist secret + + assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); + assertArrayEquals(plaintext, decryptor.processSegment(encryptedLastSegment)); + } } } }