From 4f97b2d7c157bb28ba7c3ed28d9b5aea560f1345 Mon Sep 17 00:00:00 2001 From: baoloongmao Date: Mon, 22 Jul 2024 08:58:01 +0800 Subject: [PATCH] Fix allocate size by add 9 bytes --- .../org/apache/uniffle/common/ShuffleBlockInfo.java | 13 ++++++++++--- .../uniffle/common/ShufflePartitionedBlock.java | 13 ++++++++++--- .../apache/uniffle/common/netty/MessageEncoder.java | 5 ++++- .../uniffle/common/netty/protocol/Message.java | 4 ++-- .../impl/grpc/ShuffleServerGrpcNettyClient.java | 10 +++++++++- .../server/netty/ShuffleServerNettyHandler.java | 4 +++- 6 files changed, 38 insertions(+), 11 deletions(-) diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java index 36dec5e257..0ec71a304d 100644 --- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java +++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java @@ -95,10 +95,17 @@ public int getLength() { return length; } - // calculate the data size for this block in memory including metadata which are - // partitionId, blockId, crc, taskAttemptId, length, uncompressLength + /** + * Calculate the data size for this block in memory including metadata which are partitionId, + * blockId, crc, taskAttemptId, uncompressLength and data length. This should be consistent with + * {@link ShufflePartitionedBlock#getSize()}. + * + * @return the encoded size of this object in memory + */ public int getSize() { - return length + 3 * 8 + 2 * 4; + // FIXME(maobaolong): The size is calculated based on the Protobuf message ShuffleBlock. + // If Netty's custom serialization is used, the calculation logic here needs to be modified. + return length + 3 * Long.BYTES + 2 * Integer.BYTES; } public long getCrc() { diff --git a/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java b/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java index 1ce68b6b6b..47b1487b74 100644 --- a/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java +++ b/common/src/main/java/org/apache/uniffle/common/ShufflePartitionedBlock.java @@ -51,10 +51,17 @@ public ShufflePartitionedBlock( this.data = data; } - // calculate the data size for this block in memory including metadata which are - // blockId, crc, taskAttemptId, length, uncompressLength + /** + * Calculate the data size for this block in memory including metadata which are partitionId, + * blockId, crc, taskAttemptId, uncompressLength and data length. This should be consistent with + * {@link ShuffleBlockInfo#getSize()}. + * + * @return the encoded size of this object in memory + */ public long getSize() { - return length + 3 * 8 + 2 * 4; + // FIXME(maobaolong): The size is calculated based on the Protobuf message ShuffleBlock. + // If Netty's custom serialization is used, the calculation logic here needs to be modified. + return length + 3 * Long.BYTES + 2 * Integer.BYTES; } @Override diff --git a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java index cd40024826..cc8e228893 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java @@ -41,6 +41,9 @@ public final class MessageEncoder extends MessageToMessageEncoder { private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); public static final MessageEncoder INSTANCE = new MessageEncoder(); + public static final int MESSAGE_HEADER_SIZE = + // Inner message encodedLength + TYPE_ENCODED_LENGTH + bodyLength + Integer.BYTES + Message.TYPE_ENCODED_LENGTH + Integer.BYTES; private MessageEncoder() {} @@ -79,7 +82,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro Message.Type msgType = in.type(); // message size, message type size, body size, message encoded length - int headerLength = Integer.BYTES + msgType.encodedLength() + Integer.BYTES + in.encodedLength(); + int headerLength = MESSAGE_HEADER_SIZE + in.encodedLength(); ByteBuf header = ctx.alloc().heapBuffer(headerLength); header.writeInt(in.encodedLength()); msgType.encode(header); diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java index 8e925fac80..ff0bec3b10 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java @@ -23,7 +23,7 @@ import org.apache.uniffle.common.netty.buffer.ManagedBuffer; public abstract class Message implements Encodable { - + public static final byte TYPE_ENCODED_LENGTH = 1; private ManagedBuffer body; protected Message() { @@ -79,7 +79,7 @@ public byte id() { @Override public int encodedLength() { - return 1; + return TYPE_ENCODED_LENGTH; } @Override diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index 5d3303aacb..5ceb915669 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -42,6 +42,7 @@ import org.apache.uniffle.common.exception.NotRetryException; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; +import org.apache.uniffle.common.netty.MessageEncoder; import org.apache.uniffle.common.netty.client.TransportClient; import org.apache.uniffle.common.netty.client.TransportClientFactory; import org.apache.uniffle.common.netty.client.TransportConf; @@ -170,7 +171,14 @@ public RssSendShuffleDataResponse sendShuffleData(RssSendShuffleDataRequest requ 0L, stb.getValue(), System.currentTimeMillis()); - int allocateSize = size + sendShuffleDataRequest.encodedLength(); + // allocateSize = MESSAGE_HEADER_SIZE + requestMessage encodedLength + data size + // {@link org.apache.uniffle.common.netty.MessageEncoder#encode} + // We calculated the size again, even though sendShuffleDataRequest.encodedLength() + // already included the data size, because there is a brief moment when decoding + // sendShuffleDataRequest at the shuffle server, where there are two copies of data + // in direct memory. + int allocateSize = + MessageEncoder.MESSAGE_HEADER_SIZE + sendShuffleDataRequest.encodedLength() + size; int finalBlockNum = blockNum; try { RetryUtils.retryWithCondition( diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index fdb723dfb6..9cc79d6aef 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -45,6 +45,7 @@ import org.apache.uniffle.common.exception.ExceedHugePartitionHardLimitException; import org.apache.uniffle.common.exception.FileNotFoundException; import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.netty.MessageEncoder; import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.netty.client.TransportClient; @@ -135,8 +136,9 @@ public void handleSendShuffleDataRequest(TransportClient client, SendShuffleData PreAllocatedBufferInfo info = shuffleTaskManager.getAndRemovePreAllocatedBuffer(requireBufferId); int requireSize = info == null ? 0 : info.getRequireSize(); + int encodedLength = req.encodedLength() + MessageEncoder.MESSAGE_HEADER_SIZE; int requireBlocksSize = - requireSize - req.encodedLength() < 0 ? 0 : requireSize - req.encodedLength(); + requireSize - encodedLength < 0 ? 0 : requireSize - encodedLength; boolean isPreAllocated = info != null;