From b17165d03bbc0e3eaec0eff11f6759e9792daae8 Mon Sep 17 00:00:00 2001 From: Justin Lin Date: Wed, 18 Dec 2019 10:44:29 -0800 Subject: [PATCH] Fix CryptoService to work with CompositeByteBuf (#1323) PutOperation would read multiple ByteBuf as a Composite from NettyServer layer. If the blob should be encrypted, then the crypto service has to work with CompositeByteBuf. This PR adds such support. --- .../router/CryptoService.java | 12 +- .../GCMCryptoService.java | 57 ++++- .../CryptoServiceTest.java | 227 ++++++++++++++++++ .../GCMCryptoServiceTest.java | 114 +++++++++ .../java/com.github.ambry.utils/Utils.java | 39 +++ 5 files changed, 432 insertions(+), 17 deletions(-) create mode 100644 ambry-router/src/test/java/com.github.ambry.router/CryptoServiceTest.java diff --git a/ambry-api/src/main/java/com.github.ambry/router/CryptoService.java b/ambry-api/src/main/java/com.github.ambry/router/CryptoService.java index 6167882769..acd001794a 100644 --- a/ambry-api/src/main/java/com.github.ambry/router/CryptoService.java +++ b/ambry-api/src/main/java/com.github.ambry/router/CryptoService.java @@ -13,7 +13,9 @@ */ package com.github.ambry.router; +import com.github.ambry.utils.Utils; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; @@ -47,8 +49,9 @@ public interface CryptoService { * @throws {@link GeneralSecurityException} on any exception with encryption */ default ByteBuf encrypt(ByteBuf toEncrypt, T key) throws GeneralSecurityException { - ByteBuffer encrypted = encrypt(toEncrypt.nioBuffer(), key); - return Unpooled.wrappedBuffer(encrypted); + return Utils.applyByteBufferFunctionToByteBuf(toEncrypt, (buffer) -> { + return encrypt(buffer, key); + }); } /** @@ -69,8 +72,9 @@ default ByteBuf encrypt(ByteBuf toEncrypt, T key) throws GeneralSecurityExceptio * @throws {@link GeneralSecurityException} on any exception with decryption */ default ByteBuf decrypt(ByteBuf toDecrypt, T key) throws GeneralSecurityException { - ByteBuffer decrypted = decrypt(toDecrypt.nioBuffer(), key); - return Unpooled.wrappedBuffer(decrypted); + return Utils.applyByteBufferFunctionToByteBuf(toDecrypt, (buffer) -> { + return decrypt(buffer, key); + }); } /** diff --git a/ambry-router/src/main/java/com.github.ambry.router/GCMCryptoService.java b/ambry-router/src/main/java/com.github.ambry.router/GCMCryptoService.java index 9b0eec4e34..afd23fb707 100644 --- a/ambry-router/src/main/java/com.github.ambry.router/GCMCryptoService.java +++ b/ambry-router/src/main/java/com.github.ambry.router/GCMCryptoService.java @@ -126,13 +126,28 @@ public ByteBuf encrypt(ByteBuf toEncrypt, SecretKeySpec key, byte[] iv) throws G ByteBuf encryptedContent = ByteBufAllocator.DEFAULT.heapBuffer(IVRecord_Format_V1.getIVRecordSize(iv) + outputSize); IVRecord_Format_V1.serializeIVRecord(encryptedContent, iv); - ByteBuffer toEncryptBuffer = toEncrypt.nioBuffer(); - ByteBuffer encryptedContentBuffer = encryptedContent.nioBuffer(encryptedContent.writerIndex(), - encryptedContent.capacity() - encryptedContent.writerIndex()); - int n = encrypter.doFinal(toEncryptBuffer, encryptedContentBuffer); - toEncrypt.readerIndex(toEncrypt.readerIndex() + toEncrypt.readableBytes()); - encryptedContent.writerIndex(encryptedContent.writerIndex() + n); - return encryptedContent; + + boolean toRelease = false; + if (toEncrypt.nioBufferCount() != 1) { + toRelease = true; + ByteBuf temp = ByteBufAllocator.DEFAULT.heapBuffer(toEncrypt.readableBytes()); + temp.writeBytes(toEncrypt); + toEncrypt = temp; + } + try { + ByteBuffer toEncryptBuffer = toEncrypt.nioBuffer(); + ByteBuffer encryptedContentBuffer = encryptedContent.nioBuffer(encryptedContent.writerIndex(), + encryptedContent.capacity() - encryptedContent.writerIndex()); + int n = encrypter.doFinal(toEncryptBuffer, encryptedContentBuffer); + encryptedContent.writerIndex(encryptedContent.writerIndex() + n); + return encryptedContent; + } finally { + if (toRelease) { + toEncrypt.release(); + } else { + toEncrypt.skipBytes(toEncrypt.readableBytes()); + } + } } catch (Exception e) { throw new GeneralSecurityException("Exception thrown while encrypting data", e); } @@ -161,12 +176,28 @@ public ByteBuf decrypt(ByteBuf toDecrypt, SecretKeySpec key) throws GeneralSecur decrypter.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); int outputSize = decrypter.getOutputSize(toDecrypt.readableBytes()); ByteBuf decryptedContent = ByteBufAllocator.DEFAULT.heapBuffer(outputSize); - ByteBuffer toDecryptBuffer = toDecrypt.nioBuffer(); - ByteBuffer decryptedContentBuffer = decryptedContent.nioBuffer(0, outputSize); - int n = decrypter.doFinal(toDecryptBuffer, decryptedContentBuffer); - toDecrypt.readerIndex(toDecrypt.readerIndex() + toDecrypt.readableBytes()); - decryptedContent.writerIndex(decryptedContent.writerIndex() + n); - return decryptedContent; + + boolean toRelease = false; + if (toDecrypt.nioBufferCount() != 1) { + toRelease = true; + ByteBuf temp = ByteBufAllocator.DEFAULT.heapBuffer(toDecrypt.readableBytes()); + temp.writeBytes(toDecrypt); + toDecrypt = temp; + } + + try { + ByteBuffer toDecryptBuffer = toDecrypt.nioBuffer(); + ByteBuffer decryptedContentBuffer = decryptedContent.nioBuffer(0, outputSize); + int n = decrypter.doFinal(toDecryptBuffer, decryptedContentBuffer); + decryptedContent.writerIndex(decryptedContent.writerIndex() + n); + return decryptedContent; + } finally { + if (toRelease) { + toDecrypt.release(); + } else { + toDecrypt.skipBytes(toDecrypt.readableBytes()); + } + } } catch (Exception e) { throw new GeneralSecurityException("Exception thrown while decrypting data", e); } diff --git a/ambry-router/src/test/java/com.github.ambry.router/CryptoServiceTest.java b/ambry-router/src/test/java/com.github.ambry.router/CryptoServiceTest.java new file mode 100644 index 0000000000..ba9c777fb1 --- /dev/null +++ b/ambry-router/src/test/java/com.github.ambry.router/CryptoServiceTest.java @@ -0,0 +1,227 @@ +/* + * Copyright 2017 LinkedIn Corp. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +package com.github.ambry.router; + +import com.codahale.metrics.MetricRegistry; +import com.github.ambry.commons.NettyByteBufLeakHelper; +import com.github.ambry.config.VerifiableProperties; +import com.github.ambry.utils.TestUtils; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import org.bouncycastle.util.encoders.Hex; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static com.github.ambry.router.CryptoTestUtils.*; + + +/** + * Tests for default method in interface {@link CryptoService}. + */ +@RunWith(Parameterized.class) +public class CryptoServiceTest { + private final NettyByteBufLeakHelper nettyByteBufLeakHelper = new NettyByteBufLeakHelper(); + private static final MetricRegistry REGISTRY = new MetricRegistry(); + private static final int DEFAULT_KEY_SIZE_IN_CHARS = 64; + private static final int MAX_DATA_SIZE = 10000; + private static final int MIN_DATA_SIZE = 100; + private final boolean isCompositeByteBuf; + + @Parameterized.Parameters + public static List data() { + return Arrays.asList(new Object[][]{{false}, {true}}); + } + + /** + * Constructor to create a CryptoServiceTest. + * @param isCompositeByteBuf + */ + public CryptoServiceTest(boolean isCompositeByteBuf) { + this.isCompositeByteBuf = isCompositeByteBuf; + } + + @Before + public void before() { + nettyByteBufLeakHelper.beforeTest(); + } + + @After + public void after() { + nettyByteBufLeakHelper.afterTest(); + } + + /** + * Create a {@link CompositeByteBuf} from the given byte array. + * @param data the byte array. + * @return A {@link CompositeByteBuf}. + */ + private CompositeByteBuf fromByteArrayToCompositeByteBuf(byte[] data) { + int size = data.length; + ByteBuf toEncrypt = Unpooled.wrappedBuffer(data); + CompositeByteBuf composite = new CompositeByteBuf(toEncrypt.alloc(), toEncrypt.isDirect(), size); + int start = 0; + int end = 0; + for (int j = 0; j < 3; j++) { + start = end; + end = TestUtils.RANDOM.nextInt(size / 2 - 1) + end; + if (j == 2) { + end = size; + } + ByteBuf c = Unpooled.buffer(end - start); + c.writeBytes(data, start, end - start); + composite.addComponent(true, c); + } + return composite; + } + + /** + * Create a {@link ByteBuf} based on whether it should be a composite ByteBuf or not. If it should, then + * create a {@link CompositeByteBuf} with three components. + * @return A {@link ByteBuf}. + */ + private ByteBuf createByteBuf() { + int size = TestUtils.RANDOM.nextInt(MAX_DATA_SIZE - MIN_DATA_SIZE) + MIN_DATA_SIZE; + byte[] randomData = new byte[size]; + TestUtils.RANDOM.nextBytes(randomData); + if (isCompositeByteBuf) { + return fromByteArrayToCompositeByteBuf(randomData); + } else { + return ByteBufAllocator.DEFAULT.heapBuffer(size); + } + } + + /** + * Convert the given {@link ByteBuf} to a {@link CompositeByteBuf} if the {@code isCompositeByteBuf} is true. + * @param buf The given {@link ByteBuf}. + * @return The result {@link ByteBuf}. + */ + private ByteBuf maybeConvertToComposite(ByteBuf buf) { + if (!isCompositeByteBuf) { + return buf.retainedDuplicate(); + } else { + byte[] data = new byte[buf.readableBytes()]; + buf.getBytes(buf.readerIndex(), data); + return fromByteArrayToCompositeByteBuf(data); + } + } + + /** + * Create a {@link ByteBuffer} from given {@link ByteBuf} so that they have the same content. + * @param byteBuf The given {@link ByteBuf}.... + * @return The {@link ByteBuffer}. + */ + private ByteBuffer fromByteBufToByteBuffer(ByteBuf byteBuf) { + int size = byteBuf.readableBytes(); + ByteBuffer buffer = ByteBuffer.allocate(size); + byteBuf.getBytes(0, buffer); + buffer.flip(); + return buffer; + } + + /** + * Test the default methods for those implementations that don't implement the default methods. + * @throws Exception + */ + @Test + public void testDefaultMethodForEncryptDecrypt() throws Exception { + CryptoService cryptoService = new MockCryptoService(); + String key = ((MockCryptoService) cryptoService).getKey(); + SecretKeySpec secretKeySpec = new SecretKeySpec(Hex.decode(key), "AES"); + for (int i = 0; i < 5; i++) { + ByteBuf toEncryptByteBuf = createByteBuf(); + ByteBuffer toEncrypt = fromByteBufToByteBuffer(toEncryptByteBuf); + ByteBuf encryptedBytesByteBuf = cryptoService.encrypt(toEncryptByteBuf, secretKeySpec); + ByteBuffer encryptedBytes = cryptoService.encrypt(toEncrypt, secretKeySpec); + + Assert.assertTrue(encryptedBytesByteBuf.hasArray()); + Assert.assertEquals(encryptedBytes.remaining(), encryptedBytesByteBuf.readableBytes()); + Assert.assertEquals(toEncryptByteBuf.readableBytes(), 0); + Assert.assertEquals(toEncrypt.remaining(), 0); + byte[] arrayFromByteBuf = new byte[encryptedBytesByteBuf.readableBytes()]; + encryptedBytesByteBuf.getBytes(encryptedBytesByteBuf.readerIndex(), arrayFromByteBuf); + Assert.assertArrayEquals(encryptedBytes.array(), arrayFromByteBuf); + + ByteBuf toDecryptByteBuf = maybeConvertToComposite(encryptedBytesByteBuf); + ByteBuffer toDecrypt = encryptedBytes; + ByteBuf decryptedBytesByteBuf = cryptoService.decrypt(toDecryptByteBuf, secretKeySpec); + ByteBuffer decryptedBytes = cryptoService.decrypt(encryptedBytes, secretKeySpec); + + Assert.assertTrue(decryptedBytesByteBuf.hasArray()); + Assert.assertEquals(decryptedBytes.remaining(), decryptedBytesByteBuf.readableBytes()); + Assert.assertEquals(toDecryptByteBuf.readableBytes(), 0); + Assert.assertEquals(toDecrypt.remaining(), 0); + arrayFromByteBuf = new byte[decryptedBytesByteBuf.readableBytes()]; + decryptedBytesByteBuf.getBytes(decryptedBytesByteBuf.readerIndex(), arrayFromByteBuf); + Assert.assertArrayEquals(decryptedBytes.array(), arrayFromByteBuf); + + toEncryptByteBuf.release(); + encryptedBytesByteBuf.release(); + toDecryptByteBuf.release(); + decryptedBytesByteBuf.release(); + } + } + + /** + * A mock {@link CryptoService} that doesn't implements default methods. + */ + static class MockCryptoService implements CryptoService { + private GCMCryptoService cryptoService; + private final String key; + private final byte[] fixedIv; + + public MockCryptoService() { + key = TestUtils.getRandomKey(DEFAULT_KEY_SIZE_IN_CHARS); + Properties props = getKMSProperties(key, DEFAULT_KEY_SIZE_IN_CHARS); + VerifiableProperties verifiableProperties = new VerifiableProperties((props)); + cryptoService = (GCMCryptoService) new GCMCryptoServiceFactory(verifiableProperties, REGISTRY).getCryptoService(); + fixedIv = new byte[12]; + } + + @Override + public ByteBuffer encrypt(ByteBuffer toEncrypt, SecretKeySpec key) throws GeneralSecurityException { + return cryptoService.encrypt(toEncrypt, key, fixedIv); + } + + @Override + public ByteBuffer decrypt(ByteBuffer toDecrypt, SecretKeySpec key) throws GeneralSecurityException { + return cryptoService.decrypt(toDecrypt, key); + } + + @Override + public ByteBuffer encryptKey(SecretKeySpec toEncrypt, SecretKeySpec key) throws GeneralSecurityException { + return cryptoService.encryptKey(toEncrypt, key); + } + + @Override + public SecretKeySpec decryptKey(ByteBuffer toDecrypt, SecretKeySpec key) throws GeneralSecurityException { + return cryptoService.decryptKey(toDecrypt, key); + } + + public String getKey() { + return key; + } + } +} diff --git a/ambry-router/src/test/java/com.github.ambry.router/GCMCryptoServiceTest.java b/ambry-router/src/test/java/com.github.ambry.router/GCMCryptoServiceTest.java index 3e75322375..bbe7e792bb 100644 --- a/ambry-router/src/test/java/com.github.ambry.router/GCMCryptoServiceTest.java +++ b/ambry-router/src/test/java/com.github.ambry.router/GCMCryptoServiceTest.java @@ -14,17 +14,22 @@ package com.github.ambry.router; import com.codahale.metrics.MetricRegistry; +import com.github.ambry.commons.NettyByteBufLeakHelper; import com.github.ambry.config.VerifiableProperties; import com.github.ambry.utils.TestUtils; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import java.util.Arrays; import java.util.Properties; import javax.crypto.spec.SecretKeySpec; import org.bouncycastle.util.encoders.Hex; +import org.junit.After; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import static com.github.ambry.router.CryptoTestUtils.*; @@ -36,8 +41,20 @@ public class GCMCryptoServiceTest { private static final int MAX_DATA_SIZE = 10000; + private static final int MIN_DATA_SIZE = 100; private static final int DEFAULT_KEY_SIZE_IN_CHARS = 64; private static final MetricRegistry REGISTRY = new MetricRegistry(); + private final NettyByteBufLeakHelper nettyByteBufLeakHelper = new NettyByteBufLeakHelper(); + + @Before + public void before() { + nettyByteBufLeakHelper.beforeTest(); + } + + @After + public void after() { + nettyByteBufLeakHelper.afterTest(); + } /** * Tests basic encryption and decryption for different sizes of keys and random data in bytes @@ -100,6 +117,9 @@ public void testEncryptDecryptNettyByteBuf() throws Exception { Assert.assertTrue(encryptedBytesByteBufDirect.hasArray()); Assert.assertEquals(encryptedBytes.remaining(), encryptedBytesByteBufHeap.readableBytes()); Assert.assertEquals(encryptedBytes.remaining(), encryptedBytesByteBufDirect.readableBytes()); + Assert.assertEquals(toEncrypt.remaining(), 0); + Assert.assertEquals(toEncryptByteBufDirect.readableBytes(), 0); + Assert.assertEquals(toEncryptByteBufHeap.readableBytes(), 0); byte[] arrayFromByteBuf = new byte[encryptedBytesByteBufHeap.readableBytes()]; encryptedBytesByteBufHeap.getBytes(encryptedBytesByteBufHeap.readerIndex(), arrayFromByteBuf); @@ -119,12 +139,106 @@ public void testEncryptDecryptNettyByteBuf() throws Exception { Assert.assertTrue(decryptedBytesByteBufDirect.hasArray()); Assert.assertEquals(decryptedBytes.remaining(), decryptedBytesByteBufHeap.readableBytes()); Assert.assertEquals(decryptedBytes.remaining(), decryptedBytesByteBufDirect.readableBytes()); + Assert.assertEquals(encryptedBytes.remaining(), 0); + Assert.assertEquals(toDecryptByteBufDirect.readableBytes(), 0); + Assert.assertEquals(toDecryptByteBufHeap.readableBytes(), 0); arrayFromByteBuf = new byte[decryptedBytesByteBufHeap.readableBytes()]; decryptedBytesByteBufHeap.getBytes(decryptedBytesByteBufHeap.readerIndex(), arrayFromByteBuf); Assert.assertArrayEquals(decryptedBytes.array(), arrayFromByteBuf); decryptedBytesByteBufDirect.getBytes(decryptedBytesByteBufDirect.readerIndex(), arrayFromByteBuf); Assert.assertArrayEquals(decryptedBytes.array(), arrayFromByteBuf); + + toEncryptByteBufHeap.release(); + toEncryptByteBufDirect.release(); + encryptedBytesByteBufHeap.release(); + encryptedBytesByteBufDirect.release(); + toDecryptByteBufDirect.release(); + decryptedBytesByteBufHeap.release(); + decryptedBytesByteBufDirect.release(); + } + } + + @Test + public void testEncryptDecryptNettyCompositeByteBuf() throws Exception { + // testEncryptDecryptNettyByte already tests the correctness of the encrypt decrypt methods with non-composite Netty + // ByteBuf, in this test case, we can make the assumption that these two functions always provide correct answers. + + String key = TestUtils.getRandomKey(DEFAULT_KEY_SIZE_IN_CHARS); + Properties props = getKMSProperties(key, DEFAULT_KEY_SIZE_IN_CHARS); + VerifiableProperties verifiableProperties = new VerifiableProperties((props)); + SecretKeySpec secretKeySpec = new SecretKeySpec(Hex.decode(key), "AES"); + GCMCryptoService cryptoService = + (GCMCryptoService) (new GCMCryptoServiceFactory(verifiableProperties, REGISTRY).getCryptoService()); + byte[] fixedIv = new byte[12]; + for (int i = 0; i < 5; i++) { + int size = TestUtils.RANDOM.nextInt(MAX_DATA_SIZE - MIN_DATA_SIZE) + MIN_DATA_SIZE; + byte[] randomData = new byte[size]; + TestUtils.RANDOM.nextBytes(randomData); + + ByteBuf toEncrypt = Unpooled.wrappedBuffer(randomData); + CompositeByteBuf toEncryptComposite = new CompositeByteBuf(toEncrypt.alloc(), toEncrypt.isDirect(), size); + int start = 0; + int end = 0; + for (int j = 0; j < 3; j++) { + start = end; + end = TestUtils.RANDOM.nextInt(size / 2 - 1) + end; + if (j == 2) { + end = size; + } + ByteBuf c = Unpooled.buffer(end - start); + c.writeBytes(randomData, start, end - start); + toEncryptComposite.addComponent(true, c); + } + + ByteBuf encryptedBytes = cryptoService.encrypt(toEncrypt, secretKeySpec, fixedIv); + ByteBuf encryptedBytesComposite = cryptoService.encrypt(toEncryptComposite, secretKeySpec, fixedIv); + Assert.assertEquals(encryptedBytes.readableBytes(), encryptedBytesComposite.readableBytes()); + Assert.assertEquals(toEncrypt.readableBytes(), 0); + Assert.assertEquals(toEncryptComposite.readableBytes(), 0); + + byte[] array = new byte[encryptedBytes.readableBytes()]; + encryptedBytes.getBytes(encryptedBytes.readerIndex(), array); + byte[] arrayComposite = new byte[encryptedBytesComposite.readableBytes()]; + encryptedBytesComposite.getBytes(encryptedBytesComposite.readerIndex(), arrayComposite); + Assert.assertArrayEquals(array, arrayComposite); + + ByteBuf toDecrypt = encryptedBytes; + CompositeByteBuf toDecryptComposite = new CompositeByteBuf(toDecrypt.alloc(), toEncrypt.isDirect(), size); + size = encryptedBytes.readableBytes(); + start = 0; + end = 0; + for (int j = 0; j < 3; j++) { + start = end; + end = TestUtils.RANDOM.nextInt(size / 2 - 1) + end; + if (j == 2) { + end = size; + } + ByteBuf c = Unpooled.buffer(end - start); + c.writeBytes(encryptedBytes, start, end - start); + toDecryptComposite.addComponent(true, c); + } + + ByteBuf decryptedBytes = cryptoService.decrypt(toDecrypt, secretKeySpec); + ByteBuf decryptedBytesComposite = cryptoService.decrypt(toDecryptComposite, secretKeySpec); + + Assert.assertEquals(decryptedBytes.readableBytes(), decryptedBytesComposite.readableBytes()); + Assert.assertEquals(toDecrypt.readableBytes(), 0); + Assert.assertEquals(toDecryptComposite.readableBytes(), 0); + + array = new byte[decryptedBytes.readableBytes()]; + arrayComposite = new byte[decryptedBytesComposite.readableBytes()]; + decryptedBytes.getBytes(decryptedBytes.readerIndex(), array); + decryptedBytesComposite.getBytes(decryptedBytesComposite.readerIndex(), arrayComposite); + Assert.assertArrayEquals(array, arrayComposite); + + toEncrypt.release(); + toEncryptComposite.release(); + encryptedBytes.release(); + encryptedBytesComposite.release(); + toDecryptComposite.release(); + decryptedBytes.release(); + decryptedBytesComposite.release(); } } diff --git a/ambry-utils/src/main/java/com.github.ambry.utils/Utils.java b/ambry-utils/src/main/java/com.github.ambry.utils/Utils.java index e9b9b17c2c..088a338b94 100644 --- a/ambry-utils/src/main/java/com.github.ambry.utils/Utils.java +++ b/ambry-utils/src/main/java/com.github.ambry.utils/Utils.java @@ -14,7 +14,9 @@ package com.github.ambry.utils; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; import java.io.BufferedReader; import java.io.DataInputStream; import java.io.File; @@ -290,6 +292,43 @@ public static ByteBuffer readByteBufferFromCrcInputStream(CrcInputStream crcStre return output; } + /** + * Represent an operation that accepts a {@link ByteBuffer} and returns another {@link ByteBuffer}. Side effect should + * be expected to the input {@link ByteBuffer} and the returned {@link ByteBuffer} should be ready for read. + * @param The exception to throw in this operation. + */ + @FunctionalInterface + public static interface ByteBufferFunction { + ByteBuffer apply(ByteBuffer buffer) throws T; + } + + /** + * Apply a {@link ByteBufferFunction} to a {@link ByteBuf} and return a {@link ByteBuf}. All the bytes in the input + * {@link ByteBuf} will be consumed. + * @param buf The input {@link ByteBuf}. + * @param fn The {@link ByteBufferFunction}. + * @param The exception to throw in the {@code fn}. + * @return A {@link ByteBuf}. + * @throws T Exception thrown from {@code fn}. + */ + public static ByteBuf applyByteBufferFunctionToByteBuf(ByteBuf buf, ByteBufferFunction fn) + throws T { + if (buf.nioBufferCount() == 1) { + ByteBuffer buffer = buf.nioBuffer(); + buf.skipBytes(buffer.remaining()); + return Unpooled.wrappedBuffer(fn.apply(buffer)); + } else { + ByteBuf temp = ByteBufAllocator.DEFAULT.heapBuffer(buf.readableBytes()); + try { + temp.writeBytes(buf); + ByteBuffer buffer = temp.nioBuffer(); + return Unpooled.wrappedBuffer(fn.apply(buffer)); + } finally { + temp.release(); + } + } + } + /** * Create a new thread *