diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java index 29edbb933e60..c5701e8360bb 100644 --- a/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/DatagramUnicastTest.java @@ -36,14 +36,18 @@ public class DatagramUnicastTest extends AbstractDatagramTest { private static final byte[] BYTES = {0, 1, 2, 3}; + private enum WrapType { + NONE, DUP, SLICE, + } + @Test public void testSimpleSendDirectByteBuf() throws Throwable { run(); } public void testSimpleSendDirectByteBuf(Bootstrap sb, Bootstrap cb) throws Throwable { - testSimpleSend0(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), true, BYTES, 1); - testSimpleSend0(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), true, BYTES, 4); + testSimpleSend(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), true, BYTES, 1); + testSimpleSend(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), true, BYTES, 4); } @Test @@ -52,8 +56,8 @@ public void testSimpleSendHeapByteBuf() throws Throwable { } public void testSimpleSendHeapByteBuf(Bootstrap sb, Bootstrap cb) throws Throwable { - testSimpleSend0(sb, cb, Unpooled.buffer().writeBytes(BYTES), true, BYTES, 1); - testSimpleSend0(sb, cb, Unpooled.buffer().writeBytes(BYTES), true, BYTES, 4); + testSimpleSend(sb, cb, Unpooled.buffer().writeBytes(BYTES), true, BYTES, 1); + testSimpleSend(sb, cb, Unpooled.buffer().writeBytes(BYTES), true, BYTES, 4); } @Test @@ -65,12 +69,12 @@ public void testSimpleSendCompositeDirectByteBuf(Bootstrap sb, Bootstrap cb) thr CompositeByteBuf buf = Unpooled.compositeBuffer(); buf.addComponent(true, Unpooled.directBuffer().writeBytes(BYTES, 0, 2)); buf.addComponent(true, Unpooled.directBuffer().writeBytes(BYTES, 2, 2)); - testSimpleSend0(sb, cb, buf, true, BYTES, 1); + testSimpleSend(sb, cb, buf, true, BYTES, 1); CompositeByteBuf buf2 = Unpooled.compositeBuffer(); buf2.addComponent(true, Unpooled.directBuffer().writeBytes(BYTES, 0, 2)); buf2.addComponent(true, Unpooled.directBuffer().writeBytes(BYTES, 2, 2)); - testSimpleSend0(sb, cb, buf2, true, BYTES, 4); + testSimpleSend(sb, cb, buf2, true, BYTES, 4); } @Test @@ -82,12 +86,12 @@ public void testSimpleSendCompositeHeapByteBuf(Bootstrap sb, Bootstrap cb) throw CompositeByteBuf buf = Unpooled.compositeBuffer(); buf.addComponent(true, Unpooled.buffer().writeBytes(BYTES, 0, 2)); buf.addComponent(true, Unpooled.buffer().writeBytes(BYTES, 2, 2)); - testSimpleSend0(sb, cb, buf, true, BYTES, 1); + testSimpleSend(sb, cb, buf, true, BYTES, 1); CompositeByteBuf buf2 = Unpooled.compositeBuffer(); buf2.addComponent(true, Unpooled.buffer().writeBytes(BYTES, 0, 2)); buf2.addComponent(true, Unpooled.buffer().writeBytes(BYTES, 2, 2)); - testSimpleSend0(sb, cb, buf2, true, BYTES, 4); + testSimpleSend(sb, cb, buf2, true, BYTES, 4); } @Test @@ -99,12 +103,12 @@ public void testSimpleSendCompositeMixedByteBuf(Bootstrap sb, Bootstrap cb) thro CompositeByteBuf buf = Unpooled.compositeBuffer(); buf.addComponent(true, Unpooled.directBuffer().writeBytes(BYTES, 0, 2)); buf.addComponent(true, Unpooled.buffer().writeBytes(BYTES, 2, 2)); - testSimpleSend0(sb, cb, buf, true, BYTES, 1); + testSimpleSend(sb, cb, buf, true, BYTES, 1); CompositeByteBuf buf2 = Unpooled.compositeBuffer(); buf2.addComponent(true, Unpooled.directBuffer().writeBytes(BYTES, 0, 2)); buf2.addComponent(true, Unpooled.buffer().writeBytes(BYTES, 2, 2)); - testSimpleSend0(sb, cb, buf2, true, BYTES, 4); + testSimpleSend(sb, cb, buf2, true, BYTES, 4); } @Test @@ -113,13 +117,21 @@ public void testSimpleSendWithoutBind() throws Throwable { } public void testSimpleSendWithoutBind(Bootstrap sb, Bootstrap cb) throws Throwable { - testSimpleSend0(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), false, BYTES, 1); - testSimpleSend0(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), false, BYTES, 4); + testSimpleSend(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), false, BYTES, 1); + testSimpleSend(sb, cb, Unpooled.directBuffer().writeBytes(BYTES), false, BYTES, 4); + } + + private void testSimpleSend(Bootstrap sb, Bootstrap cb, ByteBuf buf, boolean bindClient, + final byte[] bytes, int count) throws Throwable { + for (WrapType type: WrapType.values()) { + testSimpleSend0(sb, cb, buf.retain(), bindClient, bytes, count, type); + } + assertTrue(buf.release()); } @SuppressWarnings("deprecation") private void testSimpleSend0(Bootstrap sb, Bootstrap cb, ByteBuf buf, boolean bindClient, - final byte[] bytes, int count) + final byte[] bytes, int count, WrapType wrapType) throws Throwable { final CountDownLatch latch = new CountDownLatch(count); @@ -177,7 +189,15 @@ public void channelRead0(ChannelHandlerContext ctx, Object msgs) throws Exceptio } for (int i = 0; i < count; i++) { - cc.write(new DatagramPacket(buf.retain().duplicate(), addr)); + if (wrapType == WrapType.DUP) { + cc.write(new DatagramPacket(buf.retain().duplicate(), addr)); + } else if (wrapType == WrapType.SLICE) { + cc.write(new DatagramPacket(buf.retain().slice(), addr)); + } else if (wrapType == WrapType.NONE) { + cc.write(new DatagramPacket(buf.retain(), addr)); + } else { + throw new Exception("unknown wrap type: " + wrapType); + } } // release as we used buf.retain() before buf.release(); diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java index 6c3efe869e2a..2ca159846a6f 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/AbstractEpollStreamChannel.java @@ -17,7 +17,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelFuture; @@ -35,6 +34,7 @@ import io.netty.channel.unix.FileDescriptor; import io.netty.channel.unix.IovArray; import io.netty.channel.unix.SocketWritableByteChannel; +import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.ThrowableUtil; @@ -53,7 +53,6 @@ import java.util.concurrent.TimeUnit; import static io.netty.channel.unix.FileDescriptor.pipe; -import static io.netty.channel.unix.Limits.IOV_MAX; import static io.netty.util.internal.ObjectUtil.checkNotNull; public abstract class AbstractEpollStreamChannel extends AbstractEpollChannel implements DuplexChannel { @@ -539,24 +538,7 @@ private boolean doWriteMultiple(ChannelOutboundBuffer in, int writeSpinCount) th protected Object filterOutboundMessage(Object msg) { if (msg instanceof ByteBuf) { ByteBuf buf = (ByteBuf) msg; - if (!buf.hasMemoryAddress() && (PlatformDependent.hasUnsafe() || !buf.isDirect())) { - if (buf instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) buf; - if (!comp.isDirect() || comp.nioBufferCount() > IOV_MAX) { - // more then 1024 buffers for gathering writes so just do a memory copy. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } else { - // We can only handle buffers with memory address so we need to copy if a non direct is - // passed to write. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } - return buf; + return UnixChannelUtil.isBufferCopyNeededForWrite(buf)? newDirectBuffer(buf): buf; } if (msg instanceof FileRegion || msg instanceof SpliceOutTask) { diff --git a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java index eb5d367b60f7..d8fb55286abb 100644 --- a/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java +++ b/transport-native-epoll/src/main/java/io/netty/channel/epoll/EpollDatagramChannel.java @@ -17,7 +17,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; import io.netty.channel.AddressedEnvelope; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelMetadata; @@ -30,6 +29,7 @@ import io.netty.channel.socket.DatagramPacket; import io.netty.channel.unix.DatagramSocketAddress; import io.netty.channel.unix.IovArray; +import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; @@ -45,7 +45,6 @@ import java.util.List; import static io.netty.channel.epoll.LinuxSocket.newSocketDgram; -import static io.netty.channel.unix.Limits.IOV_MAX; /** * {@link DatagramChannel} implementation that uses linux EPOLL Edge-Triggered Mode for @@ -373,7 +372,7 @@ private boolean doWriteMessage(Object msg) throws Exception { long memoryAddress = data.memoryAddress(); writtenBytes = socket.sendToAddress(memoryAddress, data.readerIndex(), data.writerIndex(), remoteAddress.getAddress(), remoteAddress.getPort()); - } else if (data instanceof CompositeByteBuf) { + } else if (data.nioBufferCount() > 1) { IovArray array = ((EpollEventLoop) eventLoop()).cleanArray(); array.add(data); int cnt = array.count(); @@ -395,43 +394,13 @@ protected Object filterOutboundMessage(Object msg) { if (msg instanceof DatagramPacket) { DatagramPacket packet = (DatagramPacket) msg; ByteBuf content = packet.content(); - if (content.hasMemoryAddress()) { - return msg; - } - - if (content.isDirect() && content instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) content; - if (comp.isDirect() && comp.nioBufferCount() <= IOV_MAX) { - return msg; - } - } - // We can only handle direct buffers so we need to copy if a non direct is - // passed to write. - return new DatagramPacket(newDirectBuffer(packet, content), packet.recipient()); + return UnixChannelUtil.isBufferCopyNeededForWrite(content) ? + new DatagramPacket(newDirectBuffer(packet, content), packet.recipient()) : msg; } if (msg instanceof ByteBuf) { ByteBuf buf = (ByteBuf) msg; - if (!buf.hasMemoryAddress() && (PlatformDependent.hasUnsafe() || !buf.isDirect())) { - if (buf instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) buf; - if (!comp.isDirect() || comp.nioBufferCount() > IOV_MAX) { - // more then 1024 buffers for gathering writes so just do a memory copy. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } else { - // We can only handle buffers with memory address so we need to copy if a non direct is - // passed to write. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } - return buf; + return UnixChannelUtil.isBufferCopyNeededForWrite(buf)? newDirectBuffer(buf) : buf; } if (msg instanceof AddressedEnvelope) { @@ -441,21 +410,9 @@ protected Object filterOutboundMessage(Object msg) { (e.recipient() == null || e.recipient() instanceof InetSocketAddress)) { ByteBuf content = (ByteBuf) e.content(); - if (content.hasMemoryAddress()) { - return e; - } - if (content instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) content; - if (comp.isDirect() && comp.nioBufferCount() <= IOV_MAX) { - return e; - } - } - // We can only handle direct buffers so we need to copy if a non direct is - // passed to write. - return new DefaultAddressedEnvelope( - newDirectBuffer(e, content), (InetSocketAddress) e.recipient()); + return UnixChannelUtil.isBufferCopyNeededForWrite(content)? + new DefaultAddressedEnvelope( + newDirectBuffer(e, content), (InetSocketAddress) e.recipient()) : e; } } diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java index d55599797924..db8bb55fc5ac 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/AbstractKQueueStreamChannel.java @@ -17,7 +17,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelFuture; @@ -33,6 +32,7 @@ import io.netty.channel.socket.DuplexChannel; import io.netty.channel.unix.IovArray; import io.netty.channel.unix.SocketWritableByteChannel; +import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.ThrowableUtil; @@ -48,8 +48,6 @@ import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; -import static io.netty.channel.unix.Limits.IOV_MAX; - @UnstableApi public abstract class AbstractKQueueStreamChannel extends AbstractKQueueChannel implements DuplexChannel { private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); @@ -369,24 +367,7 @@ private boolean doWriteMultiple(ChannelOutboundBuffer in, int writeSpinCount) th protected Object filterOutboundMessage(Object msg) { if (msg instanceof ByteBuf) { ByteBuf buf = (ByteBuf) msg; - if (!buf.hasMemoryAddress() && (PlatformDependent.hasUnsafe() || !buf.isDirect())) { - if (buf instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) buf; - if (!comp.isDirect() || comp.nioBufferCount() > IOV_MAX) { - // more then 1024 buffers for gathering writes so just do a memory copy. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } else { - // We can only handle buffers with memory address so we need to copy if a non direct is - // passed to write. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } - return buf; + return UnixChannelUtil.isBufferCopyNeededForWrite(buf)? newDirectBuffer(buf) : buf; } if (msg instanceof FileRegion) { diff --git a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java index 9ae3845f7409..1c8f6137da0b 100644 --- a/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java +++ b/transport-native-kqueue/src/main/java/io/netty/channel/kqueue/KQueueDatagramChannel.java @@ -17,7 +17,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; import io.netty.channel.AddressedEnvelope; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelMetadata; @@ -30,6 +29,7 @@ import io.netty.channel.socket.DatagramPacket; import io.netty.channel.unix.DatagramSocketAddress; import io.netty.channel.unix.IovArray; +import io.netty.channel.unix.UnixChannelUtil; import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.StringUtil; import io.netty.util.internal.UnstableApi; @@ -46,7 +46,6 @@ import java.util.List; import static io.netty.channel.kqueue.BsdSocket.newSocketDgram; -import static io.netty.channel.unix.Limits.IOV_MAX; @UnstableApi public final class KQueueDatagramChannel extends AbstractKQueueChannel implements DatagramChannel { @@ -345,7 +344,7 @@ private boolean doWriteMessage(Object msg) throws Exception { long memoryAddress = data.memoryAddress(); writtenBytes = socket.sendToAddress(memoryAddress, data.readerIndex(), data.writerIndex(), remoteAddress.getAddress(), remoteAddress.getPort()); - } else if (data instanceof CompositeByteBuf) { + } else if (data.nioBufferCount() > 1) { IovArray array = ((KQueueEventLoop) eventLoop()).cleanArray(); array.add(data); int cnt = array.count(); @@ -367,43 +366,13 @@ protected Object filterOutboundMessage(Object msg) { if (msg instanceof DatagramPacket) { DatagramPacket packet = (DatagramPacket) msg; ByteBuf content = packet.content(); - if (content.hasMemoryAddress()) { - return msg; - } - - if (content.isDirect() && content instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) content; - if (comp.isDirect() && comp.nioBufferCount() <= IOV_MAX) { - return msg; - } - } - // We can only handle direct buffers so we need to copy if a non direct is - // passed to write. - return new DatagramPacket(newDirectBuffer(packet, content), packet.recipient()); + return UnixChannelUtil.isBufferCopyNeededForWrite(content)? + new DatagramPacket(newDirectBuffer(packet, content), packet.recipient()) : msg; } if (msg instanceof ByteBuf) { ByteBuf buf = (ByteBuf) msg; - if (!buf.hasMemoryAddress() && (PlatformDependent.hasUnsafe() || !buf.isDirect())) { - if (buf instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) buf; - if (!comp.isDirect() || comp.nioBufferCount() > IOV_MAX) { - // more then 1024 buffers for gathering writes so just do a memory copy. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } else { - // We can only handle buffers with memory address so we need to copy if a non direct is - // passed to write. - buf = newDirectBuffer(buf); - assert buf.hasMemoryAddress(); - } - } - return buf; + return UnixChannelUtil.isBufferCopyNeededForWrite(buf)? newDirectBuffer(buf) : buf; } if (msg instanceof AddressedEnvelope) { @@ -413,21 +382,9 @@ protected Object filterOutboundMessage(Object msg) { (e.recipient() == null || e.recipient() instanceof InetSocketAddress)) { ByteBuf content = (ByteBuf) e.content(); - if (content.hasMemoryAddress()) { - return e; - } - if (content instanceof CompositeByteBuf) { - // Special handling of CompositeByteBuf to reduce memory copies if some of the Components - // in the CompositeByteBuf are backed by a memoryAddress. - CompositeByteBuf comp = (CompositeByteBuf) content; - if (comp.isDirect() && comp.nioBufferCount() <= IOV_MAX) { - return e; - } - } - // We can only handle direct buffers so we need to copy if a non direct is - // passed to write. - return new DefaultAddressedEnvelope( - newDirectBuffer(e, content), (InetSocketAddress) e.recipient()); + return UnixChannelUtil.isBufferCopyNeededForWrite(content)? + new DefaultAddressedEnvelope( + newDirectBuffer(e, content), (InetSocketAddress) e.recipient()) : e; } } diff --git a/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/UnixChannelUtilTest.java b/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/UnixChannelUtilTest.java new file mode 100644 index 000000000000..38eeae63d207 --- /dev/null +++ b/transport-native-unix-common-tests/src/main/java/io/netty/channel/unix/tests/UnixChannelUtilTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel.unix.tests; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.Test; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import static io.netty.channel.unix.Limits.IOV_MAX; +import static io.netty.channel.unix.UnixChannelUtil.isBufferCopyNeededForWrite; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class UnixChannelUtilTest { + @Test + public void testPooledAllocatorIsBufferCopyNeededForWrite() { + testIsBufferCopyNeededForWrite(PooledByteBufAllocator.DEFAULT); + } + + @Test + public void testUnPooledAllocatorIsBufferCopyNeededForWrite() { + testIsBufferCopyNeededForWrite(UnpooledByteBufAllocator.DEFAULT); + } + + private static void testIsBufferCopyNeededForWrite(ByteBufAllocator alloc) { + ByteBuf byteBuf = alloc.directBuffer(); + assertFalse(isBufferCopyNeededForWrite(byteBuf)); + assertTrue(byteBuf.release()); + + byteBuf = alloc.heapBuffer(); + assertTrue(isBufferCopyNeededForWrite(byteBuf)); + assertTrue(byteBuf.release()); + + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, 2, 0, false); + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, IOV_MAX + 1, 0, true); + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, 0, 2, true); + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, 1, 1, true); + } + + private static void assertCompositeByteBufIsBufferCopyNeededForWrite(ByteBufAllocator alloc, int numDirect, + int numHeap, boolean expected) { + CompositeByteBuf comp = alloc.compositeBuffer(numDirect + numHeap); + List byteBufs = new LinkedList(); + + while (numDirect > 0) { + byteBufs.add(alloc.directBuffer(1)); + numDirect--; + } + while (numHeap > 0) { + byteBufs.add(alloc.heapBuffer(1)); + numHeap--; + } + + Collections.shuffle(byteBufs); + for (ByteBuf byteBuf : byteBufs) { + comp.addComponent(byteBuf); + } + + assertEquals(byteBufs.toString(), expected, isBufferCopyNeededForWrite(comp)); + assertTrue(comp.release()); + } +} diff --git a/transport-native-unix-common/src/main/c/netty_unix_socket.c b/transport-native-unix-common/src/main/c/netty_unix_socket.c index 6de618eb9a12..4ada12740d70 100644 --- a/transport-native-unix-common/src/main/c/netty_unix_socket.c +++ b/transport-native-unix-common/src/main/c/netty_unix_socket.c @@ -496,7 +496,7 @@ static jint netty_unix_socket_sendToAddresses(JNIEnv* env, jclass clazz, jint fd return -1; } - struct msghdr m; + struct msghdr m = { 0 }; m.msg_name = (void*) &addr; m.msg_namelen = addrSize; m.msg_iov = (struct iovec*) (intptr_t) memoryAddress; diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java index e47465856488..04c8d8645085 100644 --- a/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/IovArray.java @@ -16,9 +16,9 @@ package io.netty.channel.unix; import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; import io.netty.channel.ChannelOutboundBuffer.MessageProcessor; import io.netty.util.internal.PlatformDependent; + import java.nio.ByteBuffer; import static io.netty.channel.unix.Limits.IOV_MAX; @@ -76,22 +76,41 @@ public void clear() { * {@code false} otherwise. */ public boolean add(ByteBuf buf) { - if (count == IOV_MAX) { + int nioBufferCount = buf.nioBufferCount(); + if (count + nioBufferCount > IOV_MAX) { // No more room! return false; } - final int len = buf.readableBytes(); - if (len == 0) { - // No need to add an empty buffer. - // We return true here because we want ChannelOutboundBuffer.forEachFlushedMessage() to continue - // fetching the next buffers. + if (nioBufferCount == 1) { + final int len = buf.readableBytes(); + if (len == 0) { + // No need to add an empty buffer. + // We return true here because we want ChannelOutboundBuffer.forEachFlushedMessage() to continue + // fetching the next buffers. + return true; + } + + final long addr = buf.memoryAddress(); + final int offset = buf.readerIndex(); + return add(addr, offset, len); + } else { + ByteBuffer[] buffers = buf.nioBuffers(); + for (ByteBuffer nioBuffer : buffers) { + int len = nioBuffer.remaining(); + if (len == 0) { + // No need to add an empty buffer so just continue + continue; + } + int offset = nioBuffer.position(); + long addr = PlatformDependent.directBufferAddress(nioBuffer); + + if (!add(addr, offset, len)) { + return false; + } + } return true; } - - final long addr = buf.memoryAddress(); - final int offset = buf.readerIndex(); - return add(addr, offset, len); } private boolean add(long addr, int offset, int len) { @@ -126,32 +145,6 @@ private boolean add(long addr, int offset, int len) { return true; } - /** - * Try to add the given {@link CompositeByteBuf}. Returns {@code true} on success, - * {@code false} otherwise. - */ - public boolean add(CompositeByteBuf buf) { - ByteBuffer[] buffers = buf.nioBuffers(); - if (count + buffers.length >= IOV_MAX) { - // No more room! - return false; - } - for (ByteBuffer nioBuffer : buffers) { - int offset = nioBuffer.position(); - int len = nioBuffer.limit() - nioBuffer.position(); - if (len == 0) { - // No need to add an empty buffer so just continue - continue; - } - long addr = PlatformDependent.directBufferAddress(nioBuffer); - - if (!add(addr, offset, len)) { - return false; - } - } - return true; - } - /** * Process the written iov entries. This will return the length of the iov entry on the given index if it is * smaller then the given {@code written} value. Otherwise it returns {@code -1}. @@ -213,11 +206,7 @@ public void release() { @Override public boolean processMessage(Object msg) throws Exception { if (msg instanceof ByteBuf) { - if (msg instanceof CompositeByteBuf) { - return add((CompositeByteBuf) msg); - } else { - return add((ByteBuf) msg); - } + return add((ByteBuf) msg); } return false; } diff --git a/transport-native-unix-common/src/main/java/io/netty/channel/unix/UnixChannelUtil.java b/transport-native-unix-common/src/main/java/io/netty/channel/unix/UnixChannelUtil.java new file mode 100644 index 000000000000..8672c20f5c2c --- /dev/null +++ b/transport-native-unix-common/src/main/java/io/netty/channel/unix/UnixChannelUtil.java @@ -0,0 +1,34 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; + +import static io.netty.channel.unix.Limits.IOV_MAX; + +public final class UnixChannelUtil { + + private UnixChannelUtil() { + } + + /** + * Checks if the specified buffer has memory address or is composed of n(n <= IOV_MAX) NIO direct buffers. + * (We check this because otherwise we need to make it a new direct buffer.) + */ + public static boolean isBufferCopyNeededForWrite(ByteBuf byteBuf) { + return !(byteBuf.hasMemoryAddress() || byteBuf.isDirect() && byteBuf.nioBufferCount() <= IOV_MAX); + } +}