Commit 861dc40

Refactor of DiskStore for shuffle file consolidation
The main goal of this refactor was to allow the interposition of a new layer which maps logical BlockIds to physical locations other than a file with the same name as the BlockId. In particular, BlockIds will need to be mappable to chunks of files, as multiple blocks will be stored in the same file. In order to accomplish this, the following changes have been made:

- Creation of DiskBlockManager, which manages the association of logical BlockIds with physical disk locations (called FileSegments). By default, Blocks are simply mapped to physical files of the same name, as before.
- The DiskStore now indirects all requests for a given BlockId through the DiskBlockManager in order to resolve the actual File location.
- DiskBlockObjectWriter has been merged into BlockObjectWriter.
- The Netty PathResolver has been changed to map BlockIds into FileSegments, as this codepath is the only one that uses Netty, and that is likely to remain the case.

Overall, I think this refactor produces a clearer division between the logical Block paradigm and its physical on-disk location. There is now an explicit (and documented) mapping from one to the other; a short illustrative sketch of that mapping follows the commit metadata below.
1 parent 747f538 commit 861dc40
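To make the new layering concrete, here is a minimal Scala sketch of the logical-to-physical mapping this commit introduces. The FileSegment fields (file, offset, length) match the class used in the diffs below; the helper methods and the object name are illustrative assumptions, not code from the commit.

import java.io.File

object MappingSketch {
  // Physical side of the mapping: a FileSegment names a contiguous region of a
  // file. (Sketch only; the real class is org.apache.spark.storage.FileSegment.)
  class FileSegment(val file: File, val offset: Long, val length: Long)

  // Default mapping, as before: a logical block occupies a whole file whose
  // name is the block's name.
  def wholeFileSegment(dir: File, blockName: String): FileSegment = {
    val f = new File(dir, blockName)
    new FileSegment(f, 0, f.length())
  }

  // Consolidated mapping, which this refactor makes possible: several blocks
  // share one physical file, each occupying its own (offset, length) chunk.
  def chunkOf(sharedFile: File, offset: Long, length: Long): FileSegment =
    new FileSegment(sharedFile, offset, length)
}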

10 files changed: +366 −269 lines


core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java

Lines changed: 6 additions & 11 deletions
@@ -25,6 +25,7 @@
 import io.netty.channel.DefaultFileRegion;
 
 import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
 
 class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
 
@@ -37,40 +38,34 @@ public FileServerHandler(PathResolver pResolver){
   @Override
   public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
     BlockId blockId = BlockId.apply(blockIdString);
-    String path = pResolver.getAbsolutePath(blockId.name());
-    // if getFilePath returns null, close the channel
-    if (path == null) {
+    FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+    // if getBlockLocation returns null, close the channel
+    if (fileSegment == null) {
       //ctx.close();
       return;
     }
-    File file = new File(path);
+    File file = fileSegment.file();
     if (file.exists()) {
       if (!file.isFile()) {
-        //logger.info("Not a file : " + file.getAbsolutePath());
         ctx.write(new FileHeader(0, blockId).buffer());
         ctx.flush();
         return;
       }
       long length = file.length();
       if (length > Integer.MAX_VALUE || length <= 0) {
-        //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
         ctx.write(new FileHeader(0, blockId).buffer());
         ctx.flush();
         return;
       }
       int len = new Long(length).intValue();
-      //logger.info("Sending block "+blockId+" filelen = "+len);
-      //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
       ctx.write((new FileHeader(len, blockId)).buffer());
       try {
         ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
-          .getChannel(), 0, file.length()));
+          .getChannel(), fileSegment.offset(), fileSegment.length()));
       } catch (Exception e) {
-        //logger.warning("Exception when sending file : " + file.getAbsolutePath());
         e.printStackTrace();
       }
     } else {
-      //logger.warning("File not found: " + file.getAbsolutePath());
       ctx.write(new FileHeader(0, blockId).buffer());
     }
     ctx.flush();

core/src/main/java/org/apache/spark/network/netty/PathResolver.java

Lines changed: 4 additions & 7 deletions
@@ -17,13 +17,10 @@
 
 package org.apache.spark.network.netty;
 
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
 
 public interface PathResolver {
-  /**
-   * Get the absolute path of the file
-   *
-   * @param fileId
-   * @return the absolute path of file
-   */
-  public String getAbsolutePath(String fileId);
+  /** Get the file segment in which the given Block resides. */
+  public FileSegment getBlockLocation(BlockId blockId);
 }
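As a hedged illustration of why getBlockLocation returns a FileSegment rather than a path string: a resolver for a consolidated shuffle file can hand back sub-ranges of one shared file. The class name and index map below are assumptions made for this example; only the PathResolver, BlockId, and FileSegment signatures come from the commit.

import java.io.File

import org.apache.spark.network.netty.PathResolver
import org.apache.spark.storage.{BlockId, FileSegment}

// Hypothetical resolver: several reduce blocks live in one physical file, so
// each block maps to its own (offset, length) range within that file.
class ConsolidatedPathResolver(sharedFile: File, index: Map[BlockId, (Long, Long)])
  extends PathResolver {

  override def getBlockLocation(blockId: BlockId): FileSegment =
    index.get(blockId) match {
      case Some((offset, length)) => new FileSegment(sharedFile, offset, length)
      case None => null  // FileServerHandler above returns early on a null segment
    }
}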

core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala

Lines changed: 3 additions & 4 deletions
@@ -21,7 +21,7 @@ import java.io.File
 
 import org.apache.spark.Logging
 import org.apache.spark.util.Utils
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, FileSegment}
 
 
 private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
     val localDirs = args.drop(2).map(new File(_))
 
     val pResovler = new PathResolver {
-      override def getAbsolutePath(blockIdString: String): String = {
-        val blockId = BlockId(blockIdString)
+      override def getBlockLocation(blockId: BlockId): FileSegment = {
         if (!blockId.isShuffle) {
           throw new Exception("Block " + blockId + " is not a shuffle block")
         }
@@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
         val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
         val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
         val file = new File(subDir, blockId.name)
-        return file.getAbsolutePath
+        return new FileSegment(file, 0, file.length())
       }
     }
     val sender = new ShuffleSender(port, pResovler)

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ private[spark] class ShuffleMapTask(
       val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
         writer.commit()
         writer.close()
-        val size = writer.size()
+        val size = writer.fileSegment().length
         totalBytes += size
         MapOutputTracker.compressSize(size)
       }

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 24 additions & 9 deletions
@@ -28,7 +28,7 @@ import akka.dispatch.{Await, Future}
 import akka.util.Duration
 import akka.util.duration._
 
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
 
 import org.apache.spark.{Logging, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
@@ -102,18 +102,19 @@
   }
 
   val shuffleBlockManager = new ShuffleBlockManager(this)
+  val diskBlockManager = new DiskBlockManager(
+    System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
 
   private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
-  private[storage] val diskStore: DiskStore =
-    new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+  private[storage] val diskStore = new DiskStore(this, diskBlockManager)
 
   // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
   private val nettyPort: Int = {
     val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
     val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
-    if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+    if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
   }
 
   val connectionManager = new ConnectionManager(0)
@@ -567,16 +568,19 @@
 
   /**
    * A short circuited method to get a block writer that can write data directly to disk.
+   * The Block will be appended to the File specified by filename.
    * This is currently used for writing shuffle files out. Callers should handle error
    * cases.
    */
-  def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
+  def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
     : BlockObjectWriter = {
-    val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+    val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+    val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
+    val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
    writer.registerCloseEventHandler(() => {
      val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
      blockInfo.put(blockId, myInfo)
-      myInfo.markReady(writer.size())
+      myInfo.markReady(writer.fileSegment().length)
    })
    writer
  }
@@ -988,13 +992,24 @@
     if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
   }
 
+  /** Serializes into a stream. */
+  def dataSerializeStream(
+      blockId: BlockId,
+      outputStream: OutputStream,
+      values: Iterator[Any],
+      serializer: Serializer = defaultSerializer) {
+    val byteStream = new FastBufferedOutputStream(outputStream)
+    val ser = serializer.newInstance()
+    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+  }
+
+  /** Serializes into a byte buffer. */
   def dataSerialize(
       blockId: BlockId,
       values: Iterator[Any],
       serializer: Serializer = defaultSerializer): ByteBuffer = {
     val byteStream = new FastByteArrayOutputStream(4096)
-    val ser = serializer.newInstance()
-    ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+    dataSerializeStream(blockId, byteStream, values, serializer)
     byteStream.trim()
     ByteBuffer.wrap(byteStream.array)
   }

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 88 additions & 2 deletions
@@ -17,6 +17,13 @@
 
 package org.apache.spark.storage
 
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
 
 /**
  * An interface for writing JVM objects to some underlying storage. This interface allows
@@ -59,7 +66,86 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
   def write(value: Any)
 
   /**
-   * Size of the valid writes, in bytes.
+   * Returns the file segment of committed data that this Writer has written.
    */
-  def size(): Long
+  def fileSegment(): FileSegment
+}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+    blockId: BlockId,
+    file: File,
+    serializer: Serializer,
+    bufferSize: Int,
+    compressStream: OutputStream => OutputStream)
+  extends BlockObjectWriter(blockId)
+  with Logging
+{
+
+  /** The file channel, used for repositioning / truncating the file. */
+  private var channel: FileChannel = null
+  private var bs: OutputStream = null
+  private var objOut: SerializationStream = null
+  private var initialPosition = 0L
+  private var lastValidPosition = 0L
+  private var initialized = false
+
+  override def open(): BlockObjectWriter = {
+    val fos = new FileOutputStream(file, true)
+    channel = fos.getChannel()
+    initialPosition = channel.position
+    lastValidPosition = initialPosition
+    bs = compressStream(new FastBufferedOutputStream(fos, bufferSize))
+    objOut = serializer.newInstance().serializeStream(bs)
+    initialized = true
+    this
+  }
+
+  override def close() {
+    if (initialized) {
+      objOut.close()
+      channel = null
+      bs = null
+      objOut = null
+    }
+    super.close()
+  }
+
+  override def isOpen: Boolean = objOut != null
+
+  override def commit(): Long = {
+    if (initialized) {
+      // NOTE: Flush the serializer first and then the compressed/buffered output stream
+      objOut.flush()
+      bs.flush()
+      val prevPos = lastValidPosition
+      lastValidPosition = channel.position()
+      lastValidPosition - prevPos
+    } else {
+      // lastValidPosition is zero if stream is uninitialized
+      lastValidPosition
+    }
+  }
+
+  override def revertPartialWrites() {
+    if (initialized) {
+      // Discard current writes. We do this by flushing the outstanding writes and
+      // truncate the file to the last valid position.
+      objOut.flush()
+      bs.flush()
+      channel.truncate(lastValidPosition)
+    }
+  }
+
+  override def write(value: Any) {
+    if (!initialized) {
+      open()
+    }
+    objOut.writeObject(value)
+  }
+
+  override def fileSegment(): FileSegment = {
+    val bytesWritten = lastValidPosition - initialPosition
+    new FileSegment(file, initialPosition, bytesWritten)
+  }
 }
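To show the append-and-commit semantics of DiskBlockObjectWriter, here is a hedged, script-style sketch that writes two logical blocks into one shared file and recovers a non-overlapping FileSegment for each. The file path, the use of JavaSerializer, and the no-op close handler are assumptions made for the example; the writer API itself is the class above.

import java.io.{File, OutputStream}

import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}

object SharedFileSketch {
  def main(args: Array[String]) {
    val sharedFile = new File("/tmp/shuffle_demo_merged")    // assumed scratch path
    val ser = new JavaSerializer()
    val noCompression = (s: OutputStream) => s               // pass bytes through untouched

    def append(id: String, values: Seq[Any]) = {
      val writer = new DiskBlockObjectWriter(BlockId(id), sharedFile, ser, 4096, noCompression)
      writer.registerCloseEventHandler(() => ())             // no BlockManager bookkeeping in this sketch
      values.foreach(writer.write)                           // opens (and appends) on first write
      writer.commit()                                        // records the new end of valid data
      writer.close()
      writer.fileSegment()                                   // (file, offset where this block started, bytes written)
    }

    val seg1 = append("shuffle_0_0_0", Seq(1, 2, 3))
    val seg2 = append("shuffle_0_0_1", Seq(4, 5, 6))
    // seg2.offset == seg1.offset + seg1.length: both blocks share one physical file.
    println(seg1.offset + "/" + seg1.length + "  " + seg2.offset + "/" + seg2.length)
  }
}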
