 
 import java.io.File;
 import java.io.FileInputStream;
-import java.io.FileOutputStream;
 import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
 import javax.annotation.Nullable;
 
 import scala.None$;
...
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
 import org.apache.spark.internal.config.package$;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
@@ -82,6 +87,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int shuffleId;
   private final int mapId;
   private final Serializer serializer;
+  private final ShuffleWriteSupport shuffleWriteSupport;
   private final IndexShuffleBlockResolver shuffleBlockResolver;
 
   /** Array of file writers, one for each partition */
@@ -103,7 +109,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       BypassMergeSortShuffleHandle<K, V> handle,
       int mapId,
       SparkConf conf,
-      ShuffleWriteMetricsReporter writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics,
+      ShuffleWriteSupport shuffleWriteSupport) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -116,57 +123,61 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleBlockResolver = shuffleBlockResolver;
+    this.shuffleWriteSupport = shuffleWriteSupport;
   }
 
   @Override
   public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
-    if (!records.hasNext()) {
-      partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
-      return;
-    }
-    final SerializerInstance serInstance = serializer.newInstance();
-    final long openStartTime = System.nanoTime();
-    partitionWriters = new DiskBlockObjectWriter[numPartitions];
-    partitionWriterSegments = new FileSegment[numPartitions];
-    for (int i = 0; i < numPartitions; i++) {
-      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
-        blockManager.diskBlockManager().createTempShuffleBlock();
-      final File file = tempShuffleBlockIdPlusFile._2();
-      final BlockId blockId = tempShuffleBlockIdPlusFile._1();
-      partitionWriters[i] =
-        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
-    }
-    // Creating the file to write to and creating a disk writer both involve interacting with
-    // the disk, and can take a long time in aggregate when we open many files, so should be
-    // included in the shuffle write time.
-    writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
-
-    while (records.hasNext()) {
-      final Product2<K, V> record = records.next();
-      final K key = record._1();
-      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
-    }
+    ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport
+      .createMapOutputWriter(shuffleId, mapId, numPartitions);
+    try {
+      if (!records.hasNext()) {
+        partitionLengths = new long[numPartitions];
+        mapOutputWriter.commitAllPartitions();
+        mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+        return;
+      }
+      final SerializerInstance serInstance = serializer.newInstance();
+      final long openStartTime = System.nanoTime();
+      partitionWriters = new DiskBlockObjectWriter[numPartitions];
+      partitionWriterSegments = new FileSegment[numPartitions];
+      for (int i = 0; i < numPartitions; i++) {
+        final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+          blockManager.diskBlockManager().createTempShuffleBlock();
+        final File file = tempShuffleBlockIdPlusFile._2();
+        final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+        partitionWriters[i] =
+          blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+      }
+      // Creating the file to write to and creating a disk writer both involve interacting with
+      // the disk, and can take a long time in aggregate when we open many files, so should be
+      // included in the shuffle write time.
+      writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
 
-    for (int i = 0; i < numPartitions; i++) {
-      try (DiskBlockObjectWriter writer = partitionWriters[i]) {
-        partitionWriterSegments[i] = writer.commitAndGet();
+      while (records.hasNext()) {
+        final Product2<K, V> record = records.next();
+        final K key = record._1();
+        partitionWriters[partitioner.getPartition(key)].write(key, record._2());
       }
-    }
 
-    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    File tmp = Utils.tempFileWith(output);
-    try {
-      partitionLengths = writePartitionedFile(tmp);
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+      for (int i = 0; i < numPartitions; i++) {
+        try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+          partitionWriterSegments[i] = writer.commitAndGet();
+        }
       }
+
+      partitionLengths = writePartitionedData(mapOutputWriter);
+      mapOutputWriter.commitAllPartitions();
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+    } catch (Exception e) {
+      try {
+        mapOutputWriter.abort(e);
+      } catch (Exception e2) {
+        logger.error("Failed to abort the writer after failing to write map output.", e2);
+      }
+      throw e;
    }
-    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
   @VisibleForTesting
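
Review note, not part of the diff: every exit from the rewritten write() now funnels through the plugin writer's lifecycle — create the map output writer up front, do all work inside the try, commit once on success (including the empty-records fast path, which still publishes its zero-length partitions so the map output tracker sees a completed task), and abort then rethrow on any failure. A minimal sketch of that contract, assuming the org.apache.spark.api.shuffle interfaces behave as their call sites above suggest:

ShuffleMapOutputWriter mapOutputWriter =
    shuffleWriteSupport.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
  // ... write records and copy partition data through mapOutputWriter ...
  mapOutputWriter.commitAllPartitions();  // exactly once, on success
} catch (Exception e) {
  try {
    mapOutputWriter.abort(e);  // give the plugin a chance to clean up partial output
  } catch (Exception e2) {
    logger.error("Failed to abort the writer after failing to write map output.", e2);
  }
  throw e;  // always surface the original failure, not any abort failure
}
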
@@ -179,37 +190,54 @@ long[] getPartitionLengths() {
    *
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
    */
-  private long[] writePartitionedFile(File outputFile) throws IOException {
+  private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
     // Track location of the partition starts in the output file
     final long[] lengths = new long[numPartitions];
     if (partitionWriters == null) {
       // We were passed an empty iterator
       return lengths;
     }
-
-    final FileOutputStream out = new FileOutputStream(outputFile, true);
     final long writeStartTime = System.nanoTime();
-    boolean threwException = true;
     try {
       for (int i = 0; i < numPartitions; i++) {
         final File file = partitionWriterSegments[i].file();
-        if (file.exists()) {
-          final FileInputStream in = new FileInputStream(file);
-          boolean copyThrewException = true;
-          try {
-            lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+        boolean copyThrewException = true;
+        ShufflePartitionWriter writer = null;
+        try {
+          writer = mapOutputWriter.getNextPartitionWriter();
+          if (!file.exists()) {
             copyThrewException = false;
-          } finally {
-            Closeables.close(in, copyThrewException);
-          }
-          if (!file.delete()) {
-            logger.error("Unable to delete file for partition {}", i);
+          } else {
+            if (transferToEnabled) {
+              WritableByteChannel outputChannel = writer.toChannel();
+              FileInputStream in = new FileInputStream(file);
+              try (FileChannel inputChannel = in.getChannel()) {
+                Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
+                copyThrewException = false;
+              } finally {
+                Closeables.close(in, copyThrewException);
+              }
+            } else {
+              OutputStream tempOutputStream = writer.toStream();
+              FileInputStream in = new FileInputStream(file);
+              try {
+                Utils.copyStream(in, tempOutputStream, false, false);
+                copyThrewException = false;
+              } finally {
+                Closeables.close(in, copyThrewException);
+              }
+            }
+            if (!file.delete()) {
+              logger.error("Unable to delete file for partition {}", i);
+            }
           }
+        } finally {
+          Closeables.close(writer, copyThrewException);
         }
+
+        lengths[i] = writer.getNumBytesWritten();
       }
-      threwException = false;
     } finally {
-      Closeables.close(out, threwException);
       writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
     }
     partitionWriters = null;
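
Review note, not part of the diff: the copy loop above obtains one ShufflePartitionWriter per partition, copies the on-disk segment into it through either an NIO channel (when spark.file.transferTo is enabled, so the copy can use a transferTo-style transfer that avoids routing bytes through user space) or a plain OutputStream, and then reads the byte count back from the writer rather than tracking it locally as writePartitionedFile did. The interface shapes below are a sketch inferred from the call sites in this diff; the real definitions live in org.apache.spark.api.shuffle elsewhere in this PR and may differ in detail.

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.WritableByteChannel;

// Entry point handed to the shuffle writer's constructor.
public interface ShuffleWriteSupport {
  ShuffleMapOutputWriter createMapOutputWriter(
      int shuffleId, int mapId, int numPartitions) throws IOException;
}

// One per map task; commitAllPartitions or abort is called exactly once.
public interface ShuffleMapOutputWriter {
  ShufflePartitionWriter getNextPartitionWriter() throws IOException;
  void commitAllPartitions() throws IOException;
  void abort(Throwable error) throws IOException;
}

// One per reduce partition, handed out in partition order; Closeable so the
// caller can release it via Guava's Closeables.close as in the loop above.
public interface ShufflePartitionWriter extends Closeable {
  OutputStream toStream() throws IOException;          // buffered stream copy path
  WritableByteChannel toChannel() throws IOException;  // channel path for NIO copies
  long getNumBytesWritten();  // read after close() to fill lengths[i]
}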