From b53f42196306adbc0d5ebcfdbab7246a3113708a Mon Sep 17 00:00:00 2001 From: yl09099 Date: Fri, 29 Nov 2024 18:04:13 +0800 Subject: [PATCH] Triggering Stage retry requires reassigning the shuffle server in the retry Stage --- .../manager/RssShuffleManagerBase.java | 23 ++++++++----------- .../manager/ShuffleManagerGrpcService.java | 18 +++++++++------ .../spark/shuffle/RssShuffleManager.java | 4 ++-- .../shuffle/writer/RssShuffleWriter.java | 2 +- .../RssPartitionToShuffleServerRequest.java | 8 +++---- proto/src/main/proto/Rss.proto | 2 +- 6 files changed, 29 insertions(+), 28 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 195128a915..c1fc5b68eb 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -571,23 +571,20 @@ protected static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkCo } public ShuffleHandleInfo getShuffleHandleInfo( - int stageAttemptId, int stageAttemptNumber, RssShuffleHandle rssHandle) { + int stageAttemptId, + int stageAttemptNumber, + RssShuffleHandle rssHandle, + boolean isWritePhase) { int shuffleId = rssHandle.getShuffleId(); if (shuffleManagerRpcServiceEnabled && rssStageRetryForWriteFailureEnabled) { // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId. return getRemoteShuffleHandleInfoWithStageRetry( - stageAttemptId, - stageAttemptNumber, - shuffleId, - rssHandle.getDependency().partitioner().numPartitions()); + stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase); } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) { // In partition block Retry mode, Get the ShuffleServer list from the Driver based on the // shuffleId. return getRemoteShuffleHandleInfoWithBlockRetry( - stageAttemptId, - stageAttemptNumber, - shuffleId, - rssHandle.getDependency().partitioner().numPartitions()); + stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase); } else { return new SimpleShuffleHandleInfo( shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage()); @@ -601,10 +598,10 @@ public ShuffleHandleInfo getShuffleHandleInfo( * @return ShuffleHandleInfo */ protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoWithStageRetry( - int stageAttemptId, int stageAttemptNumber, int shuffleId, int numPartitions) { + int stageAttemptId, int stageAttemptNumber, int shuffleId, boolean isWritePhase) { RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = new RssPartitionToShuffleServerRequest( - stageAttemptId, stageAttemptNumber, shuffleId, numPartitions); + stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase); RssReassignOnStageRetryResponse rpcPartitionToShufflerServer = getOrCreateShuffleManagerClientSupplier() .get() @@ -622,10 +619,10 @@ protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoW * @return ShuffleHandleInfo */ protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBlockRetry( - int stageAttemptId, int stageAttemptNumber, int shuffleId, int numPartitions) { + int stageAttemptId, int stageAttemptNumber, int shuffleId, boolean isWritePhase) { RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = new RssPartitionToShuffleServerRequest( - stageAttemptId, stageAttemptNumber, shuffleId, numPartitions); + stageAttemptId, stageAttemptNumber, shuffleId, isWritePhase); RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer = getOrCreateShuffleManagerClientSupplier() .get() diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java index 768de1a2c2..667b6a9056 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java @@ -224,14 +224,18 @@ public void getPartitionToShufflerServerWithStageRetry( int stageAttemptId = request.getStageAttemptId(); int stageAttemptNumber = request.getStageAttemptNumber(); int shuffleId = request.getShuffleId(); + boolean isWritePhase = request.getIsWritePhase(); StageAttemptShuffleHandleInfo shuffleHandle; - ShuffleServerWriterFailureRecord shuffleServerWriterFailureRecord = - shuffleWriteStatus.get(shuffleId); - if (shuffleServerWriterFailureRecord != null) { - synchronized (shuffleServerWriterFailureRecord) { - if (shuffleServerWriterFailureRecord.isNeedReassignForLastStageNumber(stageAttemptNumber)) { - shuffleManager.reassignOnStageResubmit(shuffleId, stageAttemptId, stageAttemptNumber); - shuffleServerWriterFailureRecord.setShuffleServerAssignmented(true); + if (isWritePhase) { + ShuffleServerWriterFailureRecord shuffleServerWriterFailureRecord = + shuffleWriteStatus.get(shuffleId); + if (shuffleServerWriterFailureRecord != null) { + synchronized (shuffleServerWriterFailureRecord) { + if (shuffleServerWriterFailureRecord.isNeedReassignForLastStageNumber( + stageAttemptNumber)) { + shuffleManager.reassignOnStageResubmit(shuffleId, stageAttemptId, stageAttemptNumber); + shuffleServerWriterFailureRecord.setShuffleServerAssignmented(true); + } } } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index a52db23eee..d628272438 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -494,12 +494,12 @@ public ShuffleReader getReader( // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId. shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry( - context.stageId(), context.stageAttemptNumber(), shuffleId, partitionNum); + context.stageId(), context.stageAttemptNumber(), shuffleId, false); } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) { // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry( - context.stageId(), context.stageAttemptNumber(), shuffleId, partitionNum); + context.stageId(), context.stageAttemptNumber(), shuffleId, false); } else { shuffleHandleInfo = new SimpleShuffleHandleInfo( diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index e2632a7f35..aa4ff9f890 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -212,7 +212,7 @@ public RssShuffleWriter( rssHandle, taskFailureCallback, shuffleManager.getShuffleHandleInfo( - context.stageId(), context.stageAttemptNumber(), rssHandle), + context.stageId(), context.stageAttemptNumber(), rssHandle, true), context); BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf); final WriteBufferManager bufferManager = diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java index 0a020c3ef9..9e42d2c6a7 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssPartitionToShuffleServerRequest.java @@ -23,14 +23,14 @@ public class RssPartitionToShuffleServerRequest { private int stageAttemptId; private int stageAttemptNumber; private int shuffleId; - private int numPartitions; + private boolean isWritePhase; public RssPartitionToShuffleServerRequest( - int stageAttemptId, int stageAttemptNumber, int shuffleId, int numPartitions) { + int stageAttemptId, int stageAttemptNumber, int shuffleId, boolean isWritePhase) { this.stageAttemptId = stageAttemptId; this.stageAttemptNumber = stageAttemptNumber; this.shuffleId = shuffleId; - this.numPartitions = numPartitions; + this.isWritePhase = isWritePhase; } public int getShuffleId() { @@ -47,7 +47,7 @@ public RssProtos.PartitionToShuffleServerRequest toProto() { builder.setStageAttemptId(stageAttemptId); builder.setStageAttemptNumber(stageAttemptNumber); builder.setShuffleId(shuffleId); - builder.setNumPartitions(numPartitions); + builder.setIsWritePhase(isWritePhase); return builder.build(); } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index 0655aad006..4749a06584 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -608,7 +608,7 @@ message PartitionToShuffleServerRequest { int32 stageAttemptId = 1; int32 stageAttemptNumber = 2; int32 shuffleId = 3; - int32 numPartitions = 4; + bool isWritePhase = 4; } message ReassignOnStageRetryResponse {