Skip to content

Commit 24880a3

Browse files
committed
Triggering a Stage retry requires reassigning the shuffle servers for the retried Stage
1 parent 7d1d97c commit 24880a3

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java

+19-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import java.util.function.Supplier;
3535
import java.util.stream.Collectors;
3636

37+
import scala.Tuple2;
38+
3739
import com.google.common.annotations.VisibleForTesting;
3840
import com.google.common.collect.Maps;
3941
import com.google.common.collect.Sets;
@@ -684,8 +686,6 @@ public boolean reassignOnStageResubmit(
684686
int requiredShuffleServerNumber =
685687
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
686688
int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(sparkConf);
687-
// Deregister the shuffleId corresponding to the Shuffle Server.
688-
shuffleWriteClient.unregisterShuffle(appId, shuffleId);
689689
Map<Integer, List<ShuffleServerInfo>> partitionToServers =
690690
requestShuffleAssignment(
691691
shuffleId,
@@ -1042,7 +1042,7 @@ protected void registerShuffleServers(
10421042
}
10431043
LOG.info("Start to register shuffleId {}", shuffleId);
10441044
long start = System.currentTimeMillis();
1045-
Map<String, String> sparkConfMap = RssSparkConfig.sparkConfToMap(getSparkConf());
1045+
Map<String, String> sparkConfMap = sparkConfToMap(getSparkConf());
10461046
serverToPartitionRanges.entrySet().stream()
10471047
.forEach(
10481048
entry -> {
@@ -1073,7 +1073,7 @@ protected void registerShuffleServers(
10731073
}
10741074
LOG.info("Start to register shuffleId[{}]", shuffleId);
10751075
long start = System.currentTimeMillis();
1076-
Map<String, String> sparkConfMap = RssSparkConfig.sparkConfToMap(getSparkConf());
1076+
Map<String, String> sparkConfMap = sparkConfToMap(getSparkConf());
10771077
Set<Map.Entry<ShuffleServerInfo, List<PartitionRange>>> entries =
10781078
serverToPartitionRanges.entrySet();
10791079
entries.stream()
@@ -1119,4 +1119,19 @@ public boolean isRssStageRetryForFetchFailureEnabled() {
11191119
public SparkConf getSparkConf() {
11201120
return sparkConf;
11211121
}
1122+
1123+
public Map<String, String> sparkConfToMap(SparkConf sparkConf) {
1124+
Map<String, String> map = new HashMap<>();
1125+
1126+
for (Tuple2<String, String> tuple : sparkConf.getAll()) {
1127+
String key = tuple._1;
1128+
map.put(key, tuple._2);
1129+
}
1130+
1131+
return map;
1132+
}
1133+
1134+
public ShuffleWriteClient getShuffleWriteClient() {
1135+
return shuffleWriteClient;
1136+
}
11221137
}

client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerInterface.java

+8
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
2525
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
2626

27+
import org.apache.uniffle.client.api.ShuffleWriteClient;
2728
import org.apache.uniffle.common.ReceivingFailureServer;
2829
import org.apache.uniffle.shuffle.BlockIdManager;
2930

@@ -88,4 +89,11 @@ MutableShuffleHandleInfo reassignOnBlockSendFailure(
8889
int shuffleId,
8990
Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers,
9091
boolean partitionSplit);
92+
93+
/**
94+
* Driver obtains the ShuffleWriteClient.
95+
*
96+
* @return the ShuffleWriteClient used by the driver
97+
*/
98+
ShuffleWriteClient getShuffleWriteClient();
9199
}

client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.java

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ public void reportShuffleWriteFailure(
118118
// Clear the metadata of the completed task, otherwise some of the stage's data will
119119
// be lost.
120120
shuffleManager.unregisterAllMapOutput(shuffleId);
121+
// Deregister this shuffleId from its assigned shuffle servers.
122+
shuffleManager.getShuffleWriteClient().unregisterShuffle(appId, shuffleId);
121123
shuffleServerWriterFailureRecord.setClearedMapTrackerBlock(true);
122124
LOG.info(
123125
"Clear shuffle result in shuffleId:{}, stageId:{}, stageAttemptNumber:{}.",

client-spark/common/src/test/java/org/apache/uniffle/shuffle/manager/DummyRssShuffleManager.java

+6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
2626
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
2727

28+
import org.apache.uniffle.client.api.ShuffleWriteClient;
2829
import org.apache.uniffle.common.ReceivingFailureServer;
2930
import org.apache.uniffle.shuffle.BlockIdManager;
3031

@@ -84,4 +85,9 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
8485
boolean partitionSplit) {
8586
return null;
8687
}
88+
89+
@Override
90+
public ShuffleWriteClient getShuffleWriteClient() {
91+
return null;
92+
}
8793
}

0 commit comments

Comments
 (0)