diff --git a/.scalafmt.conf b/.scalafmt.conf
new file mode 100644
index 00000000..356e6edb
--- /dev/null
+++ b/.scalafmt.conf
@@ -0,0 +1,11 @@
+version = 2.4.2
+align = none
+align.openParenDefnSite = false
+align.openParenCallSite = false
+align.tokens = []
+optIn = {
+ configStyleArguments = false
+}
+danglingParentheses = false
+docstrings = JavaDoc
+maxColumn = 98
\ No newline at end of file
diff --git a/core/pom.xml b/core/pom.xml
index b400270b..9ffb0098 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -13,7 +13,7 @@
jar
com.intel.spark-pmof.java
- java
+ spark-pmof-java
1.0
@@ -40,6 +40,11 @@
hpnl
0.5
+
+ com.intel.rpmp
+ rpmp
+ 0.1
+
org.xerial
sqlite-jdbc
diff --git a/core/src/main/java/org/apache/spark/storage/pmof/RemotePersistentMemoryPool.java b/core/src/main/java/org/apache/spark/storage/pmof/RemotePersistentMemoryPool.java
new file mode 100644
index 00000000..9c177d56
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/storage/pmof/RemotePersistentMemoryPool.java
@@ -0,0 +1,72 @@
+package org.apache.spark.storage.pmof;
+
+import com.intel.rpmp.PmPoolClient;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+public class RemotePersistentMemoryPool {
+ private static String remote_host;
+ private static String remote_port_str;
+
+ private RemotePersistentMemoryPool(String remote_address, String remote_port) throws IOException {
+ pmPoolClient = new PmPoolClient(remote_address, remote_port);
+ }
+
+ public static RemotePersistentMemoryPool getInstance(String remote_address, String remote_port) throws IOException {
+ synchronized (RemotePersistentMemoryPool.class) {
+ if (instance == null) {
+ remote_host = remote_address;
+ remote_port_str = remote_port;
+ instance = new RemotePersistentMemoryPool(remote_address, remote_port);
+ }
+ }
+ return instance;
+ }
+
+ public static int close() {
+ synchronized (RemotePersistentMemoryPool.class) {
+ if (instance != null)
+ return instance.dispose();
+ else
+ return 0;
+ }
+ }
+
+ public static String getHost() {
+ return remote_host;
+ }
+
+ public static int getPort() {
+ return Integer.parseInt(remote_port_str);
+ }
+
+ public int read(long address, long size, ByteBuffer byteBuffer) {
+ return pmPoolClient.read(address, size, byteBuffer);
+ }
+
+ public long put(String key, ByteBuffer data, long size) {
+ return pmPoolClient.put(key, data, size);
+ }
+
+ public long get(String key, long size, ByteBuffer data) {
+ return pmPoolClient.get(key, size, data);
+ }
+
+ public long[] getMeta(String key) {
+ return pmPoolClient.getMeta(key);
+ }
+
+ public int del(String key) throws IOException {
+ return pmPoolClient.del(key);
+ }
+
+ public int dispose() {
+ pmPoolClient.dispose();
+ return 0;
+ }
+
+ private static PmPoolClient pmPoolClient;
+ private static RemotePersistentMemoryPool instance;
+}
\ No newline at end of file
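
The new RemotePersistentMemoryPool is a thin singleton wrapper over com.intel.rpmp.PmPoolClient. Below is a minimal usage sketch, not part of this patch: the host, port, key and buffer sizes are placeholder assumptions, and getMeta() is read here as the flat (address, length, rkey) triples that PmemBlockOutputStream.getPartitionBlockInfo expects later in this diff.

```scala
import java.nio.ByteBuffer
import org.apache.spark.storage.pmof.RemotePersistentMemoryPool

object RpmpRoundTrip {
  def main(args: Array[String]): Unit = {
    // Placeholder endpoint; the real values come from PmofConf.rpmpHost / rpmpPort.
    val pool = RemotePersistentMemoryPool.getInstance("rpmp-host", "61010")

    val data = ByteBuffer.allocateDirect(64)
    data.put("hello-rpmp".getBytes("UTF-8"))
    data.flip()

    // Store the buffer under a block-id-like key; elsewhere in this patch a -1 return
    // from put() is treated as a timeout/failure.
    val ret = pool.put("shuffle_0_0_0", data, data.remaining())

    // Flat metadata array, parsed as (address, length, rkey) triples by the shuffle writer.
    val meta = pool.getMeta("shuffle_0_0_0")
    println(s"put returned $ret, getMeta returned ${meta.length} longs")

    val readBack = ByteBuffer.allocateDirect(64)
    pool.get("shuffle_0_0_0", 64, readBack)

    RemotePersistentMemoryPool.close()
  }
}
```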
diff --git a/core/src/main/scala/org/apache/spark/network/pmof/Client.scala b/core/src/main/scala/org/apache/spark/network/pmof/Client.scala
index 022c729a..d56e073a 100644
--- a/core/src/main/scala/org/apache/spark/network/pmof/Client.scala
+++ b/core/src/main/scala/org/apache/spark/network/pmof/Client.scala
@@ -4,9 +4,10 @@ import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap
import com.intel.hpnl.core.{Connection, EqService}
+import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.pmof.PmofShuffleManager
-class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) {
+class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) extends Logging {
final val outstandingReceiveFetches: ConcurrentHashMap[Long, ReceivedCallback] =
new ConcurrentHashMap[Long, ReceivedCallback]()
final val outstandingReadFetches: ConcurrentHashMap[Int, ReadCallback] =
diff --git a/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala b/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala
index db2c9161..1a332f2e 100644
--- a/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala
+++ b/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala
@@ -5,10 +5,11 @@ import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap
import com.intel.hpnl.core._
+import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.pmof.PmofShuffleManager
import org.apache.spark.util.configuration.pmof.PmofConf
-class ClientFactory(pmofConf: PmofConf) {
+class ClientFactory(pmofConf: PmofConf) extends Logging {
final val eqService = new EqService(pmofConf.clientWorkerNums, pmofConf.clientBufferNums, false).init()
private[this] final val cqService = new CqService(eqService).init()
private[this] final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]()
@@ -28,6 +29,7 @@ class ClientFactory(pmofConf: PmofConf) {
var client = clientMap.get(socketAddress)
if (client == null) {
ClientFactory.this.synchronized {
+ logInfo(s"createClient target is ${address}:${port}")
client = clientMap.get(socketAddress)
if (client == null) {
val con = eqService.connect(address, port.toString, 0)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
new file mode 100644
index 00000000..31f70c11
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -0,0 +1,232 @@
+package org.apache.spark.scheduler
+
+import java.nio.ByteBuffer
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+
+/**
+ * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that
+ * the task ran on as well as the sizes of outputs for each reducer, for passing on to the
+ * reduce tasks.
+ */
+private[spark] trait MapStatus {
+
+ /** Location where this task was run. */
+ def location: BlockManagerId
+
+ /**
+ * Estimated size for the reduce block, in bytes.
+ *
+ * If a block is non-empty, then this method MUST return a non-zero size. This invariant is
+ * necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
+ */
+ def getSizeForBlock(reduceId: Int): Long
+
+}
+
+private[spark] object MapStatus {
+
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
+ if (uncompressedSizes.length > 2000) {
+ HighlyCompressedMapStatus(loc, uncompressedSizes)
+ } else {
+ new CompressedMapStatus(loc, uncompressedSizes)
+ }
+ }
+
+ private[this] val LOG_BASE = 1.1
+
+ /**
+ * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
+ * We do this by encoding the log base 1.1 of the size as an integer, which can support
+ * sizes up to 35 GB with at most 10% error.
+ */
+ def compressSize(size: Long): Byte = {
+ if (size == 0) {
+ 0
+ } else if (size <= 1L) {
+ 1
+ } else {
+ math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
+ }
+ }
+
+ /**
+ * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
+ */
+ def decompressSize(compressedSize: Byte): Long = {
+ if (compressedSize == 0) {
+ 0
+ } else {
+ math.pow(LOG_BASE, compressedSize & 0xFF).toLong
+ }
+ }
+}
+
+/**
+ * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is
+ * represented using a single byte.
+ *
+ * @param loc location where the task is being executed.
+ * @param compressedSizes size of the blocks, indexed by reduce partition id.
+ */
+private[spark] class CompressedMapStatus(
+ private[this] var loc: BlockManagerId,
+ private[this] var compressedSizes: Array[Byte])
+ extends MapStatus
+ with Externalizable {
+
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+
+ def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
+ this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ }
+
+ override def location: BlockManagerId = loc
+
+ override def getSizeForBlock(reduceId: Int): Long = {
+ MapStatus.decompressSize(compressedSizes(reduceId))
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ loc.writeExternal(out)
+ out.writeInt(compressedSizes.length)
+ out.write(compressedSizes)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ loc = BlockManagerId(in)
+ val len = in.readInt()
+ compressedSizes = new Array[Byte](len)
+ in.readFully(compressedSizes)
+ }
+}
+
+/**
+ * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger
+ * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks,
+ * plus a bitmap for tracking which blocks are empty.
+ *
+ * @param loc location where the task is being executed
+ * @param numNonEmptyBlocks the number of non-empty blocks
+ * @param emptyBlocks a bitmap tracking which blocks are empty
+ * @param avgSize average size of the non-empty and non-huge blocks
+ * @param hugeBlockSizes sizes of huge blocks by their reduceId.
+ */
+private[spark] class HighlyCompressedMapStatus private (
+ private[this] var loc: BlockManagerId,
+ private[this] var numNonEmptyBlocks: Int,
+ private[this] var emptyBlocks: RoaringBitmap,
+ private[this] var avgSize: Long,
+ private var hugeBlockSizes: Map[Int, Byte])
+ extends MapStatus
+ with Externalizable {
+
+ // loc could be null when the default constructor is called during deserialization
+ require(
+ loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
+ "Average size can only be zero for map stages that produced no output")
+
+ protected def this() = this(null, -1, null, -1, null) // For deserialization only
+
+ override def location: BlockManagerId = loc
+
+ override def getSizeForBlock(reduceId: Int): Long = {
+ assert(hugeBlockSizes != null)
+ if (emptyBlocks.contains(reduceId)) {
+ 0
+ } else {
+ hugeBlockSizes.get(reduceId) match {
+ case Some(size) => MapStatus.decompressSize(size)
+ case None => avgSize
+ }
+ }
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ loc.writeExternal(out)
+ emptyBlocks.writeExternal(out)
+ out.writeLong(avgSize)
+ out.writeInt(hugeBlockSizes.size)
+ hugeBlockSizes.foreach { kv =>
+ out.writeInt(kv._1)
+ out.writeByte(kv._2)
+ }
+ }
+
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ loc = BlockManagerId(in)
+ emptyBlocks = new RoaringBitmap()
+ emptyBlocks.readExternal(in)
+ avgSize = in.readLong()
+ val count = in.readInt()
+ val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
+ (0 until count).foreach { _ =>
+ val block = in.readInt()
+ val size = in.readByte()
+ hugeBlockSizesArray += Tuple2(block, size)
+ }
+ hugeBlockSizes = hugeBlockSizesArray.toMap
+ }
+}
+
+private[spark] object HighlyCompressedMapStatus {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ // We must keep track of which blocks are empty so that we don't report a zero-sized
+ // block as being non-empty (or vice-versa) when using the average block size.
+ var i = 0
+ var numNonEmptyBlocks: Int = 0
+ var numSmallBlocks: Int = 0
+ var totalSmallBlockSize: Long = 0
+ // From a compression standpoint, it shouldn't matter whether we track empty or non-empty
+ // blocks. From a performance standpoint, we benefit from tracking empty blocks because
+ // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
+ val emptyBlocks = new RoaringBitmap()
+ val totalNumBlocks = uncompressedSizes.length
+ val threshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
+ val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
+ while (i < totalNumBlocks) {
+ val size = uncompressedSizes(i)
+ if (size > 0) {
+ numNonEmptyBlocks += 1
+ // Huge blocks are not included in the calculation for average size, thus size for smaller
+ // blocks is more accurate.
+ if (size < threshold) {
+ totalSmallBlockSize += size
+ numSmallBlocks += 1
+ } else {
+ hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
+ }
+ } else {
+ emptyBlocks.add(i)
+ }
+ i += 1
+ }
+ val avgSize = if (numSmallBlocks > 0) {
+ totalSmallBlockSize / numSmallBlocks
+ } else {
+ 0
+ }
+ emptyBlocks.trim()
+ emptyBlocks.runOptimize()
+ new HighlyCompressedMapStatus(
+ loc,
+ numNonEmptyBlocks,
+ emptyBlocks,
+ avgSize,
+ hugeBlockSizesArray.toMap)
+ }
+}
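
The compressSize/decompressSize pair above is what CompressedMapStatus and HighlyCompressedMapStatus rely on. The following standalone sketch reimplements the two formulas locally (only because MapStatus is private[spark]) to show the roundtrip staying within the documented ~10% error for sizes up to roughly 35 GB.

```scala
object SizeCodecDemo {
  private val LogBase = 1.1

  // Same encoding as MapStatus.compressSize: ceil(log_{1.1}(size)) stored in one byte.
  def compress(size: Long): Byte =
    if (size == 0) 0
    else if (size <= 1L) 1
    else math.min(255, math.ceil(math.log(size) / math.log(LogBase)).toInt).toByte

  // Same decoding as MapStatus.decompressSize.
  def decompress(b: Byte): Long =
    if (b == 0) 0 else math.pow(LogBase, b & 0xFF).toLong

  def main(args: Array[String]): Unit = {
    for (size <- Seq(1L, 1500L, 1L << 20, 10L * 1024 * 1024 * 1024)) {
      val round = decompress(compress(size))
      val err = math.abs(round - size).toDouble / size
      println(f"$size%d -> $round%d (relative error ${err * 100}%.1f%%)")
    }
  }
}
```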
diff --git a/core/src/main/scala/org/apache/spark/scheduler/pmof/UnCompressedMapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/pmof/UnCompressedMapStatus.scala
new file mode 100644
index 00000000..4591a673
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/pmof/UnCompressedMapStatus.scala
@@ -0,0 +1,66 @@
+package org.apache.spark.scheduler.pmof
+
+import java.nio.ByteBuffer
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.internal.Logging
+
+private[spark] class UnCompressedMapStatus(
+ private[this] var loc: BlockManagerId,
+ private[this] var data: Array[Byte])
+ extends MapStatus
+ with Externalizable
+ with Logging {
+
+ protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ val step = 8
+
+ def this(loc: BlockManagerId, sizes: Array[Long]) {
+ this(loc, sizes.flatMap(UnCompressedMapStatus.longToBytes))
+ }
+
+ override def location: BlockManagerId = loc
+
+ override def getSizeForBlock(reduceId: Int): Long = {
+ val start = reduceId * step
+ UnCompressedMapStatus.bytesToLong(data.slice(start, start + step))
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ loc.writeExternal(out)
+ out.writeInt(data.length)
+ out.write(data)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ loc = BlockManagerId(in)
+ val len = in.readInt()
+ data = new Array[Byte](len)
+ in.readFully(data)
+ }
+}
+
+object UnCompressedMapStatus {
+ def longToBytes(x: Long): Array[Byte] = {
+ val buffer = ByteBuffer.allocate(8)
+ buffer.putLong(x)
+ buffer.array()
+ }
+
+ def bytesToLong(bytes: Array[Byte]): Long = {
+ val buffer = ByteBuffer.allocate(8)
+ buffer.put(bytes)
+ buffer.flip()
+ buffer.getLong()
+ }
+}
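
UnCompressedMapStatus trades the one-byte encoding for exact sizes: each reduce partition gets a fixed 8-byte slot, so getSizeForBlock is a plain slice-and-decode with no rounding error. A short sketch against the public helpers added above:

```scala
import org.apache.spark.scheduler.pmof.UnCompressedMapStatus

object UncompressedLayoutDemo {
  def main(args: Array[String]): Unit = {
    val sizes = Array(0L, 1024L, 5L * 1024 * 1024)

    // Flat layout used by UnCompressedMapStatus: block i lives at data[i*8, i*8+8).
    val data: Array[Byte] = sizes.flatMap(UnCompressedMapStatus.longToBytes)
    assert(data.length == sizes.length * 8)

    // Recover partition 2 exactly -- unlike CompressedMapStatus there is no compression error.
    val step = 8
    val reduceId = 2
    val recovered = UnCompressedMapStatus.bytesToLong(
      data.slice(reduceId * step, reduceId * step + step))
    println(s"partition $reduceId size = $recovered") // 5242880
  }
}
```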
diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala
index edc0d993..74753d65 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala
@@ -9,6 +9,7 @@ import java.util.zip.{Deflater, DeflaterOutputStream, Inflater, InflaterInputStr
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{ByteBufferInputStream, ByteBufferOutputStream, Input, Output}
+import org.apache.spark.internal.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.network.pmof._
import org.apache.spark.shuffle.IndexShuffleBlockResolver
@@ -27,7 +28,7 @@ import scala.util.control.{Breaks, ControlThrowable}
* and can be used by executor to send metadata to driver
* @param pmofConf
*/
-class MetadataResolver(pmofConf: PmofConf) {
+class MetadataResolver(pmofConf: PmofConf) extends Logging {
private[this] val blockManager = SparkEnv.get.blockManager
private[this] val blockMap: ConcurrentHashMap[String, ShuffleBuffer] = new ConcurrentHashMap[String, ShuffleBuffer]()
private[this] val blockInfoMap = mutable.HashMap.empty[String, ArrayBuffer[ShuffleBlockInfo]]
@@ -44,7 +45,7 @@ class MetadataResolver(pmofConf: PmofConf) {
* @param rkey
*/
def pushPmemBlockInfo(shuffleId: Int, mapId: Int, dataAddressMap: mutable.HashMap[Int, Array[(Long, Int)]], rkey: Long): Unit = {
- val buffer: Array[Byte] = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt)
+ val buffer: Array[Byte] = new Array[Byte](pmofConf.map_serializer_buffer_size.toInt)
var output = new Output(buffer)
val bufferArray = new ArrayBuffer[ByteBuffer]()
@@ -60,7 +61,7 @@ class MetadataResolver(pmofConf: PmofConf) {
blockBuffer.flip()
bufferArray += blockBuffer
output.close()
- val new_buffer = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt)
+ val new_buffer = new Array[Byte](pmofConf.map_serializer_buffer_size.toInt)
output = new Output(new_buffer)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala
index 04e7b03e..e50e5f61 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala
@@ -21,6 +21,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.network.pmof.PmofTransferService
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.pmof.UnCompressedMapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
import org.apache.spark.storage._
import org.apache.spark.util.collection.pmof.PmemExternalSorter
@@ -32,16 +33,18 @@ import org.apache.spark.storage.BlockManager
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffleBlockResolver,
- metadataResolver: MetadataResolver,
- blockManager: BlockManager,
- serializerManager: SerializerManager,
- handle: BaseShuffleHandle[K, V, C],
- mapId: Int,
- context: TaskContext,
- conf: SparkConf,
- pmofConf: PmofConf)
- extends ShuffleWriter[K, V] with Logging {
+private[spark] class PmemShuffleWriter[K, V, C](
+ shuffleBlockResolver: PmemShuffleBlockResolver,
+ metadataResolver: MetadataResolver,
+ blockManager: BlockManager,
+ serializerManager: SerializerManager,
+ handle: BaseShuffleHandle[K, V, C],
+ mapId: Int,
+ context: TaskContext,
+ conf: SparkConf,
+ pmofConf: PmofConf)
+ extends ShuffleWriter[K, V]
+ with Logging {
private[this] val dep = handle.dependency
private[this] var mapStatus: MapStatus = _
private[this] val stageId = dep.shuffleId
@@ -60,27 +63,35 @@ private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffl
private var stopping = false
/**
- * Call PMDK to write data to persistent memory
- * Original Spark writer will do write and mergesort in this function,
- * while by using pmdk, we can do that once since pmdk supports transaction.
- */
+ * Call PMDK to write data to persistent memory
+ * Original Spark writer will do write and mergesort in this function,
+ * while by using pmdk, we can do that once since pmdk supports transaction.
+ */
override def write(records: Iterator[Product2[K, V]]): Unit = {
- val PmemBlockOutputStreamArray = (0 until numPartitions).toArray.map(partitionId =>
- new PmemBlockOutputStream(
- context.taskMetrics(),
- ShuffleBlockId(stageId, mapId, partitionId),
- serializerManager,
- dep.serializer,
- conf,
- pmofConf,
- numMaps,
- numPartitions))
+ logInfo(" write start")
+ val PmemBlockOutputStreamArray = (0 until numPartitions).toArray.map(
+ partitionId =>
+ new PmemBlockOutputStream(
+ context.taskMetrics(),
+ ShuffleBlockId(stageId, mapId, partitionId),
+ serializerManager,
+ dep.serializer,
+ conf,
+ pmofConf,
+ numMaps,
+ numPartitions))
if (dep.mapSideCombine) { // do aggregation
if (dep.aggregator.isDefined) {
- sorter = new PmemExternalSorter[K, V, C](context, handle, pmofConf, dep.aggregator, Some(dep.partitioner),
- dep.keyOrdering, dep.serializer)
- sorter.setPartitionByteBufferArray(PmemBlockOutputStreamArray)
+ sorter = new PmemExternalSorter[K, V, C](
+ context,
+ handle,
+ pmofConf,
+ dep.aggregator,
+ Some(dep.partitioner),
+ dep.keyOrdering,
+ dep.serializer)
+ sorter.setPartitionByteBufferArray(PmemBlockOutputStreamArray)
sorter.insertAll(records)
sorter.forceSpillToPmem()
} else {
@@ -107,31 +118,50 @@ private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffl
spilledPartition += 1
}
val pmemBlockInfoMap = mutable.HashMap.empty[Int, Array[(Long, Int)]]
- var output_str : String = ""
+ var output_str: String = ""
+ var rKey: Int = 0
for (i <- spillPartitionArray) {
- if (pmofConf.enableRdma) {
- pmemBlockInfoMap(i) = PmemBlockOutputStreamArray(i).getPartitionMeta().map { info => (info._1, info._2) }
+ if (pmofConf.enableRdma && !pmofConf.enableRemotePmem) {
+ pmemBlockInfoMap(i) = PmemBlockOutputStreamArray(i)
+ .getPartitionMeta()
+ .map(info => {
+ if (rKey == 0) {
+ rKey = info._3
+ }
+ //logInfo(s"${ShuffleBlockId(stageId, mapId, i)} [${rKey}]${info._1}:${info._2}")
+ (info._1, info._2)
+ })
}
- partitionLengths(i) = PmemBlockOutputStreamArray(i).size
- output_str += "\tPartition " + i + ": " + partitionLengths(i) + ", records: " + PmemBlockOutputStreamArray(i).records + "\n"
+ partitionLengths(i) = PmemBlockOutputStreamArray(i).getSize
+ output_str += "\tPartition " + i + ": " + partitionLengths(i) + ", records: " + PmemBlockOutputStreamArray(
+ i).records + "\n"
}
+ //logWarning(output_str)
for (i <- 0 until numPartitions) {
PmemBlockOutputStreamArray(i).close()
}
val shuffleServerId = blockManager.shuffleServerId
- if (pmofConf.enableRdma) {
- val rkey = PmemBlockOutputStreamArray(0).getRkey()
- metadataResolver.pushPmemBlockInfo(stageId, mapId, pmemBlockInfoMap, rkey)
+ if (pmofConf.enableRemotePmem) {
+ mapStatus = new UnCompressedMapStatus(shuffleServerId, partitionLengths)
+ //mapStatus = MapStatus(shuffleServerId, partitionLengths)
+ } else if (!pmofConf.enableRdma) {
+ mapStatus = MapStatus(shuffleServerId, partitionLengths)
+ } else {
+ metadataResolver.pushPmemBlockInfo(stageId, mapId, pmemBlockInfoMap, rKey)
val blockManagerId: BlockManagerId =
- BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host),
- PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port, shuffleServerId.topologyInfo)
+ BlockManagerId(
+ shuffleServerId.executorId,
+ PmofTransferService.shuffleNodesMap(shuffleServerId.host),
+ PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port,
+ shuffleServerId.topologyInfo)
mapStatus = MapStatus(blockManagerId, partitionLengths)
- } else {
- mapStatus = MapStatus(shuffleServerId, partitionLengths)
}
+ logWarning(
+ s"shuffle_${stageId}_${mapId}_0 size is ${partitionLengths(0)}, decompressed length is ${mapStatus
+ .getSizeForBlock(0)}")
}
/** Close this writer, passing along whether the map completed */
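
Beyond the scalafmt churn, the substantive change in write() is which MapStatus implementation gets reported back to the driver. The sketch below is a condensed, illustrative restatement, not the author's code; it sits in the Spark package namespace only because MapStatus and UnCompressedMapStatus are private[spark].

```scala
package org.apache.spark.shuffle.pmof

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.scheduler.pmof.UnCompressedMapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.configuration.pmof.PmofConf

object MapStatusChoice {
  def chooseMapStatus(
      pmofConf: PmofConf,
      shuffleServerId: BlockManagerId,
      partitionLengths: Array[Long]): MapStatus = {
    if (pmofConf.enableRemotePmem) {
      // RPMP path: exact 8-bytes-per-partition sizes, no compression error.
      new UnCompressedMapStatus(shuffleServerId, partitionLengths)
    } else if (!pmofConf.enableRdma) {
      // Local PMem without RDMA: the stock (compressed) MapStatus is sufficient.
      MapStatus(shuffleServerId, partitionLengths)
    } else {
      // RDMA path: the real write() additionally pushes the PMem block metadata to the driver
      // and rewrites the BlockManagerId port to the RDMA transfer service port first.
      MapStatus(shuffleServerId, partitionLengths)
    }
  }
}
```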
diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala
index 4e61c4e5..2bb963ce 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala
@@ -16,22 +16,32 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager
}
private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
- private[this] val pmofConf = new PmofConf(conf)
+ private[this] val pmofConf = PmofConf.getConf(conf)
var metadataResolver: MetadataResolver = _
- override def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
val env: SparkEnv = SparkEnv.get
metadataResolver = MetadataResolver.getMetadataResolver(pmofConf)
if (pmofConf.enableRdma) {
- PmofTransferService.getTransferServiceInstance(pmofConf: PmofConf, env.blockManager, this, isDriver = true)
+ PmofTransferService.getTransferServiceInstance(
+ pmofConf: PmofConf,
+ env.blockManager,
+ this,
+ isDriver = true)
}
new BaseShuffleHandle(shuffleId, numMaps, dependency)
}
- override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] = {
+ override def getWriter[K, V](
+ handle: ShuffleHandle,
+ mapId: Int,
+ context: TaskContext): ShuffleWriter[K, V] = {
assert(handle.isInstanceOf[BaseShuffleHandle[_, _, _]])
val env: SparkEnv = SparkEnv.get
@@ -47,28 +57,69 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager
}
if (pmofConf.enablePmem) {
- new PmemShuffleWriter(shuffleBlockResolver.asInstanceOf[PmemShuffleBlockResolver], metadataResolver, blockManager, serializerManager,
- handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, env.conf, pmofConf)
+ new PmemShuffleWriter(
+ shuffleBlockResolver.asInstanceOf[PmemShuffleBlockResolver],
+ metadataResolver,
+ blockManager,
+ serializerManager,
+ handle.asInstanceOf[BaseShuffleHandle[K, V, _]],
+ mapId,
+ context,
+ env.conf,
+ pmofConf)
} else {
- new BaseShuffleWriter(shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], metadataResolver, blockManager, serializerManager,
- handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, pmofConf)
+ new BaseShuffleWriter(
+ shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ metadataResolver,
+ blockManager,
+ serializerManager,
+ handle.asInstanceOf[BaseShuffleHandle[K, V, _]],
+ mapId,
+ context,
+ pmofConf)
}
}
- override def getReader[K, C](handle: _root_.org.apache.spark.shuffle.ShuffleHandle, startPartition: Int, endPartition: Int, context: _root_.org.apache.spark.TaskContext): _root_.org.apache.spark.shuffle.ShuffleReader[K, C] = {
- if (pmofConf.enableRdma) {
- new RdmaShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
- startPartition, endPartition, context, pmofConf)
+ override def getReader[K, C](
+ handle: _root_.org.apache.spark.shuffle.ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: _root_.org.apache.spark.TaskContext)
+ : _root_.org.apache.spark.shuffle.ShuffleReader[K, C] = {
+ val env: SparkEnv = SparkEnv.get
+ if (pmofConf.enableRemotePmem) {
+ new RpmpShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context,
+ pmofConf)
+
+ } else if (pmofConf.enableRdma) {
+ metadataResolver = MetadataResolver.getMetadataResolver(pmofConf)
+ PmofTransferService.getTransferServiceInstance(pmofConf, env.blockManager, this)
+ new RdmaShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context,
+ pmofConf)
} else {
new BaseShuffleReader(
- handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context, pmofConf)
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context,
+ pmofConf)
}
}
override def unregisterShuffle(shuffleId: Int): Boolean = {
Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
(0 until numMaps).foreach { mapId =>
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver].removeDataByMap(shuffleId, mapId)
+ shuffleBlockResolver
+ .asInstanceOf[IndexShuffleBlockResolver]
+ .removeDataByMap(shuffleId, mapId)
}
}
true
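
The manager now has three reader paths, chosen by the PmofConf flags. The sketch below shows a SparkConf that would select the new RPMP path; the two spark.shuffle.pmof.* keys appear in this diff, while the spark.shuffle.manager value is the usual way of plugging in a custom ShuffleManager and is an assumption here, and the RPMP endpoint keys are omitted because this diff does not show them.

```scala
import org.apache.spark.SparkConf

object PmofShuffleConfSketch {
  val conf = new SparkConf()
    // Assumed: standard mechanism for installing a third-party ShuffleManager.
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.pmof.PmofShuffleManager")
    // Both keys are read (with defaultValue = true) by the readers in this patch.
    .set("spark.shuffle.pmof.enable_pmem", "true")
    .set("spark.shuffle.pmof.enable_remote_pmem", "true") // getReader() -> RpmpShuffleReader
}
```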
diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala
index 92af5417..3a044f13 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala
@@ -29,6 +29,7 @@ private[spark] class RdmaShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
private[this] val dep = handle.dependency
private[this] val serializerInstance: SerializerInstance = dep.serializer.newInstance()
private[this] val enable_pmem: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true)
+ private[this] val enable_rpmp: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_remote_pmem", defaultValue = true)
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
@@ -86,7 +87,7 @@ private[spark] class RdmaShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
- if (enable_pmem) {
+ if (enable_pmem && !enable_rpmp) {
val sorter = new PmemExternalSorter[K, C, C](context, handle, pmofConf, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/RpmpShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/RpmpShuffleReader.scala
new file mode 100644
index 00000000..97ffdaea
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/RpmpShuffleReader.scala
@@ -0,0 +1,109 @@
+package org.apache.spark.shuffle.pmof
+
+import org.apache.spark._
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.network.pmof.PmofTransferService
+import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.pmof._
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.collection.pmof.PmemExternalSorter
+import org.apache.spark.util.configuration.pmof.PmofConf
+
+/**
+ * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
+ * requesting them from other nodes' block stores.
+ */
+private[spark] class RpmpShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ pmofConf: PmofConf,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+ extends ShuffleReader[K, C]
+ with Logging {
+
+ private[this] val dep = handle.dependency
+ private[this] val serializerInstance: SerializerInstance = dep.serializer.newInstance()
+ private[this] val enable_pmem: Boolean =
+ SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true)
+ private[this] val enable_rpmp: Boolean =
+ SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_remote_pmem", defaultValue = true)
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val wrappedStreams: RpmpShuffleBlockFetcherIterator = new RpmpShuffleBlockFetcherIterator(
+ context,
+ blockManager,
+ mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+ SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", defaultValue = true),
+ pmofConf)
+
+ // Create a key/value iterator for each stream
+ val recordIter = wrappedStreams.flatMap {
+ case (_, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
+
+ // Update the context task metrics for each record read.
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](recordIter.map {
+ record =>
+ readMetrics.incRecordsRead(1)
+ record
+ }, context.taskMetrics().mergeShuffleReadMetrics())
+
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+ val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ if (dep.mapSideCombine) {
+ // We are reading values that are already combined
+ val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ } else {
+ // We don't know the value type, but also don't care -- the dependency *should*
+ // have made sure its compatible w/ this aggregator, which will convert the value
+ // type to the combined type C
+ val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+ }
+ } else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
+ interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+ }
+
+ // Sort the output if there is a sort ordering defined.
+ dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ val sorter =
+ new ExternalSorter[K, C, C](
+ context,
+ ordering = Some(keyOrd),
+ serializer = dep.serializer)
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](
+ sorter.iterator,
+ sorter.stop())
+ case None =>
+ aggregatedIter
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala b/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala
index dfdd0d5c..5419c3ac 100644
--- a/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala
@@ -5,23 +5,19 @@ import io.netty.buffer.{ByteBuf, PooledByteBufAllocator, UnpooledByteBufAllocato
import scala.collection.mutable.Stack
import java.lang.RuntimeException
import org.apache.spark.internal.Logging
+import scala.collection.mutable.Map
object NettyByteBufferPool extends Logging {
private val allocatedBufRenCnt: AtomicLong = new AtomicLong(0)
- private val allocatedBytes: AtomicLong = new AtomicLong(0)
- private val peakAllocatedBytes: AtomicLong = new AtomicLong(0)
- private val unpooledAllocatedBytes: AtomicLong = new AtomicLong(0)
- private var fixedBufferSize: Long = 0
- private val allocatedBufferPool: Stack[ByteBuf] = Stack[ByteBuf]()
+ private val allocatedBytes: AtomicLong = new AtomicLong(0)
+ private val peakAllocatedBytes: AtomicLong = new AtomicLong(0)
+ private val unpooledAllocatedBytes: AtomicLong = new AtomicLong(0)
+ private val bufferMap: Map[ByteBuf, Long] = Map()
+ private val allocatedBufferPool: Stack[ByteBuf] = Stack[ByteBuf]()
private var reachRead = false
private val allocator = UnpooledByteBufAllocator.DEFAULT
def allocateNewBuffer(bufSize: Int): ByteBuf = synchronized {
- if (fixedBufferSize == 0) {
- fixedBufferSize = bufSize
- } else if (bufSize > fixedBufferSize) {
- throw new RuntimeException(s"allocateNewBuffer, expected size is ${fixedBufferSize}, actual size is ${bufSize}")
- }
allocatedBufRenCnt.getAndIncrement()
allocatedBytes.getAndAdd(bufSize)
if (allocatedBytes.get > peakAllocatedBytes.get) {
@@ -33,9 +29,11 @@ object NettyByteBufferPool extends Logging {
} else {
allocator.directBuffer(bufSize, bufSize)
}*/
- allocator.directBuffer(bufSize, bufSize)
+ val byteBuf = allocator.directBuffer(bufSize, bufSize)
+ bufferMap += (byteBuf -> bufSize)
+ byteBuf
} catch {
- case e : Throwable =>
+ case e: Throwable =>
logError(s"allocateNewBuffer size is ${bufSize}")
throw e
}
@@ -43,7 +41,8 @@ object NettyByteBufferPool extends Logging {
def releaseBuffer(buf: ByteBuf): Unit = synchronized {
allocatedBufRenCnt.getAndDecrement()
- allocatedBytes.getAndAdd(0 - fixedBufferSize)
+ val bufSize = bufferMap.remove(buf).getOrElse(0L)
+ allocatedBytes.getAndAdd(0 - bufSize)
buf.clear()
//allocatedBufferPool.push(buf)
buf.release(buf.refCnt())
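
The bufferMap change removes the old single-fixed-buffer-size assumption, so differently sized buffers can now be allocated and returned through the pool. A minimal sketch:

```scala
import org.apache.spark.storage.pmof.NettyByteBufferPool

object BufferPoolSketch {
  def main(args: Array[String]): Unit = {
    // Two different sizes -- the old code would have thrown on the second allocation
    // once fixedBufferSize was pinned to 4 KB.
    val small = NettyByteBufferPool.allocateNewBuffer(4 * 1024)
    val large = NettyByteBufferPool.allocateNewBuffer(1024 * 1024)

    small.writeBytes(Array.fill[Byte](16)(1))

    // Release now looks each buffer's size up in bufferMap instead of assuming a fixed size.
    NettyByteBufferPool.releaseBuffer(small)
    NettyByteBufferPool.releaseBuffer(large)
  }
}
```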
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/NioManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/pmof/NioManagedBuffer.scala
new file mode 100644
index 00000000..8118232e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/NioManagedBuffer.scala
@@ -0,0 +1,75 @@
+package org.apache.spark.storage.pmof
+
+import java.io.IOException
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+
+import io.netty.buffer.ByteBuf
+import io.netty.buffer.ByteBufInputStream
+import io.netty.buffer.Unpooled
+import org.apache.commons.lang3.builder.ToStringBuilder
+import org.apache.commons.lang3.builder.ToStringStyle
+import org.apache.spark.network.buffer.ManagedBuffer
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+class NioManagedBuffer(bufSize: Int) extends ManagedBuffer {
+ private val buf: ByteBuf = NettyByteBufferPool.allocateNewBuffer(bufSize)
+ private val byteBuffer: ByteBuffer = buf.nioBuffer(0, bufSize)
+ private val refCount = new AtomicInteger(1)
+ private var in: InputStream = _
+ private var nettyObj: ByteBuf = _
+
+ def getByteBuf: ByteBuf = buf
+
+ def resize(size: Int): Unit = {
+ byteBuffer.limit(size)
+ }
+
+ override def size: Long = {
+ byteBuffer.remaining()
+ }
+
+ override def nioByteBuffer: ByteBuffer = {
+ byteBuffer
+ }
+
+ override def createInputStream: InputStream = {
+ nettyObj = Unpooled.wrappedBuffer(byteBuffer)
+ in = new ByteBufInputStream(nettyObj)
+ in
+ }
+
+ override def retain: ManagedBuffer = {
+ refCount.incrementAndGet()
+ this
+ }
+
+ override def release: ManagedBuffer = {
+ if (refCount.decrementAndGet() == 0) {
+ if (in != null) {
+ in.close()
+ }
+ if (nettyObj != null) {
+ nettyObj.release()
+ }
+ NettyByteBufferPool.releaseBuffer(buf)
+ }
+ this
+ }
+
+ override def convertToNetty: Object = {
+ if (nettyObj == null) {
+ nettyObj = Unpooled.wrappedBuffer(byteBuffer)
+ }
+ nettyObj
+ }
+
+ override def toString: String = {
+ new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+ .append("buf", buf)
+ .toString()
+ }
+}
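
NioManagedBuffer is reference counted on top of the pooled ByteBuf; the backing buffer only goes back to NettyByteBufferPool when the count reaches zero. A lifecycle sketch, assuming the class as added above:

```scala
import org.apache.spark.storage.pmof.NioManagedBuffer

object ManagedBufferLifecycle {
  def main(args: Array[String]): Unit = {
    val mb = new NioManagedBuffer(64 * 1024) // backed by a 64 KB direct ByteBuf from the pool

    mb.retain                     // refCount 1 -> 2
    val in = mb.createInputStream // stream view handed to deserializers

    mb.release                    // 2 -> 1, buffer still alive
    mb.release                    // 1 -> 0: stream closed, wrapped ByteBuf and pool buffer freed
  }
}
```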
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala
index bde6ad44..3fb19546 100644
--- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala
@@ -2,14 +2,22 @@ package org.apache.spark.storage.pmof
import com.esotericsoftware.kryo.KryoException
import org.apache.spark.SparkEnv
-import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance, SerializerManager}
+import org.apache.spark.serializer.{
+ DeserializationStream,
+ Serializer,
+ SerializerInstance,
+ SerializerManager
+}
import org.apache.spark.storage.BlockId
-class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, serializer: Serializer) {
+class PmemBlockInputStream[K, C](
+ pmemBlockOutputStream: PmemBlockOutputStream,
+ serializer: Serializer) {
val blockId: BlockId = pmemBlockOutputStream.getBlockId()
val serializerManager: SerializerManager = SparkEnv.get.serializerManager
val serInstance: SerializerInstance = serializer.newInstance()
- val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler
+ val persistentMemoryWriter: PersistentMemoryHandler =
+ PersistentMemoryHandler.getPersistentMemoryHandler
var pmemInputStream: PmemInputStream = new PmemInputStream(persistentMemoryWriter, blockId.name)
val wrappedStream = serializerManager.wrapStream(blockId, pmemInputStream)
var inObjStream: DeserializationStream = serInstance.deserializeStream(wrappedStream)
@@ -30,7 +38,7 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s
close()
return null
}
- try{
+ try {
val k = inObjStream.readObject().asInstanceOf[K]
val c = inObjStream.readObject().asInstanceOf[C]
indexInBatch += 1
@@ -39,8 +47,7 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s
}
(k, c)
} catch {
- case ex: KryoException => {
- }
+ case ex: KryoException => {}
sys.exit(0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala
index ee303b0b..65748e69 100644
--- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala
@@ -1,28 +1,30 @@
package org.apache.spark.storage.pmof
-import org.apache.spark.storage._
-import org.apache.spark.serializer._
+import java.io.File
+
+import org.apache.spark.SparkConf
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.serializer._
+import org.apache.spark.storage._
import org.apache.spark.util.Utils
-import java.io.{File, OutputStream}
-
import org.apache.spark.util.configuration.pmof.PmofConf
import scala.collection.mutable.ArrayBuffer
-class PmemBlockId (stageId: Int, tmpId: Int) extends ShuffleBlockId(stageId, 0, tmpId) {
+class PmemBlockId(stageId: Int, tmpId: Int) extends ShuffleBlockId(stageId, 0, tmpId) {
override def name: String = "reduce_spill_" + stageId + "_" + tmpId
+
override def isShuffle: Boolean = false
}
object PmemBlockId {
private var tempId: Int = 0
+
def getTempBlockId(stageId: Int): PmemBlockId = synchronized {
val cur_tempId = tempId
tempId += 1
- new PmemBlockId (stageId, cur_tempId)
+ new PmemBlockId(stageId, cur_tempId)
}
}
@@ -34,8 +36,16 @@ private[spark] class PmemBlockOutputStream(
conf: SparkConf,
pmofConf: PmofConf,
numMaps: Int = 0,
- numPartitions: Int = 1
-) extends DiskBlockObjectWriter(new File(Utils.getConfiguredLocalDirs(conf).toList(0) + "/null"), null, null, 0, true, null, null) with Logging {
+ numPartitions: Int = 1)
+ extends DiskBlockObjectWriter(
+ new File(Utils.getConfiguredLocalDirs(conf).toList(0) + "/null"),
+ null,
+ null,
+ 0,
+ true,
+ null,
+ null)
+ with Logging {
var size: Int = 0
var records: Int = 0
@@ -44,15 +54,31 @@ private[spark] class PmemBlockOutputStream(
var spilled: Boolean = false
var partitionMeta: Array[(Long, Int, Int)] = _
val root_dir = Utils.getConfiguredLocalDirs(conf).toList(0)
-
- val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler(pmofConf,
- root_dir, pmofConf.path_list, blockId.name, pmofConf.maxPoolSize)
+ var persistentMemoryWriter: PersistentMemoryHandler = _
+ var remotePersistentMemoryPool: RemotePersistentMemoryPool = _
+
+ if (!pmofConf.enableRemotePmem) {
+ persistentMemoryWriter = PersistentMemoryHandler.getPersistentMemoryHandler(
+ pmofConf,
+ root_dir,
+ pmofConf.path_list,
+ blockId.name,
+ pmofConf.maxPoolSize)
+ } else {
+ remotePersistentMemoryPool =
+ RemotePersistentMemoryPool.getInstance(pmofConf.rpmpHost, pmofConf.rpmpPort)
+ }
//disable metadata updating by default
//persistentMemoryWriter.updateShuffleMeta(blockId.name)
val pmemOutputStream: PmemOutputStream = new PmemOutputStream(
- persistentMemoryWriter, numPartitions, blockId.name, numMaps, (pmofConf.spill_throttle.toInt + 1024))
+ persistentMemoryWriter,
+ remotePersistentMemoryPool,
+ numPartitions,
+ blockId.name,
+ numMaps,
+ (pmofConf.spill_throttle.toInt + 1024))
val serInstance = serializer.newInstance()
val bs = serializerManager.wrapStream(blockId, pmemOutputStream)
var objStream: SerializationStream = serInstance.serializeStream(bs)
@@ -62,7 +88,7 @@ private[spark] class PmemBlockOutputStream(
objStream.writeValue(value)
records += 1
recordsPerBlock += 1
- if (blockId.isShuffle == true) {
+ if (blockId.isShuffle == true) {
taskMetrics.shuffleWriteMetrics.incRecordsWritten(1)
}
maybeSpill()
@@ -92,7 +118,7 @@ private[spark] class PmemBlockOutputStream(
if (bufSize > 0) {
recordsArray += recordsPerBlock
recordsPerBlock = 0
- size += bufSize
+ size = bufSize
if (blockId.isShuffle == true) {
val writeMetrics = taskMetrics.shuffleWriteMetrics
@@ -111,10 +137,31 @@ private[spark] class PmemBlockOutputStream(
spilled
}
+ def getPartitionBlockInfo(res_array: Array[Long]): Array[(Long, Int, Int)] = {
+ var i = -3
+ val blockInfo = Array.ofDim[(Long, Int, Int)](res_array.length / 3)
+ blockInfo.map { x =>
+ i += 3;
+ (res_array(i), res_array(i + 1).toInt, res_array(i + 2).toInt)
+ }
+ }
+
def getPartitionMeta(): Array[(Long, Int, Int)] = {
if (partitionMeta == null) {
var i = -1
- partitionMeta = persistentMemoryWriter.getPartitionBlockInfo(blockId.name).map{ x=> i+=1; (x._1, x._2, recordsArray(i))}
+ partitionMeta = if (!pmofConf.enableRemotePmem) {
+ persistentMemoryWriter
+ .getPartitionBlockInfo(blockId.name)
+ .map(x => {
+ i += 1
+ (x._1, x._2, getRkey().toInt)
+ })
+ } else {
+ getPartitionBlockInfo(remotePersistentMemoryPool.getMeta(blockId.name)).map(x => {
+ i += 1
+ (x._1, x._2, x._3)
+ })
+ }
}
partitionMeta
}
@@ -128,7 +175,7 @@ private[spark] class PmemBlockOutputStream(
}
def getTotalRecords(): Long = {
- records
+ records
}
def getSize(): Long = {
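
getPartitionBlockInfo above regroups the flat Array[Long] returned by RemotePersistentMemoryPool.getMeta into (address, length, rkey) triples. A standalone illustration of the same regrouping, with made-up sample values:

```scala
object MetaParseDemo {
  def main(args: Array[String]): Unit = {
    // Two blocks sharing one rkey, laid out as address, length, rkey, address, length, rkey.
    val flat = Array(0x1000L, 4096L, 7L, 0x2000L, 8192L, 7L)

    // Equivalent, slightly more direct formulation of the regrouping done in the patch.
    val blocks: Array[(Long, Int, Int)] =
      flat.grouped(3).map(t => (t(0), t(1).toInt, t(2).toInt)).toArray

    blocks.foreach { case (addr, len, rkey) =>
      println(f"address=0x$addr%x length=$len rkey=$rkey")
    }
  }
}
```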
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala
index 644a8333..4a896b45 100644
--- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala
@@ -5,14 +5,17 @@ import java.nio.ByteBuffer
import io.netty.buffer.{ByteBuf, PooledByteBufAllocator}
import org.apache.spark.internal.Logging
+import java.io.IOException
class PmemOutputStream(
- persistentMemoryWriter: PersistentMemoryHandler,
- numPartitions: Int,
- blockId: String,
- numMaps: Int,
- bufferSize: Int
- ) extends OutputStream with Logging {
+ persistentMemoryWriter: PersistentMemoryHandler,
+ remotePersistentMemoryPool: RemotePersistentMemoryPool,
+ numPartitions: Int,
+ blockId: String,
+ numMaps: Int,
+ bufferSize: Int)
+ extends OutputStream
+ with Logging {
var set_clean = true
var is_closed = false
@@ -34,7 +37,21 @@ class PmemOutputStream(
override def flush(): Unit = {
if (bufferRemainingSize > 0) {
- persistentMemoryWriter.setPartition(numPartitions, blockId, byteBuffer, bufferRemainingSize, set_clean)
+ if (persistentMemoryWriter != null) {
+ persistentMemoryWriter.setPartition(
+ numPartitions,
+ blockId,
+ byteBuffer,
+ bufferRemainingSize,
+ set_clean)
+ } else {
+ logDebug(
+ s"[put Remote Block] target is ${RemotePersistentMemoryPool.getHost}:${RemotePersistentMemoryPool.getPort}, "
+ + s"${blockId} ${bufferRemainingSize}")
+ if (remotePersistentMemoryPool.put(blockId, byteBuffer, bufferRemainingSize) == -1) {
+ throw new IOException("RPMem put failed with time out.")
+ }
+ }
bufferFlushedSize += bufferRemainingSize
bufferRemainingSize = 0
}
@@ -48,7 +65,7 @@ class PmemOutputStream(
}
def remainingSize(): Int = {
- bufferRemainingSize
+ bufferRemainingSize
}
def reset(): Unit = {
@@ -61,7 +78,7 @@ class PmemOutputStream(
if (!is_closed) {
flush()
reset()
- NettyByteBufferPool.releaseBuffer(buf)
+ NettyByteBufferPool.releaseBuffer(buf)
is_closed = true
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala
index d3c3e064..334d0ab4 100644
--- a/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala
@@ -28,6 +28,7 @@ import org.apache.spark.network.pmof._
import org.apache.spark.network.shuffle.{ShuffleClient, TempFileManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage._
+import org.apache.spark.util.configuration.pmof.PmofConf
import org.apache.spark.{SparkException, TaskContext}
import scala.collection.mutable
@@ -36,88 +37,100 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
/**
- * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
- * manager. For remote blocks, it fetches them using the provided BlockTransferService.
- *
- * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
- * in a pipelined fashion as they are received.
- *
- * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
- * using too much memory.
- *
- * @param context [[TaskContext]], used for metrics update
- * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
- * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
- * For each block we also require the size (in bytes as a long field) in
- * order to throttle the memory usage.
- * @param streamWrapper A function to wrap the returned input stream.
- * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
- * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
- * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
- * for a given remote host:port.
- * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
- * @param detectCorrupt whether to detect any corruption in fetched blocks.
- */
-private[spark]
-final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
- shuffleClient: ShuffleClient,
- blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
- streamWrapper: (BlockId, InputStream) => InputStream,
- maxBytesInFlight: Long,
- maxReqsInFlight: Int,
- maxBlocksInFlightPerAddress: Int,
- maxReqSizeShuffleToMem: Long,
- detectCorrupt: Boolean)
- extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging {
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
+ * manager. For remote blocks, it fetches them using the provided BlockTransferService.
+ *
+ * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
+ * using too much memory.
+ *
+ * @param context [[TaskContext]], used for metrics update
+ * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
+ * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
+ * For each block we also require the size (in bytes as a long field) in
+ * order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
+ * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
+ * for a given remote host:port.
+ * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
+ */
+private[spark] final class RdmaShuffleBlockFetcherIterator(
+ context: TaskContext,
+ shuffleClient: ShuffleClient,
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ streamWrapper: (BlockId, InputStream) => InputStream,
+ maxBytesInFlight: Long,
+ maxReqsInFlight: Int,
+ maxBlocksInFlightPerAddress: Int,
+ maxReqSizeShuffleToMem: Long,
+ detectCorrupt: Boolean)
+ extends Iterator[(BlockId, InputStream)]
+ with TempFileManager
+ with Logging {
import RdmaShuffleBlockFetcherIterator._
/** Local blocks to fetch, excluding zero-sized blocks. */
private[this] val localBlocks = new ArrayBuffer[BlockId]()
+
/**
- * A queue to hold our results. This turns the asynchronous model provided by
- * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
- */
+ * A queue to hold our results. This turns the asynchronous model provided by
+ * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
+ */
private[this] val results = new LinkedBlockingQueue[FetchResult]
- private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+ private[this] val shuffleMetrics =
+ context.taskMetrics().createTempShuffleReadMetrics()
+
/**
- * A set to store the files used for shuffling remote huge blocks. Files in this set will be
- * deleted when cleanup. This is a layer of defensiveness against disk file leaks.
- */
+ * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+ * deleted when cleanup. This is a layer of defensiveness against disk file leaks.
+ */
@GuardedBy("this")
private[this] val shuffleFilesSet = mutable.HashSet[File]()
- private[this] val remoteRdmaRequestQueue = new LinkedBlockingQueue[RdmaRequest]()
+ private[this] val remoteRdmaRequestQueue =
+ new LinkedBlockingQueue[RdmaRequest]()
+
/**
- * Total number of blocks to fetch. This can be smaller than the total number of blocks
- * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
- *
- * This should equal localBlocks.size + remoteBlocks.size.
- */
+ * Total number of blocks to fetch. This can be smaller than the total number of blocks
+ * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
+ *
+ * This should equal localBlocks.size + remoteBlocks.size.
+ */
private[this] var numBlocksToFetch = 0
+
/**
- * The number of blocks processed by the caller. The iterator is exhausted when
- * [[numBlocksProcessed]] == [[numBlocksToFetch]].
- */
+ * The number of blocks processed by the caller. The iterator is exhausted when
+ * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+ */
private[this] var numBlocksProcessed = 0
private[this] val numRemoteBlockToFetch = new AtomicInteger(0)
private[this] val numRemoteBlockProcessing = new AtomicInteger(0)
private[this] val numRemoteBlockProcessed = new AtomicInteger(0)
+
/**
- * Current [[FetchResult]] being processed. We track this so we can release the current buffer
- * in case of a runtime exception when processing the current buffer.
- */
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
@volatile private[this] var currentResult: SuccessFetchResult = _
+
/** Current bytes in flight from our requests */
private[this] val bytesInFlight = new AtomicLong(0)
+
/** Current number of requests in flight */
private[this] val reqsInFlight = new AtomicInteger(0)
+
/**
- * Whether the iterator is still active. If isZombie is true, the callback interface will no
- * longer place fetched blocks into [[results]].
- */
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
@GuardedBy("this")
private[this] var isZombie = false
@@ -125,13 +138,17 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
def initialize(): Unit = {
context.addTaskCompletionListener(_ => cleanup())
-
- val remoteBlocksByAddress = blocksByAddress.filter(_._1.executorId != blockManager.blockManagerId.executorId)
- for ((address, blockInfos) <- blocksByAddress) {
- if (address.executorId == blockManager.blockManagerId.executorId) {
- localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
- numBlocksToFetch += localBlocks.size
+ val (localBlocksByAddress, remoteBlocksByAddress) =
+ if (PmofConf.getConf.enableRemotePmem) {
+ (Nil, blocksByAddress)
+ } else {
+ (
+ blocksByAddress.filter(_._1.executorId == blockManager.blockManagerId.executorId),
+ blocksByAddress.filter(_._1.executorId != blockManager.blockManagerId.executorId))
}
+ for ((address, blockInfos) <- localBlocksByAddress) {
+ localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
+ numBlocksToFetch += localBlocks.size
}
startFetch(remoteBlocksByAddress)
@@ -156,7 +173,12 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
for (i <- 0 until num) {
val current = blockInfoArray(i)
if (current.getShuffleBlockId != last.getShuffleBlockId) {
- remoteRdmaRequestQueue.put(new RdmaRequest(blockManagerId, last.getShuffleBlockId, blockInfoArray.slice(startIndex, i), reqSize))
+ remoteRdmaRequestQueue.put(
+ new RdmaRequest(
+ blockManagerId,
+ last.getShuffleBlockId,
+ blockInfoArray.slice(startIndex, i),
+ reqSize))
startIndex = i
reqSize = 0
Future {
@@ -166,15 +188,18 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
last = current
reqSize += current.getLength
}
- remoteRdmaRequestQueue.put(new RdmaRequest(blockManagerId, last.getShuffleBlockId, blockInfoArray.slice(startIndex, num), reqSize))
+ remoteRdmaRequestQueue.put(
+ new RdmaRequest(
+ blockManagerId,
+ last.getShuffleBlockId,
+ blockInfoArray.slice(startIndex, num),
+ reqSize))
Future {
fetchRemoteBlocks()
}
}
- override def onFailure(e: Throwable): Unit = {
-
- }
+ override def onFailure(e: Throwable): Unit = {}
}
numBlocksToFetch += blockIds.length
@@ -185,10 +210,10 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
}
/**
- * Fetch the local blocks while we are fetching remote blocks. This is ok because
- * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
- * track in-memory are the ManagedBuffer references themselves.
- */
+ * Fetch the local blocks while we are fetching remote blocks. This is ok because
+ * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
private[this] def fetchLocalBlocks() {
val iter = localBlocks.iterator
while (iter.hasNext) {
@@ -198,7 +223,13 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
- results.put(SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, isNetworkReqDone = false))
+ results.put(
+ SuccessFetchResult(
+ blockId,
+ blockManager.blockManagerId,
+ 0,
+ buf,
+ isNetworkReqDone = false))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
@@ -210,8 +241,8 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
}
/**
- * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
- */
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
private[this] def cleanup() {
synchronized {
isZombie = true
@@ -265,13 +296,13 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
}
/**
- * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
- * underlying each InputStream will be freed by the cleanup() method registered with the
- * TaskCompletionListener. However, callers should close() these InputStreams
- * as soon as they are no longer needed, in order to release memory as early as possible.
- *
- * Throws a FetchFailedException if the next block could not be fetched.
- */
+ * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
+ * underlying each InputStream will be freed by the cleanup() method registered with the
+ * TaskCompletionListener. However, callers should close() these InputStreams
+ * as soon as they are no longer needed, in order to release memory as early as possible.
+ *
+ * Throws a FetchFailedException if the next block could not be fetched.
+ */
override def next(): (BlockId, InputStream) = {
if (!hasNext) {
throw new NoSuchElementException
@@ -307,18 +338,20 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
reqsInFlight.decrementAndGet
}
- logDebug("numRemoteBlockToFetch " + numRemoteBlockToFetch + " numRemoteBlockProcessing " + numRemoteBlockProcessing + " numRemoteBlockProcessed " + numRemoteBlockProcessed)
-
- val in = try {
- buf.createInputStream()
- } catch {
- // The exception could only be throwed by local shuffle block
- case e: IOException =>
- assert(buf.isInstanceOf[FileSegmentManagedBuffer])
- logError("Failed to create input stream from local block", e)
- buf.release()
- throwFetchFailedException(blockId, address, e)
- }
+ logDebug(
+ "numRemoteBlockToFetch " + numRemoteBlockToFetch + " numRemoteBlockProcessing " + numRemoteBlockProcessing + " numRemoteBlockProcessed " + numRemoteBlockProcessed)
+
+ val in =
+ try {
+ buf.createInputStream()
+ } catch {
+          // The exception can only be thrown by a local shuffle block
+ case e: IOException =>
+ assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+ logError("Failed to create input stream from local block", e)
+ buf.release()
+ throwFetchFailedException(blockId, address, e)
+ }
input = streamWrapper(blockId, in)
// Only copy the stream if it's wrapped by compression or encryption, also the size of
@@ -342,10 +375,43 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
if (rdmaRequest == null) {
return
}
- if (!isRemoteBlockFetchable(rdmaRequest)) {
- remoteRdmaRequestQueue.put(rdmaRequest)
+ if (PmofConf.getConf.enableRemotePmem) {
+ val shuffleBlockInfos = rdmaRequest.shuffleBlockInfos
+ val blockManagerId = rdmaRequest.blockManagerId
+ val remotePersistentMemoryPool =
+ RemotePersistentMemoryPool.getInstance(blockManagerId.host, blockManagerId.port.toString)
+
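+      // Remote PMem path: read each registered block directly from the remote persistent
+      // memory pool instead of issuing RDMA reads against the peer executor.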
+ rdmaRequest.shuffleBlockInfos.foreach(shuffleBlockInfo => {
+ logDebug(
+ s"[fetch Remote Blocks] target is ${blockManagerId.host}:${blockManagerId.port}," +
+ s" ${rdmaRequest.shuffleBlockIdName}" +
+ s" [${shuffleBlockInfo.getRkey}]${shuffleBlockInfo.getAddress}-${shuffleBlockInfo.getLength}")
+ val inputByteBuffer = new NioManagedBuffer(shuffleBlockInfo.getLength)
+ if (remotePersistentMemoryPool.read(
+ shuffleBlockInfo.getAddress,
+ shuffleBlockInfo.getLength,
+ inputByteBuffer.nioByteBuffer) == 0) {
+ results.put(
+ SuccessFetchResult(
+ BlockId(rdmaRequest.shuffleBlockIdName),
+ rdmaRequest.blockManagerId,
+ rdmaRequest.reqSize,
+ inputByteBuffer,
+ isNetworkReqDone = true))
+ } else {
+ results.put(
+ FailureFetchResult(
+ BlockId(rdmaRequest.shuffleBlockIdName),
+ rdmaRequest.blockManagerId,
+ new IllegalStateException("RemotePersistentMemoryPool read failed.")))
+ }
+ })
} else {
- sendRequest(rdmaRequest)
+ if (!isRemoteBlockFetchable(rdmaRequest)) {
+ remoteRdmaRequestQueue.put(rdmaRequest)
+ } else {
+ sendRequest(rdmaRequest)
+ }
}
}
@@ -366,7 +432,13 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
RdmaShuffleBlockFetcherIterator.this.synchronized {
blockNums -= 1
if (blockNums == 0) {
- results.put(SuccessFetchResult(BlockId(shuffleBlockIdName), blockManagerId, rdmaRequest.reqSize, shuffleBuffer, isNetworkReqDone = true))
+ results.put(
+ SuccessFetchResult(
+ BlockId(shuffleBlockIdName),
+ blockManagerId,
+ rdmaRequest.reqSize,
+ shuffleBuffer,
+ isNetworkReqDone = true))
f(shuffleBuffer.getRdmaBufferId)
}
}
@@ -378,17 +450,31 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
}
}
- val client = pmofTransferService.getClient(blockManagerId.host, blockManagerId.port)
- val shuffleBuffer = new ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true)
- val rdmaBuffer = client.getEqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(),
- shuffleBuffer.getAddress, shuffleBuffer.getLength.toInt)
+ val client =
+ pmofTransferService.getClient(blockManagerId.host, blockManagerId.port)
+ val shuffleBuffer =
+ new ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true)
+ val rdmaBuffer = client.getEqService.regRmaBufferByAddress(
+ shuffleBuffer.nioByteBuffer(),
+ shuffleBuffer.getAddress,
+ shuffleBuffer.getLength.toInt)
shuffleBuffer.setRdmaBufferId(rdmaBuffer.getBufferId)
var offset = 0
for (i <- 0 until blockNums) {
- pmofTransferService.fetchBlock(blockManagerId.host, blockManagerId.port,
- shuffleBlockInfos(i).getAddress, shuffleBlockInfos(i).getLength,
- shuffleBlockInfos(i).getRkey, offset, shuffleBuffer, client, blockFetchingReadCallback)
+ logInfo(
+ s"[fetch Remote Blocks] target is ${blockManagerId.host}:${blockManagerId.port}, ${shuffleBlockIdName} [${shuffleBlockInfos(
+ i).getRkey}]${shuffleBlockInfos(i).getAddress}-${shuffleBlockInfos(i).getLength}")
+ pmofTransferService.fetchBlock(
+ blockManagerId.host,
+ blockManagerId.port,
+ shuffleBlockInfos(i).getAddress,
+ shuffleBlockInfos(i).getLength,
+ shuffleBlockInfos(i).getRkey,
+ offset,
+ shuffleBuffer,
+ client,
+ blockFetchingReadCallback)
offset += shuffleBlockInfos(i).getLength
}
}
@@ -399,26 +485,34 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
+ private def throwFetchFailedException(
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable) = {
blockId match {
case ShuffleBlockId(shufId, mapId, reduceId) =>
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block", e)
+ "Failed to get block " + blockId + ", which is not a shuffle block",
+ e)
}
}
}
-private class RdmaRequest(val blockManagerId: BlockManagerId, val shuffleBlockIdName: String, val shuffleBlockInfos: ArrayBuffer[ShuffleBlockInfo], val reqSize: Int) {}
+private class RdmaRequest(
+ val blockManagerId: BlockManagerId,
+ val shuffleBlockIdName: String,
+ val shuffleBlockInfos: ArrayBuffer[ShuffleBlockInfo],
+ val reqSize: Int) {}
/**
- * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
- */
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+ */
private class RDMABufferReleasingInputStream(
- private val delegate: InputStream,
- private val iterator: RdmaShuffleBlockFetcherIterator)
- extends InputStream {
+ private val delegate: InputStream,
+ private val iterator: RdmaShuffleBlockFetcherIterator)
+ extends InputStream {
private[this] var closed = false
override def read(): Int = delegate.read()
@@ -441,64 +535,65 @@ private class RDMABufferReleasingInputStream(
override def read(b: Array[Byte]): Int = delegate.read(b)
- override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+ override def read(b: Array[Byte], off: Int, len: Int): Int =
+ delegate.read(b, off, len)
override def reset(): Unit = delegate.reset()
}
-private[storage]
-object RdmaShuffleBlockFetcherIterator {
+private[storage] object RdmaShuffleBlockFetcherIterator {
/**
- * Result of a fetch from a remote block.
- */
+ * Result of a fetch from a remote block.
+ */
private[storage] sealed trait FetchResult {
val blockId: BlockId
val address: BlockManagerId
}
/**
- * A request to fetch blocks from a remote BlockManager.
- *
- * @param address remote BlockManager to fetch from.
- * @param blocks Sequence of tuple, where the first element is the block id,
- * and the second element is the estimated size, used to calculate bytesInFlight.
- */
+ * A request to fetch blocks from a remote BlockManager.
+ *
+ * @param address remote BlockManager to fetch from.
+ * @param blocks Sequence of tuple, where the first element is the block id,
+ * and the second element is the estimated size, used to calculate bytesInFlight.
+ */
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size: Long = blocks.map(_._2).sum
}
/**
- * Result of a fetch from a remote block successfully.
- *
- * @param blockId block id
- * @param address BlockManager that the block was fetched from.
- * @param size estimated size of the block, used to calculate bytesInFlight.
- * Note that this is NOT the exact bytes.
- * @param buf `ManagedBuffer` for the content.
- * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
- */
+   * Result of a successful fetch from a remote block.
+ *
+ * @param blockId block id
+ * @param address BlockManager that the block was fetched from.
+ * @param size estimated size of the block, used to calculate bytesInFlight.
+ * Note that this is NOT the exact bytes.
+ * @param buf `ManagedBuffer` for the content.
+ * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
+ */
private[storage] case class SuccessFetchResult(
- blockId: BlockId,
- address: BlockManagerId,
- size: Long,
- buf: ManagedBuffer,
- isNetworkReqDone: Boolean) extends FetchResult {
+ blockId: BlockId,
+ address: BlockManagerId,
+ size: Long,
+ buf: ManagedBuffer,
+ isNetworkReqDone: Boolean)
+ extends FetchResult {
require(buf != null)
require(size >= 0)
}
/**
- * Result of a fetch from a remote block unsuccessfully.
- *
- * @param blockId block id
- * @param address BlockManager that the block was attempted to be fetched from
- * @param e the failure exception
- */
+   * Result of a failed fetch from a remote block.
+ *
+ * @param blockId block id
+ * @param address BlockManager that the block was attempted to be fetched from
+ * @param e the failure exception
+ */
private[storage] case class FailureFetchResult(
- blockId: BlockId,
- address: BlockManagerId,
- e: Throwable)
- extends FetchResult
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable)
+ extends FetchResult
}
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/RpmpShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/pmof/RpmpShuffleBlockFetcherIterator.scala
new file mode 100644
index 00000000..00ea3a0b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/RpmpShuffleBlockFetcherIterator.scala
@@ -0,0 +1,377 @@
+package org.apache.spark.storage.pmof
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.{File, IOException, InputStream}
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}
+
+import javax.annotation.concurrent.GuardedBy
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.pmof._
+import org.apache.spark.network.shuffle.{ShuffleClient, TempFileManager}
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage._
+import org.apache.spark.util.configuration.pmof.PmofConf
+import org.apache.spark.{SparkException, TaskContext}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.Future
+
+import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
+import java.nio.ByteBuffer
+import io.netty.buffer.{ByteBuf, ByteBufInputStream, ByteBufOutputStream}
+import io.netty.buffer.Unpooled
+
+/**
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
+ * manager. For remote blocks, it fetches them using the provided BlockTransferService.
+ *
+ * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
+ * using too much memory.
+ *
+ * @param context [[TaskContext]], used for metrics update
+ * @param blockManager [[BlockManager]] for reading local blocks
+ * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
+ * For each block we also require the size (in bytes as a long field) in
+ * order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
+ * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
+ * for a given remote host:port.
+ * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
+ */
+private[spark] final class RpmpShuffleBlockFetcherIterator(
+ context: TaskContext,
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+ streamWrapper: (BlockId, InputStream) => InputStream,
+ maxBytesInFlight: Long,
+ maxReqsInFlight: Int,
+ maxBlocksInFlightPerAddress: Int,
+ maxReqSizeShuffleToMem: Long,
+ detectCorrupt: Boolean,
+ pmofConf: PmofConf)
+ extends Iterator[(BlockId, InputStream)]
+ with TempFileManager
+ with Logging {
+
+ import RpmpShuffleBlockFetcherIterator._
+
+ /** Local blocks to fetch, excluding zero-sized blocks. */
+ private[this] val localBlocks = new ArrayBuffer[BlockId]()
+
+ /**
+ * A queue to hold our results. This turns the asynchronous model provided by
+ * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
+ */
+ private[this] val results = new LinkedBlockingQueue[FetchResult]
+ private[this] val shuffleMetrics =
+ context.taskMetrics().createTempShuffleReadMetrics()
+
+ /**
+ * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+ * deleted when cleanup. This is a layer of defensiveness against disk file leaks.
+ */
+ @GuardedBy("this")
+ private[this] val shuffleFilesSet = mutable.HashSet[File]()
+
+ /**
+ * Total number of blocks to fetch. This can be smaller than the total number of blocks
+ * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
+ *
+ * This should equal localBlocks.size + remoteBlocks.size.
+ */
+ private[this] var numBlocksToFetch = 0
+
+ /**
+ * The number of blocks processed by the caller. The iterator is exhausted when
+ * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+ */
+ private[this] var numBlocksProcessed = 0
+
+ private[this] val numRemoteBlockToFetch = new AtomicInteger(0)
+ private[this] val numRemoteBlockProcessing = new AtomicInteger(0)
+ private[this] val numRemoteBlockProcessed = new AtomicInteger(0)
+
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ @volatile private[this] var currentResult: SuccessFetchResult = _
+
+ /** Current bytes in flight from our requests */
+ private[this] val bytesInFlight = new AtomicLong(0)
+
+ /** Current number of requests in flight */
+ private[this] val reqsInFlight = new AtomicInteger(0)
+
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @GuardedBy("this")
+ private[this] var isZombie = false
+
+ private[this] var blocksByAddressSize = 0
+ private[this] var blocksByAddressCurrentId = 0
+ private[this] var address: BlockManagerId = _
+ private[this] var blockInfos: Seq[(BlockId, Long)] = _
+ private[this] var iterator: Iterator[(BlockId, Long)] = _
+
+ initialize()
+
+ def initialize(): Unit = {
+ context.addTaskCompletionListener(_ => cleanup())
+ blocksByAddressSize = blocksByAddress.size
+ if (blocksByAddressCurrentId < blocksByAddressSize) {
+ val res = blocksByAddress(blocksByAddressCurrentId)
+ address = res._1
+ blockInfos = res._2
+ iterator = blockInfos.iterator
+ blocksByAddressCurrentId += 1
+ }
+ }
+
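+  // Shared RPMP client: every task in this executor reuses the process-wide
+  // RemotePersistentMemoryPool connection to pmofConf.rpmpHost:pmofConf.rpmpPort.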
+ val remotePersistentMemoryPool =
+ RemotePersistentMemoryPool.getInstance(pmofConf.rpmpHost, pmofConf.rpmpPort)
+
+ /**
+ * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
+ * underlying each InputStream will be freed by the cleanup() method registered with the
+ * TaskCompletionListener. However, callers should close() these InputStreams
+ * as soon as they are no longer needed, in order to release memory as early as possible.
+ *
+ * Throws a FetchFailedException if the next block could not be fetched.
+ */
+ private[this] var next_called: Boolean = true
+ private[this] var has_next: Boolean = false
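+  // hasNext caches its result in has_next until next() is called again, since advancing
+  // the per-address block iterator is a stateful operation.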
+ override def next(): (BlockId, InputStream) = {
+ next_called = true
+ var input: InputStream = null
+ val (blockId, size) = iterator.next()
+ var buf = new NioManagedBuffer(size.toInt)
+ val startFetchWait = System.currentTimeMillis()
+    val readLen = remotePersistentMemoryPool.get(blockId.name, size, buf.nioByteBuffer)
+    if (readLen != -1) {
+ val stopFetchWait = System.currentTimeMillis()
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
+
+ val in = buf.createInputStream()
+ input = streamWrapper(blockId, in)
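+      // When the stream is wrapped (compression/encryption) and corruption detection is on,
+      // eagerly decompress into memory so corrupt remote data is detected here instead of
+      // failing later in the shuffle read path.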
+ if (detectCorrupt && !input.eq(in)) {
+ val originalInput = input
+ val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+ try {
+ Utils.copyStream(input, out)
+ out.close()
+ input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+ logDebug(s" buf ${blockId}-${size} decompress succeeded.")
+ } catch {
+ case e: IOException =>
+ val tmp_byte_buffer = buf.nioByteBuffer
+ logWarning(
+ s" buf ${blockId}-${size} decompress corrupted, input is ${input}, raw_data is ${tmp_byte_buffer}, remaining is ${tmp_byte_buffer.remaining}")
+            val tmp = new Array[Byte](size.toInt)
+            buf.nioByteBuffer.rewind
+            buf.nioByteBuffer.get(tmp, 0, size.toInt)
+ logWarning(s"content is ${convertBytesToHex(tmp)}.")
+ throw e
+
+ } finally {
+ originalInput.close()
+ in.close()
+ buf.release()
+ }
+ }
+ } else {
+      throw new IOException(s"remotePersistentMemoryPool.get(${blockId}, ${size}) failed.")
+ }
+ (blockId, new RpmpBufferReleasingInputStream(input, null))
+ }
+
+ override def hasNext: Boolean = {
+ if (!next_called) {
+ return has_next
+ }
+ next_called = false
+ if (iterator.hasNext) {
+ has_next = true
+ } else {
+ if (blocksByAddressCurrentId >= blocksByAddressSize) {
+ has_next = false
+ return has_next
+ }
+ val res = blocksByAddress(blocksByAddressCurrentId)
+ address = res._1
+ blockInfos = res._2
+ iterator = blockInfos.iterator
+ blocksByAddressCurrentId += 1
+ has_next = true
+ }
+ return has_next
+ }
+
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ synchronized {
+ isZombie = true
+ }
+ }
+
+ override def registerTempFileToClean(file: File): Boolean = synchronized {
+ if (isZombie) {
+ false
+ } else {
+ shuffleFilesSet += file
+ true
+ }
+ }
+
+ private def throwFetchFailedException(
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable) = {
+ blockId match {
+ case ShuffleBlockId(shufId, mapId, reduceId) =>
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
+ case _ =>
+ throw new SparkException(
+ "Failed to get block " + blockId + ", which is not a shuffle block",
+ e)
+ }
+ }
+ override def createTempFile(): File = {
+ blockManager.diskBlockManager.createTempLocalBlock()._2
+ }
+ def convertBytesToHex(bytes: Seq[Byte]): String = {
+ val sb = new StringBuilder
+ for (b <- bytes) {
+ sb.append(String.format("%02x", Byte.box(b)))
+ }
+ sb.toString
+ }
+
+}
+
+/**
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+ */
+private class RpmpBufferReleasingInputStream(
+ private val delegate: InputStream,
+ private val parent: NioManagedBuffer)
+ extends InputStream {
+ private[this] var closed = false
+
+ override def read(): Int = delegate.read()
+
+ override def close(): Unit = {
+ if (!closed) {
+ delegate.close()
+ if (parent != null) {
+ parent.release()
+ }
+ closed = true
+ }
+ }
+
+ override def available(): Int = delegate.available()
+
+ override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+ override def skip(n: Long): Long = delegate.skip(n)
+
+ override def markSupported(): Boolean = delegate.markSupported()
+
+ override def read(b: Array[Byte]): Int = delegate.read(b)
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int =
+ delegate.read(b, off, len)
+
+ override def reset(): Unit = delegate.reset()
+}
+
+private[storage] object RpmpShuffleBlockFetcherIterator {
+
+ /**
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult {
+ val blockId: BlockId
+ val address: BlockManagerId
+ }
+
+ /**
+ * A request to fetch blocks from a remote BlockManager.
+ *
+ * @param address remote BlockManager to fetch from.
+ * @param blocks Sequence of tuple, where the first element is the block id,
+ * and the second element is the estimated size, used to calculate bytesInFlight.
+ */
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
+ val size: Long = blocks.map(_._2).sum
+ }
+
+ /**
+   * Result of a successful fetch from a remote block.
+ *
+ * @param blockId block id
+ * @param address BlockManager that the block was fetched from.
+ * @param size estimated size of the block, used to calculate bytesInFlight.
+ * Note that this is NOT the exact bytes.
+ * @param buf `ManagedBuffer` for the content.
+ * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
+ */
+ private[storage] case class SuccessFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ size: Long,
+ buf: ManagedBuffer,
+ isNetworkReqDone: Boolean)
+ extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
+ }
+
+ /**
+   * Result of a failed fetch from a remote block.
+ *
+ * @param blockId block id
+ * @param address BlockManager that the block was attempted to be fetched from
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(
+ blockId: BlockId,
+ address: BlockManagerId,
+ e: Throwable)
+ extends FetchResult
+
+}
diff --git a/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala b/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala
index f3551704..3e560c72 100644
--- a/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala
+++ b/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala
@@ -27,4 +27,22 @@ class PmofConf(conf: SparkConf) {
val shuffleBlockSize: Int = conf.getInt("spark.shuffle.pmof.shuffle_block_size", defaultValue = 2048)
val pmemCapacity: Long = conf.getLong("spark.shuffle.pmof.pmem_capacity", defaultValue = 264239054848L)
val pmemCoreMap = conf.get("spark.shuffle.pmof.dev_core_set", defaultValue = "/dev/dax0.0:0-17,36-53").split(";").map(_.trim).map(_.split(":")).map(arr => arr(0) -> arr(1)).toMap
+  val enableRemotePmem: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_remote_pmem", defaultValue = false)
+ val rpmpHost: String = conf.get("spark.rpmp.rhost", defaultValue = "172.168.0.40")
+ val rpmpPort: String = conf.get("spark.rpmp.rport", defaultValue = "61010")
}
+
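+// Process-wide holder for the PMoF configuration: the first call to getConf(conf) builds the
+// singleton, and later parameterless getConf calls return the cached instance (or fail if it
+// was never initialized).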
+object PmofConf {
+ var ins: PmofConf = null
+ def getConf(conf: SparkConf): PmofConf = if (ins == null) {
+ ins = new PmofConf(conf)
+ ins
+ } else {
+ ins
+ }
+ def getConf: PmofConf = if (ins == null) {
+ throw new IllegalStateException("PmofConf is not initialized yet")
+ } else {
+ ins
+ }
+}
\ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala
index 4c0af8d8..91e2f9fa 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala
@@ -49,7 +49,7 @@ class PmemShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
conf.set("spark.shuffle.pmof.pmem_list", "/dev/dax0.0")
shuffleBlockResolver = new PmemShuffleBlockResolver(conf)
serializer = new JavaSerializer(conf)
- pmofConf = new PmofConf(conf)
+ pmofConf = PmofConf.getConf(conf)
taskMetrics = new TaskMetrics()
serializerManager = new SerializerManager(serializer, conf)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala
index e01e06cd..410292f6 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala
@@ -53,7 +53,7 @@ class PmemShuffleWriterWithSortSuite extends SparkFunSuite with SharedSparkConte
conf.set("spark.shuffle.pmof.pmem_list", "/dev/dax0.0")
shuffleBlockResolver = new PmemShuffleBlockResolver(conf)
serializer = new JavaSerializer(conf)
- pmofConf = new PmofConf(conf)
+ pmofConf = PmofConf.getConf(conf)
taskMetrics = new TaskMetrics()
serializerManager = new SerializerManager(serializer, conf)
diff --git a/rpmp/benchmark/CMakeLists.txt b/rpmp/benchmark/CMakeLists.txt
index 468d6b15..a670b1e9 100644
--- a/rpmp/benchmark/CMakeLists.txt
+++ b/rpmp/benchmark/CMakeLists.txt
@@ -1,17 +1,26 @@
-add_executable(local_allocate local_allocate.cc)
-target_link_libraries(local_allocate pmpool)
+#add_executable(local_allocate local_allocate.cc)
+#target_link_libraries(local_allocate pmpool)
+#
+#add_executable(remote_allocate remote_allocate.cc)
+#target_link_libraries(remote_allocate pmpool_client_jni)
+#
+#add_executable(remote_write remote_write.cc)
+#target_link_libraries(remote_write pmpool_client_jni)
+#
+#add_executable(remote_allocate_write remote_allocate_write.cc)
+#target_link_libraries(remote_allocate_write pmpool_client_jni)
+#
+#add_executable(circularbuffer circularbuffer.cc)
+#target_link_libraries(circularbuffer pmpool_client_jni)
+#
+#add_executable(remote_read remote_read.cc)
+#target_link_libraries(remote_read pmpool_client_jni)
+#
+#add_executable(remote_put remote_put.cc)
+#target_link_libraries(remote_put pmpool_client_jni)
-add_executable(remote_allocate remote_allocate.cc)
-target_link_libraries(remote_allocate pmpool)
+add_executable(put_and_read put_and_read.cc)
+target_link_libraries(put_and_read pmpool_client_jni)
-add_executable(remote_write remote_write.cc)
-target_link_libraries(remote_write pmpool)
-
-add_executable(remote_allocate_write remote_allocate_write.cc)
-target_link_libraries(remote_allocate_write pmpool)
-
-add_executable(circularbuffer circularbuffer.cc)
-target_link_libraries(circularbuffer pmpool)
-
-add_executable(remote_read remote_read.cc)
-target_link_libraries(remote_read pmpool)
+add_executable(put_and_get put_and_get.cc)
+target_link_libraries(put_and_get pmpool_client_jni)
\ No newline at end of file
diff --git a/rpmp/benchmark/Config.h b/rpmp/benchmark/Config.h
new file mode 100644
index 00000000..dc5659ef
--- /dev/null
+++ b/rpmp/benchmark/Config.h
@@ -0,0 +1,98 @@
+/*
+ * Filename: /mnt/spark-pmof/tool/rpmp/pmpool/Config.h
+ * Path: /mnt/spark-pmof/tool/rpmp/pmpool
+ * Created Date: Thursday, November 7th 2019, 3:48:52 pm
+ * Author: root
+ *
+ * Copyright (c) 2019 Intel
+ */
+
+#ifndef PMPOOL_CONFIG_H_
+#define PMPOOL_CONFIG_H_
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include <boost/program_options.hpp>
+
+using boost::program_options::error;
+using boost::program_options::options_description;
+using boost::program_options::value;
+using boost::program_options::variables_map;
+using std::string;
+using std::vector;
+
+/**
+ * @brief This class represents the current RPMP configuration.
+ *
+ */
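+// A typical invocation of a benchmark built on this config (values are illustrative):
+//   ./put_and_get -a 172.168.0.40 -p 12346 -m 0 -r 2048 -t 8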
+class Config {
+ public:
+ int init(int argc, char **argv) {
+ try {
+ options_description desc{"Options"};
+      desc.add_options()("help,h", "Help screen")(
+          "address,a", value<string>()->default_value("172.168.0.40"),
+          "set the rdma server address")(
+          "port,p", value<string>()->default_value("12346"),
+          "set the rdma server port")(
+          "log,l", value<string>()->default_value("/tmp/rpmp.log"),
+          "set rpmp log file path")("map_id,m", value<int>()->default_value(0),
+                                    "map id")(
+          "req_num,r", value<int>()->default_value(2048), "number of requests")(
+          "threads,t", value<int>()->default_value(8), "number of threads");
+
+ variables_map vm;
+ store(parse_command_line(argc, argv, desc), vm);
+ notify(vm);
+
+ if (vm.count("help")) {
+ std::cout << desc << '\n';
+ return -1;
+ }
+      set_ip(vm["address"].as<string>());
+      set_port(vm["port"].as<string>());
+      set_log_path(vm["log"].as<string>());
+      set_map_id(vm["map_id"].as<int>());
+      set_num_reqs(vm["req_num"].as<int>());
+      set_num_threads(vm["threads"].as<int>());
+ } catch (const error &ex) {
+ std::cerr << ex.what() << '\n';
+ }
+ return 0;
+ }
+
+ int get_map_id() { return map_id_; }
+ void set_map_id(int map_id) { map_id_ = map_id; }
+
+ int get_num_reqs() { return num_reqs_; }
+ void set_num_reqs(int num_reqs) { num_reqs_ = num_reqs; }
+
+ int get_num_threads() { return num_threads_; }
+ void set_num_threads(int num_threads) { num_threads_ = num_threads; }
+
+ string get_ip() { return ip_; }
+ void set_ip(string ip) { ip_ = ip; }
+
+ string get_port() { return port_; }
+ void set_port(string port) { port_ = port; }
+
+ string get_log_path() { return log_path_; }
+ void set_log_path(string log_path) { log_path_ = log_path; }
+
+ string get_log_level() { return log_level_; }
+ void set_log_level(string log_level) { log_level_ = log_level; }
+
+ private:
+ string ip_;
+ string port_;
+ string log_path_;
+ string log_level_;
+ int map_id_ = 0;
+ int num_threads_ = 8;
+ int num_reqs_ = 2048;
+};
+
+#endif // PMPOOL_CONFIG_H_
diff --git a/rpmp/benchmark/local_allocate.cc b/rpmp/benchmark/local_allocate.cc
index 8faadd68..50afbc0c 100644
--- a/rpmp/benchmark/local_allocate.cc
+++ b/rpmp/benchmark/local_allocate.cc
@@ -29,7 +29,7 @@ std::mutex mtx;
uint64_t count = 0;
char str[1048576];
-void func(AllocatorProxy *proxy, int index) {
+void func(std::shared_ptr<AllocatorProxy> proxy, int index) {
while (true) {
     std::unique_lock<std::mutex> lk(mtx);
uint64_t count_ = count++;
@@ -47,7 +47,7 @@ int main() {
   std::shared_ptr<Config> config = std::make_shared<Config>();
   config->init(0, nullptr);
   std::shared_ptr<Log> log = std::make_shared<Log>(config.get());
-  auto allocatorProxy = new AllocatorProxy(config.get(), log.get(), nullptr);
+  auto allocatorProxy = std::make_shared<AllocatorProxy>(config, log, nullptr);
allocatorProxy->init();
std::vector threads;
memset(str, '0', 1048576);
diff --git a/rpmp/benchmark/put_and_get.cc b/rpmp/benchmark/put_and_get.cc
new file mode 100644
index 00000000..473a2664
--- /dev/null
+++ b/rpmp/benchmark/put_and_get.cc
@@ -0,0 +1,141 @@
+/*
+ * Filename: /mnt/spark-pmof/tool/rpmp/benchmark/allocate_perf.cc
+ * Path: /mnt/spark-pmof/tool/rpmp/benchmark
+ * Created Date: Friday, December 20th 2019, 8:29:23 am
+ * Author: root
+ *
+ * Copyright (c) 2019 Intel
+ */
+
+#include
+#include
+#include // NOLINT
+#include "Config.h"
+#include "pmpool/Base.h"
+#include "pmpool/client/PmPoolClient.h"
+
+uint64_t timestamp_now() {
+ return std::chrono::high_resolution_clock::now().time_since_epoch() /
+ std::chrono::milliseconds(1);
+}
+
+char str[1048576];
+int numReqs = 2048;
+
+bool comp(char* str, char* str_read, uint64_t size) {
+ auto res = memcmp(str, str_read, size);
+ if (res != 0) {
+    fprintf(stderr,
+            "** memcmp is %d, read result does not match what was written. **\nread "
+            "content is \n",
+            res);
+ for (int i = 0; i < 100; i++) {
+ fprintf(stderr, "%X ", *(str_read + i));
+ }
+    fprintf(stderr, " ...\nwritten content is \n");
+ for (int i = 0; i < 100; i++) {
+ fprintf(stderr, "%X ", *(str + i));
+ }
+ fprintf(stderr, " ...\n");
+ }
+ return res == 0;
+}
+
+void get(int map_id, int start, int end, std::shared_ptr<PmPoolClient> client) {
+ int count = start;
+ while (count < end) {
+ std::string key =
+ "block_" + std::to_string(map_id) + "_" + std::to_string(count++);
+ char str_read[1048576];
+ client->begin_tx();
+ client->get(key, str_read, 1048576);
+ client->end_tx();
+ if (comp(str, str_read, 1048576) == false) {
+ throw;
+ }
+ }
+}
+
+void put(int map_id, int start, int end, std::shared_ptr<PmPoolClient> client) {
+ int count = start;
+ while (count < end) {
+ std::string key =
+ "block_" + std::to_string(map_id) + "_" + std::to_string(count++);
+ client->begin_tx();
+ client->put(key, str, 1048576);
+ client->end_tx();
+ }
+}
+
+int main(int argc, char** argv) {
+ /// initialize Config class
+  std::shared_ptr<Config> config = std::make_shared<Config>();
+ CHK_ERR("config init", config->init(argc, argv));
+
+ char temp[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
+ 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
+ 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f'};
+ for (int i = 0; i < 1048576 / 32; i++) {
+ memcpy(str + i * 32, temp, 32);
+ }
+
+ int threads = config->get_num_threads();
+ int map_id = config->get_map_id();
+ numReqs = config->get_num_reqs();
+ std::string host = config->get_ip();
+ std::string port = config->get_port();
+
+ std::cout << "=================== Put and get ======================="
+ << std::endl;
+ std::cout << "RPMP server is " << host << ":" << port << std::endl;
+ std::cout << "Total Num Requests is " << numReqs << std::endl;
+ std::cout << "Total Num Threads is " << threads << std::endl;
+ std::cout << "Block key pattern is "
+ << "block_" << map_id << "_*" << std::endl;
+
+  auto client = std::make_shared<PmPoolClient>(host, port);
+ client->init();
+ std::cout << "start put." << std::endl;
+ int start = 0;
+ int step = numReqs / threads;
+  std::vector<std::shared_ptr<std::thread>> threads_1;
+ uint64_t begin = timestamp_now();
+ for (int i = 0; i < threads; i++) {
+    auto t =
+        std::make_shared<std::thread>(put, map_id, start, start + step, client);
+ threads_1.push_back(t);
+ start += step;
+ }
+ for (auto thread : threads_1) {
+ thread->join();
+ }
+ uint64_t end = timestamp_now();
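+  // Each request carries a 1 MiB payload, so requests per second corresponds to MB/s below.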
+ std::cout << "[block_" << map_id << "_*]"
+ << "pmemkv put test: 1048576 "
+ << " bytes test, consumes " << (end - begin) / 1000.0
+ << "s, throughput is " << numReqs / ((end - begin) / 1000.0)
+ << "MB/s" << std::endl;
+
+ std::cout << "start get." << std::endl;
+  std::vector<std::shared_ptr<std::thread>> threads_2;
+ begin = timestamp_now();
+ start = 0;
+ for (int i = 0; i < threads; i++) {
+    auto t =
+        std::make_shared<std::thread>(get, map_id, start, start + step, client);
+ threads_2.push_back(t);
+ start += step;
+ }
+ for (auto thread : threads_2) {
+ thread->join();
+ }
+ end = timestamp_now();
+ std::cout << "[block_" << map_id << "_*]"
+ << "pmemkv get test: 1048576 "
+ << " bytes test, consumes " << (end - begin) / 1000.0
+ << "s, throughput is " << numReqs / ((end - begin) / 1000.0)
+ << "MB/s" << std::endl;
+
+ client.reset();
+ return 0;
+}
diff --git a/rpmp/benchmark/put_and_read.cc b/rpmp/benchmark/put_and_read.cc
new file mode 100644
index 00000000..05dff4bb
--- /dev/null
+++ b/rpmp/benchmark/put_and_read.cc
@@ -0,0 +1,139 @@
+/*
+ * Filename: /mnt/spark-pmof/tool/rpmp/benchmark/allocate_perf.cc
+ * Path: /mnt/spark-pmof/tool/rpmp/benchmark
+ * Created Date: Friday, December 20th 2019, 8:29:23 am
+ * Author: root
+ *
+ * Copyright (c) 2019 Intel
+ */
+
+#include
+#include
+#include // NOLINT
+#include "Config.h"
+#include "pmpool/Base.h"
+#include "pmpool/client/PmPoolClient.h"
+
+uint64_t timestamp_now() {
+ return std::chrono::high_resolution_clock::now().time_since_epoch() /
+ std::chrono::milliseconds(1);
+}
+
+char str[1048576];
+int numReqs = 2048;
+
+bool comp(char* str, char* str_read, uint64_t size) {
+ auto res = memcmp(str, str_read, size);
+ if (res != 0) {
+    fprintf(stderr,
+            "memcmp is %d, read result does not match what was written. read "
+            "content is \n",
+            res);
+ for (int i = 0; i < size; i++) {
+ fprintf(stderr, "%X ", *(str_read + i));
+ }
+    fprintf(stderr, "\n written content is \n");
+ for (int i = 0; i < size; i++) {
+ fprintf(stderr, "%X ", *(str + i));
+ }
+ fprintf(stderr, "\n");
+ }
+ return res == 0;
+}
+
+void get(std::vector<block_meta> addresses,
+         std::shared_ptr<PmPoolClient> client) {
+ for (auto bm : addresses) {
+ char str_read[1048576];
+ client->read(bm.address, str_read, bm.size);
+ comp(str, str_read, bm.size);
+ }
+}
+
+void put(int map_id, int start, int end, std::shared_ptr<PmPoolClient> client,
+         std::vector<block_meta>* addresses) {
+ int count = start;
+ while (count < end) {
+ std::string key =
+ "block_" + std::to_string(map_id) + "_" + std::to_string(count++);
+ client->begin_tx();
+ client->put(key, str, 1048576);
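+    // getMeta returns the block_meta records (address/size/r_key) for this key; the reader
+    // threads later fetch the data directly by address.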
+ auto res = client->getMeta(key);
+ for (auto bm : res) {
+ (*addresses).push_back(bm);
+ }
+ client->end_tx();
+ }
+}
+
+int main(int argc, char** argv) {
+ /// initialize Config class
+  std::shared_ptr<Config> config = std::make_shared<Config>();
+ CHK_ERR("config init", config->init(argc, argv));
+
+ char temp[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
+ 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
+ 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f'};
+ for (int i = 0; i < 1048576 / 32; i++) {
+ memcpy(str + i * 32, temp, 32);
+ }
+
+ int threads = config->get_num_threads();
+ int map_id = config->get_map_id();
+ numReqs = config->get_num_reqs();
+ std::string host = config->get_ip();
+ std::string port = config->get_port();
+
+ std::cout << "=================== Put and get ======================="
+ << std::endl;
+ std::cout << "RPMP server is " << host << ":" << port << std::endl;
+ std::cout << "Total Num Requests is " << numReqs << std::endl;
+ std::cout << "Total Num Threads is " << threads << std::endl;
+ std::cout << "Block key pattern is "
+ << "block_" << map_id << "_*" << std::endl;
+
+  auto client = std::make_shared<PmPoolClient>(host, port);
+ client->init();
+ std::cout << "start put." << std::endl;
+ int start = 0;
+ int step = numReqs / threads;
+  std::vector<std::vector<block_meta>> addresses_list;
+ addresses_list.resize(threads);
+  std::vector<std::shared_ptr<std::thread>> threads_1;
+ uint64_t begin = timestamp_now();
+ for (int i = 0; i < threads; i++) {
+    auto t = std::make_shared<std::thread>(put, map_id, start, start + step,
+                                           client, &addresses_list[i]);
+ threads_1.push_back(t);
+ start += step;
+ }
+ for (auto thread : threads_1) {
+ thread->join();
+ }
+ uint64_t end = timestamp_now();
+ std::cout << "[block_" << map_id << "_*]"
+ << "pmemkv put test: 1048576 "
+ << " bytes test, consumes " << (end - begin) / 1000.0
+ << "s, throughput is " << numReqs / ((end - begin) / 1000.0)
+ << "MB/s" << std::endl;
+
+ std::cout << "start get." << std::endl;
+  std::vector<std::shared_ptr<std::thread>> threads_2;
+ begin = timestamp_now();
+ for (int i = 0; i < threads; i++) {
+    auto t = std::make_shared<std::thread>(get, addresses_list[i], client);
+ threads_2.push_back(t);
+ }
+ for (auto thread : threads_2) {
+ thread->join();
+ }
+ end = timestamp_now();
+ std::cout << "[block_" << map_id << "_*]"
+ << "pmemkv get test: 1048576 "
+ << " bytes test, consumes " << (end - begin) / 1000.0
+ << "s, throughput is " << numReqs / ((end - begin) / 1000.0)
+ << "MB/s" << std::endl;
+
+ client.reset();
+ return 0;
+}
diff --git a/rpmp/benchmark/remote_allocate_write.cc b/rpmp/benchmark/remote_allocate_write.cc
index 2a1f5da7..a57d10a2 100644
--- a/rpmp/benchmark/remote_allocate_write.cc
+++ b/rpmp/benchmark/remote_allocate_write.cc
@@ -8,8 +8,10 @@
*/
#include
-#include // NOLINT
#include
+#include // NOLINT
+#include "pmpool/Base.h"
+#include "pmpool/Config.h"
#include "pmpool/client/PmPoolClient.h"
uint64_t timestamp_now() {
@@ -20,12 +22,12 @@ uint64_t timestamp_now() {
 std::atomic<uint64_t> count = {0};
std::mutex mtx;
char str[1048576];
-std::vector<PmPoolClient *> clients;
+std::vector<std::shared_ptr<PmPoolClient>> clients;
 std::map<int, std::vector<uint64_t>> addresses;
void func1(int i) {
while (true) {
- uint64_t count_ = count++;
+ auto count_ = count++;
if (count_ < 20480) {
clients[i]->begin_tx();
if (addresses.count(i) != 0) {
@@ -43,8 +45,11 @@ void func1(int i) {
}
}
-int main() {
-  std::vector<std::thread *> threads;
+int main(int argc, char **argv) {
+  /// initialize Config class
+  std::shared_ptr<Config> config = std::make_shared<Config>();
+  CHK_ERR("config init", config->init(argc, argv));
+  std::vector<std::shared_ptr<std::thread>> threads;
memset(str, '0', 1048576);
int num = 0;
@@ -52,7 +57,8 @@ int main() {
num = 0;
count = 0;
for (int i = 0; i < 4; i++) {
- PmPoolClient *client = new PmPoolClient("172.168.0.40", "12346");
+    auto client =
+        std::make_shared<PmPoolClient>(config->get_ip(), config->get_port());
client->begin_tx();
client->init();
client->end_tx();
@@ -61,12 +67,11 @@ int main() {
}
uint64_t start = timestamp_now();
for (int i = 0; i < num; i++) {
- auto t = new std::thread(func1, i);
+    auto t = std::make_shared<std::thread>(func1, i);
threads.push_back(t);
}
for (int i = 0; i < num; i++) {
threads[i]->join();
- delete threads[i];
}
uint64_t end = timestamp_now();
std::cout << "pmemkv put test: 1048576 "
@@ -84,7 +89,6 @@ int main() {
std::cout << "freed." << std::endl;
for (int i = 0; i < num; i++) {
clients[i]->wait();
- delete clients[i];
}
return 0;
}
diff --git a/rpmp/benchmark/remote_put.cc b/rpmp/benchmark/remote_put.cc
new file mode 100644
index 00000000..bb35696a
--- /dev/null
+++ b/rpmp/benchmark/remote_put.cc
@@ -0,0 +1,108 @@
+/*
+ * Filename: /mnt/spark-pmof/tool/rpmp/benchmark/allocate_perf.cc
+ * Path: /mnt/spark-pmof/tool/rpmp/benchmark
+ * Created Date: Friday, December 20th 2019, 8:29:23 am
+ * Author: root
+ *
+ * Copyright (c) 2019 Intel
+ */
+
+#include
+#include
+#include // NOLINT
+#include "pmpool/Base.h"
+#include "pmpool/Config.h"
+#include "pmpool/client/PmPoolClient.h"
+
+uint64_t timestamp_now() {
+ return std::chrono::high_resolution_clock::now().time_since_epoch() /
+ std::chrono::milliseconds(1);
+}
+
+int count = 0;
+std::mutex mtx;
+std::vector<std::string> keys;
+char str[1048576];
+int numReqs = 2048;
+
+void func1(std::shared_ptr<PmPoolClient> client) {
+ while (true) {
+    std::unique_lock<std::mutex> lk(mtx);
+ uint64_t count_ = count++;
+ lk.unlock();
+ if (count_ < numReqs) {
+ char str_read[1048576];
+ client->begin_tx();
+ client->put(keys[count_], str, 1048576);
+ auto res = client->getMeta(keys[count_]);
+ // printf("put and get Meta of key %s\n", keys[count_].c_str());
+ for (auto bm : res) {
+ client->read(bm.address, str_read, bm.size);
+ // printf("read of key %s, info is [%d]%ld-%d\n", keys[count_].c_str(),
+ // bm.r_key, bm.address, bm.size);
+ auto res = memcmp(str, str_read, 1048576);
+ if (res != 0) {
+          fprintf(stderr,
+                  "memcmp is %d, read result does not match what was written. read "
+                  "content is \n",
+                  res);
+ for (int i = 0; i < 1048576; i++) {
+ fprintf(stderr, "%X ", *(str_read + i));
+ }
+          fprintf(stderr, "\n written content is \n");
+ for (int i = 0; i < 1048576; i++) {
+ fprintf(stderr, "%X ", *(str + i));
+ }
+ fprintf(stderr, "\n");
+ }
+ }
+ client->end_tx();
+ } else {
+ break;
+ }
+ }
+}
+
+int main(int argc, char** argv) {
+ /// initialize Config class
+  std::shared_ptr<Config> config = std::make_shared<Config>();
+ CHK_ERR("config init", config->init(argc, argv));
+  auto client =
+      std::make_shared<PmPoolClient>(config->get_ip(), config->get_port());
+ char temp[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
+ 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
+ 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f'};
+ for (int i = 0; i < 1048576 / 32; i++) {
+ memcpy(str + i * 32, temp, 32);
+ }
+ for (int i = 0; i < 20480; i++) {
+ keys.emplace_back("block_" + std::to_string(i));
+ }
+ client->init();
+
+ int threads = 4;
+ std::cout << "start put." << std::endl;
+  std::vector<std::shared_ptr<std::thread>> threads_1;
+ uint64_t start = timestamp_now();
+ for (int i = 0; i < threads; i++) {
+    auto t = std::make_shared<std::thread>(func1, client);
+ threads_1.push_back(t);
+ }
+ for (auto thread : threads_1) {
+ thread->join();
+ }
+ uint64_t end = timestamp_now();
+ std::cout << "pmemkv put test: 1048576 "
+ << " bytes test, consumes " << (end - start) / 1000.0
+ << "s, throughput is " << numReqs / ((end - start) / 1000.0)
+ << "MB/s" << std::endl;
+ client.reset();
+ /*for (int i = 0; i < 20480; i++) {
+ client->begin_tx();
+ client->del(keys[i]);
+ client->end_tx();
+ }
+ std::cout << "freed." << std::endl;
+ client->wait();*/
+ return 0;
+}
diff --git a/rpmp/include/spdlog b/rpmp/include/spdlog
index cca004ef..5ca5cbd4 160000
--- a/rpmp/include/spdlog
+++ b/rpmp/include/spdlog
@@ -1 +1 @@
-Subproject commit cca004efe4e66136a5f9f37e007d28a23bb729e0
+Subproject commit 5ca5cbd44739576ee90d3d9c21ba8cbf7e11ff5c
diff --git a/rpmp/main.cc b/rpmp/main.cc
index c8fe9d16..2f97fbd8 100644
--- a/rpmp/main.cc
+++ b/rpmp/main.cc
@@ -28,7 +28,7 @@ int ServerMain(int argc, char **argv) {
   std::shared_ptr<Log> log = std::make_shared<Log>(config.get());
/// initialize DataServer class
   std::shared_ptr<DataServer> dataServer =
-      std::make_shared<DataServer>(config.get(), log.get());
+      std::make_shared<DataServer>(config, log);
log->get_file_log()->info("start to initialize data server.");
CHK_ERR("data server init", dataServer->init());
log->get_file_log()->info("data server initailized.");
diff --git a/rpmp/pmpool/AllocatorProxy.h b/rpmp/pmpool/AllocatorProxy.h
index cddc5830..4a67d402 100644
--- a/rpmp/pmpool/AllocatorProxy.h
+++ b/rpmp/pmpool/AllocatorProxy.h
@@ -11,22 +11,23 @@
#define PMPOOL_ALLOCATORPROXY_H_
#include
+#include
#include
#include
-#include
#include
+#include
#include "Allocator.h"
+#include "Base.h"
#include "Config.h"
#include "DataServer.h"
#include "Log.h"
#include "PmemAllocator.h"
-#include "Base.h"
using std::atomic;
using std::make_shared;
-using std::unordered_map;
using std::string;
+using std::unordered_map;
using std::vector;
/**
@@ -37,24 +38,21 @@ using std::vector;
class AllocatorProxy {
public:
AllocatorProxy() = delete;
- AllocatorProxy(Config *config, Log *log, NetworkServer *networkServer)
+  AllocatorProxy(std::shared_ptr<Config> config, std::shared_ptr<Log> log,
+                 std::shared_ptr<NetworkServer> networkServer)
: config_(config), log_(log) {
     vector<string> paths = config_->get_pool_paths();
     vector<uint64_t> sizes = config_->get_pool_sizes();
assert(paths.size() == sizes.size());
for (int i = 0; i < paths.size(); i++) {
- DiskInfo *diskInfo = new DiskInfo(paths[i], sizes[i]);
+      auto diskInfo = std::make_shared<DiskInfo>(paths[i], sizes[i]);
diskInfos_.push_back(diskInfo);
allocators_.push_back(
- new PmemObjAllocator(log_, diskInfo, networkServer, i));
+          std::make_shared<PmemObjAllocator>(log_, diskInfo, networkServer, i));
}
}
~AllocatorProxy() {
- for (int i = 0; i < config_->get_pool_paths().size(); i++) {
- delete allocators_[i];
- delete diskInfos_[i];
- }
allocators_.clear();
diskInfos_.clear();
}
@@ -76,6 +74,7 @@ class AllocatorProxy {
addr = allocators_[index % diskInfos_.size()]->allocate_and_write(
size, content);
}
+ return addr;
}
int write(uint64_t address, const char *content, uint64_t size) {
@@ -112,14 +111,16 @@ class AllocatorProxy {
return allocators_[wid]->get_rma_chunk();
}
- void cache_chunk(uint64_t key, uint64_t address, uint64_t size) {
- block_meta bm = {address, size};
+ void cache_chunk(uint64_t key, uint64_t address, uint64_t size, int r_key) {
+ block_meta bm = {address, size, r_key};
cache_chunk(key, bm);
}
void cache_chunk(uint64_t key, block_meta bm) {
if (kv_meta_map.count(key)) {
+ printf("key exists\n");
kv_meta_map[key].push_back(bm);
+ // kv_meta_map.erase(key);
} else {
       vector<block_meta> bml;
bml.push_back(bm);
@@ -135,16 +136,16 @@ class AllocatorProxy {
}
void del_chunk(uint64_t key) {
- if (kv_meta_map.count(key)) {
+ if (kv_meta_map.count(key)) {
kv_meta_map.erase(key);
}
}
private:
- Config *config_;
- Log *log_;
-  vector<Allocator *> allocators_;
-  vector<DiskInfo *> diskInfos_;
+  std::shared_ptr<Config> config_;
+  std::shared_ptr<Log> log_;
+  vector<std::shared_ptr<Allocator>> allocators_;
+  vector<std::shared_ptr<DiskInfo>> diskInfos_;
   atomic<uint64_t> buffer_id_{0};
   unordered_map<uint64_t, vector<block_meta>> kv_meta_map;
};
diff --git a/rpmp/pmpool/Base.h b/rpmp/pmpool/Base.h
index 685c0d60..4cfa38ed 100644
--- a/rpmp/pmpool/Base.h
+++ b/rpmp/pmpool/Base.h
@@ -44,11 +44,19 @@ struct RequestReplyMsg {
};
struct block_meta {
- block_meta() : block_meta(0, 0) {}
+ block_meta() : block_meta(0, 0, 0) {}
block_meta(uint64_t _address, uint64_t _size)
: address(_address), size(_size) {}
+ block_meta(uint64_t _address, uint64_t _size, int _r_key)
+ : address(_address), size(_size), r_key(_r_key) {}
+ void set_rKey(int _r_key) { r_key = _r_key; }
+ std::string ToString() {
+ return std::to_string(r_key) + "-" + std::to_string(address) + ":" +
+ std::to_string(size);
+ }
uint64_t address;
uint64_t size;
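+  // RDMA remote key registered for this block; clients pass it along with the address when
+  // issuing one-sided reads.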
+ int r_key;
};
#endif // PMPOOL_BASE_H_
diff --git a/rpmp/pmpool/CMakeLists.txt b/rpmp/pmpool/CMakeLists.txt
index 64c82108..f5dd3656 100644
--- a/rpmp/pmpool/CMakeLists.txt
+++ b/rpmp/pmpool/CMakeLists.txt
@@ -1,4 +1,8 @@
-add_library(pmpool SHARED DataServer.cc Protocol.cc Event.cc NetworkServer.cc hash/xxhash.cc client/PmPoolClient.cc client/NetworkClient.cc client/native/com_intel_rpmp_PmPoolClient.cc)
+add_library(pmpool_client_jni SHARED Event.cc client/PmPoolClient.cc client/NetworkClient.cc client/native/com_intel_rpmp_PmPoolClient.cc)
+target_link_libraries(pmpool_client_jni LINK_PUBLIC ${Boost_LIBRARIES} hpnl)
+set_target_properties(pmpool_client_jni PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
+
+add_library(pmpool SHARED DataServer.cc Protocol.cc Event.cc NetworkServer.cc hash/xxhash.cc client/PmPoolClient.cc client/NetworkClient.cc)
target_link_libraries(pmpool LINK_PUBLIC ${Boost_LIBRARIES} hpnl pmemobj)
set_target_properties(pmpool PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
diff --git a/rpmp/pmpool/Config.h b/rpmp/pmpool/Config.h
index 2a1b9526..d1edacbb 100644
--- a/rpmp/pmpool/Config.h
+++ b/rpmp/pmpool/Config.h
@@ -17,12 +17,14 @@
#include
-using boost::program_options::error;
+using namespace boost::program_options;
+using namespace std;
+/*using boost::program_options::error;
using boost::program_options::options_description;
using boost::program_options::value;
using boost::program_options::variables_map;
using std::string;
-using std::vector;
+using std::vector;*/
/**
* @brief This class represents the current RPMP configuration.
@@ -44,15 +46,21 @@ class Config {
"set network buffer number")("network_worker,nw",
                                      value<int>()->default_value(1),
                                      "set network wroker number")(
-          "paths,ps", value<vector<string>>(), "set memory pool path")(
-          "sizes,ss", value<vector<uint64_t>>(), "set memory pool size")(
+          "paths,ps", value<vector<string>>()->multitoken(),
+          "set memory pool path")("sizes,ss",
+                                  value<vector<uint64_t>>()->multitoken(),
+                                  "set memory pool size")(
+          "task_set, t", value<vector<int>>()->multitoken(),
+          "set affinity for each device")(
           "log,l", value<string>()->default_value("/tmp/rpmp.log"),
           "set rpmp log file path")("log_level,ll",
                                     value<string>()->default_value("warn"),
"set log level");
+ command_line_parser parser{argc, argv};
+ parsed_options parsed_options = parser.options(desc).run();
variables_map vm;
- store(parse_command_line(argc, argv, desc), vm);
+ store(parsed_options, vm);
notify(vm);
if (vm.count("help")) {
@@ -64,18 +72,31 @@ class Config {
       set_network_buffer_size(vm["network_buffer_size"].as<int>());
       set_network_buffer_num(vm["network_buffer_num"].as<int>());
       set_network_worker_num(vm["network_worker"].as<int>());
- pool_paths_.push_back("/dev/dax0.0");
- pool_paths_.push_back("/dev/dax0.1");
- pool_paths_.push_back("/dev/dax1.0");
- pool_paths_.push_back("/dev/dax1.1");
- sizes_.push_back(126833655808L);
- sizes_.push_back(126833655808L);
- sizes_.push_back(126833655808L);
- sizes_.push_back(126833655808L);
- affinities_.push_back(2);
- affinities_.push_back(41);
- affinities_.push_back(22);
- affinities_.push_back(60);
+ // pool_paths_.push_back("/dev/dax0.0");
+ if (vm.count("sizes")) {
+        set_pool_sizes(vm["sizes"].as<vector<uint64_t>>());
+ }
+ if (vm.count("paths")) {
+        set_pool_paths(vm["paths"].as<vector<string>>());
+ } else {
+ std::cerr << "No input device!!" << std::endl;
+ throw;
+ }
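+      // Reconcile path/size counts: reuse the first size when too few sizes were given,
+      // and drop extra sizes when too many were given.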
+ if (pool_paths_.size() != sizes_.size()) {
+ if (sizes_.size() < pool_paths_.size() && !sizes_.empty()) {
+ auto first = sizes_[0];
+ sizes_.resize(pool_paths_.size(), first);
+ } else if (sizes_.size() > pool_paths_.size()) {
+ sizes_.resize(pool_paths_.size());
+ } else {
+ throw 1;
+ }
+ }
+ if (vm.count("task_set")) {
+        set_affinities_(vm["task_set"].as<vector<int>>());
+ } else {
+ affinities_.resize(pool_paths_.size(), -1);
+ }
set_log_path(vm["log"].as());
set_log_level(vm["log_level"].as());
} catch (const error &ex) {
@@ -115,7 +136,18 @@ class Config {
int get_pool_size() { return sizes_.size(); }
-  std::vector<uint64_t> get_affinities_() { return affinities_; }
+  void set_affinities_(vector<int> affinities) {
+ if (affinities.size() < pool_paths_.size()) {
+ affinities_.resize(pool_paths_.size(), -1);
+ } else {
+ for (int i = 0; i < pool_paths_.size(); i++) {
+ affinities_.push_back(affinities[i]);
+ std::cout << pool_paths_[i] << " task_set to " << affinities[i]
+ << std::endl;
+ }
+ }
+ }
+ std::vector<int> get_affinities_() { return affinities_; }
string get_log_path() { return log_path_; }
void set_log_path(string log_path) { log_path_ = log_path; }
@@ -131,7 +163,7 @@ class Config {
int network_worker_num_;
vector<string> pool_paths_;
vector<uint64_t> sizes_;
- vector affinities_;
+ vector<int> affinities_;
string log_path_;
string log_level_;
};
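Editor's note: the multitoken() options introduced above accept whitespace-separated value lists on the command line, e.g. --paths /dev/dax0.0 /dev/dax0.1 --sizes 126833655808 126833655808. The stand-alone sketch below is not part of the patch; only the option names and value types mirror the diff, everything else (short-name handling omitted, output loop) is illustrative.

#include <boost/program_options.hpp>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main(int argc, char* argv[]) {
  namespace po = boost::program_options;
  po::options_description desc("rpmp options (illustrative subset)");
  desc.add_options()
      ("paths", po::value<std::vector<std::string>>()->multitoken(),
       "set memory pool path")
      ("sizes", po::value<std::vector<uint64_t>>()->multitoken(),
       "set memory pool size");
  po::variables_map vm;
  // command_line_parser(...).run() tokenizes argv; multitoken() lets one flag
  // consume several following tokens instead of requiring repeated flags.
  po::store(po::command_line_parser(argc, argv).options(desc).run(), vm);
  po::notify(vm);
  if (vm.count("paths")) {
    for (const auto& p : vm["paths"].as<std::vector<std::string>>())
      std::cout << "pool path: " << p << std::endl;
  }
  return 0;
}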
diff --git a/rpmp/pmpool/DataServer.cc b/rpmp/pmpool/DataServer.cc
index 47859c50..638a1acb 100644
--- a/rpmp/pmpool/DataServer.cc
+++ b/rpmp/pmpool/DataServer.cc
@@ -9,14 +9,15 @@
#include "pmpool/DataServer.h"
-#include "AllocatorProxy.h"
-#include "Config.h"
-#include "Digest.h"
-#include "NetworkServer.h"
-#include "Protocol.h"
-#include "Log.h"
+#include "pmpool/AllocatorProxy.h"
+#include "pmpool/Config.h"
+#include "pmpool/Digest.h"
+#include "pmpool/Log.h"
+#include "pmpool/NetworkServer.h"
+#include "pmpool/Protocol.h"
-DataServer::DataServer(Config *config, Log *log) : config_(config), log_(log) {}
+DataServer::DataServer(std::shared_ptr<Config> config, std::shared_ptr<Log> log)
+ : config_(config), log_(log) {}
int DataServer::init() {
networkServer_ = std::make_shared<NetworkServer>(config_, log_);
@@ -24,12 +25,12 @@ int DataServer::init() {
log_->get_file_log()->info("network server initialized.");
allocatorProxy_ =
- std::make_shared<AllocatorProxy>(config_, log_, networkServer_.get());
+ std::make_shared<AllocatorProxy>(config_, log_, networkServer_);
CHK_ERR("allocator proxy init", allocatorProxy_->init());
log_->get_file_log()->info("allocator proxy initialized.");
- protocol_ = std::make_shared<Protocol>(config_, log_, networkServer_.get(),
- allocatorProxy_.get());
+ protocol_ = std::make_shared<Protocol>(config_, log_, networkServer_,
+ allocatorProxy_);
CHK_ERR("protocol init", protocol_->init());
log_->get_file_log()->info("protocol initialized.");
diff --git a/rpmp/pmpool/DataServer.h b/rpmp/pmpool/DataServer.h
index 14e70ec1..dde9c8f8 100644
--- a/rpmp/pmpool/DataServer.h
+++ b/rpmp/pmpool/DataServer.h
@@ -3,15 +3,15 @@
* Path: /mnt/spark-pmof/tool/rpmp/pmpool
* Created Date: Thursday, November 7th 2019, 3:48:52 pm
* Author: root
- *
+ *
* Copyright (c) 2019 Intel
*/
#ifndef PMPOOL_DATASERVER_H_
#define PMPOOL_DATASERVER_H_
-#include
#include
+#include
#include
@@ -25,18 +25,20 @@ class Log;
/**
* @brief DataServer is designed as distributed remote memory pool.
- * DataServer on every node communicated with each other to guarantee data consistency.
- *
+ * DataServer on every node communicates with the others to guarantee data
+ * consistency.
+ *
*/
class DataServer {
public:
DataServer() = delete;
- explicit DataServer(Config* config, Log* log);
+ explicit DataServer(std::shared_ptr<Config> config, std::shared_ptr<Log> log);
int init();
void wait();
+
private:
- Config* config_;
- Log* log_;
+ std::shared_ptr<Config> config_;
+ std::shared_ptr<Log> log_;
std::shared_ptr<NetworkServer> networkServer_;
std::shared_ptr<AllocatorProxy> allocatorProxy_;
std::shared_ptr<Protocol> protocol_;
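Editor's note: the constructor change above is one instance of the patch-wide move from borrowed raw pointers to std::shared_ptr ownership, so Config and Log stay alive as long as any component still references them. A minimal sketch of that pattern follows; the class names are placeholders, not the project's API.

#include <memory>
#include <utility>

struct Cfg { int port = 12346; };  // stand-in for a shared configuration object

class Component {
 public:
  explicit Component(std::shared_ptr<Cfg> cfg) : cfg_(std::move(cfg)) {}
  int port() const { return cfg_->port; }

 private:
  std::shared_ptr<Cfg> cfg_;  // shared ownership: no dangling pointer if the creator goes away
};

int main() {
  auto cfg = std::make_shared<Cfg>();
  Component a(cfg), b(cfg);  // both hold a reference; Cfg lives until the last one is destroyed
  cfg.reset();               // the creator can drop its handle safely
  return a.port() == b.port() ? 0 : 1;
}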
diff --git a/rpmp/pmpool/Event.cc b/rpmp/pmpool/Event.cc
index 2415c58d..a118ead3 100644
--- a/rpmp/pmpool/Event.cc
+++ b/rpmp/pmpool/Event.cc
@@ -21,6 +21,7 @@ Request::Request(char *data, uint64_t size, Connection *con) : size_(size) {
}
Request::~Request() {
+ const std::lock_guard<std::mutex> lock(data_lock_);
if (data_ != nullptr) {
std::free(data_);
data_ = nullptr;
@@ -30,87 +31,102 @@ Request::~Request() {
RequestContext &Request::get_rc() { return requestContext_; }
void Request::encode() {
+ const std::lock_guard<std::mutex> lock(data_lock_);
OpType rt = requestContext_.type;
assert(rt == ALLOC || rt == FREE || rt == WRITE || rt == READ);
- requestMsg_.type = requestContext_.type;
- requestMsg_.rid = requestContext_.rid;
- requestMsg_.address = requestContext_.address;
- requestMsg_.src_address = requestContext_.src_address;
- requestMsg_.src_rkey = requestContext_.src_rkey;
- requestMsg_.size = requestContext_.size;
- requestMsg_.key = requestContext_.key;
-
- size_ = sizeof(requestMsg_);
- data_ = static_cast<char *>(std::malloc(size_));
- memcpy(data_, &requestMsg_, size_);
+ size_ = sizeof(RequestMsg);
+ data_ = static_cast<char *>(std::malloc(sizeof(RequestMsg)));
+ RequestMsg *requestMsg = (RequestMsg *)data_;
+ requestMsg->type = requestContext_.type;
+ requestMsg->rid = requestContext_.rid;
+ requestMsg->address = requestContext_.address;
+ requestMsg->src_address = requestContext_.src_address;
+ requestMsg->src_rkey = requestContext_.src_rkey;
+ requestMsg->size = requestContext_.size;
+ requestMsg->key = requestContext_.key;
}
void Request::decode() {
- assert(size_ == sizeof(requestMsg_));
- memcpy(&requestMsg_, data_, size_);
- requestContext_.type = (OpType)requestMsg_.type;
- requestContext_.rid = requestMsg_.rid;
- requestContext_.address = requestMsg_.address;
- requestContext_.src_address = requestMsg_.src_address;
- requestContext_.src_rkey = requestMsg_.src_rkey;
- requestContext_.size = requestMsg_.size;
- requestContext_.key = requestMsg_.key;
+ const std::lock_guard<std::mutex> lock(data_lock_);
+ assert(size_ == sizeof(RequestMsg));
+ RequestMsg *requestMsg = (RequestMsg *)data_;
+ requestContext_.type = (OpType)requestMsg->type;
+ requestContext_.rid = requestMsg->rid;
+ requestContext_.address = requestMsg->address;
+ requestContext_.src_address = requestMsg->src_address;
+ requestContext_.src_rkey = requestMsg->src_rkey;
+ requestContext_.size = requestMsg->size;
+ requestContext_.key = requestMsg->key;
}
-RequestReply::RequestReply(RequestReplyContext requestReplyContext)
+RequestReply::RequestReply(
+ std::shared_ptr<RequestReplyContext> requestReplyContext)
: data_(nullptr), size_(0), requestReplyContext_(requestReplyContext) {}
RequestReply::RequestReply(char *data, uint64_t size, Connection *con)
: size_(size) {
data_ = static_cast<char *>(std::malloc(size_));
memcpy(data_, data, size_);
- requestReplyContext_.con = con;
+ requestReplyContext_ = std::make_shared<RequestReplyContext>();
+ requestReplyContext_->con = con;
}
RequestReply::~RequestReply() {
+ const std::lock_guard<std::mutex> lock(data_lock_);
if (data_ != nullptr) {
std::free(data_);
data_ = nullptr;
}
}
-RequestReplyContext &RequestReply::get_rrc() { return requestReplyContext_; }
+std::shared_ptr<RequestReplyContext> RequestReply::get_rrc() {
+ return requestReplyContext_;
+}
void RequestReply::encode() {
- requestReplyMsg_.type = (OpType)requestReplyContext_.type;
- requestReplyMsg_.success = requestReplyContext_.success;
- requestReplyMsg_.rid = requestReplyContext_.rid;
- requestReplyMsg_.address = requestReplyContext_.address;
- requestReplyMsg_.size = requestReplyContext_.size;
- requestReplyMsg_.key = requestReplyContext_.key;
- auto msg_size = sizeof(requestReplyMsg_);
+ const std::lock_guard<std::mutex> lock(data_lock_);
+ RequestReplyMsg requestReplyMsg;
+ requestReplyMsg.type = (OpType)requestReplyContext_->type;
+ requestReplyMsg.success = requestReplyContext_->success;
+ requestReplyMsg.rid = requestReplyContext_->rid;
+ requestReplyMsg.address = requestReplyContext_->address;
+ requestReplyMsg.size = requestReplyContext_->size;
+ requestReplyMsg.key = requestReplyContext_->key;
+ auto msg_size = sizeof(requestReplyMsg);
size_ = msg_size;
/// copy data from block metadata list
uint32_t bml_size = 0;
- if (!requestReplyContext_.bml.empty()) {
- bml_size = sizeof(block_meta) * requestReplyContext_.bml.size();
+ if (!requestReplyContext_->bml.empty()) {
+ bml_size = sizeof(block_meta) * requestReplyContext_->bml.size();
size_ += bml_size;
}
data_ = static_cast<char *>(std::malloc(size_));
- memcpy(data_, &requestReplyMsg_, msg_size);
+ memcpy(data_, &requestReplyMsg, msg_size);
if (bml_size != 0) {
- memcpy(data_ + msg_size, &requestReplyContext_.bml[0], bml_size);
+ memcpy(data_ + msg_size, &requestReplyContext_->bml[0], bml_size);
}
}
void RequestReply::decode() {
- memcpy(&requestReplyMsg_, data_, size_);
- requestReplyContext_.type = (OpType)requestReplyMsg_.type;
- requestReplyContext_.success = requestReplyMsg_.success;
- requestReplyContext_.rid = requestReplyMsg_.rid;
- requestReplyContext_.address = requestReplyMsg_.address;
- requestReplyContext_.size = requestReplyMsg_.size;
- requestReplyContext_.key = requestReplyMsg_.key;
- if (size_ > sizeof(requestReplyMsg_)) {
- auto bml_size = size_ - sizeof(requestReplyMsg_);
- requestReplyContext_.bml.resize(bml_size / sizeof(block_meta));
- memcpy(&requestReplyContext_.bml[0], data_ + sizeof(requestReplyMsg_),
+ const std::lock_guard<std::mutex> lock(data_lock_);
+ // memcpy(&requestReplyMsg_, data_, size_);
+ if (data_ == nullptr) {
+ std::string err_msg = "Decode with null data";
+ std::cerr << err_msg << std::endl;
+ throw;
+ }
+ RequestReplyMsg *requestReplyMsg = (RequestReplyMsg *)data_;
+ requestReplyContext_->type = (OpType)requestReplyMsg->type;
+ requestReplyContext_->success = requestReplyMsg->success;
+ requestReplyContext_->rid = requestReplyMsg->rid;
+ requestReplyContext_->address = requestReplyMsg->address;
+ requestReplyContext_->size = requestReplyMsg->size;
+ requestReplyContext_->key = requestReplyMsg->key;
+ if (size_ > sizeof(RequestReplyMsg)) {
+ auto bml_size = size_ - sizeof(RequestReplyMsg);
+ requestReplyContext_->bml.resize(bml_size / sizeof(block_meta));
+ memcpy(&requestReplyContext_->bml[0], data_ + sizeof(RequestReplyMsg),
bml_size);
}
}
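Editor's note: the reworked encode()/decode() above serialize a trivially copyable message struct straight into a malloc'd wire buffer and guard every buffer access with a mutex. The self-contained sketch below shows the same pattern with illustrative types only; it uses memcpy on both sides rather than the pointer cast seen in the patch.

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <mutex>

struct Msg { uint32_t type; uint64_t address; uint64_t size; };  // placeholder wire format

class Frame {
 public:
  ~Frame() {
    std::lock_guard<std::mutex> lock(lock_);
    std::free(data_);  // free(nullptr) is a no-op
  }
  void encode(const Msg& m) {
    std::lock_guard<std::mutex> lock(lock_);
    size_ = sizeof(Msg);
    data_ = static_cast<char*>(std::malloc(size_));
    std::memcpy(data_, &m, size_);  // Msg is trivially copyable, a byte copy is enough
  }
  Msg decode() {
    std::lock_guard<std::mutex> lock(lock_);
    assert(data_ != nullptr && size_ == sizeof(Msg));
    Msg m;
    std::memcpy(&m, data_, sizeof(Msg));
    return m;
  }

 private:
  std::mutex lock_;          // serializes concurrent encode/decode/destruction
  char* data_ = nullptr;     // heap buffer handed to the network layer
  uint64_t size_ = 0;
};

int main() {
  Frame f;
  f.encode({1, 0x1000, 4096});
  return f.decode().size == 4096 ? 0 : 1;
}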
diff --git a/rpmp/pmpool/Event.h b/rpmp/pmpool/Event.h
index 136b73f3..78432ca1 100644
--- a/rpmp/pmpool/Event.h
+++ b/rpmp/pmpool/Event.h
@@ -68,7 +68,7 @@ struct RequestReplyContext {
uint64_t key;
Connection* con;
Chunk* ck;
- vector<block_meta> bml;
+ vector<block_meta> bml;
};
template <class T>
@@ -88,19 +88,21 @@ inline void decode_(T* t, char* data, uint64_t size) {
class RequestReply {
public:
RequestReply() = delete;
- explicit RequestReply(RequestReplyContext requestReplyContext);
+ explicit RequestReply(
+ std::shared_ptr<RequestReplyContext> requestReplyContext);
RequestReply(char* data, uint64_t size, Connection* con);
~RequestReply();
- RequestReplyContext& get_rrc();
+ std::shared_ptr<RequestReplyContext> get_rrc();
void decode();
void encode();
private:
+ std::mutex data_lock_;
friend Protocol;
- char* data_;
- uint64_t size_;
- RequestReplyMsg requestReplyMsg_;
- RequestReplyContext requestReplyContext_;
+ char* data_ = nullptr;
+ uint64_t size_ = 0;
+ // RequestReplyMsg requestReplyMsg_;
+ std::shared_ptr<RequestReplyContext> requestReplyContext_;
};
typedef promise Promise;
@@ -126,13 +128,17 @@ class Request {
RequestContext& get_rc();
void encode();
void decode();
+ //#ifdef DEBUG
+ char* getData() { return data_; }
+ uint64_t getSize() { return size_; }
+ //#endif
private:
+ std::mutex data_lock_;
friend RequestHandler;
friend ClientRecvCallback;
char* data_;
uint64_t size_;
- RequestMsg requestMsg_;
RequestContext requestContext_;
};
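Editor's note: Request and RequestReply are now handed to worker threads as std::shared_ptr tasks, so the producer can drop its reference as soon as the task is queued. The sketch below approximates that hand-off using only the standard library; the wait_dequeue_timed name merely mirrors the blocking-queue interface the workers in Protocol.cc rely on, and the Task type is a placeholder.

#include <chrono>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>

struct Task { int id; };  // stand-in for Request / RequestReply

class TaskQueue {
 public:
  void enqueue(std::shared_ptr<Task> t) {
    {
      std::lock_guard<std::mutex> lk(mtx_);
      q_.push(std::move(t));
    }
    cv_.notify_one();
  }
  // Returns false if nothing arrives within the timeout, like a timed blocking dequeue.
  bool wait_dequeue_timed(std::shared_ptr<Task>& out, std::chrono::milliseconds timeout) {
    std::unique_lock<std::mutex> lk(mtx_);
    if (!cv_.wait_for(lk, timeout, [this] { return !q_.empty(); })) return false;
    out = std::move(q_.front());
    q_.pop();
    return true;
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  std::queue<std::shared_ptr<Task>> q_;
};

int main() {
  TaskQueue q;
  q.enqueue(std::make_shared<Task>(Task{7}));  // producer may now forget the pointer
  std::shared_ptr<Task> t;
  bool ok = q.wait_dequeue_timed(t, std::chrono::milliseconds(1000));
  return ok && t->id == 7 ? 0 : 1;
}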
diff --git a/rpmp/pmpool/Log.h b/rpmp/pmpool/Log.h
index d3701ce9..55ec9758 100644
--- a/rpmp/pmpool/Log.h
+++ b/rpmp/pmpool/Log.h
@@ -12,10 +12,10 @@
#include
-#include "Config.h"
-#include "spdlog/spdlog.h"
+#include "pmpool/Config.h"
#include "spdlog/sinks/basic_file_sink.h"
#include "spdlog/sinks/stdout_color_sinks.h"
+#include "spdlog/spdlog.h"
class Log {
public:
diff --git a/rpmp/pmpool/NetworkServer.cc b/rpmp/pmpool/NetworkServer.cc
index 2f731a79..62f0a396 100644
--- a/rpmp/pmpool/NetworkServer.cc
+++ b/rpmp/pmpool/NetworkServer.cc
@@ -9,13 +9,14 @@
#include "pmpool/NetworkServer.h"
-#include "Base.h"
-#include "Config.h"
-#include "Event.h"
-#include "Log.h"
-#include "buffer/CircularBuffer.h"
-
-NetworkServer::NetworkServer(Config *config, Log *log)
+#include "pmpool/Base.h"
+#include "pmpool/Config.h"
+#include "pmpool/Event.h"
+#include "pmpool/Log.h"
+#include "pmpool/buffer/CircularBuffer.h"
+
+NetworkServer::NetworkServer(std::shared_ptr<Config> config,
+ std::shared_ptr<Log> log)
: config_(config), log_(log) {
time = 0;
}
@@ -44,8 +45,8 @@ int NetworkServer::start() {
CHK_ERR("hpnl server listen", server_->listen(config_->get_ip().c_str(),
config_->get_port().c_str()));
- circularBuffer_ =
- std::make_shared<CircularBuffer>(1024 * 1024, 4096, true, this);
+ circularBuffer_ = std::make_shared<CircularBuffer>(1024 * 1024, 4096, true,
+ shared_from_this());
return 0;
}
@@ -59,7 +60,7 @@ void NetworkServer::unregister_rma_buffer(int buffer_id) {
server_->unreg_rma_buffer(buffer_id);
}
-void NetworkServer::get_dram_buffer(RequestReplyContext *rrc) {
+void NetworkServer::get_dram_buffer(std::shared_ptr<RequestReplyContext> rrc) {
char *buffer = circularBuffer_->get(rrc->size);
rrc->dest_address = (uint64_t)buffer;
@@ -76,13 +77,15 @@ void NetworkServer::get_dram_buffer(RequestReplyContext *rrc) {
rrc->ck = ck;
}
-void NetworkServer::reclaim_dram_buffer(RequestReplyContext *rrc) {
+void NetworkServer::reclaim_dram_buffer(
+ std::shared_ptr<RequestReplyContext> rrc) {
char *buffer_tmp = reinterpret_cast<char *>(rrc->dest_address);
circularBuffer_->put(buffer_tmp, rrc->size);
delete rrc->ck;
}
-void NetworkServer::get_pmem_buffer(RequestReplyContext *rrc, Chunk *base_ck) {
+void NetworkServer::get_pmem_buffer(std::shared_ptr<RequestReplyContext> rrc,
+ Chunk *base_ck) {
Chunk *ck = new Chunk();
ck->buffer = reinterpret_cast<char *>(rrc->dest_address);
ck->capacity = rrc->size;
@@ -92,13 +95,18 @@ void NetworkServer::get_pmem_buffer(RequestReplyContext *rrc, Chunk *base_ck) {
rrc->ck = ck;
}
-void NetworkServer::reclaim_pmem_buffer(RequestReplyContext *rrc) {
+void NetworkServer::reclaim_pmem_buffer(
+ std::shared_ptr<RequestReplyContext> rrc) {
if (rrc->ck != nullptr) {
delete rrc->ck;
}
}
-ChunkMgr *NetworkServer::get_chunk_mgr() { return chunkMgr_.get(); }
+uint64_t NetworkServer::get_rkey() {
+ return circularBuffer_->get_rma_chunk()->mr->key;
+}
+
+std::shared_ptr<ChunkMgr> NetworkServer::get_chunk_mgr() { return chunkMgr_; }
void NetworkServer::set_recv_callback(Callback *callback) {
server_->set_recv_callback(callback);
@@ -123,12 +131,20 @@ void NetworkServer::send(char *data, uint64_t size, Connection *con) {
con->send(ck);
}
-void NetworkServer::read(RequestReply *rr) {
- RequestReplyContext rrc = rr->get_rrc();
- rrc.con->read(rrc.ck, 0, rrc.size, rrc.src_address, rrc.src_rkey);
+void NetworkServer::read(std::shared_ptr<RequestReply> rr) {
+ auto rrc = rr->get_rrc();
+#ifdef DEBUG
+ printf("[NetworkServer::read] dest is %ld-%d, src is %ld-%d\n",
+ rrc->ck->buffer, rrc->ck->size, rrc->src_address, rrc->size);
+#endif
+ rrc->con->read(rrc->ck, 0, rrc->size, rrc->src_address, rrc->src_rkey);
}
-void NetworkServer::write(RequestReply *rr) {
- RequestReplyContext rrc = rr->get_rrc();
- rrc.con->write(rrc.ck, 0, rrc.size, rrc.src_address, rrc.src_rkey);
+void NetworkServer::write(std::shared_ptr<RequestReply> rr) {
+ auto rrc = rr->get_rrc();
+#ifdef DEBUG
+ printf("[NetworkServer::write] src is %ld-%d, dest is %ld-%d\n",
+ rrc->ck->buffer, rrc->ck->size, rrc->src_address, rrc->size);
+#endif
+ rrc->con->write(rrc->ck, 0, rrc->size, rrc->src_address, rrc->src_rkey);
}
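Editor's note: the shared_from_this() call added in NetworkServer::start() only works because NetworkServer now derives from std::enable_shared_from_this and is itself owned by a shared_ptr. Minimal illustration below, with a placeholder class rather than the project's API.

#include <iostream>
#include <memory>

class Server : public std::enable_shared_from_this<Server> {
 public:
  // Hands out a shared_ptr tied to the existing control block, so helpers can
  // keep the server alive without a raw back-pointer.
  std::shared_ptr<Server> self() { return shared_from_this(); }
};

int main() {
  auto server = std::make_shared<Server>();  // must already be managed by shared_ptr;
                                             // calling self() on a stack object would throw
  auto alias = server->self();               // shares the same control block
  std::cout << server.use_count() << std::endl;  // prints 2
  return 0;
}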
diff --git a/rpmp/pmpool/NetworkServer.h b/rpmp/pmpool/NetworkServer.h
index d125a83c..4bcdd7f0 100644
--- a/rpmp/pmpool/NetworkServer.h
+++ b/rpmp/pmpool/NetworkServer.h
@@ -17,7 +17,7 @@
#include
#include
-#include "RmaBufferRegister.h"
+#include "pmpool/RmaBufferRegister.h"
class CircularBuffer;
class Config;
@@ -30,10 +30,11 @@ class Log;
* asynchronous network library. RPMP currently supports RDMA iWarp and RoCE V2
* protocol.
*/
-class NetworkServer : public RmaBufferRegister {
+class NetworkServer : public RmaBufferRegister,
+ public std::enable_shared_from_this<NetworkServer> {
public:
NetworkServer() = delete;
- NetworkServer(Config *config, Log *log_);
+ NetworkServer(std::shared_ptr<Config> config, std::shared_ptr<Log> log_);
~NetworkServer();
int init();
int start();
@@ -47,19 +48,22 @@ class NetworkServer : public RmaBufferRegister {
void unregister_rma_buffer(int buffer_id) override;
/// get DRAM buffer from circular buffer pool.
- void get_dram_buffer(RequestReplyContext *rrc);
+ void get_dram_buffer(std::shared_ptr<RequestReplyContext> rrc);
/// reclaim DRAM buffer from circular buffer pool.
- void reclaim_dram_buffer(RequestReplyContext *rrc);
+ void reclaim_dram_buffer(std::shared_ptr<RequestReplyContext> rrc);
+
+ /// get rdma registered memory key for client.
+ uint64_t get_rkey();
/// get Persistent Memory buffer from circular buffer pool
- void get_pmem_buffer(RequestReplyContext *rrc, Chunk *ck);
+ void get_pmem_buffer(std::shared_ptr<RequestReplyContext> rrc, Chunk *ck);
/// reclaim Persistent Memory buffer form circular buffer pool
- void reclaim_pmem_buffer(RequestReplyContext *rrc);
+ void reclaim_pmem_buffer(std::shared_ptr<RequestReplyContext> rrc);
/// return the pointer of chunk manager.
- ChunkMgr *get_chunk_mgr();
+ std::shared_ptr<ChunkMgr> get_chunk_mgr();
/// since the network implementation is asynchronous,
/// we need to define callback better before starting network service.
@@ -69,12 +73,12 @@ class NetworkServer : public RmaBufferRegister {
void set_write_callback(Callback *callback);
void send(char *data, uint64_t size, Connection *con);
- void read(RequestReply *rrc);
- void write(RequestReply *rrc);
+ void read(std::shared_ptr<RequestReply> rrc);
+ void write(std::shared_ptr<RequestReply> rrc);
private:
- Config *config_;
- Log* log_;
+ std::shared_ptr<Config> config_;
+ std::shared_ptr<Log> log_;
std::shared_ptr<Server> server_;
std::shared_ptr<ChunkMgr> chunkMgr_;
std::shared_ptr<CircularBuffer> circularBuffer_;
diff --git a/rpmp/pmpool/PmemAllocator.h b/rpmp/pmpool/PmemAllocator.h
index b6d279ce..ccfecd68 100644
--- a/rpmp/pmpool/PmemAllocator.h
+++ b/rpmp/pmpool/PmemAllocator.h
@@ -20,10 +20,10 @@
#include
#include
-#include "Allocator.h"
-#include "DataServer.h"
-#include "Log.h"
-#include "NetworkServer.h"
+#include "pmpool/Allocator.h"
+#include "pmpool/DataServer.h"
+#include "pmpool/Log.h"
+#include "pmpool/NetworkServer.h"
using std::shared_ptr;
using std::unordered_map;
@@ -68,8 +68,9 @@ enum types { BLOCK_ENTRY_TYPE, DATA_TYPE, MAX_TYPE };
class PmemObjAllocator : public Allocator {
public:
PmemObjAllocator() = delete;
- explicit PmemObjAllocator(Log *log, DiskInfo *diskInfos,
- NetworkServer *server, int wid)
+ explicit PmemObjAllocator(std::shared_ptr<Log> log,
+ std::shared_ptr<DiskInfo> diskInfos,
+ std::shared_ptr<NetworkServer> server, int wid)
: log_(log), diskInfo_(diskInfos), server_(server), wid_(wid) {}
~PmemObjAllocator() { close(); }
@@ -305,6 +306,7 @@ class PmemObjAllocator : public Allocator {
err_msg);
return -1;
}
+
pmemContext_.poid = pmemobj_root(pmemContext_.pop, sizeof(struct Base));
pmemContext_.base = (struct Base *)pmemobj_direct(pmemContext_.poid);
pmemContext_.base->head = OID_NULL;
@@ -315,9 +317,11 @@ class PmemObjAllocator : public Allocator {
base_ck = server_->register_rma_buffer(
reinterpret_cast<char *>(pmemContext_.pop), diskInfo_->size);
assert(base_ck != nullptr);
+ auto addr = reinterpret_cast<uint64_t>(pmemContext_.pop);
log_->get_console_log()->info(
"successfully registered Persistent Memory(" + diskInfo_->path +
- ") as RDMA region");
+ ") as RDMA region, size is {0}",
+ diskInfo_->size);
}
return 0;
}
@@ -374,17 +378,19 @@ class PmemObjAllocator : public Allocator {
int free_meta() {
std::lock_guard<std::mutex> l(mtx);
index_map.clear();
+ return 0;
}
private:
- Log *log_;
- DiskInfo *diskInfo_;
- NetworkServer *server_;
+ std::shared_ptr<Log> log_;
+ std::shared_ptr<DiskInfo> diskInfo_;
+ std::shared_ptr<NetworkServer> server_;
int wid_;
PmemContext pmemContext_;
std::mutex mtx;
unordered_map index_map;
uint64_t total = 0;
+ uint64_t disk_size = 0;
char str[1048576];
Chunk *base_ck;
};
diff --git a/rpmp/pmpool/Protocol.cc b/rpmp/pmpool/Protocol.cc
index cfa88f81..da59880f 100644
--- a/rpmp/pmpool/Protocol.cc
+++ b/rpmp/pmpool/Protocol.cc
@@ -11,65 +11,83 @@
#include
-#include "AllocatorProxy.h"
-#include "Config.h"
-#include "Digest.h"
-#include "Event.h"
-#include "Log.h"
-#include "NetworkServer.h"
-
-RecvCallback::RecvCallback(Protocol *protocol, ChunkMgr *chunkMgr)
+#include "pmpool/AllocatorProxy.h"
+#include "pmpool/Config.h"
+#include "pmpool/Digest.h"
+#include "pmpool/Event.h"
+#include "pmpool/Log.h"
+#include "pmpool/NetworkServer.h"
+
+RecvCallback::RecvCallback(std::shared_ptr<Protocol> protocol,
+ std::shared_ptr<ChunkMgr> chunkMgr)
: protocol_(protocol), chunkMgr_(chunkMgr) {}
void RecvCallback::operator()(void *buffer_id, void *buffer_size) {
auto buffer_id_ = *static_cast<int *>(buffer_id);
Chunk *ck = chunkMgr_->get(buffer_id_);
assert(*static_cast<uint64_t *>(buffer_size) == ck->size);
- Request *request = new Request(reinterpret_cast<char *>(ck->buffer), ck->size,
- reinterpret_cast<Connection *>(ck->con));
+ auto request =
+ std::make_shared<Request>(reinterpret_cast<char *>(ck->buffer), ck->size,
+ reinterpret_cast<Connection *>(ck->con));
request->decode();
- protocol_->enqueue_recv_msg(request);
+ RequestMsg *requestMsg = (RequestMsg *)(request->getData());
+ if (requestMsg->type != 0) {
+ protocol_->enqueue_recv_msg(request);
+ } else {
+ std::cout << "[RecvCallback::RecvCallback][" << requestMsg->type
+ << "] size is " << ck->size << std::endl;
+ for (int i = 0; i < ck->size; i++) {
+ printf("%X ", *(request->getData() + i));
+ }
+ printf("\n");
+ }
chunkMgr_->reclaim(ck, static_cast<Connection *>(ck->con));
}
-ReadCallback::ReadCallback(Protocol *protocol) : protocol_(protocol) {}
+ReadCallback::ReadCallback(std::shared_ptr<Protocol> protocol)
+ : protocol_(protocol) {}
void ReadCallback::operator()(void *buffer_id, void *buffer_size) {
auto buffer_id_ = *static_cast<int *>(buffer_id);
protocol_->enqueue_rma_msg(buffer_id_);
}
-SendCallback::SendCallback(ChunkMgr *chunkMgr) : chunkMgr_(chunkMgr) {}
+SendCallback::SendCallback(
+ std::shared_ptr<ChunkMgr> chunkMgr,
+ std::unordered_map<uint64_t, std::shared_ptr<RequestReply>> rrcMap)
+ : chunkMgr_(chunkMgr), rrcMap_(rrcMap) {}
void SendCallback::operator()(void *buffer_id, void *buffer_size) {
auto buffer_id_ = *static_cast<int *>(buffer_id);
auto ck = chunkMgr_->get(buffer_id_);
/// free the memory of class RequestReply
- auto reqeustReply = static_cast<RequestReply *>(ck->ptr);
- delete reqeustReply;
+ rrcMap_.erase(buffer_id_);
chunkMgr_->reclaim(ck, static_cast<Connection *>(ck->con));
}
-WriteCallback::WriteCallback(Protocol *protocol) : protocol_(protocol) {}
+WriteCallback::WriteCallback(std::shared_ptr<Protocol> protocol)
+ : protocol_(protocol) {}
void WriteCallback::operator()(void *buffer_id, void *buffer_size) {
auto buffer_id_ = *static_cast<int *>(buffer_id);
protocol_->enqueue_rma_msg(buffer_id_);
}
-RecvWorker::RecvWorker(Protocol *protocol, int index)
+RecvWorker::RecvWorker(std::shared_ptr<Protocol> protocol, int index)
: protocol_(protocol), index_(index) {
init = false;
}
int RecvWorker::entry() {
if (!init) {
- set_affinity(index_);
+ if (index_ != -1) {
+ set_affinity(index_);
+ }
init = true;
}
- Request *request;
+ std::shared_ptr<Request> request;
bool res = pendingRecvRequestQueue_.wait_dequeue_timed(
request, std::chrono::milliseconds(1000));
if (res) {
@@ -80,21 +98,23 @@ int RecvWorker::entry() {
void RecvWorker::abort() {}
-void RecvWorker::addTask(Request *request) {
+void RecvWorker::addTask(std::shared_ptr<Request> request) {
pendingRecvRequestQueue_.enqueue(request);
}
-ReadWorker::ReadWorker(Protocol *protocol, int index)
+ReadWorker::ReadWorker(std::shared_ptr<Protocol> protocol, int index)
: protocol_(protocol), index_(index) {
init = false;
}
int ReadWorker::entry() {
if (!init) {
- set_affinity(index_);
+ if (index_ != -1) {
+ set_affinity(index_);
+ }
init = true;
}
- RequestReply *requestReply;
+ std::shared_ptr<RequestReply> requestReply;
bool res = pendingReadRequestQueue_.wait_dequeue_timed(
requestReply, std::chrono::milliseconds(1000));
if (res) {
@@ -105,16 +125,18 @@ int ReadWorker::entry() {
void ReadWorker::abort() {}
-void ReadWorker::addTask(RequestReply *rr) {
+void ReadWorker::addTask(std::shared_ptr<RequestReply> rr) {
pendingReadRequestQueue_.enqueue(rr);
}
-FinalizeWorker::FinalizeWorker(Protocol *protocol) : protocol_(protocol) {}
+FinalizeWorker::FinalizeWorker(std::shared_ptr<Protocol> protocol)
+ : protocol_(protocol) {}
int FinalizeWorker::entry() {
- RequestReply *requestReply;
+ std::shared_ptr<RequestReply> requestReply;
bool res = pendingRequestReplyQueue_.wait_dequeue_timed(
requestReply, std::chrono::milliseconds(1000));
+ assert(res);
if (res) {
protocol_->handle_finalize_msg(requestReply);
}
@@ -123,17 +145,22 @@ int FinalizeWorker::entry() {
void FinalizeWorker::abort() {}
-void FinalizeWorker::addTask(RequestReply *requestReply) {
+void FinalizeWorker::addTask(std::shared_ptr<RequestReply> requestReply) {
pendingRequestReplyQueue_.enqueue(requestReply);
}
-Protocol::Protocol(Config *config, Log *log, NetworkServer *server,
- AllocatorProxy *allocatorProxy)
+Protocol::Protocol(std::shared_ptr<Config> config, std::shared_ptr<Log> log,
+ std::shared_ptr<NetworkServer> server,
+ std::shared_ptr<AllocatorProxy> allocatorProxy)
: config_(config),
log_(log),
networkServer_(server),
allocatorProxy_(allocatorProxy) {
time = 0;
+ char *buffer = new char[2147483648]();
+ memset(buffer, 0, 2147483648);
+ address = allocatorProxy_->allocate_and_write(2147483648, buffer, 0);
+ delete[] buffer;
}
Protocol::~Protocol() {
@@ -150,26 +177,28 @@ Protocol::~Protocol() {
}
int Protocol::init() {
- recvCallback_ =
- std::make_shared<RecvCallback>(this, networkServer_->get_chunk_mgr());
+ recvCallback_ = std::make_shared<RecvCallback>(
+ shared_from_this(), networkServer_->get_chunk_mgr());
sendCallback_ =
- std::make_shared<SendCallback>(networkServer_->get_chunk_mgr());
- readCallback_ = std::make_shared<ReadCallback>(this);
- writeCallback_ = std::make_shared<WriteCallback>(this);
+ std::make_shared<SendCallback>(networkServer_->get_chunk_mgr(), rrcMap_);
+ readCallback_ = std::make_shared<ReadCallback>(shared_from_this());
+ writeCallback_ = std::make_shared<WriteCallback>(shared_from_this());
for (int i = 0; i < config_->get_pool_size(); i++) {
- auto recvWorker = new RecvWorker(this, config_->get_affinities_()[i] - 1);
+ auto recvWorker = std::make_shared<RecvWorker>(
+ shared_from_this(), config_->get_affinities_()[i] - 1);
recvWorker->start();
- recvWorkers_.push_back(std::shared_ptr