@@ -153,7 +153,7 @@ public void copyFromRssServer() throws IOException {
// fetch a block
if (!hasPendingData) {
final long startFetch = System.currentTimeMillis();
compressedBlock = shuffleReadClient.readShuffleBlockData();
compressedBlock = (CompressedShuffleBlock) shuffleReadClient.readShuffleBlockData();
if (compressedBlock != null) {
compressedData = compressedBlock.getByteBuffer();
}
@@ -39,6 +39,18 @@

public class RssSparkConfig {

public static final ConfigOption<Boolean> RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED =
ConfigOptions.key("rss.client.read.overlappingDecompressionEnable")
.booleanType()
.defaultValue(false)
.withDescription("Whether to enable overlapping decompression of shuffle blocks.");

public static final ConfigOption<Integer> RSS_READ_OVERLAPPING_DECOMPRESSION_THREADS =
ConfigOptions.key("rss.client.read.overlappingDecompressionThreads")
.intType()
.defaultValue(1)
.withDescription("Number of threads to use for overlapping decompression of shuffle blocks.");

public static final ConfigOption<Boolean> RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED =
ConfigOptions.key("rss.client.write.overlappingCompressionEnable")
.booleanType()
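The two options above gate the new read-side overlapping decompression. A minimal sketch of turning them on from a Spark job, assuming the usual "spark." prefix that RssSparkConfig applies to client keys (the thread count of 4 is an illustrative value); note that the reader only wires these in when a codec is present, i.e. when shuffle compression is enabled, as shown in the RssShuffleReader change below:

import org.apache.spark.SparkConf;

SparkConf conf = new SparkConf()
    .set("spark.rss.client.read.overlappingDecompressionEnable", "true")
    .set("spark.rss.client.read.overlappingDecompressionThreads", "4");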
@@ -40,6 +40,7 @@

import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.client.response.ShuffleBlock;
import org.apache.uniffle.common.ShuffleReadTimes;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
@@ -63,14 +64,14 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
private ByteBuffer uncompressedData;
private Optional<Codec> codec;

// only for tests
@VisibleForTesting
public RssShuffleDataIterator(
Serializer serializer,
ShuffleReadClient shuffleReadClient,
ShuffleReadMetrics shuffleReadMetrics,
RssConf rssConf) {
this.serializerInstance = serializer.newInstance();
this.shuffleReadClient = shuffleReadClient;
this.shuffleReadMetrics = shuffleReadMetrics;
this(serializer, shuffleReadClient, shuffleReadMetrics, rssConf, Optional.empty());
boolean compress =
rssConf.getBoolean(
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
@@ -79,6 +80,18 @@ public RssShuffleDataIterator(
this.codec = compress ? Codec.newInstance(rssConf) : Optional.empty();
}

public RssShuffleDataIterator(
Serializer serializer,
ShuffleReadClient shuffleReadClient,
ShuffleReadMetrics shuffleReadMetrics,
RssConf rssConf,
Optional<Codec> codec) {
this.serializerInstance = serializer.newInstance();
this.shuffleReadClient = shuffleReadClient;
this.shuffleReadMetrics = shuffleReadMetrics;
this.codec = codec;
}

public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
clearDeserializationStream();
// Unpooled.wrapperBuffer will return a ByteBuf, but this ByteBuf won't release direct/heap
@@ -114,17 +127,29 @@ public boolean hasNext() {
// read next segment
long startFetch = System.currentTimeMillis();
// depends on spark.shuffle.compress, shuffled block may not be compressed
CompressedShuffleBlock rawBlock = shuffleReadClient.readShuffleBlockData();
// If ShuffleServer delete
ShuffleBlock shuffleBlock = shuffleReadClient.readShuffleBlockData();
ByteBuffer rawData = shuffleBlock != null ? shuffleBlock.getByteBuffer() : null;

ByteBuffer rawData = rawBlock != null ? rawBlock.getByteBuffer() : null;
long fetchDuration = System.currentTimeMillis() - startFetch;
shuffleReadMetrics.incFetchWaitTime(fetchDuration);
if (rawData != null) {
uncompress(rawBlock, rawData);
// collect metrics from raw data
long rawDataLength = rawData.limit() - rawData.position();
totalRawBytesLength += rawDataLength;
shuffleReadMetrics.incRemoteBytesRead(rawDataLength);

// get initial data
ByteBuffer decompressed = null;
if (shuffleBlock instanceof CompressedShuffleBlock) {
uncompress(shuffleBlock, rawData);
decompressed = uncompressedData;
} else {
decompressed = shuffleBlock.getByteBuffer();
}

// create new iterator for shuffle data
long startSerialization = System.currentTimeMillis();
recordsIterator = createKVIterator(uncompressedData);
recordsIterator = createKVIterator(decompressed);
long serializationDuration = System.currentTimeMillis() - startSerialization;
readTime += fetchDuration;
serializeTime += serializationDuration;
@@ -156,11 +181,7 @@ private boolean isSameMemoryType(ByteBuffer left, ByteBuffer right) {
return left.isDirect() == right.isDirect();
}

private int uncompress(CompressedShuffleBlock rawBlock, ByteBuffer rawData) {
long rawDataLength = rawData.limit() - rawData.position();
totalRawBytesLength += rawDataLength;
shuffleReadMetrics.incRemoteBytesRead(rawDataLength);

private int uncompress(ShuffleBlock rawBlock, ByteBuffer rawData) {
int uncompressedLen = rawBlock.getUncompressLength();
if (uncompressedLen < 0) {
LOG.error("Uncompressed length is negative: {}", uncompressedLen);
@@ -45,6 +45,7 @@
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.FunctionUtils;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.util.CompletionIterator;
import org.apache.spark.util.CompletionIterator$;
@@ -62,11 +63,14 @@
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleReadTimes;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_OVERLAPPING_DECOMPRESSION_THREADS;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_READ_REORDER_MULTI_SERVERS_ENABLED;
import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;

@@ -289,29 +293,42 @@ class MultiPartitionIterator<K, C> extends AbstractIterator<Product2<K, C>> {
rssConf.getLong(
RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX,
RssClientConfig.RSS_CLIENT_RETRY_INTERVAL_MAX_DEFAULT_VALUE);
boolean compress =
rssConf.getBoolean(
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
Optional<Codec> codec = compress ? Codec.newInstance(rssConf) : Optional.empty();
ShuffleClientFactory.ReadClientBuilder builder =
ShuffleClientFactory.newReadBuilder()
.readCostTracker(shuffleServerReadCostTracker)
.appId(appId)
.shuffleId(shuffleId)
.partitionId(partition)
.basePath(basePath)
.partitionNumPerRange(1)
.partitionNum(partitionNum)
.blockIdBitmap(partitionToExpectBlocks.get(partition))
.taskIdBitmap(taskIdBitmap)
.shuffleServerInfoList(shuffleServerInfoList)
.hadoopConf(hadoopConf)
.shuffleDataDistributionType(dataDistributionType)
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
.retryMax(retryMax)
.retryIntervalMax(retryIntervalMax)
.rssConf(rssConf);
if (codec.isPresent() && rssConf.get(RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED)) {
builder
.overlappingDecompressionEnabled(true)
.codec(codec.get())
.overlappingDecompressionThreadNum(
rssConf.get(RSS_READ_OVERLAPPING_DECOMPRESSION_THREADS));
}
ShuffleReadClient shuffleReadClient =
ShuffleClientFactory.getInstance()
.createShuffleReadClient(
ShuffleClientFactory.newReadBuilder()
.readCostTracker(shuffleServerReadCostTracker)
.appId(appId)
.shuffleId(shuffleId)
.partitionId(partition)
.basePath(basePath)
.partitionNumPerRange(1)
.partitionNum(partitionNum)
.blockIdBitmap(partitionToExpectBlocks.get(partition))
.taskIdBitmap(taskIdBitmap)
.shuffleServerInfoList(shuffleServerInfoList)
.hadoopConf(hadoopConf)
.shuffleDataDistributionType(dataDistributionType)
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
.retryMax(retryMax)
.retryIntervalMax(retryIntervalMax)
.rssConf(rssConf));
ShuffleClientFactory.getInstance().createShuffleReadClient(builder);
RssShuffleDataIterator<K, C> iterator =
new RssShuffleDataIterator<>(
shuffleDependency.serializer(), shuffleReadClient, readMetrics, rssConf);
shuffleDependency.serializer(), shuffleReadClient, readMetrics, rssConf, codec);
CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
CompletionIterator$.MODULE$.apply(
iterator,
@@ -101,7 +101,7 @@ public void copyFromRssServer() throws IOException {
if (!hasPendingData) {
final long startFetch = System.currentTimeMillis();
blockStartFetch = System.currentTimeMillis();
compressedBlock = shuffleReadClient.readShuffleBlockData();
compressedBlock = (CompressedShuffleBlock) shuffleReadClient.readShuffleBlockData();
if (compressedBlock != null) {
compressedData = compressedBlock.getByteBuffer();
}
@@ -143,7 +143,7 @@ public void copyFromRssServer() throws IOException {
// fetch a block
if (!hasPendingData) {
final long startFetch = System.currentTimeMillis();
compressedBlock = shuffleReadClient.readShuffleBlockData();
compressedBlock = (CompressedShuffleBlock) shuffleReadClient.readShuffleBlockData();
if (compressedBlock != null) {
compressedData = compressedBlock.getByteBuffer();
}
@@ -17,12 +17,12 @@

package org.apache.uniffle.client.api;

import org.apache.uniffle.client.response.CompressedShuffleBlock;
import org.apache.uniffle.client.response.ShuffleBlock;
import org.apache.uniffle.common.ShuffleReadTimes;

public interface ShuffleReadClient {

CompressedShuffleBlock readShuffleBlockData();
ShuffleBlock readShuffleBlockData();

void checkProcessedBlockIds();

@@ -29,6 +29,7 @@
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.IdHelper;
import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;
@@ -227,6 +228,27 @@ public static class ReadClientBuilder {
private long retryIntervalMax;
private ShuffleServerReadCostTracker readCostTracker;

private boolean overlappingDecompressionEnabled;
private int overlappingDecompressionThreadNum;
private Codec codec;

public ReadClientBuilder overlappingDecompressionEnabled(
boolean overlappingDecompressionEnabled) {
this.overlappingDecompressionEnabled = overlappingDecompressionEnabled;
return this;
}

public ReadClientBuilder overlappingDecompressionThreadNum(
int overlappingDecompressionThreadNum) {
this.overlappingDecompressionThreadNum = overlappingDecompressionThreadNum;
return this;
}

public ReadClientBuilder codec(Codec codec) {
this.codec = codec;
return this;
}

public ReadClientBuilder readCostTracker(ShuffleServerReadCostTracker tracker) {
this.readCostTracker = tracker;
return this;
@@ -429,6 +451,18 @@ public long getRetryIntervalMax() {
return retryIntervalMax;
}

public boolean isOverlappingDecompressionEnabled() {
return overlappingDecompressionEnabled;
}

public int getOverlappingDecompressionThreadNum() {
return overlappingDecompressionThreadNum;
}

public Codec getCodec() {
return codec;
}

public ShuffleReadClientImpl build() {
return new ShuffleReadClientImpl(this);
}
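For code that builds the read client directly, the new builder hooks compose as below; a minimal sketch, assuming the builder already carries the usual appId/shuffleId/partition fields as in the RssShuffleReader change above, and import paths matching the existing client module:

import org.apache.uniffle.client.api.ShuffleReadClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.common.compression.Codec;

// Hypothetical helper, not part of this PR.
static ShuffleReadClient withOverlappingDecompression(
    ShuffleClientFactory.ReadClientBuilder builder, Codec codec, int threads) {
  builder
      .overlappingDecompressionEnabled(true)
      .codec(codec)
      .overlappingDecompressionThreadNum(threads);
  return ShuffleClientFactory.getInstance().createShuffleReadClient(builder);
}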
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.client.impl;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.response.DecompressedShuffleBlock;
import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.util.JavaUtils;

public class DecompressionWorker {
private static final Logger LOG = LoggerFactory.getLogger(DecompressionWorker.class);

private final ExecutorService executorService;
private final ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, DecompressedShuffleBlock>>
tasks;
private Codec codec;
private final ThreadLocal<ByteBuffer> bufferLocal =
ThreadLocal.withInitial(() -> ByteBuffer.allocate(0));

public DecompressionWorker(Codec codec, int threads) {
if (codec == null) {
throw new IllegalArgumentException("Codec cannot be null");
}
if (threads <= 0) {
throw new IllegalArgumentException("Threads must be greater than 0");
}
this.tasks = JavaUtils.newConcurrentMap();
this.executorService = Executors.newFixedThreadPool(threads);
this.codec = codec;
}

public void add(int batchIndex, ShuffleDataResult shuffleDataResult) {
List<BufferSegment> bufferSegments = shuffleDataResult.getBufferSegments();
ByteBuffer sharedByteBuffer = shuffleDataResult.getDataBuffer();
int index = 0;
LOG.debug(
"Adding {} segments with batch index:{} to decompression worker",
bufferSegments.size(),
batchIndex);
for (BufferSegment bufferSegment : bufferSegments) {
CompletableFuture<ByteBuffer> f =
CompletableFuture.supplyAsync(
() -> {
int offset = bufferSegment.getOffset();
int length = bufferSegment.getLength();
ByteBuffer buffer = sharedByteBuffer.duplicate();
buffer.position(offset);
buffer.limit(offset + length);

int uncompressedLen = bufferSegment.getUncompressLength();
ByteBuffer dst =
buffer.isDirect()
? ByteBuffer.allocateDirect(uncompressedLen)
: ByteBuffer.allocate(uncompressedLen);
codec.decompress(buffer, uncompressedLen, dst, 0);
return dst;
},
executorService);
ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks =
tasks.computeIfAbsent(batchIndex, k -> new ConcurrentHashMap<>());
blocks.put(index++, new DecompressedShuffleBlock(f));
}
}

public DecompressedShuffleBlock get(int batchIndex, int segmentIndex) {
ConcurrentHashMap<Integer, DecompressedShuffleBlock> blocks = tasks.get(batchIndex);
if (blocks == null) {
return null;
}
DecompressedShuffleBlock block = blocks.remove(segmentIndex);
return block;
}

public void close() {
executorService.shutdown();
}
}
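A minimal usage sketch for the worker above (not part of this PR's diff): one fetched batch is scheduled, then its segments are drained in order. It assumes the sketch lives in the same package as DecompressionWorker and that DecompressedShuffleBlock.getByteBuffer() waits on the underlying CompletableFuture:

package org.apache.uniffle.client.impl;

import java.nio.ByteBuffer;

import org.apache.uniffle.client.response.DecompressedShuffleBlock;
import org.apache.uniffle.common.ShuffleDataResult;
import org.apache.uniffle.common.compression.Codec;

class DecompressionWorkerSketch {
  // Hypothetical helper: schedule one batch, then consume its decompressed segments in order.
  static void drainBatch(Codec codec, ShuffleDataResult batch, int batchIndex) {
    DecompressionWorker worker = new DecompressionWorker(codec, 2);
    try {
      worker.add(batchIndex, batch); // kicks off async decompression per BufferSegment
      int segmentCount = batch.getBufferSegments().size();
      for (int i = 0; i < segmentCount; i++) {
        DecompressedShuffleBlock block = worker.get(batchIndex, i);
        // getByteBuffer() is assumed to join the CompletableFuture and return the decompressed buffer
        ByteBuffer data = block.getByteBuffer();
        // hand 'data' to the deserialization path here
      }
    } finally {
      worker.close();
    }
  }
}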