diff --git a/.scalafmt.conf b/.scalafmt.conf new file mode 100644 index 00000000..356e6edb --- /dev/null +++ b/.scalafmt.conf @@ -0,0 +1,11 @@ +version = 2.4.2 +align = none +align.openParenDefnSite = false +align.openParenCallSite = false +align.tokens = [] +optIn = { + configStyleArguments = false +} +danglingParentheses = false +docstrings = JavaDoc +maxColumn = 98 \ No newline at end of file diff --git a/core/pom.xml b/core/pom.xml index b400270b..9ffb0098 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -13,7 +13,7 @@ jar com.intel.spark-pmof.java - java + spark-pmof-java 1.0 @@ -40,6 +40,11 @@ hpnl 0.5 + + com.intel.rpmp + rpmp + 0.1 + org.xerial sqlite-jdbc diff --git a/core/src/main/java/org/apache/spark/storage/pmof/RemotePersistentMemoryPool.java b/core/src/main/java/org/apache/spark/storage/pmof/RemotePersistentMemoryPool.java new file mode 100644 index 00000000..9c177d56 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/pmof/RemotePersistentMemoryPool.java @@ -0,0 +1,72 @@ +package org.apache.spark.storage.pmof; + +import com.intel.rpmp.PmPoolClient; +import java.io.IOException; +import java.nio.ByteBuffer; + +public class RemotePersistentMemoryPool { + private static String remote_host; + private static String remote_port_str; + + private RemotePersistentMemoryPool(String remote_address, String remote_port) throws IOException { + pmPoolClient = new PmPoolClient(remote_address, remote_port); + } + + public static RemotePersistentMemoryPool getInstance(String remote_address, String remote_port) throws IOException { + synchronized (RemotePersistentMemoryPool.class) { + if (instance == null) { + if (instance == null) { + remote_host = remote_address; + remote_port_str = remote_port; + instance = new RemotePersistentMemoryPool(remote_address, remote_port); + } + } + } + return instance; + } + + public static int close() { + synchronized (RemotePersistentMemoryPool.class) { + if (instance != null) + return instance.dispose(); + else + return 0; 
+ } + } + + public static String getHost() { + return remote_host; + } + + public static int getPort() { + return Integer.parseInt(remote_port_str); + } + + public int read(long address, long size, ByteBuffer byteBuffer) { + return pmPoolClient.read(address, size, byteBuffer); + } + + public long put(String key, ByteBuffer data, long size) { + return pmPoolClient.put(key, data, size); + } + + public long get(String key, long size, ByteBuffer data) { + return pmPoolClient.get(key, size, data); + } + + public long[] getMeta(String key) { + return pmPoolClient.getMeta(key); + } + + public int del(String key) throws IOException { + return pmPoolClient.del(key); + } + + public int dispose() { + pmPoolClient.dispose(); + return 0; + } + + private static PmPoolClient pmPoolClient; + private static RemotePersistentMemoryPool instance; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/network/pmof/Client.scala b/core/src/main/scala/org/apache/spark/network/pmof/Client.scala index 022c729a..d56e073a 100644 --- a/core/src/main/scala/org/apache/spark/network/pmof/Client.scala +++ b/core/src/main/scala/org/apache/spark/network/pmof/Client.scala @@ -4,9 +4,10 @@ import java.nio.ByteBuffer import java.util.concurrent.ConcurrentHashMap import com.intel.hpnl.core.{Connection, EqService} +import org.apache.spark.internal.Logging import org.apache.spark.shuffle.pmof.PmofShuffleManager -class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) { +class Client(clientFactory: ClientFactory, val shuffleManager: PmofShuffleManager, con: Connection) extends Logging { final val outstandingReceiveFetches: ConcurrentHashMap[Long, ReceivedCallback] = new ConcurrentHashMap[Long, ReceivedCallback]() final val outstandingReadFetches: ConcurrentHashMap[Int, ReadCallback] = diff --git a/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala b/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala 
index db2c9161..1a332f2e 100644 --- a/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/pmof/ClientFactory.scala @@ -5,10 +5,11 @@ import java.nio.ByteBuffer import java.util.concurrent.ConcurrentHashMap import com.intel.hpnl.core._ +import org.apache.spark.internal.Logging import org.apache.spark.shuffle.pmof.PmofShuffleManager import org.apache.spark.util.configuration.pmof.PmofConf -class ClientFactory(pmofConf: PmofConf) { +class ClientFactory(pmofConf: PmofConf) extends Logging { final val eqService = new EqService(pmofConf.clientWorkerNums, pmofConf.clientBufferNums, false).init() private[this] final val cqService = new CqService(eqService).init() private[this] final val clientMap = new ConcurrentHashMap[InetSocketAddress, Client]() @@ -28,6 +29,7 @@ class ClientFactory(pmofConf: PmofConf) { var client = clientMap.get(socketAddress) if (client == null) { ClientFactory.this.synchronized { + logInfo(s"createClient target is ${address}:${port}") client = clientMap.get(socketAddress) if (client == null) { val con = eqService.connect(address, port.toString, 0) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala new file mode 100644 index 00000000..31f70c11 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -0,0 +1,232 @@ +package org.apache.spark.scheduler + +import java.nio.ByteBuffer +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.roaringbitmap.RoaringBitmap + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +/** + * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is + * represented using a single byte. 
+ * + * @param loc location where the task is being executed. + * @param compressedSizes size of the blocks, indexed by reduce partition id. + */ +private[spark] trait MapStatus { + + /** Location where this task was run. */ + def location: BlockManagerId + + /** + * Estimated size for the reduce block, in bytes. + * + * If a block is non-empty, then this method MUST return a non-zero size. This invariant is + * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. + */ + def getSizeForBlock(reduceId: Int): Long + +} + +private[spark] object MapStatus { + + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + if (uncompressedSizes.length > 2000) { + HighlyCompressedMapStatus(loc, uncompressedSizes) + } else { + new CompressedMapStatus(loc, uncompressedSizes) + } + } + + private[this] val LOG_BASE = 1.1 + + /** + * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. + * We do this by encoding the log base 1.1 of the size as an integer, which can support + * sizes up to 35 GB with at most 10% error. + */ + def compressSize(size: Long): Byte = { + if (size == 0) { + 0 + } else if (size <= 1L) { + 1 + } else { + math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte + } + } + + /** + * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. + */ + def decompressSize(compressedSize: Byte): Long = { + if (compressedSize == 0) { + 0 + } else { + math.pow(LOG_BASE, compressedSize & 0xFF).toLong + } + } +} + +/** + * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is + * represented using a single byte. + * + * @param loc location where the task is being executed. + * @param compressedSizes size of the blocks, indexed by reduce partition id. 
+ */ +private[spark] class CompressedMapStatus( + private[this] var loc: BlockManagerId, + private[this] var compressedSizes: Array[Byte]) + extends MapStatus + with Externalizable { + + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize)) + } + + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = { + MapStatus.decompressSize(compressedSizes(reduceId)) + } + + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { + loc.writeExternal(out) + out.writeInt(compressedSizes.length) + out.write(compressedSizes) + } + + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + loc = BlockManagerId(in) + val len = in.readInt() + compressedSizes = new Array[Byte](len) + in.readFully(compressedSizes) + } +} + +/** + * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger + * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks, + * plus a bitmap for tracking which blocks are empty. + * + * @param loc location where the task is being executed + * @param numNonEmptyBlocks the number of non-empty blocks + * @param emptyBlocks a bitmap tracking which blocks are empty + * @param avgSize average size of the non-empty and non-huge blocks + * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
+ */ +private[spark] class HighlyCompressedMapStatus private ( + private[this] var loc: BlockManagerId, + private[this] var numNonEmptyBlocks: Int, + private[this] var emptyBlocks: RoaringBitmap, + private[this] var avgSize: Long, + private var hugeBlockSizes: Map[Int, Byte]) + extends MapStatus + with Externalizable { + + // loc could be null when the default constructor is called during deserialization + require( + loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, + "Average size can only be zero for map stages that produced no output") + + protected def this() = this(null, -1, null, -1, null) // For deserialization only + + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = { + assert(hugeBlockSizes != null) + if (emptyBlocks.contains(reduceId)) { + 0 + } else { + hugeBlockSizes.get(reduceId) match { + case Some(size) => MapStatus.decompressSize(size) + case None => avgSize + } + } + } + + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { + loc.writeExternal(out) + emptyBlocks.writeExternal(out) + out.writeLong(avgSize) + out.writeInt(hugeBlockSizes.size) + hugeBlockSizes.foreach { kv => + out.writeInt(kv._1) + out.writeByte(kv._2) + } + } + + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + loc = BlockManagerId(in) + emptyBlocks = new RoaringBitmap() + emptyBlocks.readExternal(in) + avgSize = in.readLong() + val count = in.readInt() + val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + (0 until count).foreach { _ => + val block = in.readInt() + val size = in.readByte() + hugeBlockSizesArray += Tuple2(block, size) + } + hugeBlockSizes = hugeBlockSizesArray.toMap + } +} + +private[spark] object HighlyCompressedMapStatus { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + // We must keep track of which blocks are empty so that we don't report a zero-sized + // 
block as being non-empty (or vice-versa) when using the average block size. + var i = 0 + var numNonEmptyBlocks: Int = 0 + var numSmallBlocks: Int = 0 + var totalSmallBlockSize: Long = 0 + // From a compression standpoint, it shouldn't matter whether we track empty or non-empty + // blocks. From a performance standpoint, we benefit from tracking empty blocks because + // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. + val emptyBlocks = new RoaringBitmap() + val totalNumBlocks = uncompressedSizes.length + val threshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() + while (i < totalNumBlocks) { + val size = uncompressedSizes(i) + if (size > 0) { + numNonEmptyBlocks += 1 + // Huge blocks are not included in the calculation for average size, thus size for smaller + // blocks is more accurate. 
+ if (size < threshold) { + totalSmallBlockSize += size + numSmallBlocks += 1 + } else { + hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) + } + } else { + emptyBlocks.add(i) + } + i += 1 + } + val avgSize = if (numSmallBlocks > 0) { + totalSmallBlockSize / numSmallBlocks + } else { + 0 + } + emptyBlocks.trim() + emptyBlocks.runOptimize() + new HighlyCompressedMapStatus( + loc, + numNonEmptyBlocks, + emptyBlocks, + avgSize, + hugeBlockSizesArray.toMap) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/pmof/UnCompressedMapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/pmof/UnCompressedMapStatus.scala new file mode 100644 index 00000000..4591a673 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/pmof/UnCompressedMapStatus.scala @@ -0,0 +1,66 @@ +package org.apache.spark.scheduler.pmof + +import java.nio.ByteBuffer +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.roaringbitmap.RoaringBitmap + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.internal.Logging + +private[spark] class UnCompressedMapStatus( + private[this] var loc: BlockManagerId, + private[this] var data: Array[Byte]) + extends MapStatus + with Externalizable + with Logging { + + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + val step = 8 + + def this(loc: BlockManagerId, sizes: Array[Long]) { + this(loc, sizes.flatMap(UnCompressedMapStatus.longToBytes)) + } + + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = { + val start = reduceId * step + UnCompressedMapStatus.bytesToLong(data.slice(start, start + step)) + } + + override def writeExternal(out: 
ObjectOutput): Unit = Utils.tryOrIOException { + loc.writeExternal(out) + out.writeInt(data.length) + out.write(data) + } + + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { + loc = BlockManagerId(in) + val len = in.readInt() + data = new Array[Byte](len) + in.readFully(data) + } +} + +object UnCompressedMapStatus { + def longToBytes(x: Long): Array[Byte] = { + val buffer = ByteBuffer.allocate(8) + buffer.putLong(x) + buffer.array() + } + + def bytesToLong(bytes: Array[Byte]): Long = { + val buffer = ByteBuffer.allocate(8); + buffer.put(bytes) + buffer.flip() + buffer.getLong() + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala index edc0d993..74753d65 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/MetadataResolver.scala @@ -9,6 +9,7 @@ import java.util.zip.{Deflater, DeflaterOutputStream, Inflater, InflaterInputStr import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{ByteBufferInputStream, ByteBufferOutputStream, Input, Output} +import org.apache.spark.internal.Logging import org.apache.spark.SparkEnv import org.apache.spark.network.pmof._ import org.apache.spark.shuffle.IndexShuffleBlockResolver @@ -27,7 +28,7 @@ import scala.util.control.{Breaks, ControlThrowable} * and can be used by executor to send metadata to driver * @param pmofConf */ -class MetadataResolver(pmofConf: PmofConf) { +class MetadataResolver(pmofConf: PmofConf) extends Logging { private[this] val blockManager = SparkEnv.get.blockManager private[this] val blockMap: ConcurrentHashMap[String, ShuffleBuffer] = new ConcurrentHashMap[String, ShuffleBuffer]() private[this] val blockInfoMap = mutable.HashMap.empty[String, ArrayBuffer[ShuffleBlockInfo]] @@ -44,7 +45,7 @@ class MetadataResolver(pmofConf: PmofConf) { * @param rkey */ def 
pushPmemBlockInfo(shuffleId: Int, mapId: Int, dataAddressMap: mutable.HashMap[Int, Array[(Long, Int)]], rkey: Long): Unit = { - val buffer: Array[Byte] = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt) + val buffer: Array[Byte] = new Array[Byte](pmofConf.map_serializer_buffer_size.toInt) var output = new Output(buffer) val bufferArray = new ArrayBuffer[ByteBuffer]() @@ -60,7 +61,7 @@ class MetadataResolver(pmofConf: PmofConf) { blockBuffer.flip() bufferArray += blockBuffer output.close() - val new_buffer = new Array[Byte](pmofConf.reduce_serializer_buffer_size.toInt) + val new_buffer = new Array[Byte](pmofConf.map_serializer_buffer_size.toInt) output = new Output(new_buffer) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala index 04e7b03e..e50e5f61 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriter.scala @@ -21,6 +21,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.network.pmof.PmofTransferService import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.pmof.UnCompressedMapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} import org.apache.spark.storage._ import org.apache.spark.util.collection.pmof.PmemExternalSorter @@ -32,16 +33,18 @@ import org.apache.spark.storage.BlockManager import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffleBlockResolver, - metadataResolver: MetadataResolver, - blockManager: BlockManager, - serializerManager: SerializerManager, - handle: BaseShuffleHandle[K, V, C], - mapId: Int, - context: TaskContext, - conf: SparkConf, - pmofConf: PmofConf) - extends ShuffleWriter[K, V] with Logging { +private[spark] class 
PmemShuffleWriter[K, V, C]( + shuffleBlockResolver: PmemShuffleBlockResolver, + metadataResolver: MetadataResolver, + blockManager: BlockManager, + serializerManager: SerializerManager, + handle: BaseShuffleHandle[K, V, C], + mapId: Int, + context: TaskContext, + conf: SparkConf, + pmofConf: PmofConf) + extends ShuffleWriter[K, V] + with Logging { private[this] val dep = handle.dependency private[this] var mapStatus: MapStatus = _ private[this] val stageId = dep.shuffleId @@ -60,27 +63,35 @@ private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffl private var stopping = false /** - * Call PMDK to write data to persistent memory - * Original Spark writer will do write and mergesort in this function, - * while by using pmdk, we can do that once since pmdk supports transaction. - */ + * Call PMDK to write data to persistent memory + * Original Spark writer will do write and mergesort in this function, + * while by using pmdk, we can do that once since pmdk supports transaction. 
+ */ override def write(records: Iterator[Product2[K, V]]): Unit = { - val PmemBlockOutputStreamArray = (0 until numPartitions).toArray.map(partitionId => - new PmemBlockOutputStream( - context.taskMetrics(), - ShuffleBlockId(stageId, mapId, partitionId), - serializerManager, - dep.serializer, - conf, - pmofConf, - numMaps, - numPartitions)) + logInfo(" write start") + val PmemBlockOutputStreamArray = (0 until numPartitions).toArray.map( + partitionId => + new PmemBlockOutputStream( + context.taskMetrics(), + ShuffleBlockId(stageId, mapId, partitionId), + serializerManager, + dep.serializer, + conf, + pmofConf, + numMaps, + numPartitions)) if (dep.mapSideCombine) { // do aggregation if (dep.aggregator.isDefined) { - sorter = new PmemExternalSorter[K, V, C](context, handle, pmofConf, dep.aggregator, Some(dep.partitioner), - dep.keyOrdering, dep.serializer) - sorter.setPartitionByteBufferArray(PmemBlockOutputStreamArray) + sorter = new PmemExternalSorter[K, V, C]( + context, + handle, + pmofConf, + dep.aggregator, + Some(dep.partitioner), + dep.keyOrdering, + dep.serializer) + sorter.setPartitionByteBufferArray(PmemBlockOutputStreamArray) sorter.insertAll(records) sorter.forceSpillToPmem() } else { @@ -107,31 +118,50 @@ private[spark] class PmemShuffleWriter[K, V, C](shuffleBlockResolver: PmemShuffl spilledPartition += 1 } val pmemBlockInfoMap = mutable.HashMap.empty[Int, Array[(Long, Int)]] - var output_str : String = "" + var output_str: String = "" + var rKey: Int = 0 for (i <- spillPartitionArray) { - if (pmofConf.enableRdma) { - pmemBlockInfoMap(i) = PmemBlockOutputStreamArray(i).getPartitionMeta().map { info => (info._1, info._2) } + if (pmofConf.enableRdma && !pmofConf.enableRemotePmem) { + pmemBlockInfoMap(i) = PmemBlockOutputStreamArray(i) + .getPartitionMeta() + .map(info => { + if (rKey == 0) { + rKey = info._3 + } + //logInfo(s"${ShuffleBlockId(stageId, mapId, i)} [${rKey}]${info._1}:${info._2}") + (info._1, info._2) + }) } - partitionLengths(i) = 
PmemBlockOutputStreamArray(i).size - output_str += "\tPartition " + i + ": " + partitionLengths(i) + ", records: " + PmemBlockOutputStreamArray(i).records + "\n" + partitionLengths(i) = PmemBlockOutputStreamArray(i).getSize + output_str += "\tPartition " + i + ": " + partitionLengths(i) + ", records: " + PmemBlockOutputStreamArray( + i).records + "\n" } + //logWarning(output_str) for (i <- 0 until numPartitions) { PmemBlockOutputStreamArray(i).close() } val shuffleServerId = blockManager.shuffleServerId - if (pmofConf.enableRdma) { - val rkey = PmemBlockOutputStreamArray(0).getRkey() - metadataResolver.pushPmemBlockInfo(stageId, mapId, pmemBlockInfoMap, rkey) + if (pmofConf.enableRemotePmem) { + mapStatus = new UnCompressedMapStatus(shuffleServerId, partitionLengths) + //mapStatus = MapStatus(shuffleServerId, partitionLengths) + } else if (!pmofConf.enableRdma) { + mapStatus = MapStatus(shuffleServerId, partitionLengths) + } else { + metadataResolver.pushPmemBlockInfo(stageId, mapId, pmemBlockInfoMap, rKey) val blockManagerId: BlockManagerId = - BlockManagerId(shuffleServerId.executorId, PmofTransferService.shuffleNodesMap(shuffleServerId.host), - PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port, shuffleServerId.topologyInfo) + BlockManagerId( + shuffleServerId.executorId, + PmofTransferService.shuffleNodesMap(shuffleServerId.host), + PmofTransferService.getTransferServiceInstance(pmofConf, blockManager).port, + shuffleServerId.topologyInfo) mapStatus = MapStatus(blockManagerId, partitionLengths) - } else { - mapStatus = MapStatus(shuffleServerId, partitionLengths) } + logWarning( + s"shuffle_${stageId}_${mapId}_0 size is ${partitionLengths(0)}, decompressed length is ${mapStatus + .getSizeForBlock(0)}") } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala index 
4e61c4e5..2bb963ce 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/PmofShuffleManager.scala @@ -16,22 +16,32 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager } private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() - private[this] val pmofConf = new PmofConf(conf) + private[this] val pmofConf = PmofConf.getConf(conf) var metadataResolver: MetadataResolver = _ - override def registerShuffle[K, V, C](shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { val env: SparkEnv = SparkEnv.get metadataResolver = MetadataResolver.getMetadataResolver(pmofConf) if (pmofConf.enableRdma) { - PmofTransferService.getTransferServiceInstance(pmofConf: PmofConf, env.blockManager, this, isDriver = true) + PmofTransferService.getTransferServiceInstance( + pmofConf: PmofConf, + env.blockManager, + this, + isDriver = true) } new BaseShuffleHandle(shuffleId, numMaps, dependency) } - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] = { + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { assert(handle.isInstanceOf[BaseShuffleHandle[_, _, _]]) val env: SparkEnv = SparkEnv.get @@ -47,28 +57,69 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager } if (pmofConf.enablePmem) { - new PmemShuffleWriter(shuffleBlockResolver.asInstanceOf[PmemShuffleBlockResolver], metadataResolver, blockManager, serializerManager, - handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, env.conf, pmofConf) + new PmemShuffleWriter( + shuffleBlockResolver.asInstanceOf[PmemShuffleBlockResolver], + metadataResolver, + blockManager, + 
serializerManager, + handle.asInstanceOf[BaseShuffleHandle[K, V, _]], + mapId, + context, + env.conf, + pmofConf) } else { - new BaseShuffleWriter(shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], metadataResolver, blockManager, serializerManager, - handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context, pmofConf) + new BaseShuffleWriter( + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + metadataResolver, + blockManager, + serializerManager, + handle.asInstanceOf[BaseShuffleHandle[K, V, _]], + mapId, + context, + pmofConf) } } - override def getReader[K, C](handle: _root_.org.apache.spark.shuffle.ShuffleHandle, startPartition: Int, endPartition: Int, context: _root_.org.apache.spark.TaskContext): _root_.org.apache.spark.shuffle.ShuffleReader[K, C] = { - if (pmofConf.enableRdma) { - new RdmaShuffleReader(handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, endPartition, context, pmofConf) + override def getReader[K, C]( + handle: _root_.org.apache.spark.shuffle.ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: _root_.org.apache.spark.TaskContext) + : _root_.org.apache.spark.shuffle.ShuffleReader[K, C] = { + val env: SparkEnv = SparkEnv.get + if (pmofConf.enableRemotePmem) { + new RpmpShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + pmofConf) + + } else if (pmofConf.enableRdma) { + metadataResolver = MetadataResolver.getMetadataResolver(pmofConf) + PmofTransferService.getTransferServiceInstance(pmofConf, env.blockManager, this) + new RdmaShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + pmofConf) } else { new BaseShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context, pmofConf) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + pmofConf) } } override def 
unregisterShuffle(shuffleId: Int): Boolean = { Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => (0 until numMaps).foreach { mapId => - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver].removeDataByMap(shuffleId, mapId) + shuffleBlockResolver + .asInstanceOf[IndexShuffleBlockResolver] + .removeDataByMap(shuffleId, mapId) } } true diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala index 92af5417..3a044f13 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/RdmaShuffleReader.scala @@ -29,6 +29,7 @@ private[spark] class RdmaShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], private[this] val dep = handle.dependency private[this] val serializerInstance: SerializerInstance = dep.serializer.newInstance() private[this] val enable_pmem: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) + private[this] val enable_rpmp: Boolean = SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_remote_pmem", defaultValue = true) /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { @@ -86,7 +87,7 @@ private[spark] class RdmaShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C], // Sort the output if there is a sort ordering defined. 
dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - if (enable_pmem) { + if (enable_pmem && !enable_rpmp) { val sorter = new PmemExternalSorter[K, C, C](context, handle, pmofConf, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) diff --git a/core/src/main/scala/org/apache/spark/shuffle/pmof/RpmpShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/pmof/RpmpShuffleReader.scala new file mode 100644 index 00000000..97ffdaea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/pmof/RpmpShuffleReader.scala @@ -0,0 +1,109 @@ +package org.apache.spark.shuffle.pmof + +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.network.pmof.PmofTransferService +import org.apache.spark.serializer.{SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.pmof._ +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.collection.pmof.PmemExternalSorter +import org.apache.spark.util.configuration.pmof.PmofConf + +/** + * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by + * requesting them from other nodes' block stores. 
+ */ +private[spark] class RpmpShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext, + pmofConf: PmofConf, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + extends ShuffleReader[K, C] + with Logging { + + private[this] val dep = handle.dependency + private[this] val serializerInstance: SerializerInstance = dep.serializer.newInstance() + private[this] val enable_pmem: Boolean = + SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_pmem", defaultValue = true) + private[this] val enable_rpmp: Boolean = + SparkEnv.get.conf.getBoolean("spark.shuffle.pmof.enable_remote_pmem", defaultValue = true) + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val wrappedStreams: RpmpShuffleBlockFetcherIterator = new RpmpShuffleBlockFetcherIterator( + context, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", defaultValue = true), + pmofConf) + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { + case (_, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. 
The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](recordIter.map { + record => + readMetrics.incRecordsRead(1) + record + }, context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. 
+ dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + val sorter = + new ExternalSorter[K, C, C]( + context, + ordering = Some(keyOrd), + serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]]( + sorter.iterator, + sorter.stop()) + case None => + aggregatedIter + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala b/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala index dfdd0d5c..5419c3ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala @@ -5,23 +5,19 @@ import io.netty.buffer.{ByteBuf, PooledByteBufAllocator, UnpooledByteBufAllocato import scala.collection.mutable.Stack import java.lang.RuntimeException import org.apache.spark.internal.Logging +import scala.collection.mutable.Map object NettyByteBufferPool extends Logging { private val allocatedBufRenCnt: AtomicLong = new AtomicLong(0) - private val allocatedBytes: AtomicLong = new AtomicLong(0) - private val peakAllocatedBytes: AtomicLong = new AtomicLong(0) - private val unpooledAllocatedBytes: AtomicLong = new AtomicLong(0) - private var fixedBufferSize: Long = 0 - private val allocatedBufferPool: Stack[ByteBuf] = Stack[ByteBuf]() + private val allocatedBytes: AtomicLong = new AtomicLong(0) + private val peakAllocatedBytes: AtomicLong = new AtomicLong(0) + private val unpooledAllocatedBytes: AtomicLong = new AtomicLong(0) + private val bufferMap: Map[ByteBuf, Long] = Map() + private val allocatedBufferPool: Stack[ByteBuf] = Stack[ByteBuf]() private var reachRead = false private val allocator = UnpooledByteBufAllocator.DEFAULT 
def allocateNewBuffer(bufSize: Int): ByteBuf = synchronized {
-    if (fixedBufferSize == 0) {
-      fixedBufferSize = bufSize
-    } else if (bufSize > fixedBufferSize) {
-      throw new RuntimeException(s"allocateNewBuffer, expected size is ${fixedBufferSize}, actual size is ${bufSize}")
-    }
     allocatedBufRenCnt.getAndIncrement()
     allocatedBytes.getAndAdd(bufSize)
     if (allocatedBytes.get > peakAllocatedBytes.get) {
@@ -33,9 +29,11 @@ object NettyByteBufferPool extends Logging {
       } else {
         allocator.directBuffer(bufSize, bufSize)
       }*/
-      allocator.directBuffer(bufSize, bufSize)
+      val byteBuf = allocator.directBuffer(bufSize, bufSize)
+      // Initial and maximum capacity are both bufSize, so releaseBuffer reads the size from the buffer.
+      byteBuf
     } catch {
-      case e : Throwable =>
+      case e: Throwable =>
         logError(s"allocateNewBuffer size is ${bufSize}")
         throw e
     }
@@ -43,7 +41,13 @@ object NettyByteBufferPool extends Logging {
 
+  /**
+   * Releases a buffer obtained from allocateNewBuffer and subtracts its size from the
+   * allocation counters.
+   */
   def releaseBuffer(buf: ByteBuf): Unit = synchronized {
     allocatedBufRenCnt.getAndDecrement()
-    allocatedBytes.getAndAdd(0 - fixedBufferSize)
+    // ByteBuf equality and hashing are content-based, so the size is taken from the buffer itself
+    // instead of a map keyed by the buffer; it is subtracted because the bytes are being freed.
+    allocatedBytes.getAndAdd(0L - buf.capacity())
     buf.clear()
     //allocatedBufferPool.push(buf)
     buf.release(buf.refCnt())
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/NioManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/pmof/NioManagedBuffer.scala
new file mode 100644
index 00000000..8118232e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/NioManagedBuffer.scala
@@ -0,0 +1,75 @@
+package org.apache.spark.storage.pmof
+
+import java.io.IOException
+import java.io.InputStream
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
+
+import io.netty.buffer.ByteBuf
+import io.netty.buffer.ByteBufInputStream
+import io.netty.buffer.Unpooled
+import org.apache.commons.lang3.builder.ToStringBuilder
+import org.apache.commons.lang3.builder.ToStringStyle
+import org.apache.spark.network.buffer.ManagedBuffer
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+class NioManagedBuffer(bufSize: Int) extends ManagedBuffer {
+  // Buffer taken from NettyByteBufferPool; handed back to the pool in release.
+  private val buf: ByteBuf = NettyByteBufferPool.allocateNewBuffer(bufSize)
+  private val byteBuffer: ByteBuffer = buf.nioBuffer(0, bufSize)
+  private val refCount = new AtomicInteger(1)
+  private var in: InputStream = _
+  private var nettyObj: ByteBuf = _
+
+  def getByteBuf: ByteBuf = buf
+
+  /** Sets the limit of the NIO view over the buffer. */
+  def resize(size: Int): Unit = {
+    byteBuffer.limit(size)
+  }
+
+  /** Number of bytes remaining in the NIO view. */
+  override def size: Long = byteBuffer.remaining()
+
+  /** Returns the NIO view itself, not a copy. */
+  override def nioByteBuffer: ByteBuffer = byteBuffer
+
+  /** Wraps the NIO view in a new Netty buffer and opens a stream over it. */
+  override def createInputStream: InputStream = {
+    val wrapped = Unpooled.wrappedBuffer(byteBuffer)
+    nettyObj = wrapped
+    in = new ByteBufInputStream(wrapped)
+    in
+  }
+
+  override def retain: ManagedBuffer = {
+    refCount.incrementAndGet()
+    this
+  }
+
+  /** Drops one reference; at zero, closes the stream and hands buf back to the pool. */
+  override def release: ManagedBuffer = {
+    val remaining = refCount.decrementAndGet()
+    if (remaining == 0) {
+      Option(in).foreach(_.close())
+      Option(nettyObj).foreach(_.release())
+      NettyByteBufferPool.releaseBuffer(buf)
+    }
+    this
+  }
+
+  /** Returns the Netty wrapper over the NIO view, creating it on first use. */
+  override def convertToNetty: Object = {
+    if (nettyObj == null) {
+      nettyObj = Unpooled.wrappedBuffer(byteBuffer)
+    }
+    nettyObj
+  }
+
+  override def toString: String =
+    new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+      .append("buf", buf)
+      .toString()
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala
index bde6ad44..3fb19546 100644
--- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala
@@ -2,14 +2,22 @@ package org.apache.spark.storage.pmof
 
 import com.esotericsoftware.kryo.KryoException
 import org.apache.spark.SparkEnv
-import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerInstance, SerializerManager}
+import org.apache.spark.serializer.{
+  DeserializationStream,
+
Serializer, + SerializerInstance, + SerializerManager +} import org.apache.spark.storage.BlockId -class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, serializer: Serializer) { +class PmemBlockInputStream[K, C]( + pmemBlockOutputStream: PmemBlockOutputStream, + serializer: Serializer) { val blockId: BlockId = pmemBlockOutputStream.getBlockId() val serializerManager: SerializerManager = SparkEnv.get.serializerManager val serInstance: SerializerInstance = serializer.newInstance() - val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler + val persistentMemoryWriter: PersistentMemoryHandler = + PersistentMemoryHandler.getPersistentMemoryHandler var pmemInputStream: PmemInputStream = new PmemInputStream(persistentMemoryWriter, blockId.name) val wrappedStream = serializerManager.wrapStream(blockId, pmemInputStream) var inObjStream: DeserializationStream = serInstance.deserializeStream(wrappedStream) @@ -30,7 +38,7 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s close() return null } - try{ + try { val k = inObjStream.readObject().asInstanceOf[K] val c = inObjStream.readObject().asInstanceOf[C] indexInBatch += 1 @@ -39,8 +47,7 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s } (k, c) } catch { - case ex: KryoException => { - } + case ex: KryoException => {} sys.exit(0) } } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala index ee303b0b..65748e69 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala @@ -1,28 +1,30 @@ package org.apache.spark.storage.pmof -import org.apache.spark.storage._ -import org.apache.spark.serializer._ +import java.io.File + +import org.apache.spark.SparkConf import 
org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging -import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.serializer._ +import org.apache.spark.storage._ import org.apache.spark.util.Utils -import java.io.{File, OutputStream} - import org.apache.spark.util.configuration.pmof.PmofConf import scala.collection.mutable.ArrayBuffer -class PmemBlockId (stageId: Int, tmpId: Int) extends ShuffleBlockId(stageId, 0, tmpId) { +class PmemBlockId(stageId: Int, tmpId: Int) extends ShuffleBlockId(stageId, 0, tmpId) { override def name: String = "reduce_spill_" + stageId + "_" + tmpId + override def isShuffle: Boolean = false } object PmemBlockId { private var tempId: Int = 0 + def getTempBlockId(stageId: Int): PmemBlockId = synchronized { val cur_tempId = tempId tempId += 1 - new PmemBlockId (stageId, cur_tempId) + new PmemBlockId(stageId, cur_tempId) } } @@ -34,8 +36,16 @@ private[spark] class PmemBlockOutputStream( conf: SparkConf, pmofConf: PmofConf, numMaps: Int = 0, - numPartitions: Int = 1 -) extends DiskBlockObjectWriter(new File(Utils.getConfiguredLocalDirs(conf).toList(0) + "/null"), null, null, 0, true, null, null) with Logging { + numPartitions: Int = 1) + extends DiskBlockObjectWriter( + new File(Utils.getConfiguredLocalDirs(conf).toList(0) + "/null"), + null, + null, + 0, + true, + null, + null) + with Logging { var size: Int = 0 var records: Int = 0 @@ -44,15 +54,31 @@ private[spark] class PmemBlockOutputStream( var spilled: Boolean = false var partitionMeta: Array[(Long, Int, Int)] = _ val root_dir = Utils.getConfiguredLocalDirs(conf).toList(0) - - val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler(pmofConf, - root_dir, pmofConf.path_list, blockId.name, pmofConf.maxPoolSize) + var persistentMemoryWriter: PersistentMemoryHandler = _ + var remotePersistentMemoryPool: RemotePersistentMemoryPool = _ + + if (!pmofConf.enableRemotePmem) { + persistentMemoryWriter = 
PersistentMemoryHandler.getPersistentMemoryHandler(
+      pmofConf,
+      root_dir,
+      pmofConf.path_list,
+      blockId.name,
+      pmofConf.maxPoolSize)
+  } else {
+    remotePersistentMemoryPool =
+      RemotePersistentMemoryPool.getInstance(pmofConf.rpmpHost, pmofConf.rpmpPort)
+  }
 
   //disable metadata updating by default
   //persistentMemoryWriter.updateShuffleMeta(blockId.name)
 
   val pmemOutputStream: PmemOutputStream = new PmemOutputStream(
-    persistentMemoryWriter, numPartitions, blockId.name, numMaps, (pmofConf.spill_throttle.toInt + 1024))
+    persistentMemoryWriter,
+    remotePersistentMemoryPool,
+    numPartitions,
+    blockId.name,
+    numMaps,
+    (pmofConf.spill_throttle.toInt + 1024))
   val serInstance = serializer.newInstance()
   val bs = serializerManager.wrapStream(blockId, pmemOutputStream)
   var objStream: SerializationStream = serInstance.serializeStream(bs)
@@ -62,7 +88,7 @@ private[spark] class PmemBlockOutputStream(
     objStream.writeValue(value)
     records += 1
     recordsPerBlock += 1
-    if (blockId.isShuffle == true) {
+    if (blockId.isShuffle == true) {
       taskMetrics.shuffleWriteMetrics.incRecordsWritten(1)
     }
     maybeSpill()
@@ -92,7 +118,7 @@ private[spark] class PmemBlockOutputStream(
     if (bufSize > 0) {
       recordsArray += recordsPerBlock
       recordsPerBlock = 0
-      size += bufSize
+      size = bufSize
 
       if (blockId.isShuffle == true) {
         val writeMetrics = taskMetrics.shuffleWriteMetrics
@@ -111,10 +137,31 @@ private[spark] class PmemBlockOutputStream(
     spilled
   }
 
+  // Decodes a flat array of consecutive triples into (Long, Int, Int) tuples, in order.
+  // Trailing elements that do not form a complete triple are ignored.
+  def getPartitionBlockInfo(res_array: Array[Long]): Array[(Long, Int, Int)] = {
+    Array.tabulate(res_array.length / 3) { idx =>
+      val base = idx * 3
+      (res_array(base), res_array(base + 1).toInt, res_array(base + 2).toInt)
+    }
+  }
+
   def getPartitionMeta(): Array[(Long, Int, Int)] = {
     if (partitionMeta == null) {
       var i = -1
-      partitionMeta = persistentMemoryWriter.getPartitionBlockInfo(blockId.name).map{ x=> i+=1; (x._1, x._2, recordsArray(i))}
+      partitionMeta = if (!pmofConf.enableRemotePmem) {
+        persistentMemoryWriter
+
.getPartitionBlockInfo(blockId.name) + .map(x => { + i += 1 + (x._1, x._2, getRkey().toInt) + }) + } else { + getPartitionBlockInfo(remotePersistentMemoryPool.getMeta(blockId.name)).map(x => { + i += 1 + (x._1, x._2, x._3) + }) + } } partitionMeta } @@ -128,7 +175,7 @@ private[spark] class PmemBlockOutputStream( } def getTotalRecords(): Long = { - records + records } def getSize(): Long = { diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala index 644a8333..4a896b45 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala @@ -5,14 +5,17 @@ import java.nio.ByteBuffer import io.netty.buffer.{ByteBuf, PooledByteBufAllocator} import org.apache.spark.internal.Logging +import java.io.IOException class PmemOutputStream( - persistentMemoryWriter: PersistentMemoryHandler, - numPartitions: Int, - blockId: String, - numMaps: Int, - bufferSize: Int - ) extends OutputStream with Logging { + persistentMemoryWriter: PersistentMemoryHandler, + remotePersistentMemoryPool: RemotePersistentMemoryPool, + numPartitions: Int, + blockId: String, + numMaps: Int, + bufferSize: Int) + extends OutputStream + with Logging { var set_clean = true var is_closed = false @@ -34,7 +37,21 @@ class PmemOutputStream( override def flush(): Unit = { if (bufferRemainingSize > 0) { - persistentMemoryWriter.setPartition(numPartitions, blockId, byteBuffer, bufferRemainingSize, set_clean) + if (persistentMemoryWriter != null) { + persistentMemoryWriter.setPartition( + numPartitions, + blockId, + byteBuffer, + bufferRemainingSize, + set_clean) + } else { + logDebug( + s"[put Remote Block] target is ${RemotePersistentMemoryPool.getHost}:${RemotePersistentMemoryPool.getPort}, " + + s"${blockId} ${bufferRemainingSize}") + if (remotePersistentMemoryPool.put(blockId, byteBuffer, bufferRemainingSize) 
== -1) { + throw new IOException("RPMem put failed with time out.") + } + } bufferFlushedSize += bufferRemainingSize bufferRemainingSize = 0 } @@ -48,7 +65,7 @@ class PmemOutputStream( } def remainingSize(): Int = { - bufferRemainingSize + bufferRemainingSize } def reset(): Unit = { @@ -61,7 +78,7 @@ class PmemOutputStream( if (!is_closed) { flush() reset() - NettyByteBufferPool.releaseBuffer(buf) + NettyByteBufferPool.releaseBuffer(buf) is_closed = true } } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala index d3c3e064..334d0ab4 100644 --- a/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/pmof/RdmaShuffleBlockFetcherIterator.scala @@ -28,6 +28,7 @@ import org.apache.spark.network.pmof._ import org.apache.spark.network.shuffle.{ShuffleClient, TempFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage._ +import org.apache.spark.util.configuration.pmof.PmofConf import org.apache.spark.{SparkException, TaskContext} import scala.collection.mutable @@ -36,88 +37,100 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future /** - * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block - * manager. For remote blocks, it fetches them using the provided BlockTransferService. - * - * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks - * in a pipelined fashion as they are received. - * - * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid - * using too much memory. 
- * - * @param context [[TaskContext]], used for metrics update - * @param shuffleClient [[ShuffleClient]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks - * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. - * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. - * @param streamWrapper A function to wrap the returned input stream. - * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. - * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. - * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point - * for a given remote host:port. - * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. - * @param detectCorrupt whether to detect any corruption in fetched blocks. - */ -private[spark] -final class RdmaShuffleBlockFetcherIterator(context: TaskContext, - shuffleClient: ShuffleClient, - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - streamWrapper: (BlockId, InputStream) => InputStream, - maxBytesInFlight: Long, - maxReqsInFlight: Int, - maxBlocksInFlightPerAddress: Int, - maxReqSizeShuffleToMem: Long, - detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. 
+ * + * @param context [[TaskContext]], used for metrics update + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks + * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. + * For each block we also require the size (in bytes as a long field) in + * order to throttle the memory usage. + * @param streamWrapper A function to wrap the returned input stream. + * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point + * for a given remote host:port. + * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. + * @param detectCorrupt whether to detect any corruption in fetched blocks. + */ +private[spark] final class RdmaShuffleBlockFetcherIterator( + context: TaskContext, + shuffleClient: ShuffleClient, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + detectCorrupt: Boolean) + extends Iterator[(BlockId, InputStream)] + with TempFileManager + with Logging { import RdmaShuffleBlockFetcherIterator._ /** Local blocks to fetch, excluding zero-sized blocks. */ private[this] val localBlocks = new ArrayBuffer[BlockId]() + /** - * A queue to hold our results. This turns the asynchronous model provided by - * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). - */ + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). 
+ */ private[this] val results = new LinkedBlockingQueue[FetchResult] - private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() + private[this] val shuffleMetrics = + context.taskMetrics().createTempShuffleReadMetrics() + /** - * A set to store the files used for shuffling remote huge blocks. Files in this set will be - * deleted when cleanup. This is a layer of defensiveness against disk file leaks. - */ + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ @GuardedBy("this") private[this] val shuffleFilesSet = mutable.HashSet[File]() - private[this] val remoteRdmaRequestQueue = new LinkedBlockingQueue[RdmaRequest]() + private[this] val remoteRdmaRequestQueue = + new LinkedBlockingQueue[RdmaRequest]() + /** - * Total number of blocks to fetch. This can be smaller than the total number of blocks - * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. - * - * This should equal localBlocks.size + remoteBlocks.size. - */ + * Total number of blocks to fetch. This can be smaller than the total number of blocks + * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * + * This should equal localBlocks.size + remoteBlocks.size. + */ private[this] var numBlocksToFetch = 0 + /** - * The number of blocks processed by the caller. The iterator is exhausted when - * [[numBlocksProcessed]] == [[numBlocksToFetch]]. - */ + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ private[this] var numBlocksProcessed = 0 private[this] val numRemoteBlockToFetch = new AtomicInteger(0) private[this] val numRemoteBlockProcessing = new AtomicInteger(0) private[this] val numRemoteBlockProcessed = new AtomicInteger(0) + /** - * Current [[FetchResult]] being processed. 
We track this so we can release the current buffer
-    * in case of a runtime exception when processing the current buffer.
-    */
+   * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+   * in case of a runtime exception when processing the current buffer.
+   */
   @volatile private[this] var currentResult: SuccessFetchResult = _
+
   /** Current bytes in flight from our requests */
   private[this] val bytesInFlight = new AtomicLong(0)
+
   /** Current number of requests in flight */
   private[this] val reqsInFlight = new AtomicInteger(0)
+
   /**
-    * Whether the iterator is still active. If isZombie is true, the callback interface will no
-    * longer place fetched blocks into [[results]].
-    */
+   * Whether the iterator is still active. If isZombie is true, the callback interface will no
+   * longer place fetched blocks into [[results]].
+   */
   @GuardedBy("this")
   private[this] var isZombie = false
@@ -125,13 +138,17 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
 
   def initialize(): Unit = {
     context.addTaskCompletionListener(_ => cleanup())
-
-    val remoteBlocksByAddress = blocksByAddress.filter(_._1.executorId != blockManager.blockManagerId.executorId)
-    for ((address, blockInfos) <- blocksByAddress) {
-      if (address.executorId == blockManager.blockManagerId.executorId) {
-        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
-        numBlocksToFetch += localBlocks.size
+    val (localBlocksByAddress, remoteBlocksByAddress) =
+      if (PmofConf.getConf.enableRemotePmem) {
+        (Nil, blocksByAddress)
+      } else {
+        (
+          blocksByAddress.filter(_._1.executorId == blockManager.blockManagerId.executorId),
+          blocksByAddress.filter(_._1.executorId != blockManager.blockManagerId.executorId))
       }
+    for ((address, blockInfos) <- localBlocksByAddress) {
+      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
+      numBlocksToFetch += blockInfos.count(_._2 != 0) // only this entry's blocks, not the running total
     }
 
     startFetch(remoteBlocksByAddress)
@@ -156,7 +173,12 @@ final class RdmaShuffleBlockFetcherIterator(context:
TaskContext, for (i <- 0 until num) { val current = blockInfoArray(i) if (current.getShuffleBlockId != last.getShuffleBlockId) { - remoteRdmaRequestQueue.put(new RdmaRequest(blockManagerId, last.getShuffleBlockId, blockInfoArray.slice(startIndex, i), reqSize)) + remoteRdmaRequestQueue.put( + new RdmaRequest( + blockManagerId, + last.getShuffleBlockId, + blockInfoArray.slice(startIndex, i), + reqSize)) startIndex = i reqSize = 0 Future { @@ -166,15 +188,18 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, last = current reqSize += current.getLength } - remoteRdmaRequestQueue.put(new RdmaRequest(blockManagerId, last.getShuffleBlockId, blockInfoArray.slice(startIndex, num), reqSize)) + remoteRdmaRequestQueue.put( + new RdmaRequest( + blockManagerId, + last.getShuffleBlockId, + blockInfoArray.slice(startIndex, num), + reqSize)) Future { fetchRemoteBlocks() } } - override def onFailure(e: Throwable): Unit = { - - } + override def onFailure(e: Throwable): Unit = {} } numBlocksToFetch += blockIds.length @@ -185,10 +210,10 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, } /** - * Fetch the local blocks while we are fetching remote blocks. This is ok because - * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we - * track in-memory are the ManagedBuffer references themselves. - */ + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. 
+ */ private[this] def fetchLocalBlocks() { val iter = localBlocks.iterator while (iter.hasNext) { @@ -198,7 +223,13 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, isNetworkReqDone = false)) + results.put( + SuccessFetchResult( + blockId, + blockManager.blockManagerId, + 0, + buf, + isNetworkReqDone = false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -210,8 +241,8 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, } /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ private[this] def cleanup() { synchronized { isZombie = true @@ -265,13 +296,13 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, } /** - * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers - * underlying each InputStream will be freed by the cleanup() method registered with the - * TaskCompletionListener. However, callers should close() these InputStreams - * as soon as they are no longer needed, in order to release memory as early as possible. - * - * Throws a FetchFailedException if the next block could not be fetched. - */ + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. 
+ */ override def next(): (BlockId, InputStream) = { if (!hasNext) { throw new NoSuchElementException @@ -307,18 +338,20 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, reqsInFlight.decrementAndGet } - logDebug("numRemoteBlockToFetch " + numRemoteBlockToFetch + " numRemoteBlockProcessing " + numRemoteBlockProcessing + " numRemoteBlockProcessed " + numRemoteBlockProcessed) - - val in = try { - buf.createInputStream() - } catch { - // The exception could only be throwed by local shuffle block - case e: IOException => - assert(buf.isInstanceOf[FileSegmentManagedBuffer]) - logError("Failed to create input stream from local block", e) - buf.release() - throwFetchFailedException(blockId, address, e) - } + logDebug( + "numRemoteBlockToFetch " + numRemoteBlockToFetch + " numRemoteBlockProcessing " + numRemoteBlockProcessing + " numRemoteBlockProcessed " + numRemoteBlockProcessed) + + val in = + try { + buf.createInputStream() + } catch { + // The exception could only be throwed by local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + logError("Failed to create input stream from local block", e) + buf.release() + throwFetchFailedException(blockId, address, e) + } input = streamWrapper(blockId, in) // Only copy the stream if it's wrapped by compression or encryption, also the size of @@ -342,10 +375,43 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, if (rdmaRequest == null) { return } - if (!isRemoteBlockFetchable(rdmaRequest)) { - remoteRdmaRequestQueue.put(rdmaRequest) + if (PmofConf.getConf.enableRemotePmem) { + val shuffleBlockInfos = rdmaRequest.shuffleBlockInfos + val blockManagerId = rdmaRequest.blockManagerId + val remotePersistentMemoryPool = + RemotePersistentMemoryPool.getInstance(blockManagerId.host, blockManagerId.port.toString) + + rdmaRequest.shuffleBlockInfos.foreach(shuffleBlockInfo => { + logDebug( + s"[fetch Remote Blocks] target is 
${blockManagerId.host}:${blockManagerId.port}," +
+              s" ${rdmaRequest.shuffleBlockIdName}" +
+              s" [${shuffleBlockInfo.getRkey}]${shuffleBlockInfo.getAddress}-${shuffleBlockInfo.getLength}")
+        val inputByteBuffer = new NioManagedBuffer(shuffleBlockInfo.getLength)
+        if (remotePersistentMemoryPool.read(
+            shuffleBlockInfo.getAddress,
+            shuffleBlockInfo.getLength,
+            inputByteBuffer.nioByteBuffer) == 0) {
+          results.put(
+            SuccessFetchResult(
+              BlockId(rdmaRequest.shuffleBlockIdName),
+              rdmaRequest.blockManagerId,
+              rdmaRequest.reqSize,
+              inputByteBuffer,
+              isNetworkReqDone = true))
+        } else {
+          // Nothing will consume the buffer after a failed read, so hand it back to the pool.
+          inputByteBuffer.release()
+          val readFailure = new IllegalStateException("RemotePersistentMemoryPool read failed.")
+          val failedId = BlockId(rdmaRequest.shuffleBlockIdName)
+          results.put(FailureFetchResult(failedId, rdmaRequest.blockManagerId, readFailure))
+        }
+      })
     } else {
-      sendRequest(rdmaRequest)
+      if (!isRemoteBlockFetchable(rdmaRequest)) {
+        remoteRdmaRequestQueue.put(rdmaRequest)
+      } else {
+        sendRequest(rdmaRequest)
+      }
     }
   }
 
@@ -366,7 +432,13 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
         RdmaShuffleBlockFetcherIterator.this.synchronized {
           blockNums -= 1
           if (blockNums == 0) {
-            results.put(SuccessFetchResult(BlockId(shuffleBlockIdName), blockManagerId, rdmaRequest.reqSize, shuffleBuffer, isNetworkReqDone = true))
+            results.put(
+              SuccessFetchResult(
+                BlockId(shuffleBlockIdName),
+                blockManagerId,
+                rdmaRequest.reqSize,
+                shuffleBuffer,
+                isNetworkReqDone = true))
             f(shuffleBuffer.getRdmaBufferId)
           }
         }
@@ -378,17 +450,31 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext,
       }
     }
 
-    val client = pmofTransferService.getClient(blockManagerId.host, blockManagerId.port)
-    val shuffleBuffer = new ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true)
-    val rdmaBuffer = client.getEqService.regRmaBufferByAddress(shuffleBuffer.nioByteBuffer(),
-      shuffleBuffer.getAddress, shuffleBuffer.getLength.toInt)
+    val client =
+      pmofTransferService.getClient(blockManagerId.host, blockManagerId.port)
+    val shuffleBuffer =
+      new
ShuffleBuffer(rdmaRequest.reqSize, client.getEqService, true) + val rdmaBuffer = client.getEqService.regRmaBufferByAddress( + shuffleBuffer.nioByteBuffer(), + shuffleBuffer.getAddress, + shuffleBuffer.getLength.toInt) shuffleBuffer.setRdmaBufferId(rdmaBuffer.getBufferId) var offset = 0 for (i <- 0 until blockNums) { - pmofTransferService.fetchBlock(blockManagerId.host, blockManagerId.port, - shuffleBlockInfos(i).getAddress, shuffleBlockInfos(i).getLength, - shuffleBlockInfos(i).getRkey, offset, shuffleBuffer, client, blockFetchingReadCallback) + logInfo( + s"[fetch Remote Blocks] target is ${blockManagerId.host}:${blockManagerId.port}, ${shuffleBlockIdName} [${shuffleBlockInfos( + i).getRkey}]${shuffleBlockInfos(i).getAddress}-${shuffleBlockInfos(i).getLength}") + pmofTransferService.fetchBlock( + blockManagerId.host, + blockManagerId.port, + shuffleBlockInfos(i).getAddress, + shuffleBlockInfos(i).getLength, + shuffleBlockInfos(i).getRkey, + offset, + shuffleBuffer, + client, + blockFetchingReadCallback) offset += shuffleBlockInfos(i).getLength } } @@ -399,26 +485,34 @@ final class RdmaShuffleBlockFetcherIterator(context: TaskContext, override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + private def throwFetchFailedException( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) + "Failed to get block " + blockId + ", which is not a shuffle block", + e) } } } -private class RdmaRequest(val blockManagerId: BlockManagerId, val shuffleBlockIdName: String, val shuffleBlockInfos: ArrayBuffer[ShuffleBlockInfo], val reqSize: Int) {} +private class RdmaRequest( + val blockManagerId: 
BlockManagerId, + val shuffleBlockIdName: String, + val shuffleBlockInfos: ArrayBuffer[ShuffleBlockInfo], + val reqSize: Int) {} /** - * Helper class that ensures a ManagedBuffer is released upon InputStream.close() - */ + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() + */ private class RDMABufferReleasingInputStream( - private val delegate: InputStream, - private val iterator: RdmaShuffleBlockFetcherIterator) - extends InputStream { + private val delegate: InputStream, + private val iterator: RdmaShuffleBlockFetcherIterator) + extends InputStream { private[this] var closed = false override def read(): Int = delegate.read() @@ -441,64 +535,65 @@ private class RDMABufferReleasingInputStream( override def read(b: Array[Byte]): Int = delegate.read(b) - override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + override def read(b: Array[Byte], off: Int, len: Int): Int = + delegate.read(b, off, len) override def reset(): Unit = delegate.reset() } -private[storage] -object RdmaShuffleBlockFetcherIterator { +private[storage] object RdmaShuffleBlockFetcherIterator { /** - * Result of a fetch from a remote block. - */ + * Result of a fetch from a remote block. + */ private[storage] sealed trait FetchResult { val blockId: BlockId val address: BlockManagerId } /** - * A request to fetch blocks from a remote BlockManager. - * - * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ + * A request to fetch blocks from a remote BlockManager. + * + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. 
+ */ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { val size: Long = blocks.map(_._2).sum } /** - * Result of a fetch from a remote block successfully. - * - * @param blockId block id - * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param buf `ManagedBuffer` for the content. - * @param isNetworkReqDone Is this the last network request for this host in this fetch request. - */ + * Result of a fetch from a remote block successfully. + * + * @param blockId block id + * @param address BlockManager that the block was fetched from. + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param buf `ManagedBuffer` for the content. + * @param isNetworkReqDone Is this the last network request for this host in this fetch request. + */ private[storage] case class SuccessFetchResult( - blockId: BlockId, - address: BlockManagerId, - size: Long, - buf: ManagedBuffer, - isNetworkReqDone: Boolean) extends FetchResult { + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { require(buf != null) require(size >= 0) } /** - * Result of a fetch from a remote block unsuccessfully. - * - * @param blockId block id - * @param address BlockManager that the block was attempted to be fetched from - * @param e the failure exception - */ + * Result of a fetch from a remote block unsuccessfully. 
+ * + * @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from + * @param e the failure exception + */ private[storage] case class FailureFetchResult( - blockId: BlockId, - address: BlockManagerId, - e: Throwable) - extends FetchResult + blockId: BlockId, + address: BlockManagerId, + e: Throwable) + extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/storage/pmof/RpmpShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/pmof/RpmpShuffleBlockFetcherIterator.scala new file mode 100644 index 00000000..00ea3a0b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/pmof/RpmpShuffleBlockFetcherIterator.scala @@ -0,0 +1,377 @@ +package org.apache.spark.storage.pmof + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.{File, IOException, InputStream} +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import javax.annotation.concurrent.GuardedBy +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.pmof._ +import org.apache.spark.network.shuffle.{ShuffleClient, TempFileManager} +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage._ +import org.apache.spark.util.configuration.pmof.PmofConf +import org.apache.spark.{SparkException, TaskContext} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + +import org.apache.spark.util.Utils +import org.apache.spark.util.io.ChunkedByteBufferOutputStream +import java.nio.ByteBuffer +import io.netty.buffer.{ByteBuf, ByteBufInputStream, ByteBufOutputStream} +import io.netty.buffer.Unpooled + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. + * + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. + * + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context [[TaskContext]], used for metrics update + * @param blockManager [[BlockManager]] for reading local blocks + * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. + * For each block we also require the size (in bytes as a long field) in + * order to throttle the memory usage. + * @param streamWrapper A function to wrap the returned input stream. 
+ * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point + * for a given remote host:port. + * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. + * @param detectCorrupt whether to detect any corruption in fetched blocks. + */ +private[spark] final class RpmpShuffleBlockFetcherIterator( + context: TaskContext, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + streamWrapper: (BlockId, InputStream) => InputStream, + maxBytesInFlight: Long, + maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, + maxReqSizeShuffleToMem: Long, + detectCorrupt: Boolean, + pmofConf: PmofConf) + extends Iterator[(BlockId, InputStream)] + with TempFileManager + with Logging { + + import RpmpShuffleBlockFetcherIterator._ + + /** Local blocks to fetch, excluding zero-sized blocks. */ + private[this] val localBlocks = new ArrayBuffer[BlockId]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). + */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + private[this] val shuffleMetrics = + context.taskMetrics().createTempShuffleReadMetrics() + + /** + * A set to store the files used for shuffling remote huge blocks. Files in this set will be + * deleted when cleanup. This is a layer of defensiveness against disk file leaks. + */ + @GuardedBy("this") + private[this] val shuffleFilesSet = mutable.HashSet[File]() + + /** + * Total number of blocks to fetch. This can be smaller than the total number of blocks + * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. 
+ * + * This should equal localBlocks.size + remoteBlocks.size. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks processed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val numRemoteBlockToFetch = new AtomicInteger(0) + private[this] val numRemoteBlockProcessing = new AtomicInteger(0) + private[this] val numRemoteBlockProcessed = new AtomicInteger(0) + + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + @volatile private[this] var currentResult: SuccessFetchResult = _ + + /** Current bytes in flight from our requests */ + private[this] val bytesInFlight = new AtomicLong(0) + + /** Current number of requests in flight */ + private[this] val reqsInFlight = new AtomicInteger(0) + + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @GuardedBy("this") + private[this] var isZombie = false + + private[this] var blocksByAddressSize = 0 + private[this] var blocksByAddressCurrentId = 0 + private[this] var address: BlockManagerId = _ + private[this] var blockInfos: Seq[(BlockId, Long)] = _ + private[this] var iterator: Iterator[(BlockId, Long)] = _ + + initialize() + + def initialize(): Unit = { + context.addTaskCompletionListener(_ => cleanup()) + blocksByAddressSize = blocksByAddress.size + if (blocksByAddressCurrentId < blocksByAddressSize) { + val res = blocksByAddress(blocksByAddressCurrentId) + address = res._1 + blockInfos = res._2 + iterator = blockInfos.iterator + blocksByAddressCurrentId += 1 + } + } + + val remotePersistentMemoryPool = + RemotePersistentMemoryPool.getInstance(pmofConf.rpmpHost, pmofConf.rpmpPort) + + /** + * Fetches the next (BlockId, InputStream). 
If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. + */ + private[this] var next_called: Boolean = true + private[this] var has_next: Boolean = false + override def next(): (BlockId, InputStream) = { + next_called = true + var input: InputStream = null + val (blockId, size) = iterator.next() + var buf = new NioManagedBuffer(size.toInt) + val startFetchWait = System.currentTimeMillis() + val readed_len = remotePersistentMemoryPool.get(blockId.name, size, buf.nioByteBuffer) + if (readed_len != -1) { + val stopFetchWait = System.currentTimeMillis() + shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + + val in = buf.createInputStream() + input = streamWrapper(blockId, in) + if (detectCorrupt && !input.eq(in)) { + val originalInput = input + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + try { + Utils.copyStream(input, out) + out.close() + input = out.toChunkedByteBuffer.toInputStream(dispose = true) + logDebug(s" buf ${blockId}-${size} decompress succeeded.") + } catch { + case e: IOException => + val tmp_byte_buffer = buf.nioByteBuffer + logWarning( + s" buf ${blockId}-${size} decompress corrupted, input is ${input}, raw_data is ${tmp_byte_buffer}, remaining is ${tmp_byte_buffer.remaining}") + val tmp: Array[Byte] = Array() + buf.nioByteBuffer.rewind + buf.nioByteBuffer.get(tmp, 0, size.toInt) + logWarning(s"content is ${convertBytesToHex(tmp)}.") + throw e + + } finally { + originalInput.close() + in.close() + buf.release() + } + } + } else { + throw new IOException(s"remotePersistentMemoryPool.get(${blockId}, 
${size}) failed."); + } + (blockId, new RpmpBufferReleasingInputStream(input, null)) + } + + override def hasNext: Boolean = { + if (!next_called) { + return has_next + } + next_called = false + if (iterator.hasNext) { + has_next = true + } else { + if (blocksByAddressCurrentId >= blocksByAddressSize) { + has_next = false + return has_next + } + val res = blocksByAddress(blocksByAddressCurrentId) + address = res._1 + blockInfos = res._2 + iterator = blockInfos.iterator + blocksByAddressCurrentId += 1 + has_next = true + } + return has_next + } + + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + synchronized { + isZombie = true + } + } + + override def registerTempFileToClean(file: File): Boolean = synchronized { + if (isZombie) { + false + } else { + shuffleFilesSet += file + true + } + } + + private def throwFetchFailedException( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", + e) + } + } + override def createTempFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + def convertBytesToHex(bytes: Seq[Byte]): String = { + val sb = new StringBuilder + for (b <- bytes) { + sb.append(String.format("%02x", Byte.box(b))) + } + sb.toString + } + +} + +/** + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() + */ +private class RpmpBufferReleasingInputStream( + private val delegate: InputStream, + private val parent: NioManagedBuffer) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + if (parent != null) { + parent.release() + } + 
closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = + delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} + +private[storage] object RpmpShuffleBlockFetcherIterator { + + /** + * Result of a fetch from a remote block. + */ + private[storage] sealed trait FetchResult { + val blockId: BlockId + val address: BlockManagerId + } + + /** + * A request to fetch blocks from a remote BlockManager. + * + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { + val size: Long = blocks.map(_._2).sum + } + + /** + * Result of a fetch from a remote block successfully. + * + * @param blockId block id + * @param address BlockManager that the block was fetched from. + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param buf `ManagedBuffer` for the content. + * @param isNetworkReqDone Is this the last network request for this host in this fetch request. + */ + private[storage] case class SuccessFetchResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer, + isNetworkReqDone: Boolean) + extends FetchResult { + require(buf != null) + require(size >= 0) + } + + /** + * Result of a fetch from a remote block unsuccessfully. 
+ * + * @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from + * @param e the failure exception + */ + private[storage] case class FailureFetchResult( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) + extends FetchResult + +} diff --git a/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala b/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala index f3551704..3e560c72 100644 --- a/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala +++ b/core/src/main/scala/org/apache/spark/util/configuration/pmof/PmofConf.scala @@ -27,4 +27,22 @@ class PmofConf(conf: SparkConf) { val shuffleBlockSize: Int = conf.getInt("spark.shuffle.pmof.shuffle_block_size", defaultValue = 2048) val pmemCapacity: Long = conf.getLong("spark.shuffle.pmof.pmem_capacity", defaultValue = 264239054848L) val pmemCoreMap = conf.get("spark.shuffle.pmof.dev_core_set", defaultValue = "/dev/dax0.0:0-17,36-53").split(";").map(_.trim).map(_.split(":")).map(arr => arr(0) -> arr(1)).toMap + val enableRemotePmem: Boolean = conf.getBoolean("spark.shuffle.pmof.enable_remote_pmem", defaultValue = false); + val rpmpHost: String = conf.get("spark.rpmp.rhost", defaultValue = "172.168.0.40") + val rpmpPort: String = conf.get("spark.rpmp.rport", defaultValue = "61010") } + +object PmofConf { + var ins: PmofConf = null + def getConf(conf: SparkConf): PmofConf = if (ins == null) { + ins = new PmofConf(conf) + ins + } else { + ins + } + def getConf: PmofConf = if (ins == null) { + throw new IllegalStateException("PmofConf is not initialized yet") + } else { + ins + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala index 4c0af8d8..91e2f9fa 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterSuite.scala @@ -49,7 +49,7 @@ class PmemShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with conf.set("spark.shuffle.pmof.pmem_list", "/dev/dax0.0") shuffleBlockResolver = new PmemShuffleBlockResolver(conf) serializer = new JavaSerializer(conf) - pmofConf = new PmofConf(conf) + pmofConf = PmofConf.getConf(conf) taskMetrics = new TaskMetrics() serializerManager = new SerializerManager(serializer, conf) diff --git a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala index e01e06cd..410292f6 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/pmof/PmemShuffleWriterWithSortSuite.scala @@ -53,7 +53,7 @@ class PmemShuffleWriterWithSortSuite extends SparkFunSuite with SharedSparkConte conf.set("spark.shuffle.pmof.pmem_list", "/dev/dax0.0") shuffleBlockResolver = new PmemShuffleBlockResolver(conf) serializer = new JavaSerializer(conf) - pmofConf = new PmofConf(conf) + pmofConf = PmofConf.getConf(conf) taskMetrics = new TaskMetrics() serializerManager = new SerializerManager(serializer, conf) diff --git a/rpmp/benchmark/CMakeLists.txt b/rpmp/benchmark/CMakeLists.txt index 468d6b15..a670b1e9 100644 --- a/rpmp/benchmark/CMakeLists.txt +++ b/rpmp/benchmark/CMakeLists.txt @@ -1,17 +1,26 @@ -add_executable(local_allocate local_allocate.cc) -target_link_libraries(local_allocate pmpool) +#add_executable(local_allocate local_allocate.cc) +#target_link_libraries(local_allocate pmpool) +# +#add_executable(remote_allocate remote_allocate.cc) +#target_link_libraries(remote_allocate pmpool_client_jni) +# +#add_executable(remote_write remote_write.cc) +#target_link_libraries(remote_write pmpool_client_jni) +# +#add_executable(remote_allocate_write remote_allocate_write.cc) 
+#target_link_libraries(remote_allocate_write pmpool_client_jni) +# +#add_executable(circularbuffer circularbuffer.cc) +#target_link_libraries(circularbuffer pmpool_client_jni) +# +#add_executable(remote_read remote_read.cc) +#target_link_libraries(remote_read pmpool_client_jni) +# +#add_executable(remote_put remote_put.cc) +#target_link_libraries(remote_put pmpool_client_jni) -add_executable(remote_allocate remote_allocate.cc) -target_link_libraries(remote_allocate pmpool) +add_executable(put_and_read put_and_read.cc) +target_link_libraries(put_and_read pmpool_client_jni) -add_executable(remote_write remote_write.cc) -target_link_libraries(remote_write pmpool) - -add_executable(remote_allocate_write remote_allocate_write.cc) -target_link_libraries(remote_allocate_write pmpool) - -add_executable(circularbuffer circularbuffer.cc) -target_link_libraries(circularbuffer pmpool) - -add_executable(remote_read remote_read.cc) -target_link_libraries(remote_read pmpool) +add_executable(put_and_get put_and_get.cc) +target_link_libraries(put_and_get pmpool_client_jni) \ No newline at end of file diff --git a/rpmp/benchmark/Config.h b/rpmp/benchmark/Config.h new file mode 100644 index 00000000..dc5659ef --- /dev/null +++ b/rpmp/benchmark/Config.h @@ -0,0 +1,98 @@ +/* + * Filename: /mnt/spark-pmof/tool/rpmp/pmpool/Config.h + * Path: /mnt/spark-pmof/tool/rpmp/pmpool + * Created Date: Thursday, November 7th 2019, 3:48:52 pm + * Author: root + * + * Copyright (c) 2019 Intel + */ + +#ifndef PMPOOL_CONFIG_H_ +#define PMPOOL_CONFIG_H_ + +#include +#include +#include +#include + +#include + +using boost::program_options::error; +using boost::program_options::options_description; +using boost::program_options::value; +using boost::program_options::variables_map; +using std::string; +using std::vector; + +/** + * @brief This class represents the current RPMP configuration. 
+ * + */ +class Config { + public: + int init(int argc, char **argv) { + try { + options_description desc{"Options"}; + desc.add_options()("help,h", "Help screen")( + "address,a", value()->default_value("172.168.0.40"), + "set the rdma server address")( + "port,p", value()->default_value("12346"), + "set the rdma server port")( + "log,l", value()->default_value("/tmp/rpmp.log"), + "set rpmp log file path")("map_id,m", value()->default_value(0), + "map id")( + "req_num,r", value()->default_value(2048), "number of requests")( + "threads,t", value()->default_value(8), "number of threads"); + + variables_map vm; + store(parse_command_line(argc, argv, desc), vm); + notify(vm); + + if (vm.count("help")) { + std::cout << desc << '\n'; + return -1; + } + set_ip(vm["address"].as()); + set_port(vm["port"].as()); + set_log_path(vm["log"].as()); + set_map_id(vm["map_id"].as()); + set_num_reqs(vm["req_num"].as()); + set_num_threads(vm["threads"].as()); + } catch (const error &ex) { + std::cerr << ex.what() << '\n'; + } + return 0; + } + + int get_map_id() { return map_id_; } + void set_map_id(int map_id) { map_id_ = map_id; } + + int get_num_reqs() { return num_reqs_; } + void set_num_reqs(int num_reqs) { num_reqs_ = num_reqs; } + + int get_num_threads() { return num_threads_; } + void set_num_threads(int num_threads) { num_threads_ = num_threads; } + + string get_ip() { return ip_; } + void set_ip(string ip) { ip_ = ip; } + + string get_port() { return port_; } + void set_port(string port) { port_ = port; } + + string get_log_path() { return log_path_; } + void set_log_path(string log_path) { log_path_ = log_path; } + + string get_log_level() { return log_level_; } + void set_log_level(string log_level) { log_level_ = log_level; } + + private: + string ip_; + string port_; + string log_path_; + string log_level_; + int map_id_ = 0; + int num_threads_ = 8; + int num_reqs_ = 2048; +}; + +#endif // PMPOOL_CONFIG_H_ diff --git a/rpmp/benchmark/local_allocate.cc 
b/rpmp/benchmark/local_allocate.cc index 8faadd68..50afbc0c 100644 --- a/rpmp/benchmark/local_allocate.cc +++ b/rpmp/benchmark/local_allocate.cc @@ -29,7 +29,7 @@ std::mutex mtx; uint64_t count = 0; char str[1048576]; -void func(AllocatorProxy *proxy, int index) { +void func(std::shared_ptr proxy, int index) { while (true) { std::unique_lock lk(mtx); uint64_t count_ = count++; @@ -47,7 +47,7 @@ int main() { std::shared_ptr config = std::make_shared(); config->init(0, nullptr); std::shared_ptr log = std::make_shared(config.get()); - auto allocatorProxy = new AllocatorProxy(config.get(), log.get(), nullptr); + auto allocatorProxy = std::make_shared(config, log, nullptr); allocatorProxy->init(); std::vector threads; memset(str, '0', 1048576); diff --git a/rpmp/benchmark/put_and_get.cc b/rpmp/benchmark/put_and_get.cc new file mode 100644 index 00000000..473a2664 --- /dev/null +++ b/rpmp/benchmark/put_and_get.cc @@ -0,0 +1,141 @@ +/* + * Filename: /mnt/spark-pmof/tool/rpmp/benchmark/allocate_perf.cc + * Path: /mnt/spark-pmof/tool/rpmp/benchmark + * Created Date: Friday, December 20th 2019, 8:29:23 am + * Author: root + * + * Copyright (c) 2019 Intel + */ + +#include +#include +#include // NOLINT +#include "Config.h" +#include "pmpool/Base.h" +#include "pmpool/client/PmPoolClient.h" + +uint64_t timestamp_now() { + return std::chrono::high_resolution_clock::now().time_since_epoch() / + std::chrono::milliseconds(1); +} + +char str[1048576]; +int numReqs = 2048; + +bool comp(char* str, char* str_read, uint64_t size) { + auto res = memcmp(str, str_read, size); + if (res != 0) { + fprintf(stderr, + "** strcmp is %d, read res is not aligned with wrote. 
**\nreaded " + "content is \n", + res); + for (int i = 0; i < 100; i++) { + fprintf(stderr, "%X ", *(str_read + i)); + } + fprintf(stderr, " ...\nwrote content is \n"); + for (int i = 0; i < 100; i++) { + fprintf(stderr, "%X ", *(str + i)); + } + fprintf(stderr, " ...\n"); + } + return res == 0; +} + +void get(int map_id, int start, int end, std::shared_ptr client) { + int count = start; + while (count < end) { + std::string key = + "block_" + std::to_string(map_id) + "_" + std::to_string(count++); + char str_read[1048576]; + client->begin_tx(); + client->get(key, str_read, 1048576); + client->end_tx(); + if (comp(str, str_read, 1048576) == false) { + throw; + } + } +} + +void put(int map_id, int start, int end, std::shared_ptr client) { + int count = start; + while (count < end) { + std::string key = + "block_" + std::to_string(map_id) + "_" + std::to_string(count++); + client->begin_tx(); + client->put(key, str, 1048576); + client->end_tx(); + } +} + +int main(int argc, char** argv) { + /// initialize Config class + std::shared_ptr config = std::make_shared(); + CHK_ERR("config init", config->init(argc, argv)); + + char temp[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', + 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f'}; + for (int i = 0; i < 1048576 / 32; i++) { + memcpy(str + i * 32, temp, 32); + } + + int threads = config->get_num_threads(); + int map_id = config->get_map_id(); + numReqs = config->get_num_reqs(); + std::string host = config->get_ip(); + std::string port = config->get_port(); + + std::cout << "=================== Put and get =======================" + << std::endl; + std::cout << "RPMP server is " << host << ":" << port << std::endl; + std::cout << "Total Num Requests is " << numReqs << std::endl; + std::cout << "Total Num Threads is " << threads << std::endl; + std::cout << "Block key pattern is " + << "block_" << map_id << "_*" << std::endl; + + auto client = 
std::make_shared(host, port); + client->init(); + std::cout << "start put." << std::endl; + int start = 0; + int step = numReqs / threads; + std::vector> threads_1; + uint64_t begin = timestamp_now(); + for (int i = 0; i < threads; i++) { + auto t = + std::make_shared(put, map_id, start, start + step, client); + threads_1.push_back(t); + start += step; + } + for (auto thread : threads_1) { + thread->join(); + } + uint64_t end = timestamp_now(); + std::cout << "[block_" << map_id << "_*]" + << "pmemkv put test: 1048576 " + << " bytes test, consumes " << (end - begin) / 1000.0 + << "s, throughput is " << numReqs / ((end - begin) / 1000.0) + << "MB/s" << std::endl; + + std::cout << "start get." << std::endl; + std::vector> threads_2; + begin = timestamp_now(); + start = 0; + for (int i = 0; i < threads; i++) { + auto t = + std::make_shared(get, map_id, start, start + step, client); + threads_2.push_back(t); + start += step; + } + for (auto thread : threads_2) { + thread->join(); + } + end = timestamp_now(); + std::cout << "[block_" << map_id << "_*]" + << "pmemkv get test: 1048576 " + << " bytes test, consumes " << (end - begin) / 1000.0 + << "s, throughput is " << numReqs / ((end - begin) / 1000.0) + << "MB/s" << std::endl; + + client.reset(); + return 0; +} diff --git a/rpmp/benchmark/put_and_read.cc b/rpmp/benchmark/put_and_read.cc new file mode 100644 index 00000000..05dff4bb --- /dev/null +++ b/rpmp/benchmark/put_and_read.cc @@ -0,0 +1,139 @@ +/* + * Filename: /mnt/spark-pmof/tool/rpmp/benchmark/allocate_perf.cc + * Path: /mnt/spark-pmof/tool/rpmp/benchmark + * Created Date: Friday, December 20th 2019, 8:29:23 am + * Author: root + * + * Copyright (c) 2019 Intel + */ + +#include +#include +#include // NOLINT +#include "Config.h" +#include "pmpool/Base.h" +#include "pmpool/client/PmPoolClient.h" + +uint64_t timestamp_now() { + return std::chrono::high_resolution_clock::now().time_since_epoch() / + std::chrono::milliseconds(1); +} + +char str[1048576]; +int numReqs 
= 2048; + +bool comp(char* str, char* str_read, uint64_t size) { + auto res = memcmp(str, str_read, size); + if (res != 0) { + fprintf(stderr, + "strcmp is %d, read res is not aligned with wrote. readed " + "content is \n", + res); + for (int i = 0; i < size; i++) { + fprintf(stderr, "%X ", *(str_read + i)); + } + fprintf(stderr, "\n wrote content is \n"); + for (int i = 0; i < size; i++) { + fprintf(stderr, "%X ", *(str + i)); + } + fprintf(stderr, "\n"); + } + return res == 0; +} + +void get(std::vector addresses, + std::shared_ptr client) { + for (auto bm : addresses) { + char str_read[1048576]; + client->read(bm.address, str_read, bm.size); + comp(str, str_read, bm.size); + } +} + +void put(int map_id, int start, int end, std::shared_ptr client, + std::vector* addresses) { + int count = start; + while (count < end) { + std::string key = + "block_" + std::to_string(map_id) + "_" + std::to_string(count++); + client->begin_tx(); + client->put(key, str, 1048576); + auto res = client->getMeta(key); + for (auto bm : res) { + (*addresses).push_back(bm); + } + client->end_tx(); + } +} + +int main(int argc, char** argv) { + /// initialize Config class + std::shared_ptr config = std::make_shared(); + CHK_ERR("config init", config->init(argc, argv)); + + char temp[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', + 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f'}; + for (int i = 0; i < 1048576 / 32; i++) { + memcpy(str + i * 32, temp, 32); + } + + int threads = config->get_num_threads(); + int map_id = config->get_map_id(); + numReqs = config->get_num_reqs(); + std::string host = config->get_ip(); + std::string port = config->get_port(); + + std::cout << "=================== Put and get =======================" + << std::endl; + std::cout << "RPMP server is " << host << ":" << port << std::endl; + std::cout << "Total Num Requests is " << numReqs << std::endl; + std::cout << "Total Num Threads is " << 
threads << std::endl; + std::cout << "Block key pattern is " + << "block_" << map_id << "_*" << std::endl; + + auto client = std::make_shared(host, port); + client->init(); + std::cout << "start put." << std::endl; + int start = 0; + int step = numReqs / threads; + std::vector> addresses_list; + addresses_list.resize(threads); + std::vector> threads_1; + uint64_t begin = timestamp_now(); + for (int i = 0; i < threads; i++) { + auto t = std::make_shared(put, map_id, start, start + step, + client, &addresses_list[i]); + threads_1.push_back(t); + start += step; + } + for (auto thread : threads_1) { + thread->join(); + } + uint64_t end = timestamp_now(); + std::cout << "[block_" << map_id << "_*]" + << "pmemkv put test: 1048576 " + << " bytes test, consumes " << (end - begin) / 1000.0 + << "s, throughput is " << numReqs / ((end - begin) / 1000.0) + << "MB/s" << std::endl; + + std::cout << "start get." << std::endl; + std::vector> threads_2; + begin = timestamp_now(); + for (int i = 0; i < threads; i++) { + auto t = std::make_shared(get, addresses_list[i], client); + threads_2.push_back(t); + } + for (auto thread : threads_2) { + thread->join(); + } + end = timestamp_now(); + std::cout << "[block_" << map_id << "_*]" + << "pmemkv get test: 1048576 " + << " bytes test, consumes " << (end - begin) / 1000.0 + << "s, throughput is " << numReqs / ((end - begin) / 1000.0) + << "MB/s" << std::endl; + + client.reset(); + return 0; +} diff --git a/rpmp/benchmark/remote_allocate_write.cc b/rpmp/benchmark/remote_allocate_write.cc index 2a1f5da7..a57d10a2 100644 --- a/rpmp/benchmark/remote_allocate_write.cc +++ b/rpmp/benchmark/remote_allocate_write.cc @@ -8,8 +8,10 @@ */ #include -#include // NOLINT #include +#include // NOLINT +#include "pmpool/Base.h" +#include "pmpool/Config.h" #include "pmpool/client/PmPoolClient.h" uint64_t timestamp_now() { @@ -20,12 +22,12 @@ uint64_t timestamp_now() { std::atomic count = {0}; std::mutex mtx; char str[1048576]; -std::vector clients; 
+std::vector> clients; std::map> addresses; void func1(int i) { while (true) { - uint64_t count_ = count++; + auto count_ = count++; if (count_ < 20480) { clients[i]->begin_tx(); if (addresses.count(i) != 0) { @@ -43,8 +45,11 @@ void func1(int i) { } } -int main() { - std::vector threads; +int main(int argc, char **argv) { + /// initialize Config class + std::shared_ptr config = std::make_shared(); + CHK_ERR("config init", config->init(argc, argv)); + std::vector> threads; memset(str, '0', 1048576); int num = 0; @@ -52,7 +57,8 @@ int main() { num = 0; count = 0; for (int i = 0; i < 4; i++) { - PmPoolClient *client = new PmPoolClient("172.168.0.40", "12346"); + auto client = + std::make_shared(config->get_ip(), config->get_port()); client->begin_tx(); client->init(); client->end_tx(); @@ -61,12 +67,11 @@ int main() { } uint64_t start = timestamp_now(); for (int i = 0; i < num; i++) { - auto t = new std::thread(func1, i); + auto t = std::make_shared(func1, i); threads.push_back(t); } for (int i = 0; i < num; i++) { threads[i]->join(); - delete threads[i]; } uint64_t end = timestamp_now(); std::cout << "pmemkv put test: 1048576 " @@ -84,7 +89,6 @@ int main() { std::cout << "freed." 
<< std::endl; for (int i = 0; i < num; i++) { clients[i]->wait(); - delete clients[i]; } return 0; } diff --git a/rpmp/benchmark/remote_put.cc b/rpmp/benchmark/remote_put.cc new file mode 100644 index 00000000..bb35696a --- /dev/null +++ b/rpmp/benchmark/remote_put.cc @@ -0,0 +1,108 @@ +/* + * Filename: /mnt/spark-pmof/tool/rpmp/benchmark/allocate_perf.cc + * Path: /mnt/spark-pmof/tool/rpmp/benchmark + * Created Date: Friday, December 20th 2019, 8:29:23 am + * Author: root + * + * Copyright (c) 2019 Intel + */ + +#include +#include +#include // NOLINT +#include "pmpool/Base.h" +#include "pmpool/Config.h" +#include "pmpool/client/PmPoolClient.h" + +uint64_t timestamp_now() { + return std::chrono::high_resolution_clock::now().time_since_epoch() / + std::chrono::milliseconds(1); +} + +int count = 0; +std::mutex mtx; +std::vector keys; +char str[1048576]; +int numReqs = 2048; + +void func1(std::shared_ptr client) { + while (true) { + std::unique_lock lk(mtx); + uint64_t count_ = count++; + lk.unlock(); + if (count_ < numReqs) { + char str_read[1048576]; + client->begin_tx(); + client->put(keys[count_], str, 1048576); + auto res = client->getMeta(keys[count_]); + // printf("put and get Meta of key %s\n", keys[count_].c_str()); + for (auto bm : res) { + client->read(bm.address, str_read, bm.size); + // printf("read of key %s, info is [%d]%ld-%d\n", keys[count_].c_str(), + // bm.r_key, bm.address, bm.size); + auto res = memcmp(str, str_read, 1048576); + if (res != 0) { + fprintf(stderr, + "strcmp is %d, read res is not aligned with wrote. 
readed " + "content is \n", + res); + for (int i = 0; i < 1048576; i++) { + fprintf(stderr, "%X ", *(str_read + i)); + } + fprintf(stderr, "\n wrote content is \n"); + for (int i = 0; i < 1048576; i++) { + fprintf(stderr, "%X ", *(str + i)); + } + fprintf(stderr, "\n"); + } + } + client->end_tx(); + } else { + break; + } + } +} + +int main(int argc, char** argv) { + /// initialize Config class + std::shared_ptr config = std::make_shared(); + CHK_ERR("config init", config->init(argc, argv)); + auto client = + std::make_shared(config->get_ip(), config->get_port()); + char temp[] = {'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', + 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f'}; + for (int i = 0; i < 1048576 / 32; i++) { + memcpy(str + i * 32, temp, 32); + } + for (int i = 0; i < 20480; i++) { + keys.emplace_back("block_" + std::to_string(i)); + } + client->init(); + + int threads = 4; + std::cout << "start put." << std::endl; + std::vector> threads_1; + uint64_t start = timestamp_now(); + for (int i = 0; i < threads; i++) { + auto t = std::make_shared(func1, client); + threads_1.push_back(t); + } + for (auto thread : threads_1) { + thread->join(); + } + uint64_t end = timestamp_now(); + std::cout << "pmemkv put test: 1048576 " + << " bytes test, consumes " << (end - start) / 1000.0 + << "s, throughput is " << numReqs / ((end - start) / 1000.0) + << "MB/s" << std::endl; + client.reset(); + /*for (int i = 0; i < 20480; i++) { + client->begin_tx(); + client->del(keys[i]); + client->end_tx(); + } + std::cout << "freed." 
<< std::endl; + client->wait();*/ + return 0; +} diff --git a/rpmp/include/spdlog b/rpmp/include/spdlog index cca004ef..5ca5cbd4 160000 --- a/rpmp/include/spdlog +++ b/rpmp/include/spdlog @@ -1 +1 @@ -Subproject commit cca004efe4e66136a5f9f37e007d28a23bb729e0 +Subproject commit 5ca5cbd44739576ee90d3d9c21ba8cbf7e11ff5c diff --git a/rpmp/main.cc b/rpmp/main.cc index c8fe9d16..2f97fbd8 100644 --- a/rpmp/main.cc +++ b/rpmp/main.cc @@ -28,7 +28,7 @@ int ServerMain(int argc, char **argv) { std::shared_ptr log = std::make_shared(config.get()); /// initialize DataServer class std::shared_ptr dataServer = - std::make_shared(config.get(), log.get()); + std::make_shared(config, log); log->get_file_log()->info("start to initialize data server."); CHK_ERR("data server init", dataServer->init()); log->get_file_log()->info("data server initailized."); diff --git a/rpmp/pmpool/AllocatorProxy.h b/rpmp/pmpool/AllocatorProxy.h index cddc5830..4a67d402 100644 --- a/rpmp/pmpool/AllocatorProxy.h +++ b/rpmp/pmpool/AllocatorProxy.h @@ -11,22 +11,23 @@ #define PMPOOL_ALLOCATORPROXY_H_ #include +#include #include #include -#include #include +#include #include "Allocator.h" +#include "Base.h" #include "Config.h" #include "DataServer.h" #include "Log.h" #include "PmemAllocator.h" -#include "Base.h" using std::atomic; using std::make_shared; -using std::unordered_map; using std::string; +using std::unordered_map; using std::vector; /** @@ -37,24 +38,21 @@ using std::vector; class AllocatorProxy { public: AllocatorProxy() = delete; - AllocatorProxy(Config *config, Log *log, NetworkServer *networkServer) + AllocatorProxy(std::shared_ptr config, std::shared_ptr log, + std::shared_ptr networkServer) : config_(config), log_(log) { vector paths = config_->get_pool_paths(); vector sizes = config_->get_pool_sizes(); assert(paths.size() == sizes.size()); for (int i = 0; i < paths.size(); i++) { - DiskInfo *diskInfo = new DiskInfo(paths[i], sizes[i]); + auto diskInfo = std::make_shared(paths[i], 
sizes[i]); diskInfos_.push_back(diskInfo); allocators_.push_back( - new PmemObjAllocator(log_, diskInfo, networkServer, i)); + std::make_shared(log_, diskInfo, networkServer, i)); } } ~AllocatorProxy() { - for (int i = 0; i < config_->get_pool_paths().size(); i++) { - delete allocators_[i]; - delete diskInfos_[i]; - } allocators_.clear(); diskInfos_.clear(); } @@ -76,6 +74,7 @@ class AllocatorProxy { addr = allocators_[index % diskInfos_.size()]->allocate_and_write( size, content); } + return addr; } int write(uint64_t address, const char *content, uint64_t size) { @@ -112,14 +111,16 @@ class AllocatorProxy { return allocators_[wid]->get_rma_chunk(); } - void cache_chunk(uint64_t key, uint64_t address, uint64_t size) { - block_meta bm = {address, size}; + void cache_chunk(uint64_t key, uint64_t address, uint64_t size, int r_key) { + block_meta bm = {address, size, r_key}; cache_chunk(key, bm); } void cache_chunk(uint64_t key, block_meta bm) { if (kv_meta_map.count(key)) { + printf("key exists\n"); kv_meta_map[key].push_back(bm); + // kv_meta_map.erase(key); } else { vector bml; bml.push_back(bm); @@ -135,16 +136,16 @@ class AllocatorProxy { } void del_chunk(uint64_t key) { - if (kv_meta_map.count(key)) { + if (kv_meta_map.count(key)) { kv_meta_map.erase(key); } } private: - Config *config_; - Log *log_; - vector allocators_; - vector diskInfos_; + std::shared_ptr config_; + std::shared_ptr log_; + vector> allocators_; + vector> diskInfos_; atomic buffer_id_{0}; unordered_map> kv_meta_map; }; diff --git a/rpmp/pmpool/Base.h b/rpmp/pmpool/Base.h index 685c0d60..4cfa38ed 100644 --- a/rpmp/pmpool/Base.h +++ b/rpmp/pmpool/Base.h @@ -44,11 +44,19 @@ struct RequestReplyMsg { }; struct block_meta { - block_meta() : block_meta(0, 0) {} + block_meta() : block_meta(0, 0, 0) {} block_meta(uint64_t _address, uint64_t _size) : address(_address), size(_size) {} + block_meta(uint64_t _address, uint64_t _size, int _r_key) + : address(_address), size(_size), r_key(_r_key) {} + void 
set_rKey(int _r_key) { r_key = _r_key; } + std::string ToString() { + return std::to_string(r_key) + "-" + std::to_string(address) + ":" + + std::to_string(size); + } uint64_t address; uint64_t size; + int r_key; }; #endif // PMPOOL_BASE_H_ diff --git a/rpmp/pmpool/CMakeLists.txt b/rpmp/pmpool/CMakeLists.txt index 64c82108..f5dd3656 100644 --- a/rpmp/pmpool/CMakeLists.txt +++ b/rpmp/pmpool/CMakeLists.txt @@ -1,4 +1,8 @@ -add_library(pmpool SHARED DataServer.cc Protocol.cc Event.cc NetworkServer.cc hash/xxhash.cc client/PmPoolClient.cc client/NetworkClient.cc client/native/com_intel_rpmp_PmPoolClient.cc) +add_library(pmpool_client_jni SHARED Event.cc client/PmPoolClient.cc client/NetworkClient.cc client/native/com_intel_rpmp_PmPoolClient.cc) +target_link_libraries(pmpool_client_jni LINK_PUBLIC ${Boost_LIBRARIES} hpnl) +set_target_properties(pmpool_client_jni PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") + +add_library(pmpool SHARED DataServer.cc Protocol.cc Event.cc NetworkServer.cc hash/xxhash.cc client/PmPoolClient.cc client/NetworkClient.cc) target_link_libraries(pmpool LINK_PUBLIC ${Boost_LIBRARIES} hpnl pmemobj) set_target_properties(pmpool PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") diff --git a/rpmp/pmpool/Config.h b/rpmp/pmpool/Config.h index 2a1b9526..d1edacbb 100644 --- a/rpmp/pmpool/Config.h +++ b/rpmp/pmpool/Config.h @@ -17,12 +17,14 @@ #include -using boost::program_options::error; +using namespace boost::program_options; +using namespace std; +/*using boost::program_options::error; using boost::program_options::options_description; using boost::program_options::value; using boost::program_options::variables_map; using std::string; -using std::vector; +using std::vector;*/ /** * @brief This class represents the current RPMP configuration. 
@@ -44,15 +46,21 @@ class Config { "set network buffer number")("network_worker,nw", value()->default_value(1), "set network wroker number")( - "paths,ps", value>(), "set memory pool path")( - "sizes,ss", value>(), "set memory pool size")( + "paths,ps", value>()->multitoken(), + "set memory pool path")("sizes,ss", + value>()->multitoken(), + "set memory pool size")( + "task_set, t", value>()->multitoken(), + "set affinity for each device")( "log,l", value()->default_value("/tmp/rpmp.log"), "set rpmp log file path")("log_level,ll", value()->default_value("warn"), "set log level"); + command_line_parser parser{argc, argv}; + parsed_options parsed_options = parser.options(desc).run(); variables_map vm; - store(parse_command_line(argc, argv, desc), vm); + store(parsed_options, vm); notify(vm); if (vm.count("help")) { @@ -64,18 +72,31 @@ class Config { set_network_buffer_size(vm["network_buffer_size"].as()); set_network_buffer_num(vm["network_buffer_num"].as()); set_network_worker_num(vm["network_worker"].as()); - pool_paths_.push_back("/dev/dax0.0"); - pool_paths_.push_back("/dev/dax0.1"); - pool_paths_.push_back("/dev/dax1.0"); - pool_paths_.push_back("/dev/dax1.1"); - sizes_.push_back(126833655808L); - sizes_.push_back(126833655808L); - sizes_.push_back(126833655808L); - sizes_.push_back(126833655808L); - affinities_.push_back(2); - affinities_.push_back(41); - affinities_.push_back(22); - affinities_.push_back(60); + // pool_paths_.push_back("/dev/dax0.0"); + if (vm.count("sizes")) { + set_pool_sizes(vm["sizes"].as>()); + } + if (vm.count("paths")) { + set_pool_paths(vm["paths"].as>()); + } else { + std::cerr << "No input device!!" 
<< std::endl; + throw error("No input device"); /* a bare `throw;` outside a handler calls std::terminate; throw the type the catch below handles */ + } + if (pool_paths_.size() != sizes_.size()) { + if (sizes_.size() < pool_paths_.size() && !sizes_.empty()) { + auto first = sizes_[0]; + sizes_.resize(pool_paths_.size(), first); + } else if (sizes_.size() > pool_paths_.size()) { + sizes_.resize(pool_paths_.size()); + } else { + throw error("No pool size given for the input device(s)"); /* was `throw 1;`, which catch (const error &) below cannot match */ + } + } + if (vm.count("task_set")) { + set_affinities_(vm["task_set"].as>()); + } else { + affinities_.resize(pool_paths_.size(), -1); + } set_log_path(vm["log"].as()); set_log_level(vm["log_level"].as()); } catch (const error &ex) { @@ -115,7 +136,18 @@ class Config { int get_pool_size() { return sizes_.size(); } - std::vector get_affinities_() { return affinities_; } + void set_affinities_(vector affinities) { + if (affinities.size() < pool_paths_.size()) { + affinities_.resize(pool_paths_.size(), -1); + } else { + for (int i = 0; i < pool_paths_.size(); i++) { + affinities_.push_back(affinities[i]); + std::cout << pool_paths_[i] << " task_set to " << affinities[i] + << std::endl; + } + } + } + std::vector get_affinities_() { return affinities_; } string get_log_path() { return log_path_; } void set_log_path(string log_path) { log_path_ = log_path; } @@ -131,7 +163,7 @@ class Config { int network_worker_num_; vector pool_paths_; vector sizes_; - vector affinities_; + vector affinities_; string log_path_; string log_level_; }; diff --git a/rpmp/pmpool/DataServer.cc b/rpmp/pmpool/DataServer.cc index 47859c50..638a1acb 100644 --- a/rpmp/pmpool/DataServer.cc +++ b/rpmp/pmpool/DataServer.cc @@ -9,14 +9,15 @@ #include "pmpool/DataServer.h" -#include "AllocatorProxy.h" -#include "Config.h" -#include "Digest.h" -#include "NetworkServer.h" -#include "Protocol.h" -#include "Log.h" +#include "pmpool/AllocatorProxy.h" +#include "pmpool/Config.h" +#include "pmpool/Digest.h" +#include "pmpool/Log.h" +#include "pmpool/NetworkServer.h" +#include "pmpool/Protocol.h" -DataServer::DataServer(Config *config, Log *log) : config_(config), log_(log) {} 
+DataServer::DataServer(std::shared_ptr config, std::shared_ptr log) + : config_(config), log_(log) {} int DataServer::init() { networkServer_ = std::make_shared(config_, log_); @@ -24,12 +25,12 @@ int DataServer::init() { log_->get_file_log()->info("network server initialized."); allocatorProxy_ = - std::make_shared(config_, log_, networkServer_.get()); + std::make_shared(config_, log_, networkServer_); CHK_ERR("allocator proxy init", allocatorProxy_->init()); log_->get_file_log()->info("allocator proxy initialized."); - protocol_ = std::make_shared(config_, log_, networkServer_.get(), - allocatorProxy_.get()); + protocol_ = std::make_shared(config_, log_, networkServer_, + allocatorProxy_); CHK_ERR("protocol init", protocol_->init()); log_->get_file_log()->info("protocol initialized."); diff --git a/rpmp/pmpool/DataServer.h b/rpmp/pmpool/DataServer.h index 14e70ec1..dde9c8f8 100644 --- a/rpmp/pmpool/DataServer.h +++ b/rpmp/pmpool/DataServer.h @@ -3,15 +3,15 @@ * Path: /mnt/spark-pmof/tool/rpmp/pmpool * Created Date: Thursday, November 7th 2019, 3:48:52 pm * Author: root - * + * * Copyright (c) 2019 Intel */ #ifndef PMPOOL_DATASERVER_H_ #define PMPOOL_DATASERVER_H_ -#include #include +#include #include @@ -25,18 +25,20 @@ class Log; /** * @brief DataServer is designed as distributed remote memory pool. - * DataServer on every node communicated with each other to guarantee data consistency. - * + * DataServer on every node communicated with each other to guarantee data + * consistency. 
+ * */ class DataServer { public: DataServer() = delete; - explicit DataServer(Config* config, Log* log); + explicit DataServer(std::shared_ptr config, std::shared_ptr log); int init(); void wait(); + private: - Config* config_; - Log* log_; + std::shared_ptr config_; + std::shared_ptr log_; std::shared_ptr networkServer_; std::shared_ptr allocatorProxy_; std::shared_ptr protocol_; diff --git a/rpmp/pmpool/Event.cc b/rpmp/pmpool/Event.cc index 2415c58d..a118ead3 100644 --- a/rpmp/pmpool/Event.cc +++ b/rpmp/pmpool/Event.cc @@ -21,6 +21,7 @@ Request::Request(char *data, uint64_t size, Connection *con) : size_(size) { } Request::~Request() { + const std::lock_guard lock(data_lock_); if (data_ != nullptr) { std::free(data_); data_ = nullptr; @@ -30,87 +31,102 @@ Request::~Request() { RequestContext &Request::get_rc() { return requestContext_; } void Request::encode() { + const std::lock_guard lock(data_lock_); OpType rt = requestContext_.type; assert(rt == ALLOC || rt == FREE || rt == WRITE || rt == READ); - requestMsg_.type = requestContext_.type; - requestMsg_.rid = requestContext_.rid; - requestMsg_.address = requestContext_.address; - requestMsg_.src_address = requestContext_.src_address; - requestMsg_.src_rkey = requestContext_.src_rkey; - requestMsg_.size = requestContext_.size; - requestMsg_.key = requestContext_.key; - - size_ = sizeof(requestMsg_); - data_ = static_cast(std::malloc(size_)); - memcpy(data_, &requestMsg_, size_); + size_ = sizeof(RequestMsg); + data_ = static_cast(std::malloc(sizeof(RequestMsg))); + RequestMsg *requestMsg = (RequestMsg *)data_; + requestMsg->type = requestContext_.type; + requestMsg->rid = requestContext_.rid; + requestMsg->address = requestContext_.address; + requestMsg->src_address = requestContext_.src_address; + requestMsg->src_rkey = requestContext_.src_rkey; + requestMsg->size = requestContext_.size; + requestMsg->key = requestContext_.key; } void Request::decode() { - assert(size_ == sizeof(requestMsg_)); - 
memcpy(&requestMsg_, data_, size_); - requestContext_.type = (OpType)requestMsg_.type; - requestContext_.rid = requestMsg_.rid; - requestContext_.address = requestMsg_.address; - requestContext_.src_address = requestMsg_.src_address; - requestContext_.src_rkey = requestMsg_.src_rkey; - requestContext_.size = requestMsg_.size; - requestContext_.key = requestMsg_.key; + const std::lock_guard lock(data_lock_); + assert(size_ == sizeof(RequestMsg)); + RequestMsg *requestMsg = (RequestMsg *)data_; + requestContext_.type = (OpType)requestMsg->type; + requestContext_.rid = requestMsg->rid; + requestContext_.address = requestMsg->address; + requestContext_.src_address = requestMsg->src_address; + requestContext_.src_rkey = requestMsg->src_rkey; + requestContext_.size = requestMsg->size; + requestContext_.key = requestMsg->key; } -RequestReply::RequestReply(RequestReplyContext requestReplyContext) +RequestReply::RequestReply( + std::shared_ptr requestReplyContext) : data_(nullptr), size_(0), requestReplyContext_(requestReplyContext) {} RequestReply::RequestReply(char *data, uint64_t size, Connection *con) : size_(size) { data_ = static_cast(std::malloc(size_)); memcpy(data_, data, size_); - requestReplyContext_.con = con; + requestReplyContext_ = std::make_shared(); + requestReplyContext_->con = con; } RequestReply::~RequestReply() { + const std::lock_guard lock(data_lock_); if (data_ != nullptr) { std::free(data_); data_ = nullptr; } } -RequestReplyContext &RequestReply::get_rrc() { return requestReplyContext_; } +std::shared_ptr RequestReply::get_rrc() { + return requestReplyContext_; +} void RequestReply::encode() { - requestReplyMsg_.type = (OpType)requestReplyContext_.type; - requestReplyMsg_.success = requestReplyContext_.success; - requestReplyMsg_.rid = requestReplyContext_.rid; - requestReplyMsg_.address = requestReplyContext_.address; - requestReplyMsg_.size = requestReplyContext_.size; - requestReplyMsg_.key = requestReplyContext_.key; - auto msg_size = 
sizeof(requestReplyMsg_); + const std::lock_guard lock(data_lock_); + RequestReplyMsg requestReplyMsg; + requestReplyMsg.type = (OpType)requestReplyContext_->type; + requestReplyMsg.success = requestReplyContext_->success; + requestReplyMsg.rid = requestReplyContext_->rid; + requestReplyMsg.address = requestReplyContext_->address; + requestReplyMsg.size = requestReplyContext_->size; + requestReplyMsg.key = requestReplyContext_->key; + auto msg_size = sizeof(requestReplyMsg); size_ = msg_size; /// copy data from block metadata list uint32_t bml_size = 0; - if (!requestReplyContext_.bml.empty()) { - bml_size = sizeof(block_meta) * requestReplyContext_.bml.size(); + if (!requestReplyContext_->bml.empty()) { + bml_size = sizeof(block_meta) * requestReplyContext_->bml.size(); size_ += bml_size; } data_ = static_cast(std::malloc(size_)); - memcpy(data_, &requestReplyMsg_, msg_size); + memcpy(data_, &requestReplyMsg, msg_size); if (bml_size != 0) { - memcpy(data_ + msg_size, &requestReplyContext_.bml[0], bml_size); + memcpy(data_ + msg_size, &requestReplyContext_->bml[0], bml_size); } } void RequestReply::decode() { - memcpy(&requestReplyMsg_, data_, size_); - requestReplyContext_.type = (OpType)requestReplyMsg_.type; - requestReplyContext_.success = requestReplyMsg_.success; - requestReplyContext_.rid = requestReplyMsg_.rid; - requestReplyContext_.address = requestReplyMsg_.address; - requestReplyContext_.size = requestReplyMsg_.size; - requestReplyContext_.key = requestReplyMsg_.key; - if (size_ > sizeof(requestReplyMsg_)) { - auto bml_size = size_ - sizeof(requestReplyMsg_); - requestReplyContext_.bml.resize(bml_size / sizeof(block_meta)); - memcpy(&requestReplyContext_.bml[0], data_ + sizeof(requestReplyMsg_), + const std::lock_guard lock(data_lock_); + // memcpy(&requestReplyMsg_, data_, size_); + if (data_ == nullptr) { + std::string err_msg = "Decode with null data"; + std::cerr << err_msg << std::endl; + throw; + } + RequestReplyMsg *requestReplyMsg = 
(RequestReplyMsg *)data_; + requestReplyContext_->type = (OpType)requestReplyMsg->type; + requestReplyContext_->success = requestReplyMsg->success; + requestReplyContext_->rid = requestReplyMsg->rid; + requestReplyContext_->address = requestReplyMsg->address; + requestReplyContext_->size = requestReplyMsg->size; + requestReplyContext_->key = requestReplyMsg->key; + if (size_ > sizeof(RequestReplyMsg)) { + auto bml_size = size_ - sizeof(RequestReplyMsg); + requestReplyContext_->bml.resize(bml_size / sizeof(block_meta)); + memcpy(&requestReplyContext_->bml[0], data_ + sizeof(RequestReplyMsg), bml_size); } } diff --git a/rpmp/pmpool/Event.h b/rpmp/pmpool/Event.h index 136b73f3..78432ca1 100644 --- a/rpmp/pmpool/Event.h +++ b/rpmp/pmpool/Event.h @@ -68,7 +68,7 @@ struct RequestReplyContext { uint64_t key; Connection* con; Chunk* ck; - vector bml; + vector bml; }; template @@ -88,19 +88,21 @@ inline void decode_(T* t, char* data, uint64_t size) { class RequestReply { public: RequestReply() = delete; - explicit RequestReply(RequestReplyContext requestReplyContext); + explicit RequestReply( + std::shared_ptr requestReplyContext); RequestReply(char* data, uint64_t size, Connection* con); ~RequestReply(); - RequestReplyContext& get_rrc(); + std::shared_ptr get_rrc(); void decode(); void encode(); private: + std::mutex data_lock_; friend Protocol; - char* data_; - uint64_t size_; - RequestReplyMsg requestReplyMsg_; - RequestReplyContext requestReplyContext_; + char* data_ = nullptr; + uint64_t size_ = 0; + // RequestReplyMsg requestReplyMsg_; + std::shared_ptr requestReplyContext_; }; typedef promise Promise; @@ -126,13 +128,17 @@ class Request { RequestContext& get_rc(); void encode(); void decode(); + //#ifdef DEBUG + char* getData() { return data_; } + uint64_t getSize() { return size_; } + //#endif private: + std::mutex data_lock_; friend RequestHandler; friend ClientRecvCallback; char* data_; uint64_t size_; - RequestMsg requestMsg_; RequestContext requestContext_; }; 
diff --git a/rpmp/pmpool/Log.h b/rpmp/pmpool/Log.h index d3701ce9..55ec9758 100644 --- a/rpmp/pmpool/Log.h +++ b/rpmp/pmpool/Log.h @@ -12,10 +12,10 @@ #include -#include "Config.h" -#include "spdlog/spdlog.h" +#include "pmpool/Config.h" #include "spdlog/sinks/basic_file_sink.h" #include "spdlog/sinks/stdout_color_sinks.h" +#include "spdlog/spdlog.h" class Log { public: diff --git a/rpmp/pmpool/NetworkServer.cc b/rpmp/pmpool/NetworkServer.cc index 2f731a79..62f0a396 100644 --- a/rpmp/pmpool/NetworkServer.cc +++ b/rpmp/pmpool/NetworkServer.cc @@ -9,13 +9,14 @@ #include "pmpool/NetworkServer.h" -#include "Base.h" -#include "Config.h" -#include "Event.h" -#include "Log.h" -#include "buffer/CircularBuffer.h" - -NetworkServer::NetworkServer(Config *config, Log *log) +#include "pmpool/Base.h" +#include "pmpool/Config.h" +#include "pmpool/Event.h" +#include "pmpool/Log.h" +#include "pmpool/buffer/CircularBuffer.h" + +NetworkServer::NetworkServer(std::shared_ptr config, + std::shared_ptr log) : config_(config), log_(log) { time = 0; } @@ -44,8 +45,8 @@ int NetworkServer::start() { CHK_ERR("hpnl server listen", server_->listen(config_->get_ip().c_str(), config_->get_port().c_str())); - circularBuffer_ = - std::make_shared(1024 * 1024, 4096, true, this); + circularBuffer_ = std::make_shared(1024 * 1024, 4096, true, + shared_from_this()); return 0; } @@ -59,7 +60,7 @@ void NetworkServer::unregister_rma_buffer(int buffer_id) { server_->unreg_rma_buffer(buffer_id); } -void NetworkServer::get_dram_buffer(RequestReplyContext *rrc) { +void NetworkServer::get_dram_buffer(std::shared_ptr rrc) { char *buffer = circularBuffer_->get(rrc->size); rrc->dest_address = (uint64_t)buffer; @@ -76,13 +77,15 @@ void NetworkServer::get_dram_buffer(RequestReplyContext *rrc) { rrc->ck = ck; } -void NetworkServer::reclaim_dram_buffer(RequestReplyContext *rrc) { +void NetworkServer::reclaim_dram_buffer( + std::shared_ptr rrc) { char *buffer_tmp = reinterpret_cast(rrc->dest_address); 
circularBuffer_->put(buffer_tmp, rrc->size); delete rrc->ck; } -void NetworkServer::get_pmem_buffer(RequestReplyContext *rrc, Chunk *base_ck) { +void NetworkServer::get_pmem_buffer(std::shared_ptr rrc, + Chunk *base_ck) { Chunk *ck = new Chunk(); ck->buffer = reinterpret_cast(rrc->dest_address); ck->capacity = rrc->size; @@ -92,13 +95,18 @@ void NetworkServer::get_pmem_buffer(RequestReplyContext *rrc, Chunk *base_ck) { rrc->ck = ck; } -void NetworkServer::reclaim_pmem_buffer(RequestReplyContext *rrc) { +void NetworkServer::reclaim_pmem_buffer( + std::shared_ptr rrc) { if (rrc->ck != nullptr) { delete rrc->ck; } } -ChunkMgr *NetworkServer::get_chunk_mgr() { return chunkMgr_.get(); } +uint64_t NetworkServer::get_rkey() { + return circularBuffer_->get_rma_chunk()->mr->key; +} + +std::shared_ptr NetworkServer::get_chunk_mgr() { return chunkMgr_; } void NetworkServer::set_recv_callback(Callback *callback) { server_->set_recv_callback(callback); @@ -123,12 +131,20 @@ void NetworkServer::send(char *data, uint64_t size, Connection *con) { con->send(ck); } -void NetworkServer::read(RequestReply *rr) { - RequestReplyContext rrc = rr->get_rrc(); - rrc.con->read(rrc.ck, 0, rrc.size, rrc.src_address, rrc.src_rkey); +void NetworkServer::read(std::shared_ptr rr) { + auto rrc = rr->get_rrc(); +#ifdef DEBUG + printf("[NetworkServer::read] dest is %ld-%d, src is %ld-%d\n", + rrc->ck->buffer, rrc->ck->size, rrc->src_address, rrc->size); +#endif + rrc->con->read(rrc->ck, 0, rrc->size, rrc->src_address, rrc->src_rkey); } -void NetworkServer::write(RequestReply *rr) { - RequestReplyContext rrc = rr->get_rrc(); - rrc.con->write(rrc.ck, 0, rrc.size, rrc.src_address, rrc.src_rkey); +void NetworkServer::write(std::shared_ptr rr) { + auto rrc = rr->get_rrc(); +#ifdef DEBUG + printf("[NetworkServer::write] src is %ld-%d, dest is %ld-%d\n", + rrc->ck->buffer, rrc->ck->size, rrc->src_address, rrc->size); +#endif + rrc->con->write(rrc->ck, 0, rrc->size, rrc->src_address, rrc->src_rkey); } diff 
--git a/rpmp/pmpool/NetworkServer.h b/rpmp/pmpool/NetworkServer.h index d125a83c..4bcdd7f0 100644 --- a/rpmp/pmpool/NetworkServer.h +++ b/rpmp/pmpool/NetworkServer.h @@ -17,7 +17,7 @@ #include #include -#include "RmaBufferRegister.h" +#include "pmpool/RmaBufferRegister.h" class CircularBuffer; class Config; @@ -30,10 +30,11 @@ class Log; * asynchronous network library. RPMP currently supports RDMA iWarp and RoCE V2 * protocol. */ -class NetworkServer : public RmaBufferRegister { +class NetworkServer : public RmaBufferRegister, + public std::enable_shared_from_this { public: NetworkServer() = delete; - NetworkServer(Config *config, Log *log_); + NetworkServer(std::shared_ptr config, std::shared_ptr log_); ~NetworkServer(); int init(); int start(); @@ -47,19 +48,22 @@ class NetworkServer : public RmaBufferRegister { void unregister_rma_buffer(int buffer_id) override; /// get DRAM buffer from circular buffer pool. - void get_dram_buffer(RequestReplyContext *rrc); + void get_dram_buffer(std::shared_ptr rrc); /// reclaim DRAM buffer from circular buffer pool. - void reclaim_dram_buffer(RequestReplyContext *rrc); + void reclaim_dram_buffer(std::shared_ptr rrc); + + /// get rdma registered memory key for client. + uint64_t get_rkey(); /// get Persistent Memory buffer from circular buffer pool - void get_pmem_buffer(RequestReplyContext *rrc, Chunk *ck); + void get_pmem_buffer(std::shared_ptr rrc, Chunk *ck); /// reclaim Persistent Memory buffer form circular buffer pool - void reclaim_pmem_buffer(RequestReplyContext *rrc); + void reclaim_pmem_buffer(std::shared_ptr rrc); /// return the pointer of chunk manager. - ChunkMgr *get_chunk_mgr(); + std::shared_ptr get_chunk_mgr(); /// since the network implementation is asynchronous, /// we need to define callback better before starting network service. 
@@ -69,12 +73,12 @@ class NetworkServer : public RmaBufferRegister { void set_write_callback(Callback *callback); void send(char *data, uint64_t size, Connection *con); - void read(RequestReply *rrc); - void write(RequestReply *rrc); + void read(std::shared_ptr rrc); + void write(std::shared_ptr rrc); private: - Config *config_; - Log* log_; + std::shared_ptr config_; + std::shared_ptr log_; std::shared_ptr server_; std::shared_ptr chunkMgr_; std::shared_ptr circularBuffer_; diff --git a/rpmp/pmpool/PmemAllocator.h b/rpmp/pmpool/PmemAllocator.h index b6d279ce..ccfecd68 100644 --- a/rpmp/pmpool/PmemAllocator.h +++ b/rpmp/pmpool/PmemAllocator.h @@ -20,10 +20,10 @@ #include #include -#include "Allocator.h" -#include "DataServer.h" -#include "Log.h" -#include "NetworkServer.h" +#include "pmpool/Allocator.h" +#include "pmpool/DataServer.h" +#include "pmpool/Log.h" +#include "pmpool/NetworkServer.h" using std::shared_ptr; using std::unordered_map; @@ -68,8 +68,9 @@ enum types { BLOCK_ENTRY_TYPE, DATA_TYPE, MAX_TYPE }; class PmemObjAllocator : public Allocator { public: PmemObjAllocator() = delete; - explicit PmemObjAllocator(Log *log, DiskInfo *diskInfos, - NetworkServer *server, int wid) + explicit PmemObjAllocator(std::shared_ptr log, + std::shared_ptr diskInfos, + std::shared_ptr server, int wid) : log_(log), diskInfo_(diskInfos), server_(server), wid_(wid) {} ~PmemObjAllocator() { close(); } @@ -305,6 +306,7 @@ class PmemObjAllocator : public Allocator { err_msg); return -1; } + pmemContext_.poid = pmemobj_root(pmemContext_.pop, sizeof(struct Base)); pmemContext_.base = (struct Base *)pmemobj_direct(pmemContext_.poid); pmemContext_.base->head = OID_NULL; @@ -315,9 +317,11 @@ class PmemObjAllocator : public Allocator { base_ck = server_->register_rma_buffer( reinterpret_cast(pmemContext_.pop), diskInfo_->size); assert(base_ck != nullptr); + auto addr = reinterpret_cast(pmemContext_.pop); log_->get_console_log()->info( "successfully registered Persistent Memory(" + 
diskInfo_->path + - ") as RDMA region"); + ") as RDMA region, size is {0}", + diskInfo_->size); } return 0; } @@ -374,17 +378,19 @@ class PmemObjAllocator : public Allocator { int free_meta() { std::lock_guard l(mtx); index_map.clear(); + return 0; } private: - Log *log_; - DiskInfo *diskInfo_; - NetworkServer *server_; + std::shared_ptr log_; + std::shared_ptr diskInfo_; + std::shared_ptr server_; int wid_; PmemContext pmemContext_; std::mutex mtx; unordered_map index_map; uint64_t total = 0; + uint64_t disk_size = 0; char str[1048576]; Chunk *base_ck; }; diff --git a/rpmp/pmpool/Protocol.cc b/rpmp/pmpool/Protocol.cc index cfa88f81..da59880f 100644 --- a/rpmp/pmpool/Protocol.cc +++ b/rpmp/pmpool/Protocol.cc @@ -11,65 +11,83 @@ #include -#include "AllocatorProxy.h" -#include "Config.h" -#include "Digest.h" -#include "Event.h" -#include "Log.h" -#include "NetworkServer.h" - -RecvCallback::RecvCallback(Protocol *protocol, ChunkMgr *chunkMgr) +#include "pmpool/AllocatorProxy.h" +#include "pmpool/Config.h" +#include "pmpool/Digest.h" +#include "pmpool/Event.h" +#include "pmpool/Log.h" +#include "pmpool/NetworkServer.h" + +RecvCallback::RecvCallback(std::shared_ptr protocol, + std::shared_ptr chunkMgr) : protocol_(protocol), chunkMgr_(chunkMgr) {} void RecvCallback::operator()(void *buffer_id, void *buffer_size) { auto buffer_id_ = *static_cast(buffer_id); Chunk *ck = chunkMgr_->get(buffer_id_); assert(*static_cast(buffer_size) == ck->size); - Request *request = new Request(reinterpret_cast(ck->buffer), ck->size, - reinterpret_cast(ck->con)); + auto request = + std::make_shared(reinterpret_cast(ck->buffer), ck->size, + reinterpret_cast(ck->con)); request->decode(); - protocol_->enqueue_recv_msg(request); + RequestMsg *requestMsg = (RequestMsg *)(request->getData()); + if (requestMsg->type != 0) { + protocol_->enqueue_recv_msg(request); + } else { + std::cout << "[RecvCallback::RecvCallback][" << requestMsg->type + << "] size is " << ck->size << std::endl; + for (int i = 
0; i < ck->size; i++) { + printf("%X ", *(request->getData() + i)); + } + printf("\n"); + } chunkMgr_->reclaim(ck, static_cast(ck->con)); } -ReadCallback::ReadCallback(Protocol *protocol) : protocol_(protocol) {} +ReadCallback::ReadCallback(std::shared_ptr protocol) + : protocol_(protocol) {} void ReadCallback::operator()(void *buffer_id, void *buffer_size) { auto buffer_id_ = *static_cast(buffer_id); protocol_->enqueue_rma_msg(buffer_id_); } -SendCallback::SendCallback(ChunkMgr *chunkMgr) : chunkMgr_(chunkMgr) {} +SendCallback::SendCallback( + std::shared_ptr chunkMgr, + std::unordered_map> rrcMap) + : chunkMgr_(chunkMgr), rrcMap_(rrcMap) {} void SendCallback::operator()(void *buffer_id, void *buffer_size) { auto buffer_id_ = *static_cast(buffer_id); auto ck = chunkMgr_->get(buffer_id_); /// free the memory of class RequestReply - auto reqeustReply = static_cast(ck->ptr); - delete reqeustReply; + rrcMap_.erase(buffer_id_); chunkMgr_->reclaim(ck, static_cast(ck->con)); } -WriteCallback::WriteCallback(Protocol *protocol) : protocol_(protocol) {} +WriteCallback::WriteCallback(std::shared_ptr protocol) + : protocol_(protocol) {} void WriteCallback::operator()(void *buffer_id, void *buffer_size) { auto buffer_id_ = *static_cast(buffer_id); protocol_->enqueue_rma_msg(buffer_id_); } -RecvWorker::RecvWorker(Protocol *protocol, int index) +RecvWorker::RecvWorker(std::shared_ptr protocol, int index) : protocol_(protocol), index_(index) { init = false; } int RecvWorker::entry() { if (!init) { - set_affinity(index_); + if (index_ != -1) { + set_affinity(index_); + } init = true; } - Request *request; + std::shared_ptr request; bool res = pendingRecvRequestQueue_.wait_dequeue_timed( request, std::chrono::milliseconds(1000)); if (res) { @@ -80,21 +98,23 @@ int RecvWorker::entry() { void RecvWorker::abort() {} -void RecvWorker::addTask(Request *request) { +void RecvWorker::addTask(std::shared_ptr request) { pendingRecvRequestQueue_.enqueue(request); } 
-ReadWorker::ReadWorker(Protocol *protocol, int index) +ReadWorker::ReadWorker(std::shared_ptr protocol, int index) : protocol_(protocol), index_(index) { init = false; } int ReadWorker::entry() { if (!init) { - set_affinity(index_); + if (index_ != -1) { + set_affinity(index_); + } init = true; } - RequestReply *requestReply; + std::shared_ptr requestReply; bool res = pendingReadRequestQueue_.wait_dequeue_timed( requestReply, std::chrono::milliseconds(1000)); if (res) { @@ -105,16 +125,18 @@ int ReadWorker::entry() { void ReadWorker::abort() {} -void ReadWorker::addTask(RequestReply *rr) { +void ReadWorker::addTask(std::shared_ptr rr) { pendingReadRequestQueue_.enqueue(rr); } -FinalizeWorker::FinalizeWorker(Protocol *protocol) : protocol_(protocol) {} +FinalizeWorker::FinalizeWorker(std::shared_ptr protocol) + : protocol_(protocol) {} int FinalizeWorker::entry() { - RequestReply *requestReply; + std::shared_ptr requestReply; bool res = pendingRequestReplyQueue_.wait_dequeue_timed( requestReply, std::chrono::milliseconds(1000)); + assert(res); if (res) { protocol_->handle_finalize_msg(requestReply); } @@ -123,17 +145,22 @@ int FinalizeWorker::entry() { void FinalizeWorker::abort() {} -void FinalizeWorker::addTask(RequestReply *requestReply) { +void FinalizeWorker::addTask(std::shared_ptr requestReply) { pendingRequestReplyQueue_.enqueue(requestReply); } -Protocol::Protocol(Config *config, Log *log, NetworkServer *server, - AllocatorProxy *allocatorProxy) +Protocol::Protocol(std::shared_ptr config, std::shared_ptr log, + std::shared_ptr server, + std::shared_ptr allocatorProxy) : config_(config), log_(log), networkServer_(server), allocatorProxy_(allocatorProxy) { time = 0; + char *buffer = new char[2147483648](); + memset(buffer, 0, 2147483648); + address = allocatorProxy_->allocate_and_write(2147483648, buffer, 0); + delete[] buffer; } Protocol::~Protocol() { @@ -150,26 +177,28 @@ Protocol::~Protocol() { } int Protocol::init() { - recvCallback_ = - 
std::make_shared(this, networkServer_->get_chunk_mgr()); + recvCallback_ = std::make_shared( + shared_from_this(), networkServer_->get_chunk_mgr()); sendCallback_ = - std::make_shared(networkServer_->get_chunk_mgr()); - readCallback_ = std::make_shared(this); - writeCallback_ = std::make_shared(this); + std::make_shared(networkServer_->get_chunk_mgr(), rrcMap_); + readCallback_ = std::make_shared(shared_from_this()); + writeCallback_ = std::make_shared(shared_from_this()); for (int i = 0; i < config_->get_pool_size(); i++) { - auto recvWorker = new RecvWorker(this, config_->get_affinities_()[i] - 1); + auto recvWorker = std::make_shared( + shared_from_this(), config_->get_affinities_()[i] - 1); recvWorker->start(); - recvWorkers_.push_back(std::shared_ptr(recvWorker)); + recvWorkers_.push_back(std::move(recvWorker)); } - finalizeWorker_ = make_shared(this); + finalizeWorker_ = make_shared(shared_from_this()); finalizeWorker_->start(); for (int i = 0; i < config_->get_pool_size(); i++) { - auto readWorker = new ReadWorker(this, config_->get_affinities_()[i]); + auto readWorker = std::make_shared( + shared_from_this(), config_->get_affinities_()[i]); readWorker->start(); - readWorkers_.push_back(std::shared_ptr(readWorker)); + readWorkers_.push_back(std::move(readWorker)); } networkServer_->set_recv_callback(recvCallback_.get()); @@ -179,7 +208,7 @@ int Protocol::init() { return 0; } -void Protocol::enqueue_recv_msg(Request *request) { +void Protocol::enqueue_recv_msg(std::shared_ptr request) { RequestContext rc = request->get_rc(); if (rc.address != 0) { auto wid = GET_WID(rc.address); @@ -189,187 +218,273 @@ void Protocol::enqueue_recv_msg(Request *request) { } } -void Protocol::handle_recv_msg(Request *request) { +void Protocol::handle_recv_msg(std::shared_ptr request) { + num_requests_++; + if (num_requests_ % 10000 == 0) { + log_->get_file_log()->info( + "Protocol::handle_recv_msg handled requests number is {0}.", + num_requests_); + 
log_->get_console_log()->info( + "Protocol::handle_recv_msg handled requests number is {0}.", + num_requests_); + } RequestContext rc = request->get_rc(); - RequestReplyContext rrc; + auto rrc = std::make_shared(); switch (rc.type) { case ALLOC: { + log_->get_console_log()->info( + "Protocol::handle_recv_msg Allocate request."); uint64_t addr = allocatorProxy_->allocate_and_write( rc.size, nullptr, rc.rid % config_->get_pool_size()); auto wid = GET_WID(addr); assert(wid == rc.rid % config_->get_pool_size()); - rrc.type = ALLOC_REPLY; - rrc.success = 0; - rrc.rid = rc.rid; - rrc.address = addr; - rrc.size = rc.size; - rrc.con = rc.con; - RequestReply *requestReply = new RequestReply(rrc); - rrc.ck->ptr = requestReply; + rrc->type = ALLOC_REPLY; + rrc->success = 0; + rrc->rid = rc.rid; + rrc->address = addr; + rrc->size = rc.size; + rrc->con = rc.con; + networkServer_->get_dram_buffer(rrc); + std::shared_ptr requestReply = + std::make_shared(rrc); + rrc->ck->ptr = requestReply.get(); enqueue_finalize_msg(requestReply); break; } case FREE: { - rrc.type = FREE_REPLY; - rrc.success = allocatorProxy_->release(rc.address); - rrc.rid = rc.rid; - rrc.address = rc.address; - rrc.size = rc.size; - rrc.con = rc.con; - RequestReply *requestReply = new RequestReply(rrc); - rrc.ck->ptr = requestReply; + log_->get_console_log()->info("Protocol::handle_recv_msg Free request."); + rrc->type = FREE_REPLY; + rrc->success = allocatorProxy_->release(rc.address); + rrc->rid = rc.rid; + rrc->address = rc.address; + rrc->size = rc.size; + rrc->con = rc.con; + std::shared_ptr requestReply = + std::make_shared(rrc); + rrc->ck->ptr = requestReply.get(); enqueue_finalize_msg(requestReply); break; } case WRITE: { - rrc.type = WRITE_REPLY; - rrc.success = 0; - rrc.rid = rc.rid; - rrc.address = rc.address; - rrc.src_address = rc.src_address; - rrc.src_rkey = rc.src_rkey; - rrc.size = rc.size; - rrc.con = rc.con; - networkServer_->get_dram_buffer(&rrc); - RequestReply *requestReply = new 
RequestReply(rrc); - rrc.ck->ptr = requestReply; + log_->get_console_log()->info("Protocol::handle_recv_msg Write request."); + rrc->type = WRITE_REPLY; + rrc->success = 0; + rrc->rid = rc.rid; + rrc->address = rc.address; + rrc->src_address = rc.src_address; + rrc->src_rkey = rc.src_rkey; + rrc->size = rc.size; + rrc->con = rc.con; + networkServer_->get_dram_buffer(rrc); + std::shared_ptr requestReply = + std::make_shared(rrc); + rrc->ck->ptr = requestReply.get(); std::unique_lock lk(rrcMtx_); - rrcMap_[rrc.ck->buffer_id] = requestReply; + rrcMap_[rrc->ck->buffer_id] = requestReply; lk.unlock(); networkServer_->read(requestReply); break; } case READ: { - rrc.type = READ_REPLY; - rrc.success = 0; - rrc.rid = rc.rid; - rrc.address = rc.address; - rrc.src_address = rc.src_address; - rrc.src_rkey = rc.src_rkey; - rrc.size = rc.size; - rrc.con = rc.con; - rrc.dest_address = allocatorProxy_->get_virtual_address(rrc.address); - rrc.ck = nullptr; - Chunk *base_ck = allocatorProxy_->get_rma_chunk(rrc.address); - networkServer_->get_pmem_buffer(&rrc, base_ck); - RequestReply *requestReply = new RequestReply(rrc); - rrc.ck->ptr = requestReply; +#ifdef DEBUG + std::cout << "[Protocol::handle_recv_msg][READ], info is " << rc.address + << "-" << rc.size << std::endl; +#endif + rrc->type = READ_REPLY; + rrc->success = 0; + rrc->rid = rc.rid; + rrc->address = rc.address; + rrc->src_address = rc.src_address; + rrc->src_rkey = rc.src_rkey; + rrc->size = rc.size; + rrc->con = rc.con; + rrc->dest_address = allocatorProxy_->get_virtual_address(rrc->address); + rrc->ck = nullptr; + Chunk *base_ck = allocatorProxy_->get_rma_chunk(rrc->address); + networkServer_->get_pmem_buffer(rrc, base_ck); + std::shared_ptr requestReply = + std::make_shared(rrc); + rrc->ck->ptr = requestReply.get(); std::unique_lock lk(rrcMtx_); - rrcMap_[rrc.ck->buffer_id] = requestReply; + rrcMap_[rrc->ck->buffer_id] = requestReply; lk.unlock(); networkServer_->write(requestReply); break; } case PUT: { - rrc.type = 
PUT_REPLY; - rrc.success = 0; - rrc.rid = rc.rid; - rrc.address = rc.address; - rrc.src_address = rc.src_address; - rrc.src_rkey = rc.src_rkey; - rrc.size = rc.size; - rrc.key = rc.key; - rrc.con = rc.con; - networkServer_->get_dram_buffer(&rrc); - RequestReply *requestReply = new RequestReply(rrc); - rrc.ck->ptr = requestReply; + // log_->get_console_log()->info("Protocol::handle_recv_msg put + // request."); + rrc->type = PUT_REPLY; + rrc->success = 0; + rrc->rid = rc.rid; + rrc->address = rc.address; + rrc->src_address = rc.src_address; + rrc->src_rkey = rc.src_rkey; + rrc->size = rc.size; + rrc->key = rc.key; + rrc->con = rc.con; + networkServer_->get_dram_buffer(rrc); + std::shared_ptr requestReply = + std::make_shared(rrc); + rrc->ck->ptr = requestReply.get(); // rrc->ck set by get_dram_buffer std::unique_lock lk(rrcMtx_); - rrcMap_[rrc.ck->buffer_id] = requestReply; + rrcMap_[rrc->ck->buffer_id] = requestReply; lk.unlock(); networkServer_->read(requestReply); break; } + case GET: { + rrc->type = GET_REPLY; + rrc->success = 0; + rrc->rid = rc.rid; + rrc->src_address = rc.src_address; + rrc->src_rkey = rc.src_rkey; + rrc->size = rc.size; + rrc->key = rc.key; + rrc->con = rc.con; + rrc->ck = nullptr; + auto bml = allocatorProxy_->get_cached_chunk(rrc->key); + uint64_t wrote_size = 0; + if (bml.size() == 1) { + rrc->address = bml[0].address; + rrc->dest_address = allocatorProxy_->get_virtual_address(rrc->address); + Chunk *base_ck = allocatorProxy_->get_rma_chunk(rrc->address); + networkServer_->get_pmem_buffer(rrc, base_ck); + + } else { + throw; + networkServer_->get_dram_buffer(rrc); + rrc->address = rrc->dest_address; + for (auto bm : bml) { + if ((wrote_size + bm.size) <= rrc->size) { + auto partition_data = reinterpret_cast( + allocatorProxy_->get_virtual_address(bm.address)); + auto dest_address = reinterpret_cast(rrc->dest_address); + if ((wrote_size + bm.size) < rrc->size) { + printf("[GET]key is %ld, rrc->size is %ld, bm.size is %ld\n", + rrc->key, 
rrc->size, bm.size); + } + memcpy((dest_address + wrote_size), partition_data, bm.size); + wrote_size += bm.size; + } + } + rrc->size = wrote_size; + } + std::shared_ptr requestReply = + std::make_shared(rrc); + rrc->ck->ptr = requestReply.get(); + std::unique_lock lk(rrcMtx_); + rrcMap_[rrc->ck->buffer_id] = requestReply; + lk.unlock(); + networkServer_->write(requestReply); + break; + } case GET_META: { - rrc.type = GET_META_REPLY; - rrc.success = 0; - rrc.rid = rc.rid; - rrc.size = rc.size; - rrc.key = rc.key; - rrc.con = rc.con; - RequestReply *requestReply = new RequestReply(rrc); - rrc.ck->ptr = requestReply; + // log_->get_console_log()->info( + // "Protocol::handle_recv_msg getMeta request."); + rrc->type = GET_META_REPLY; + rrc->success = 0; + rrc->rid = rc.rid; + rrc->size = rc.size; + rrc->key = rc.key; + rrc->con = rc.con; + std::shared_ptr requestReply = + std::make_shared(rrc); enqueue_finalize_msg(requestReply); + break; } case DELETE: { - rrc.type = DELETE_REPLY; - rrc.key = rc.key; - rrc.con = rc.con; - rrc.rid = rc.rid; - rrc.success = 0; + log_->get_console_log()->info( + "Protocol::handle_recv_msg Delete request."); + rrc->type = DELETE_REPLY; + rrc->key = rc.key; + rrc->con = rc.con; + rrc->rid = rc.rid; + rrc->success = 0; } default: { break; } } - - delete request; } -void Protocol::enqueue_finalize_msg(RequestReply *requestReply) { +void Protocol::enqueue_finalize_msg( + std::shared_ptr requestReply) { finalizeWorker_->addTask(requestReply); } -void Protocol::handle_finalize_msg(RequestReply *requestReply) { - RequestReplyContext rrc = requestReply->get_rrc(); - if (rrc.type == PUT_REPLY) { - allocatorProxy_->cache_chunk(rrc.key, rrc.address, rrc.size); - } else if (rrc.type == GET_META_REPLY) { - auto bml = allocatorProxy_->get_cached_chunk(rrc.key); - requestReply->requestReplyContext_.bml = bml; - } else if (rrc.type == DELETE_REPLY) { - auto bml = allocatorProxy_->get_cached_chunk(rrc.key); +void 
Protocol::handle_finalize_msg(std::shared_ptr requestReply) { + std::shared_ptr rrc = requestReply->get_rrc(); + if (rrc->type == PUT_REPLY) { + allocatorProxy_->cache_chunk(rrc->key, rrc->address, rrc->size, + networkServer_->get_rkey()); + } else if (rrc->type == GET_META_REPLY) { + auto bml = allocatorProxy_->get_cached_chunk(rrc->key); + requestReply->requestReplyContext_->bml = bml; + } else if (rrc->type == DELETE_REPLY) { + auto bml = allocatorProxy_->get_cached_chunk(rrc->key); for (auto bm : bml) { - rrc.success = allocatorProxy_->release(bm.address); - if (rrc.success) { + rrc->success = allocatorProxy_->release(bm.address); + if (rrc->success) { break; } } - allocatorProxy_->del_chunk(rrc.key); + allocatorProxy_->del_chunk(rrc->key); + } else if (rrc->type == GET_REPLY) { } else { } requestReply->encode(); networkServer_->send(reinterpret_cast(requestReply->data_), - requestReply->size_, rrc.con); + requestReply->size_, rrc->con); } void Protocol::enqueue_rma_msg(uint64_t buffer_id) { std::unique_lock lk(rrcMtx_); - RequestReply *requestReply = rrcMap_[buffer_id]; + auto requestReply = rrcMap_[buffer_id]; lk.unlock(); - RequestReplyContext rrc = requestReply->get_rrc(); - if (rrc.address != 0) { - auto wid = GET_WID(rrc.address); + std::shared_ptr rrc = requestReply->get_rrc(); + if (rrc->address != 0) { + auto wid = GET_WID(rrc->address); readWorkers_[wid]->addTask(requestReply); } else { - readWorkers_[rrc.rid % config_->get_pool_size()]->addTask(requestReply); + readWorkers_[rrc->rid % config_->get_pool_size()]->addTask(requestReply); } } -void Protocol::handle_rma_msg(RequestReply *requestReply) { - RequestReplyContext &rrc = requestReply->get_rrc(); - switch (rrc.type) { +void Protocol::handle_rma_msg(std::shared_ptr requestReply) { + std::shared_ptr rrc = requestReply->get_rrc(); + switch (rrc->type) { case WRITE_REPLY: { - char *buffer = static_cast(rrc.ck->buffer); - if (rrc.address == 0) { - rrc.address = allocatorProxy_->allocate_and_write( - 
rrc.size, buffer, rrc.rid % config_->get_pool_size()); + char *buffer = static_cast(rrc->ck->buffer); + if (rrc->address == 0) { + rrc->address = allocatorProxy_->allocate_and_write( + rrc->size, buffer, rrc->rid % config_->get_pool_size()); } else { - allocatorProxy_->write(rrc.address, buffer, rrc.size); + allocatorProxy_->write(rrc->address, buffer, rrc->size); } - networkServer_->reclaim_dram_buffer(&rrc); + networkServer_->reclaim_dram_buffer(rrc); break; } case READ_REPLY: { - networkServer_->reclaim_pmem_buffer(&rrc); + networkServer_->reclaim_pmem_buffer(rrc); break; } case PUT_REPLY: { - char *buffer = static_cast(rrc.ck->buffer); - assert(rrc.address == 0); - rrc.address = allocatorProxy_->allocate_and_write( - rrc.size, buffer, rrc.rid % config_->get_pool_size()); - networkServer_->reclaim_dram_buffer(&rrc); + char *buffer = static_cast(rrc->ck->buffer); + assert(rrc->address == 0); +#ifdef DEBUG + std::cout << "[handle_rma_msg]" << rrc->src_rkey << "-" + << rrc->src_address << ":" << rrc->size << std::endl; +#endif + rrc->address = allocatorProxy_->allocate_and_write( + rrc->size, buffer, rrc->rid % config_->get_pool_size()); + networkServer_->reclaim_dram_buffer(rrc); + break; + } + case GET_REPLY: { + // networkServer_->reclaim_dram_buffer(rrc); + networkServer_->reclaim_pmem_buffer(rrc); break; } default: { break; } diff --git a/rpmp/pmpool/Protocol.h b/rpmp/pmpool/Protocol.h index b2cfb679..9284ca80 100644 --- a/rpmp/pmpool/Protocol.h +++ b/rpmp/pmpool/Protocol.h @@ -22,10 +22,10 @@ #include #include -#include "Event.h" -#include "ThreadWrapper.h" -#include "queue/blockingconcurrentqueue.h" -#include "queue/concurrentqueue.h" +#include "pmpool/Event.h" +#include "pmpool/ThreadWrapper.h" +#include "pmpool/queue/blockingconcurrentqueue.h" +#include "pmpool/queue/concurrentqueue.h" class Digest; class AllocatorProxy; @@ -50,92 +50,98 @@ struct MessageHeader { class RecvCallback : public Callback { public: RecvCallback() = delete; - 
RecvCallback(Protocol *protocol, ChunkMgr *chunkMgr); + RecvCallback(std::shared_ptr protocol, + std::shared_ptr chunkMgr); ~RecvCallback() override = default; void operator()(void *buffer_id, void *buffer_size) override; private: - Protocol *protocol_; - ChunkMgr *chunkMgr_; + std::shared_ptr protocol_; + std::shared_ptr chunkMgr_; }; class SendCallback : public Callback { public: SendCallback() = delete; - explicit SendCallback(ChunkMgr *chunkMgr); + explicit SendCallback( + std::shared_ptr chunkMgr, + std::unordered_map> rrcMap); ~SendCallback() override = default; void operator()(void *buffer_id, void *buffer_size) override; private: - ChunkMgr *chunkMgr_; + std::shared_ptr chunkMgr_; + std::unordered_map> rrcMap_; }; class ReadCallback : public Callback { public: ReadCallback() = delete; - explicit ReadCallback(Protocol *protocol); + explicit ReadCallback(std::shared_ptr protocol); ~ReadCallback() override = default; void operator()(void *buffer_id, void *buffer_size) override; private: - Protocol *protocol_; + std::shared_ptr protocol_; }; class WriteCallback : public Callback { public: WriteCallback() = delete; - explicit WriteCallback(Protocol *protocol); + explicit WriteCallback(std::shared_ptr protocol); ~WriteCallback() override = default; void operator()(void *buffer_id, void *buffer_size) override; private: - Protocol *protocol_; + std::shared_ptr protocol_; }; class RecvWorker : public ThreadWrapper { public: RecvWorker() = delete; - RecvWorker(Protocol *protocol, int index); + RecvWorker(std::shared_ptr protocol, int index); ~RecvWorker() override = default; int entry() override; void abort() override; - void addTask(Request *request); + void addTask(std::shared_ptr request); private: - Protocol *protocol_; + std::shared_ptr protocol_; int index_; bool init; - BlockingConcurrentQueue pendingRecvRequestQueue_; + BlockingConcurrentQueue> pendingRecvRequestQueue_; }; class ReadWorker : public ThreadWrapper { public: ReadWorker() = delete; - 
ReadWorker(Protocol *protocol, int index); + ReadWorker(std::shared_ptr protocol, int index); ~ReadWorker() override = default; int entry() override; void abort() override; - void addTask(RequestReply *requestReply); + void addTask(std::shared_ptr requestReply); private: - Protocol *protocol_; + std::shared_ptr protocol_; int index_; bool init; - BlockingConcurrentQueue pendingReadRequestQueue_; + BlockingConcurrentQueue> + pendingReadRequestQueue_; }; class FinalizeWorker : public ThreadWrapper { public: FinalizeWorker() = delete; - explicit FinalizeWorker(Protocol *protocol); + explicit FinalizeWorker(std::shared_ptr protocol); ~FinalizeWorker() override = default; int entry() override; void abort() override; - void addTask(RequestReply *requestReply); + void addTask(std::shared_ptr requestReply); private: - Protocol *protocol_; - BlockingConcurrentQueue pendingRequestReplyQueue_; + std::shared_ptr protocol_; + BlockingConcurrentQueue> + pendingRequestReplyQueue_; }; /** @@ -146,33 +152,35 @@ class FinalizeWorker : public ThreadWrapper { * finalize queue-> to handle finalization event. * rma queue-> to handle remote memory access event. 
*/ -class Protocol { +class Protocol : public std::enable_shared_from_this { public: Protocol() = delete; - Protocol(Config *config, Log *log, NetworkServer *server, - AllocatorProxy *allocatorProxy); + Protocol(std::shared_ptr config, std::shared_ptr log, + std::shared_ptr server, + std::shared_ptr allocatorProxy); ~Protocol(); int init(); friend class RecvCallback; friend class RecvWorker; - void enqueue_recv_msg(Request *request); - void handle_recv_msg(Request *request); + void enqueue_recv_msg(std::shared_ptr request); + void handle_recv_msg(std::shared_ptr request); - void enqueue_finalize_msg(RequestReply *requestReply); - void handle_finalize_msg(RequestReply *requestReply); + void enqueue_finalize_msg(std::shared_ptr requestReply); + void handle_finalize_msg(std::shared_ptr requestReply); void enqueue_rma_msg(uint64_t buffer_id); - void handle_rma_msg(RequestReply *requestReply); + void handle_rma_msg(std::shared_ptr requestReply); public: - Config *config_; - Log *log_; + std::shared_ptr config_; + std::shared_ptr log_; + uint64_t num_requests_ = 0; private: - NetworkServer *networkServer_; - AllocatorProxy *allocatorProxy_; + std::shared_ptr networkServer_; + std::shared_ptr allocatorProxy_; std::shared_ptr recvCallback_; std::shared_ptr sendCallback_; @@ -187,8 +195,9 @@ class Protocol { std::vector> readWorkers_; std::mutex rrcMtx_; - std::unordered_map rrcMap_; + std::unordered_map> rrcMap_; uint64_t time; + long address = 0; }; #endif // PMPOOL_PROTOCOL_H_ diff --git a/rpmp/pmpool/ThreadWrapper.h b/rpmp/pmpool/ThreadWrapper.h index 769cabea..476f382b 100644 --- a/rpmp/pmpool/ThreadWrapper.h +++ b/rpmp/pmpool/ThreadWrapper.h @@ -13,10 +13,10 @@ #include #include -#include // NOLINT +#include // NOLINT #include -#include // NOLINT -#include // NOLINT +#include // NOLINT +#include // NOLINT class ThreadWrapper { public: @@ -41,7 +41,9 @@ class ThreadWrapper { #ifdef __linux__ cpu_set_t cpuset; CPU_ZERO(&cpuset); - CPU_SET(cpu, &cpuset); + for (int i = 
0; i < 10; i++) { + CPU_SET(cpu + i, &cpuset); + } int res = pthread_setaffinity_np(thread.native_handle(), sizeof(cpu_set_t), &cpuset); if (res) { diff --git a/rpmp/pmpool/buffer/CircularBuffer.h b/rpmp/pmpool/buffer/CircularBuffer.h index a41ce3ab..1f3991b5 100644 --- a/rpmp/pmpool/buffer/CircularBuffer.h +++ b/rpmp/pmpool/buffer/CircularBuffer.h @@ -25,9 +25,9 @@ #include // NOLINT #include -#include "../Common.h" -#include "../NetworkServer.h" -#include "../RmaBufferRegister.h" +#include "pmpool/Common.h" +#include "pmpool/NetworkServer.h" +#include "pmpool/RmaBufferRegister.h" #define p2align(x, a) (((x) + (a)-1) & ~((a)-1)) @@ -36,12 +36,24 @@ class CircularBuffer { CircularBuffer() = delete; CircularBuffer(const CircularBuffer &) = delete; CircularBuffer(uint64_t buffer_size, uint32_t buffer_num, - bool is_server = false, RmaBufferRegister *rbr = nullptr) + bool is_server = false) : buffer_size_(buffer_size), buffer_num_(buffer_num), - rbr_(rbr), read_(0), write_(0) { + init(); + } + CircularBuffer(uint64_t buffer_size, uint32_t buffer_num, bool is_server, + std::shared_ptr rbr) + : rbr_(rbr), + buffer_size_(buffer_size), + buffer_num_(buffer_num), + read_(0), + write_(0) { + init(); + } + + void init() { uint64_t total = buffer_num_ * buffer_size_; buffer_ = static_cast(mmap(0, buffer_num_ * buffer_size_, PROT_READ | PROT_WRITE, @@ -57,16 +69,29 @@ class CircularBuffer { if (rbr_) { ck_ = rbr_->register_rma_buffer(buffer_, buffer_num_ * buffer_size_); +#ifdef DEBUG + printf("[CircularBuffer::Register_RMA_Buffer] range is %ld - %ld\n", + (uint64_t)buffer_, + (uint64_t)(buffer_ + buffer_num_ * buffer_size_)); +#endif } - for (int i = 0; i < buffer_num; i++) { + for (int i = 0; i < buffer_num_; i++) { bits.push_back(0); } } + ~CircularBuffer() { + if (ck_ != nullptr) { + rbr_->unregister_rma_buffer(ck_->buffer_id); + } munmap(buffer_, buffer_num_ * buffer_size_); buffer_ = nullptr; +#ifdef DEBUG + std::cout << "CircularBuffer destructed" << std::endl; +#endif 
} + char *get(uint64_t bytes) { uint64_t offset = 0; bool res = get(bytes, &offset); @@ -75,6 +100,7 @@ class CircularBuffer { } return buffer_ + offset * buffer_size_; } + void put(const char *data, uint64_t bytes) { assert((data - buffer_) % buffer_size_ == 0); uint64_t offset = (data - buffer_) / buffer_size_; @@ -203,8 +229,8 @@ class CircularBuffer { char *tmp_; uint64_t buffer_size_; uint64_t buffer_num_; - RmaBufferRegister *rbr_; - Chunk *ck_; + std::shared_ptr rbr_; + Chunk *ck_ = nullptr; std::vector bits; uint64_t read_; uint64_t write_; diff --git a/rpmp/pmpool/client/NetworkClient.cc b/rpmp/pmpool/client/NetworkClient.cc index 1e32cba9..06290f43 100644 --- a/rpmp/pmpool/client/NetworkClient.cc +++ b/rpmp/pmpool/client/NetworkClient.cc @@ -13,89 +13,208 @@ #include #include -#include "../Event.h" -#include "../buffer/CircularBuffer.h" +#include "pmpool/Event.h" +#include "pmpool/buffer/CircularBuffer.h" +using namespace std::chrono_literals; uint64_t timestamp_now() { return std::chrono::high_resolution_clock::now().time_since_epoch() / std::chrono::milliseconds(1); } -RequestHandler::RequestHandler(NetworkClient *networkClient) +RequestHandler::RequestHandler(std::shared_ptr networkClient) : networkClient_(networkClient) {} -void RequestHandler::addTask(Request *request) { handleRequest(request); } +RequestHandler::~RequestHandler() { +#ifdef DEBUG + std::cout << "RequestHandler destructed" << std::endl; +#endif +} + +void RequestHandler::reset() { + this->stop(); + this->join(); + networkClient_.reset(); +#ifdef DEBUG + std::cout << "Callback map is " + << (callback_map.empty() ? "empty" : "not empty") << std::endl; + std::cout << "inflight map is " << (inflight_.empty() ? 
"empty" : "not empty") + << std::endl; +#endif +} -void RequestHandler::addTask(Request *request, std::function func) { +void RequestHandler::addTask(std::shared_ptr request) { + pendingRequestQueue_.enqueue(request); +} + +void RequestHandler::addTask(std::shared_ptr request, + std::function func) { callback_map[request->get_rc().rid] = func; - handleRequest(request); + pendingRequestQueue_.enqueue(request); } -void RequestHandler::wait() { - unique_lock lk(h_mtx); - while (!op_finished) { - cv.wait(lk); +int RequestHandler::entry() { + std::shared_ptr request; + bool res = pendingRequestQueue_.wait_dequeue_timed( + request, std::chrono::milliseconds(1000)); + if (res) { + handleRequest(request); + } + return 0; +} + +std::shared_ptr +RequestHandler::inflight_insert_or_get(std::shared_ptr request) { + const std::lock_guard lock(inflight_mtx_); + auto rid = request->requestContext_.rid; + if (inflight_.find(rid) == inflight_.end()) { + auto ctx = std::make_shared(); + inflight_.emplace(rid, ctx); + return ctx; + } else { + auto ctx = inflight_[rid]; + return ctx; + } +} + +void RequestHandler::inflight_erase(std::shared_ptr request) { + const std::lock_guard lock(inflight_mtx_); + inflight_.erase(request->requestContext_.rid); +} + +uint64_t RequestHandler::wait(std::shared_ptr request) { + auto ctx = inflight_insert_or_get(request); + unique_lock lk(ctx->mtx_reply); + while (!ctx->cv_reply.wait_for(lk, 5ms, [ctx, request] { + auto current = std::chrono::steady_clock::now(); + auto elapse = current - ctx->start; + if (elapse > 30s) { // tried 10s and found 8 process * 8 threads request + // will still go timeout, need to fix + ctx->op_failed = true; + fprintf(stderr, "Request [TYPE %ld][Key %ld] spent %ld s, time out\n", + request->requestContext_.type, request->requestContext_.key, + std::chrono::duration_cast(elapse).count()); + return true; + } + return ctx->op_finished; + })) { + } + uint64_t res = 0; + if (ctx->op_failed) { + res = -1; + } + // res = 
ctx->requestReplyContext->size; + inflight_erase(request); + return res; +} + +std::shared_ptr RequestHandler::get( + std::shared_ptr request) { + auto ctx = inflight_insert_or_get(request); + unique_lock lk(ctx->mtx_reply); + while (!ctx->cv_reply.wait_for(lk, 5ms, [ctx, request] { + auto current = std::chrono::steady_clock::now(); + auto elapse = current - ctx->start; + if (elapse > 30s) { // tried 10s and found 8 process * 8 threads request + // will still go timeout, need to fix + ctx->op_failed = true; + fprintf(stderr, "Request [TYPE %ld] spent %ld s, time out\n", + request->requestContext_.type, + std::chrono::duration_cast(elapse).count()); + return true; + } + return ctx->op_finished; + })) { + } + auto res = std::move(ctx->requestReplyContext); + if (ctx->op_failed) { + throw; } + inflight_erase(request); + return res; } -void RequestHandler::notify(RequestReply *requestReply) { - unique_lock lk(h_mtx); - requestReplyContext = requestReply->get_rrc(); - op_finished = true; - if (callback_map.count(requestReplyContext.rid) != 0) { - callback_map[requestReplyContext.rid](); - callback_map.erase(requestReplyContext.rid); +void RequestHandler::notify(std::shared_ptr requestReply) { + const std::lock_guard lock(inflight_mtx_); + auto rid = requestReply->get_rrc()->rid; + auto ctx = inflight_[rid]; + ctx->op_finished = true; + auto rrc = requestReply->get_rrc(); + if (expectedReturnType != rrc->type) { + std::string err_msg = "expected return type is " + + std::to_string(expectedReturnType) + + ", current rrc.type is " + std::to_string(rrc->type); + std::cout << err_msg << std::endl; + return; + } + ctx->requestReplyContext = std::move(rrc); + if (callback_map.count(ctx->requestReplyContext->rid) != 0) { + callback_map[ctx->requestReplyContext->rid](); + callback_map.erase(ctx->requestReplyContext->rid); } else { - cv.notify_one(); - lk.unlock(); + ctx->cv_reply.notify_one(); } } -void RequestHandler::handleRequest(Request *request) { - op_finished = false; 
+void RequestHandler::handleRequest(std::shared_ptr request) { + auto ctx = inflight_insert_or_get(request); OpType rt = request->get_rc().type; switch (rt) { case ALLOC: { + expectedReturnType = ALLOC_REPLY; request->encode(); networkClient_->send(reinterpret_cast(request->data_), request->size_); break; } case FREE: { + expectedReturnType = FREE_REPLY; request->encode(); networkClient_->send(reinterpret_cast(request->data_), request->size_); break; } case WRITE: { + expectedReturnType = WRITE_REPLY; request->encode(); networkClient_->send(reinterpret_cast(request->data_), request->size_); break; } case READ: { + expectedReturnType = READ_REPLY; request->encode(); networkClient_->send(reinterpret_cast(request->data_), request->size_); break; } case PUT: { + expectedReturnType = PUT_REPLY; request->encode(); networkClient_->send(reinterpret_cast(request->data_), request->size_); + break; + } + case GET: { + expectedReturnType = GET_REPLY; + request->encode(); + networkClient_->send(reinterpret_cast(request->data_), + request->size_); + break; } case GET_META: { + expectedReturnType = GET_META_REPLY; request->encode(); networkClient_->send(reinterpret_cast(request->data_), request->size_); + break; } default: {} } } -RequestReplyContext &RequestHandler::get() { return requestReplyContext; } - -ClientConnectedCallback::ClientConnectedCallback(NetworkClient *networkClient) { +ClientConnectedCallback::ClientConnectedCallback( + std::shared_ptr networkClient) { networkClient_ = networkClient; } @@ -104,8 +223,9 @@ void ClientConnectedCallback::operator()(void *param_1, void *param_2) { networkClient_->connected(con); } -ClientRecvCallback::ClientRecvCallback(ChunkMgr *chunkMgr, - RequestHandler *requestHandler) +ClientRecvCallback::ClientRecvCallback( + std::shared_ptr chunkMgr, + std::shared_ptr requestHandler) : chunkMgr_(chunkMgr), requestHandler_(requestHandler) {} void ClientRecvCallback::operator()(void *param_1, void *param_2) { @@ -136,25 +256,38 @@ void 
ClientRecvCallback::operator()(void *param_1, void *param_2) { // con->send(new_ck); // test end - RequestReply requestReply(reinterpret_cast(ck->buffer), ck->size, - reinterpret_cast(ck->con)); - requestReply.decode(); - RequestReplyContext rrc = requestReply.get_rrc(); - switch (rrc.type) { + auto requestReply = std::make_shared( + reinterpret_cast(ck->buffer), ck->size, + reinterpret_cast(ck->con)); + requestReply->decode(); + auto rrc = requestReply->get_rrc(); + switch (rrc->type) { case ALLOC_REPLY: { - requestHandler_->notify(&requestReply); + requestHandler_->notify(requestReply); break; } case FREE_REPLY: { - requestHandler_->notify(&requestReply); + requestHandler_->notify(requestReply); break; } case WRITE_REPLY: { - requestHandler_->notify(&requestReply); + requestHandler_->notify(requestReply); break; } case READ_REPLY: { - requestHandler_->notify(&requestReply); + requestHandler_->notify(requestReply); + break; + } + case PUT_REPLY: { + requestHandler_->notify(requestReply); + break; + } + case GET_REPLY: { + requestHandler_->notify(requestReply); + break; + } + case GET_META_REPLY: { + requestHandler_->notify(requestReply); break; } default: {} @@ -179,30 +312,32 @@ NetworkClient::NetworkClient(const string &remote_address, connected_(false) {} NetworkClient::~NetworkClient() { - delete shutdownCallback; - delete connectedCallback; - delete sendCallback; - delete recvCallback; +#ifdef DUBUG + std::cout << "NetworkClient destructed" << std::endl; +#endif } -int NetworkClient::init(RequestHandler *requestHandler) { - client_ = new Client(worker_num_, buffer_num_per_con_); +int NetworkClient::init(std::shared_ptr requestHandler) { + client_ = std::make_shared(worker_num_, buffer_num_per_con_); if ((client_->init()) != 0) { return -1; } - chunkMgr_ = new ChunkPool(client_, buffer_size_, init_buffer_num_); + chunkMgr_ = std::make_shared(client_.get(), buffer_size_, + init_buffer_num_); - client_->set_chunk_mgr(chunkMgr_); + 
client_->set_chunk_mgr(chunkMgr_.get()); - shutdownCallback = new ClientShutdownCallback(); - connectedCallback = new ClientConnectedCallback(this); - recvCallback = new ClientRecvCallback(chunkMgr_, requestHandler); - sendCallback = new ClientSendCallback(chunkMgr_); + shutdownCallback = std::make_shared(); + connectedCallback = + std::make_shared(shared_from_this()); + recvCallback = + std::make_shared(chunkMgr_, requestHandler); + sendCallback = std::make_shared(chunkMgr_); - client_->set_shutdown_callback(shutdownCallback); - client_->set_connected_callback(connectedCallback); - client_->set_recv_callback(recvCallback); - client_->set_send_callback(sendCallback); + client_->set_shutdown_callback(shutdownCallback.get()); + client_->set_connected_callback(connectedCallback.get()); + client_->set_recv_callback(recvCallback.get()); + client_->set_send_callback(sendCallback.get()); client_->start(); int res = client_->connect(remote_address_.c_str(), remote_port_.c_str()); @@ -211,13 +346,32 @@ int NetworkClient::init(RequestHandler *requestHandler) { con_v.wait(lk); } - circularBuffer_ = make_shared(1024 * 1024, 512, false, this); + circularBuffer_ = + make_shared(1024 * 1024, 512, false, shared_from_this()); + return 0; } void NetworkClient::shutdown() { client_->shutdown(); } void NetworkClient::wait() { client_->wait(); } +void NetworkClient::reset() { + circularBuffer_.reset(); + shutdownCallback.reset(); + connectedCallback.reset(); + recvCallback.reset(); + sendCallback.reset(); + if (con_ != nullptr) { + con_->shutdown(); + } + if (client_) { + client_->shutdown(); + client_.reset(); + } +} + +std::shared_ptr NetworkClient::get_chunkMgr() { return chunkMgr_; } + Chunk *NetworkClient::register_rma_buffer(char *rma_buffer, uint64_t size) { return client_->reg_rma_buffer(rma_buffer, size, buffer_id_++); } @@ -254,9 +408,18 @@ void NetworkClient::send(char *data, uint64_t size) { auto ck = chunkMgr_->get(con_); std::memcpy(reinterpret_cast(ck->buffer), data, 
size); ck->size = size; +#ifdef DEBUG + RequestMsg *requestMsg = (RequestMsg *)(data); + std::cout << "[NetworkClient::send][" << requestMsg->type << "] size is " + << size << std::endl; + for (int i = 0; i < size; i++) { + printf("%X ", *(data + i)); + } + printf("\n"); +#endif con_->send(ck); } -void NetworkClient::read(Request *request) { +void NetworkClient::read(std::shared_ptr request) { RequestContext rc = request->get_rc(); } diff --git a/rpmp/pmpool/client/NetworkClient.h b/rpmp/pmpool/client/NetworkClient.h index f9e4f54c..1f8ebc8f 100644 --- a/rpmp/pmpool/client/NetworkClient.h +++ b/rpmp/pmpool/client/NetworkClient.h @@ -22,11 +22,11 @@ #include #include -#include "../Event.h" -#include "../RmaBufferRegister.h" -#include "../ThreadWrapper.h" -#include "../queue/blockingconcurrentqueue.h" -#include "../queue/concurrentqueue.h" +#include "pmpool/Event.h" +#include "pmpool/RmaBufferRegister.h" +#include "pmpool/ThreadWrapper.h" +#include "pmpool/queue/blockingconcurrentqueue.h" +#include "pmpool/queue/concurrentqueue.h" using moodycamel::BlockingConcurrentQueue; using std::atomic; @@ -48,31 +48,46 @@ class ChunkMgr; typedef promise Promise; typedef future Future; -class RequestHandler { +class RequestHandler : public ThreadWrapper { public: - explicit RequestHandler(NetworkClient *networkClient); - ~RequestHandler() = default; - void addTask(Request *request); - void addTask(Request *request, std::function func); - void notify(RequestReply *requestReply); - void wait(); - RequestReplyContext &get(); - - private: - void handleRequest(Request *request); + explicit RequestHandler(std::shared_ptr networkClient); + ~RequestHandler(); + void addTask(std::shared_ptr request); + void addTask(std::shared_ptr request, std::function func); + void reset(); + int entry() override; + void abort() override {} + void notify(std::shared_ptr requestReply); + uint64_t wait(std::shared_ptr request); + std::shared_ptr get(std::shared_ptr request); private: - NetworkClient 
*networkClient_; - BlockingConcurrentQueue pendingRequestQueue_; - std::mutex h_mtx; + std::shared_ptr networkClient_; + BlockingConcurrentQueue> pendingRequestQueue_; unordered_map> callback_map; uint64_t total_num = 0; uint64_t begin = 0; uint64_t end = 0; uint64_t time = 0; - bool op_finished = false; - std::condition_variable cv; - RequestReplyContext requestReplyContext; + struct InflightRequestContext { + std::mutex mtx_reply; + std::condition_variable cv_reply; + std::mutex mtx_returned; + std::chrono::time_point start; + bool op_finished = false; + bool op_failed = false; + InflightRequestContext() { start = std::chrono::steady_clock::now(); } + std::shared_ptr requestReplyContext; + }; + std::unordered_map> + inflight_; + std::mutex inflight_mtx_; + long expectedReturnType; + + std::shared_ptr inflight_insert_or_get( + std::shared_ptr); + void inflight_erase(std::shared_ptr request); + void handleRequest(std::shared_ptr request); }; class ClientShutdownCallback : public Callback { @@ -84,23 +99,25 @@ class ClientShutdownCallback : public Callback { class ClientConnectedCallback : public Callback { public: - explicit ClientConnectedCallback(NetworkClient *networkClient); + explicit ClientConnectedCallback( + std::shared_ptr networkClient); ~ClientConnectedCallback() = default; void operator()(void *param_1, void *param_2); private: - NetworkClient *networkClient_; + std::shared_ptr networkClient_; }; class ClientRecvCallback : public Callback { public: - ClientRecvCallback(ChunkMgr *chunkMgr, RequestHandler *requestHandler); + ClientRecvCallback(std::shared_ptr chunkMgr, + std::shared_ptr requestHandler); ~ClientRecvCallback() = default; void operator()(void *param_1, void *param_2); private: - ChunkMgr *chunkMgr_; - RequestHandler *requestHandler_; + std::shared_ptr chunkMgr_; + std::shared_ptr requestHandler_; uint64_t count_ = 0; uint64_t time = 0; uint64_t start = 0; @@ -110,7 +127,8 @@ class ClientRecvCallback : public Callback { class 
ClientSendCallback : public Callback { public: - explicit ClientSendCallback(ChunkMgr *chunkMgr) : chunkMgr_(chunkMgr) {} + explicit ClientSendCallback(std::shared_ptr chunkMgr) + : chunkMgr_(chunkMgr) {} ~ClientSendCallback() = default; void operator()(void *param_1, void *param_2) { auto buffer_id_ = *static_cast(param_1); @@ -119,10 +137,11 @@ class ClientSendCallback : public Callback { } private: - ChunkMgr *chunkMgr_; + std::shared_ptr chunkMgr_; }; -class NetworkClient : public RmaBufferRegister { +class NetworkClient : public RmaBufferRegister, + public std::enable_shared_from_this { public: friend ClientConnectedCallback; NetworkClient() = delete; @@ -131,7 +150,7 @@ class NetworkClient : public RmaBufferRegister { int worker_num, int buffer_num_per_con, int buffer_size, int init_buffer_num); ~NetworkClient(); - int init(RequestHandler *requesthandler); + int init(std::shared_ptr requesthandler); void shutdown(); void wait(); Chunk *register_rma_buffer(char *rma_buffer, uint64_t size) override; @@ -141,7 +160,9 @@ class NetworkClient : public RmaBufferRegister { uint64_t get_rkey(); void connected(Connection *con); void send(char *data, uint64_t size); - void read(Request *request); + void read(std::shared_ptr request); + std::shared_ptr get_chunkMgr(); + void reset(); private: string remote_address_; @@ -150,13 +171,13 @@ class NetworkClient : public RmaBufferRegister { int buffer_num_per_con_; int buffer_size_; int init_buffer_num_; - Client *client_; - ChunkMgr *chunkMgr_; + std::shared_ptr client_; + std::shared_ptr chunkMgr_; Connection *con_; - ClientShutdownCallback *shutdownCallback; - ClientConnectedCallback *connectedCallback; - ClientRecvCallback *recvCallback; - ClientSendCallback *sendCallback; + std::shared_ptr shutdownCallback; + std::shared_ptr connectedCallback; + std::shared_ptr recvCallback; + std::shared_ptr sendCallback; mutex con_mtx; bool connected_; condition_variable con_v; diff --git a/rpmp/pmpool/client/PmPoolClient.cc 
b/rpmp/pmpool/client/PmPoolClient.cc index f1da09a2..5d7494c1 100644 --- a/rpmp/pmpool/client/PmPoolClient.cc +++ b/rpmp/pmpool/client/PmPoolClient.cc @@ -19,12 +19,23 @@ PmPoolClient::PmPoolClient(const string &remote_address, tx_finished = true; op_finished = false; networkClient_ = make_shared(remote_address, remote_port); - requestHandler_ = make_shared(networkClient_.get()); + requestHandler_ = make_shared(networkClient_); } -PmPoolClient::~PmPoolClient() {} +PmPoolClient::~PmPoolClient() { + requestHandler_->reset(); + networkClient_->reset(); -int PmPoolClient::init() { networkClient_->init(requestHandler_.get()); } +#ifdef DEBUG + std::cout << "PmPoolClient destructed" << std::endl; +#endif +} + +int PmPoolClient::init() { + auto res = networkClient_->init(requestHandler_); + requestHandler_->start(); + return res; +} void PmPoolClient::begin_tx() { std::unique_lock lk(tx_mtx); @@ -39,10 +50,9 @@ uint64_t PmPoolClient::alloc(uint64_t size) { rc.type = ALLOC; rc.rid = rid_++; rc.size = size; - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - return requestHandler_->get().address; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + return requestHandler_->get(request)->address; } int PmPoolClient::free(uint64_t address) { @@ -50,10 +60,9 @@ int PmPoolClient::free(uint64_t address) { rc.type = FREE; rc.rid = rid_++; rc.address = address; - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - return requestHandler_->get().success; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + return requestHandler_->get(request)->success; } void PmPoolClient::shutdown() { networkClient_->shutdown(); } @@ -69,10 +78,9 @@ int PmPoolClient::write(uint64_t address, const char *data, uint64_t size) { // allocate memory for RMA read from client. 
rc.src_address = networkClient_->get_dram_buffer(data, rc.size); rc.src_rkey = networkClient_->get_rkey(); - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - auto res = requestHandler_->get().success; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto res = requestHandler_->get(request)->success; networkClient_->reclaim_dram_buffer(rc.src_address, rc.size); return res; } @@ -86,10 +94,9 @@ uint64_t PmPoolClient::write(const char *data, uint64_t size) { // allocate memory for RMA read from client. rc.src_address = networkClient_->get_dram_buffer(data, rc.size); rc.src_rkey = networkClient_->get_rkey(); - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - auto res = requestHandler_->get().address; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto res = requestHandler_->get(request)->address; networkClient_->reclaim_dram_buffer(rc.src_address, rc.size); return res; } @@ -103,10 +110,9 @@ int PmPoolClient::read(uint64_t address, char *data, uint64_t size) { // allocate memory for RMA read from client. rc.src_address = networkClient_->get_dram_buffer(nullptr, rc.size); rc.src_rkey = networkClient_->get_rkey(); - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - auto res = requestHandler_->get().success; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto res = requestHandler_->get(request)->success; if (!res) { memcpy(data, reinterpret_cast(rc.src_address), size); } @@ -124,9 +130,9 @@ int PmPoolClient::read(uint64_t address, char *data, uint64_t size, // allocate memory for RMA read from client. 
rc.src_address = networkClient_->get_dram_buffer(nullptr, rc.size); rc.src_rkey = networkClient_->get_rkey(); - Request request(rc); - requestHandler_->addTask(&request, [&] { - auto res = requestHandler_->get().success; + auto request = std::make_shared(rc); + requestHandler_->addTask(request, [&] { + auto res = requestHandler_->get(request)->success; if (res) { memcpy(data, reinterpret_cast(rc.src_address), size); } @@ -154,16 +160,60 @@ uint64_t PmPoolClient::put(const string &key, const char *value, // allocate memory for RMA read from client. rc.src_address = networkClient_->get_dram_buffer(value, rc.size); rc.src_rkey = networkClient_->get_rkey(); +#ifdef DEBUG + std::cout << "[PmPoolClient::put] " << rc.src_rkey << "-" << rc.src_address + << ":" << rc.size << std::endl; +#endif rc.key = key_uint; - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - auto address = requestHandler_->get().address; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto res = requestHandler_->wait(request); networkClient_->reclaim_dram_buffer(rc.src_address, rc.size); - return address; +#ifdef DEBUG + fprintf(stderr, "[PUT]key is %s, length is %ld, content is \n", key.c_str(), + size); + for (int i = 0; i < 100; i++) { + fprintf(stderr, "%X ", *(value + i)); + } + fprintf(stderr, " ...\n"); +#endif + return res; } -vector PmPoolClient::get(const string &key) { +uint64_t PmPoolClient::get(const string &key, char *value, uint64_t size) { + uint64_t key_uint; + Digest::computeKeyHash(key, &key_uint); + RequestContext rc = {}; + rc.type = GET; + rc.rid = rid_++; + rc.size = size; + rc.address = 0; + // allocate memory for RMA read from client. 
+ rc.src_address = networkClient_->get_dram_buffer(nullptr, rc.size); + rc.src_rkey = networkClient_->get_rkey(); +#ifdef DEBUG + std::cout << "[PmPoolClient::get] " << rc.src_rkey << "-" << rc.src_address + << ":" << rc.size << std::endl; +#endif + rc.key = key_uint; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto res = requestHandler_->wait(request); + memcpy(value, reinterpret_cast(rc.src_address), rc.size); + networkClient_->reclaim_dram_buffer(rc.src_address, rc.size); + +#ifdef DEBUG + fprintf(stderr, "[GET]key is %s, length is %ld, content is \n", key.c_str(), + size); + for (int i = 0; i < 100; i++) { + fprintf(stderr, "%X ", *(value + i)); + } + fprintf(stderr, " ...\n"); +#endif + return res; +} + +vector PmPoolClient::getMeta(const string &key) { uint64_t key_uint; Digest::computeKeyHash(key, &key_uint); RequestContext rc = {}; @@ -171,11 +221,17 @@ vector PmPoolClient::get(const string &key) { rc.rid = rid_++; rc.address = 0; rc.key = key_uint; - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - auto bml = requestHandler_->get().bml; - return bml; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto rrc = requestHandler_->get(request); + if (rrc->type == GET_META_REPLY) { + return rrc->bml; + } else { + std::string err_msg = + "GetMeta function got " + std::to_string(rrc->type) + " msg."; + std::cerr << err_msg << std::endl; + throw; + } } int PmPoolClient::del(const string &key) { @@ -185,9 +241,8 @@ int PmPoolClient::del(const string &key) { rc.type = DELETE; rc.rid = rid_++; rc.key = key_uint; - Request request(rc); - requestHandler_->addTask(&request); - requestHandler_->wait(); - auto res = requestHandler_->get().success; + auto request = std::make_shared(rc); + requestHandler_->addTask(request); + auto res = requestHandler_->get(request)->success; return res; } diff --git a/rpmp/pmpool/client/PmPoolClient.h b/rpmp/pmpool/client/PmPoolClient.h 
index b9544f35..0c458cf3 100644 --- a/rpmp/pmpool/client/PmPoolClient.h +++ b/rpmp/pmpool/client/PmPoolClient.h @@ -29,9 +29,9 @@ #include #include -#include "../Base.h" -#include "../Common.h" -#include "../ThreadWrapper.h" +#include "pmpool/Base.h" +#include "pmpool/Common.h" +#include "pmpool/ThreadWrapper.h" class NetworkClient; class RequestHandler; @@ -80,7 +80,8 @@ class PmPoolClient { /// key-value storage interface uint64_t put(const string &key, const char *value, uint64_t size); - vector get(const string &key); + uint64_t get(const string &key, char *value, uint64_t size); + vector getMeta(const string &key); int del(const string &key); void shutdown(); diff --git a/rpmp/pmpool/client/java/rpmp/pom.xml b/rpmp/pmpool/client/java/rpmp/pom.xml index 65310b2a..ab6a9b0d 100644 --- a/rpmp/pmpool/client/java/rpmp/pom.xml +++ b/rpmp/pmpool/client/java/rpmp/pom.xml @@ -16,9 +16,16 @@ UTF-8 1.7 1.7 + ../../../../build/lib + + org.slf4j + slf4j-api + 1.6.1 + + junit junit @@ -28,6 +35,14 @@ + + + ${cpp.build.dir} + + **/libpmpool_client_jni.so + + + diff --git a/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/JniUtils.java b/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/JniUtils.java new file mode 100644 index 00000000..daab29e7 --- /dev/null +++ b/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/JniUtils.java @@ -0,0 +1,70 @@ +package com.intel.rpmp; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Helper class for JNI related operations. 
*/ +public class JniUtils { + private static final String LIBRARY_NAME = "pmpool_client_jni"; + private static boolean isLoaded = false; + private static volatile JniUtils INSTANCE; + private static final Logger LOG = LoggerFactory.getLogger(JniUtils.class); + + public static JniUtils getInstance() throws IOException { + if (INSTANCE == null) { + synchronized (JniUtils.class) { + if (INSTANCE == null) { + try { + INSTANCE = new JniUtils(); + } catch (IllegalAccessException ex) { + throw new IOException("IllegalAccess"); + } + } + } + } + + return INSTANCE; + } + + private JniUtils() throws IOException, IllegalAccessException { + try { + loadLibraryFromJar(); + } catch (IOException ex) { + System.loadLibrary(LIBRARY_NAME); + } + } + + static void loadLibraryFromJar() throws IOException, IllegalAccessException { + synchronized (JniUtils.class) { + if (!isLoaded) { + final String libraryToLoad = System.mapLibraryName(LIBRARY_NAME); + final File libraryFile = + moveFileFromJarToTemp(System.getProperty("java.io.tmpdir"), libraryToLoad); + LOG.info("library path is " + libraryFile.getAbsolutePath()); + System.load(libraryFile.getAbsolutePath()); + isLoaded = true; + } + } + } + + private static File moveFileFromJarToTemp(final String tmpDir, String libraryToLoad) + throws IOException { + final File temp = File.createTempFile(tmpDir, libraryToLoad); + try (final InputStream is = + JniUtils.class.getClassLoader().getResourceAsStream(libraryToLoad)) { + if (is == null) { + throw new FileNotFoundException(libraryToLoad); + } else { + Files.copy(is, temp.toPath(), StandardCopyOption.REPLACE_EXISTING); + } + } + return temp; + } +} diff --git a/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/PmPoolClient.java b/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/PmPoolClient.java index 12245816..a89953be 100644 --- a/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/PmPoolClient.java +++ 
b/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/PmPoolClient.java @@ -4,116 +4,132 @@ import java.lang.reflect.Constructor; import java.nio.ByteBuffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + /** * PmPoolClient * */ public class PmPoolClient { - static { - System.loadLibrary("pmpool"); - } - - public PmPoolClient(String remote_address, String remote_port) { - objectId = newPmPoolClient_(remote_address, remote_port); - } - - public long alloc(long size) { - return alloc_(size, objectId); - } - - public int free(long address) { - return free_(address, objectId); - } - - public int write(long address, String data, long size) { - return write_(address, data, size, objectId); - } - - public long write(String data, long size) { - return alloc_and_write_(data, size, objectId); - } - - public long write(ByteBuffer data, long size) { - return alloc_and_write_(data, size, objectId); - } - - public int read(long address, long size, ByteBuffer byteBuffer) { - return read_(address, size, byteBuffer, objectId); - } - - public long put(String key, ByteBuffer data, long size) { - return put(key, data, size, objectId); - } - - public long[] getMeta(String key) { - return getMeta(key, objectId); + private static final Logger LOG = LoggerFactory.getLogger(PmPoolClient.class); + + public PmPoolClient(String remote_address, String remote_port) throws IOException { + JniUtils.getInstance(); + LOG.info("create PmPoolClient instance, remote address is " + remote_address + + ", remote port is " + remote_port); + objectId = nativeOpenPmPoolClient(remote_address, remote_port); + } + + public long alloc(long size) { + return nativeAlloc(size, objectId); + } + + public int free(long address) { + return nativeFree(address, objectId); + } + + public int write(long address, String data, long size) { + return nativeWrite(address, data, size, objectId); + } + + public long write(String data, long size) { + return nativeAllocAndWriteWithString(data, size, objectId); 
+ } + + public long write(ByteBuffer data, long size) { + return nativeAllocAndWriteWithByteBuffer(data, size, objectId); + } + + public int read(long address, long size, ByteBuffer byteBuffer) { + return nativeRead(address, size, byteBuffer, objectId); + } + + public long put(String key, ByteBuffer data, long size) { + return nativePut(key, data, size, objectId); + } + + public long get(String key, long size, ByteBuffer data) { + return nativeGet(key, size, data, objectId); + } + + public long[] getMeta(String key) { + long[] res = nativeGetMeta(key, objectId); + if (res == null) { + return new long[0]; + } else { + return res; } - - public int del(String key) { - return del(key, objectId); + } + + public int del(String key) throws IOException { + throw new IOException("Delete " + key); + // return nativeRemove(key, objectId); + } + + public void shutdown() { + nativeShutdown(objectId); + } + + public void waitToStop() { + nativeWaitToStop(objectId); + } + + public void dispose() { + nativeDispose(objectId); + } + + private ByteBuffer convertToByteBuffer(long address, int length) throws IOException { + Class classDirectByteBuffer; + try { + classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer"); + } catch (ClassNotFoundException e) { + throw new IOException("java.nio.DirectByteBuffer class not found"); } - - public void shutdown() { - shutdown_(objectId); + Constructor constructor; + try { + constructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class); + } catch (NoSuchMethodException e) { + throw new IOException("java.nio.DirectByteBuffer constructor not found"); } - - public void waitToStop() { - waitToStop_(objectId); + constructor.setAccessible(true); + ByteBuffer byteBuffer; + try { + byteBuffer = (ByteBuffer) constructor.newInstance(address, length); + } catch (Exception e) { + throw new IOException("java.nio.DirectByteBuffer exception: " + e.toString()); } - public void dispose() { - dispose_(objectId); - } + return 
byteBuffer; + } - private ByteBuffer convertToByteBuffer(long address, int length) throws IOException { - Class classDirectByteBuffer; - try { - classDirectByteBuffer = Class.forName("java.nio.DirectByteBuffer"); - } catch (ClassNotFoundException e) { - throw new IOException("java.nio.DirectByteBuffer class not found"); - } - Constructor constructor; - try { - constructor = classDirectByteBuffer.getDeclaredConstructor(long.class, int.class); - } catch (NoSuchMethodException e) { - throw new IOException("java.nio.DirectByteBuffer constructor not found"); - } - constructor.setAccessible(true); - ByteBuffer byteBuffer; - try { - byteBuffer = (ByteBuffer) constructor.newInstance(address, length); - } catch (Exception e) { - throw new IOException("java.nio.DirectByteBuffer exception: " + e.toString()); - } - - return byteBuffer; - } + private native long nativeOpenPmPoolClient(String remote_address, String remote_port); - private native long newPmPoolClient_(String remote_address, String remote_port); + private native long nativeAlloc(long size, long objectId); - private native long alloc_(long size, long objectId); + private native int nativeFree(long address, long objectId); - private native int free_(long address, long objectId); + private native int nativeWrite(long address, String data, long size, long objectId); - private native int write_(long address, String data, long size, long objectId); + private native long nativeAllocAndWriteWithString(String data, long size, long objectId); - private native long alloc_and_write_(String data, long size, long objectId); + private native long nativeAllocAndWriteWithByteBuffer(ByteBuffer data, long size, long objectId); - private native long alloc_and_write_(ByteBuffer data, long size, long objectId); + private native long nativePut(String key, ByteBuffer data, long size, long objectId); - private native long put(String key, ByteBuffer data, long size, long objectId); + private native long nativeGet(String key, long size, 
ByteBuffer data, long objectId); - private native long[] getMeta(String key, long objectId); + private native long[] nativeGetMeta(String key, long objectId); - private native int del(String key, long objectId); + private native int nativeRemove(String key, long objectId); - private native int read_(long address, long size, ByteBuffer byteBuffer, long objectId); + private native int nativeRead(long address, long size, ByteBuffer byteBuffer, long objectId); - private native void shutdown_(long objectId); + private native void nativeShutdown(long objectId); - private native void waitToStop_(long objectId); + private native void nativeWaitToStop(long objectId); - private native void dispose_(long objectId); + private native void nativeDispose(long objectId); - private long objectId; + private long objectId; } diff --git a/rpmp/pmpool/client/java/rpmp/src/test/java/com/intel/rpmp/PmPoolClientTest.java b/rpmp/pmpool/client/java/rpmp/src/test/java/com/intel/rpmp/PmPoolClientTest.java index 2f4669ad..cc97b5c0 100644 --- a/rpmp/pmpool/client/java/rpmp/src/test/java/com/intel/rpmp/PmPoolClientTest.java +++ b/rpmp/pmpool/client/java/rpmp/src/test/java/com/intel/rpmp/PmPoolClientTest.java @@ -16,7 +16,11 @@ public class PmPoolClientTest { @Before public void setup() { - pmPoolClient = new PmPoolClient("172.168.0.40", "12346"); + try { + pmPoolClient = new PmPoolClient("172.168.0.209", "12346"); + } catch (Exception e) { + assertTrue(false); + } } @After diff --git a/rpmp/pmpool/client/java/rpmp/target/classes/com/intel/rpmp/PmPoolClient.class b/rpmp/pmpool/client/java/rpmp/target/classes/com/intel/rpmp/PmPoolClient.class index 2cdedac0..f0a1585f 100644 Binary files a/rpmp/pmpool/client/java/rpmp/target/classes/com/intel/rpmp/PmPoolClient.class and b/rpmp/pmpool/client/java/rpmp/target/classes/com/intel/rpmp/PmPoolClient.class differ diff --git a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/createdFiles.lst 
b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/createdFiles.lst index e69de29b..7ae962a6 100644 --- a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/createdFiles.lst +++ b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/createdFiles.lst @@ -0,0 +1,2 @@ +com/intel/rpmp/JniUtils.class +com/intel/rpmp/PmPoolClient.class diff --git a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/inputFiles.lst b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/inputFiles.lst index ab698a7d..4e161099 100644 --- a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/inputFiles.lst +++ b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/compile/default-compile/inputFiles.lst @@ -1 +1,2 @@ -/mnt/spark-pmof/tool/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/PmPoolClient.java +/mnt/spark-pmof/Spark-PMoF/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/JniUtils.java +/mnt/spark-pmof/Spark-PMoF/rpmp/pmpool/client/java/rpmp/src/main/java/com/intel/rpmp/PmPoolClient.java diff --git a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/createdFiles.lst b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/createdFiles.lst index e69de29b..9ca782e7 100644 --- a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/createdFiles.lst +++ b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/createdFiles.lst @@ -0,0 +1 @@ +com/intel/rpmp/PmPoolClientTest.class diff --git 
a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/inputFiles.lst b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/inputFiles.lst index dfd02fe7..fdcf343d 100644 --- a/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/inputFiles.lst +++ b/rpmp/pmpool/client/java/rpmp/target/maven-status/maven-compiler-plugin/testCompile/default-testCompile/inputFiles.lst @@ -1 +1 @@ -/mnt/spark-pmof/tool/rpmp/pmpool/client/java/rpmp/src/test/java/com/intel/rpmp/PmPoolClientTest.java +/mnt/spark-pmof/Spark-PMoF/rpmp/pmpool/client/java/rpmp/src/test/java/com/intel/rpmp/PmPoolClientTest.java diff --git a/rpmp/pmpool/client/java/rpmp/target/surefire-reports/TEST-com.intel.rpmp.PmPoolClientTest.xml b/rpmp/pmpool/client/java/rpmp/target/surefire-reports/TEST-com.intel.rpmp.PmPoolClientTest.xml deleted file mode 100644 index 33c46b59..00000000 --- a/rpmp/pmpool/client/java/rpmp/target/surefire-reports/TEST-com.intel.rpmp.PmPoolClientTest.xml +++ /dev/null @@ -1,64 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/rpmp/pmpool/client/java/rpmp/target/surefire-reports/com.intel.rpmp.PmPoolClientTest.txt b/rpmp/pmpool/client/java/rpmp/target/surefire-reports/com.intel.rpmp.PmPoolClientTest.txt deleted file mode 100644 index 53952175..00000000 --- a/rpmp/pmpool/client/java/rpmp/target/surefire-reports/com.intel.rpmp.PmPoolClientTest.txt +++ /dev/null @@ -1,4 +0,0 @@ -------------------------------------------------------------------------------- -Test set: com.intel.rpmp.PmPoolClientTest -------------------------------------------------------------------------------- -Tests run: 4, Failures: 0, Errors: 0, Skipped: 0, Time elapsed: 2.63 s - in com.intel.rpmp.PmPoolClientTest diff --git 
a/rpmp/pmpool/client/java/rpmp/target/test-classes/com/intel/rpmp/PmPoolClientTest.class b/rpmp/pmpool/client/java/rpmp/target/test-classes/com/intel/rpmp/PmPoolClientTest.class index d196a320..09622d5b 100644 Binary files a/rpmp/pmpool/client/java/rpmp/target/test-classes/com/intel/rpmp/PmPoolClientTest.class and b/rpmp/pmpool/client/java/rpmp/target/test-classes/com/intel/rpmp/PmPoolClientTest.class differ diff --git a/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.cc b/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.cc index 3aa641ee..47501ffe 100644 --- a/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.cc +++ b/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.cc @@ -8,15 +8,59 @@ */ #include +#include #include "pmpool/client/PmPoolClient.h" -#include "pmpool/client/native/com_intel_rpmp_PmPoolClient.h" +#include "pmpool/client/native/concurrent_map.h" + +static jint JNI_VERSION = JNI_VERSION_1_8; +static jclass io_exception_class; +static jclass illegal_argument_exception_class; +static arrow::jni::ConcurrentMap> handler_holder_; + +jclass CreateGlobalClassReference(JNIEnv *env, const char *class_name) { + jclass local_class = env->FindClass(class_name); + jclass global_class = (jclass)env->NewGlobalRef(local_class); + env->DeleteLocalRef(local_class); + if (global_class == nullptr) { + std::string error_message = + "Unable to createGlobalClassReference for" + std::string(class_name); + env->ThrowNew(illegal_argument_exception_class, error_message.c_str()); + } + return global_class; +} + +#ifdef __cplusplus +extern "C" { +#endif + +jint JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION) != JNI_OK) { + return JNI_ERR; + } + + io_exception_class = CreateGlobalClassReference(env, "Ljava/io/IOException;"); + illegal_argument_exception_class = + CreateGlobalClassReference(env, "Ljava/lang/IllegalArgumentException;"); + + return JNI_VERSION; +} -JNIEXPORT jlong JNICALL 
Java_com_intel_rpmp_PmPoolClient_newPmPoolClient_1( +std::shared_ptr GetClient(JNIEnv *env, jlong id) { + auto handler = handler_holder_.Lookup(id); + if (!handler) { + std::string error_message = "invalid handler id " + std::to_string(id); + env->ThrowNew(illegal_argument_exception_class, error_message.c_str()); + } + return handler; +} + +JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_nativeOpenPmPoolClient( JNIEnv *env, jobject obj, jstring address, jstring port) { const char *remote_address = env->GetStringUTFChars(address, 0); const char *remote_port = env->GetStringUTFChars(port, 0); - PmPoolClient *client = new PmPoolClient(remote_address, remote_port); + auto client = std::make_shared(remote_address, remote_port); client->begin_tx(); client->init(); client->end_tx(); @@ -24,35 +68,33 @@ JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_newPmPoolClient_1( env->ReleaseStringUTFChars(address, remote_address); env->ReleaseStringUTFChars(port, remote_port); - return reinterpret_cast(client); + return handler_holder_.Insert(std::move(client)); } -JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_alloc_1( +JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_nativeAlloc( JNIEnv *env, jobject obj, jlong size, jlong objectId) { - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); uint64_t address = client->alloc(size); client->end_tx(); return address; } -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_free_1(JNIEnv *env, - jobject obj, - jlong address, - jlong objectId) { - PmPoolClient *client = reinterpret_cast(objectId); +JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_nativeFree( + JNIEnv *env, jobject obj, jlong address, jlong objectId) { + auto client = GetClient(env, objectId); client->begin_tx(); int success = client->free(address); client->end_tx(); return success; } -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_write_1( +JNIEXPORT jint 
JNICALL Java_com_intel_rpmp_PmPoolClient_nativeWrite( JNIEnv *env, jobject obj, jlong address, jstring data, jlong size, jlong objectId) { const char *raw_data = env->GetStringUTFChars(data, 0); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); int success = client->write(address, raw_data, size); client->end_tx(); @@ -62,11 +104,11 @@ JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_write_1( return success; } JNIEXPORT jlong JNICALL -Java_com_intel_rpmp_PmPoolClient_alloc_1and_1write_1__Ljava_lang_String_2JJ( +Java_com_intel_rpmp_PmPoolClient_nativeAllocAndWriteWithString( JNIEnv *env, jobject obj, jstring data, jlong size, jlong objectId) { const char *raw_data = env->GetStringUTFChars(data, 0); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); uint64_t address = client->write(raw_data, size); client->end_tx(); @@ -76,21 +118,22 @@ Java_com_intel_rpmp_PmPoolClient_alloc_1and_1write_1__Ljava_lang_String_2JJ( } JNIEXPORT jlong JNICALL -Java_com_intel_rpmp_PmPoolClient_alloc_1and_1write_1__Ljava_nio_ByteBuffer_2JJ( +Java_com_intel_rpmp_PmPoolClient_nativeAllocateAndWriteWithByteBuffer( JNIEnv *env, jobject obj, jobject data, jlong size, jlong objectId) { char *raw_data = static_cast((*env).GetDirectBufferAddress(data)); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); uint64_t address = client->write(raw_data, size); client->end_tx(); + return address; } -JNIEXPORT jlong JNICALL -Java_com_intel_rpmp_PmPoolClient_put(JNIEnv *env, jobject obj, jstring key, - jobject data, jlong size, jlong objectId) { +JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_nativePut( + JNIEnv *env, jobject obj, jstring key, jobject data, jlong size, + jlong objectId) { char *raw_data = static_cast((*env).GetDirectBufferAddress(data)); const char *raw_key = env->GetStringUTFChars(key, 
0); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); auto address = client->put(raw_key, raw_data, size); client->end_tx(); @@ -98,39 +141,49 @@ Java_com_intel_rpmp_PmPoolClient_put(JNIEnv *env, jobject obj, jstring key, return address; } -JNIEXPORT jlongArray JNICALL Java_com_intel_rpmp_PmPoolClient_getMeta( +JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_nativeGet( + JNIEnv *env, jobject obj, jstring key, jlong size, jobject data, + jlong objectId) { + char *raw_data = static_cast((*env).GetDirectBufferAddress(data)); + const char *raw_key = env->GetStringUTFChars(key, 0); + auto client = GetClient(env, objectId); + client->begin_tx(); + auto address = client->get(raw_key, raw_data, size); + client->end_tx(); + env->ReleaseStringUTFChars(key, raw_key); + return address; +} + +JNIEXPORT jlongArray JNICALL Java_com_intel_rpmp_PmPoolClient_nativeGetMeta( JNIEnv *env, jobject obj, jstring key, jlong objectId) { const char *raw_key = env->GetStringUTFChars(key, 0); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); - auto bml = client->get(raw_key); + auto bml = client->getMeta(raw_key); client->end_tx(); env->ReleaseStringUTFChars(key, raw_key); - int longCArraySize = bml.size() * 2; - jlongArray longJavaArray = env->NewLongArray(longCArraySize); - uint64_t *longCArray = - static_cast(std::malloc(longCArraySize * sizeof(uint64_t))); - if (longJavaArray == nullptr) { + int longCArraySize = bml.size() * 3; + if (longCArraySize == 0) { return nullptr; } + auto longCArray = new uint64_t[longCArraySize](); int i = 0; for (auto bm : bml) { longCArray[i++] = bm.address; longCArray[i++] = bm.size; + longCArray[i++] = bm.r_key; } + jlongArray longJavaArray = env->NewLongArray(longCArraySize); env->SetLongArrayRegion(longJavaArray, 0, longCArraySize, reinterpret_cast(longCArray)); - std::free(longCArray); - env->ReleaseStringUTFChars(key, 
raw_key); + delete[] longCArray; return longJavaArray; } -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_del(JNIEnv *env, - jobject obj, - jstring key, - jlong objectId) { +JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_nativeRemove( + JNIEnv *env, jobject obj, jstring key, jlong objectId) { const char *raw_key = env->GetStringUTFChars(key, 0); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); int res = client->del(raw_key); client->end_tx(); @@ -138,31 +191,34 @@ JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_del(JNIEnv *env, return res; } -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_read_1( +JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_nativeRead( JNIEnv *env, jobject obj, jlong address, jlong size, jobject data, jlong objectId) { char *raw_data = static_cast((*env).GetDirectBufferAddress(data)); - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->begin_tx(); int success = client->read(address, raw_data, size); client->end_tx(); return success; } -JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_shutdown_1( +JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_nativeShutdown( JNIEnv *env, jobject obj, jlong objectId) { - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->shutdown(); } -JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_waitToStop_1( +JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_nativeWaitToStop( JNIEnv *env, jobject obj, jlong objectId) { - PmPoolClient *client = reinterpret_cast(objectId); + auto client = GetClient(env, objectId); client->wait(); } -JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_dispose_1( +JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_nativeDispose( JNIEnv *env, jobject obj, jlong objectId) { - PmPoolClient *client = reinterpret_cast(objectId); - delete client; + 
handler_holder_.Erase(objectId); +} + +#ifdef __cplusplus } +#endif diff --git a/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.h b/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.h deleted file mode 100644 index 1ba2636b..00000000 --- a/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.h +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Filename: - * /mnt/spark-pmof/Spark-PMoF/rpmp/pmpool/client/native/com_intel_rpmp_PmPoolClient.h - * Path: /mnt/spark-pmof/Spark-PMoF/rpmp/pmpool/client/native - * Created Date: Thursday, March 5th 2020, 10:44:12 am - * Author: root - * - * Copyright (c) 2020 Intel - */ - -#include -/* Header for class com_intel_rpmp_PmPoolClient */ - -#ifndef PMPOOL_CLIENT_NATIVE_COM_INTEL_RPMP_PMPOOLCLIENT_H_ -#define PMPOOL_CLIENT_NATIVE_COM_INTEL_RPMP_PMPOOLCLIENT_H_ -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: newPmPoolClient_ - * Signature: (Ljava/lang/String;Ljava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_newPmPoolClient_1( - JNIEnv *, jobject, jstring, jstring); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: alloc_ - * Signature: (JJ)J - */ -JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_alloc_1(JNIEnv *, - jobject, jlong, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: free_ - * Signature: (JJ)I - */ -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_free_1(JNIEnv *, - jobject, jlong, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: write_ - * Signature: (JLjava/lang/String;JJ)I - */ -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_write_1(JNIEnv *, - jobject, jlong, - jstring, jlong, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: alloc_and_write_ - * Signature: (Ljava/lang/String;JJ)J - */ -JNIEXPORT jlong JNICALL -Java_com_intel_rpmp_PmPoolClient_alloc_1and_1write_1__Ljava_lang_String_2JJ( - JNIEnv *, jobject, jstring, jlong, jlong); - -/* - * Class: 
com_intel_rpmp_PmPoolClient - * Method: alloc_and_write_ - * Signature: (Ljava/nio/ByteBuffer;JJ)J - */ -JNIEXPORT jlong JNICALL -Java_com_intel_rpmp_PmPoolClient_alloc_1and_1write_1__Ljava_nio_ByteBuffer_2JJ( - JNIEnv *, jobject, jobject, jlong, jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: put - * Signature: (Ljava/lang/String;Ljava/nio/ByteBuffer;JJ)J - */ -JNIEXPORT jlong JNICALL Java_com_intel_rpmp_PmPoolClient_put(JNIEnv *, jobject, - jstring, jobject, - jlong, jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: getMeta - * Signature: (Ljava/lang/String;J)[J - */ -JNIEXPORT jlongArray JNICALL Java_com_intel_rpmp_PmPoolClient_getMeta(JNIEnv *, - jobject, - jstring, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: del - * Signature: (Ljava/lang/String;J)I - */ -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_del(JNIEnv *, jobject, - jstring, jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: read_ - * Signature: (JJLjava/nio/ByteBuffer;J)I - */ -JNIEXPORT jint JNICALL Java_com_intel_rpmp_PmPoolClient_read_1(JNIEnv *, - jobject, jlong, - jlong, jobject, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: shutdown_ - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_shutdown_1(JNIEnv *, - jobject, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: waitToStop_ - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_waitToStop_1(JNIEnv *, - jobject, - jlong); - -/* - * Class: com_intel_rpmp_PmPoolClient - * Method: dispose_ - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_com_intel_rpmp_PmPoolClient_dispose_1(JNIEnv *, - jobject, - jlong); - -#ifdef __cplusplus -} -#endif -#endif // PMPOOL_CLIENT_NATIVE_COM_INTEL_RPMP_PMPOOLCLIENT_H_ diff --git a/rpmp/pmpool/client/native/concurrent_map.h b/rpmp/pmpool/client/native/concurrent_map.h new file mode 100644 index 00000000..c74659ed --- /dev/null +++ 
b/rpmp/pmpool/client/native/concurrent_map.h @@ -0,0 +1,67 @@ +#ifndef JNI_ID_TO_MODULE_MAP_H +#define JNI_ID_TO_MODULE_MAP_H + +#include +#include +#include +#include +#include + +namespace arrow { +namespace jni { + +/** + * An utility class that map module id to module pointers. + * @tparam Holder class of the object to hold. + */ +template +class ConcurrentMap { + public: + ConcurrentMap() : module_id_(init_module_id_) {} + + jlong Insert(Holder holder) { + std::lock_guard lock(mtx_); + jlong result = module_id_++; + map_.insert(std::pair(result, holder)); + return result; + } + + void Erase(jlong module_id) { + std::lock_guard lock(mtx_); + map_.erase(module_id); + } + + Holder Lookup(jlong module_id) { + std::lock_guard lock(mtx_); + auto it = map_.find(module_id); + if (it != map_.end()) { + return it->second; + } + return nullptr; + } + + void Clear() { + std::lock_guard lock(mtx_); + map_.clear(); + } + + size_t Size() { + std::lock_guard lock(mtx_); + return map_.size(); + } + + private: + // Initialize the module id starting value to a number greater than zero + // to allow for easier debugging of uninitialized java variables. 
+ static constexpr int init_module_id_ = 4; + + int64_t module_id_; + std::mutex mtx_; + // map from module ids returned to Java and module pointers + std::unordered_map map_; +}; + +} // namespace jni +} // namespace arrow + +#endif // JNI_ID_TO_MODULE_MAP_H diff --git a/rpmp/pmpool/queue/blockingconcurrentqueue.h b/rpmp/pmpool/queue/blockingconcurrentqueue.h index c855f9df..d4e848cc 100644 --- a/rpmp/pmpool/queue/blockingconcurrentqueue.h +++ b/rpmp/pmpool/queue/blockingconcurrentqueue.h @@ -6,12 +6,12 @@ #pragma once -#include "concurrentqueue.h" -#include #include -#include #include #include +#include +#include +#include "concurrentqueue.h" #if defined(_WIN32) // Avoid including windows.h in a header; we only need a handful of @@ -20,11 +20,16 @@ // I know this is an ugly hack but it still beats polluting the global // namespace with thousands of generic names or adding a .cpp for nothing. extern "C" { - struct _SECURITY_ATTRIBUTES; - __declspec(dllimport) void* __stdcall CreateSemaphoreW(_SECURITY_ATTRIBUTES* lpSemaphoreAttributes, long lInitialCount, long lMaximumCount, const wchar_t* lpName); - __declspec(dllimport) int __stdcall CloseHandle(void* hObject); - __declspec(dllimport) unsigned long __stdcall WaitForSingleObject(void* hHandle, unsigned long dwMilliseconds); - __declspec(dllimport) int __stdcall ReleaseSemaphore(void* hSemaphore, long lReleaseCount, long* lpPreviousCount); +struct _SECURITY_ATTRIBUTES; +__declspec(dllimport) void *__stdcall CreateSemaphoreW( + _SECURITY_ATTRIBUTES *lpSemaphoreAttributes, long lInitialCount, + long lMaximumCount, const wchar_t *lpName); +__declspec(dllimport) int __stdcall CloseHandle(void *hObject); +__declspec(dllimport) unsigned long __stdcall WaitForSingleObject( + void *hHandle, unsigned long dwMilliseconds); +__declspec(dllimport) int __stdcall ReleaseSemaphore(void *hSemaphore, + long lReleaseCount, + long *lpPreviousCount); } #elif defined(__MACH__) #include @@ -32,950 +37,898 @@ extern "C" { #include 
#endif -namespace moodycamel -{ -namespace details -{ - // Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's - // portable + lightweight semaphore implementations, originally from - // https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h - // LICENSE: - // Copyright (c) 2015 Jeff Preshing - // - // This software is provided 'as-is', without any express or implied - // warranty. In no event will the authors be held liable for any damages - // arising from the use of this software. - // - // Permission is granted to anyone to use this software for any purpose, - // including commercial applications, and to alter it and redistribute it - // freely, subject to the following restrictions: - // - // 1. The origin of this software must not be misrepresented; you must not - // claim that you wrote the original software. If you use this software - // in a product, an acknowledgement in the product documentation would be - // appreciated but is not required. - // 2. Altered source versions must be plainly marked as such, and must not be - // misrepresented as being the original software. - // 3. This notice may not be removed or altered from any source distribution. - namespace mpmc_sema - { +namespace moodycamel { +namespace details { +// Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's +// portable + lightweight semaphore implementations, originally from +// https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h +// LICENSE: +// Copyright (c) 2015 Jeff Preshing +// +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. +// +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: +// +// 1. 
The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgement in the product documentation would be +// appreciated but is not required. +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. +// 3. This notice may not be removed or altered from any source distribution. +namespace mpmc_sema { #if defined(_WIN32) - class Semaphore - { - private: - void* m_hSema; - - Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - const long maxLong = 0x7fffffff; - m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr); - } - - ~Semaphore() - { - CloseHandle(m_hSema); - } - - void wait() - { - const unsigned long infinite = 0xffffffff; - WaitForSingleObject(m_hSema, infinite); - } - - bool try_wait() - { - const unsigned long RC_WAIT_TIMEOUT = 0x00000102; - return WaitForSingleObject(m_hSema, 0) != RC_WAIT_TIMEOUT; - } - - bool timed_wait(std::uint64_t usecs) - { - const unsigned long RC_WAIT_TIMEOUT = 0x00000102; - return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) != RC_WAIT_TIMEOUT; - } - - void signal(int count = 1) - { - ReleaseSemaphore(m_hSema, count, nullptr); - } - }; +class Semaphore { + private: + void *m_hSema; + + Semaphore(const Semaphore &other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore &operator=(const Semaphore &other) MOODYCAMEL_DELETE_FUNCTION; + + public: + Semaphore(int initialCount = 0) { + assert(initialCount >= 0); + const long maxLong = 0x7fffffff; + m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr); + } + + ~Semaphore() { CloseHandle(m_hSema); } + + void wait() { + const unsigned long infinite = 0xffffffff; + WaitForSingleObject(m_hSema, infinite); + } + + 
bool try_wait() { + const unsigned long RC_WAIT_TIMEOUT = 0x00000102; + return WaitForSingleObject(m_hSema, 0) != RC_WAIT_TIMEOUT; + } + + bool timed_wait(std::uint64_t usecs) { + const unsigned long RC_WAIT_TIMEOUT = 0x00000102; + return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) != + RC_WAIT_TIMEOUT; + } + + void signal(int count = 1) { ReleaseSemaphore(m_hSema, count, nullptr); } +}; #elif defined(__MACH__) - //--------------------------------------------------------- - // Semaphore (Apple iOS and OSX) - // Can't use POSIX semaphores due to http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html - //--------------------------------------------------------- - class Semaphore - { - private: - semaphore_t m_sema; - - Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - semaphore_create(mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount); - } - - ~Semaphore() - { - semaphore_destroy(mach_task_self(), m_sema); - } - - void wait() - { - semaphore_wait(m_sema); - } - - bool try_wait() - { - return timed_wait(0); - } - - bool timed_wait(std::uint64_t timeout_usecs) - { - mach_timespec_t ts; - ts.tv_sec = static_cast(timeout_usecs / 1000000); - ts.tv_nsec = (timeout_usecs % 1000000) * 1000; - - // added in OSX 10.10: https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html - kern_return_t rc = semaphore_timedwait(m_sema, ts); - - return rc != KERN_OPERATION_TIMED_OUT && rc != KERN_ABORTED; - } - - void signal() - { - semaphore_signal(m_sema); - } - - void signal(int count) - { - while (count-- > 0) - { - semaphore_signal(m_sema); - } - } - }; +//--------------------------------------------------------- +// Semaphore (Apple iOS and OSX) +// Can't use POSIX semaphores due to +// 
http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html +//--------------------------------------------------------- +class Semaphore { + private: + semaphore_t m_sema; + + Semaphore(const Semaphore &other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore &operator=(const Semaphore &other) MOODYCAMEL_DELETE_FUNCTION; + + public: + Semaphore(int initialCount = 0) { + assert(initialCount >= 0); + semaphore_create(mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount); + } + + ~Semaphore() { semaphore_destroy(mach_task_self(), m_sema); } + + void wait() { semaphore_wait(m_sema); } + + bool try_wait() { return timed_wait(0); } + + bool timed_wait(std::uint64_t timeout_usecs) { + mach_timespec_t ts; + ts.tv_sec = static_cast(timeout_usecs / 1000000); + ts.tv_nsec = (timeout_usecs % 1000000) * 1000; + + // added in OSX 10.10: + // https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html + kern_return_t rc = semaphore_timedwait(m_sema, ts); + + return rc != KERN_OPERATION_TIMED_OUT && rc != KERN_ABORTED; + } + + void signal() { semaphore_signal(m_sema); } + + void signal(int count) { + while (count-- > 0) { + semaphore_signal(m_sema); + } + } +}; #elif defined(__unix__) - //--------------------------------------------------------- - // Semaphore (POSIX, Linux) - //--------------------------------------------------------- - class Semaphore - { - private: - sem_t m_sema; - - Semaphore(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - Semaphore& operator=(const Semaphore& other) MOODYCAMEL_DELETE_FUNCTION; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - sem_init(&m_sema, 0, initialCount); - } - - ~Semaphore() - { - sem_destroy(&m_sema); - } - - void wait() - { - // http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error - int rc; - do { - rc = sem_wait(&m_sema); - } while (rc == -1 && errno == EINTR); - } - - bool try_wait() - 
{ - int rc; - do { - rc = sem_trywait(&m_sema); - } while (rc == -1 && errno == EINTR); - return !(rc == -1 && errno == EAGAIN); - } - - bool timed_wait(std::uint64_t usecs) - { - struct timespec ts; - const int usecs_in_1_sec = 1000000; - const int nsecs_in_1_sec = 1000000000; - clock_gettime(CLOCK_REALTIME, &ts); - ts.tv_sec += usecs / usecs_in_1_sec; - ts.tv_nsec += (usecs % usecs_in_1_sec) * 1000; - // sem_timedwait bombs if you have more than 1e9 in tv_nsec - // so we have to clean things up before passing it in - if (ts.tv_nsec >= nsecs_in_1_sec) { - ts.tv_nsec -= nsecs_in_1_sec; - ++ts.tv_sec; - } - - int rc; - do { - rc = sem_timedwait(&m_sema, &ts); - } while (rc == -1 && errno == EINTR); - return !(rc == -1 && errno == ETIMEDOUT); - } - - void signal() - { - sem_post(&m_sema); - } - - void signal(int count) - { - while (count-- > 0) - { - sem_post(&m_sema); - } - } - }; +//--------------------------------------------------------- +// Semaphore (POSIX, Linux) +//--------------------------------------------------------- +class Semaphore { + private: + sem_t m_sema; + + Semaphore(const Semaphore &other) MOODYCAMEL_DELETE_FUNCTION; + Semaphore &operator=(const Semaphore &other) MOODYCAMEL_DELETE_FUNCTION; + + public: + Semaphore(int initialCount = 0) { + assert(initialCount >= 0); + sem_init(&m_sema, 0, initialCount); + } + + ~Semaphore() { sem_destroy(&m_sema); } + + void wait() { + // http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error + int rc; + do { + rc = sem_wait(&m_sema); + } while (rc == -1 && errno == EINTR); + } + + bool try_wait() { + int rc; + do { + rc = sem_trywait(&m_sema); + } while (rc == -1 && errno == EINTR); + return !(rc == -1 && errno == EAGAIN); + } + + bool timed_wait(std::uint64_t usecs) { + struct timespec ts; + const int usecs_in_1_sec = 1000000; + const int nsecs_in_1_sec = 1000000000; + clock_gettime(CLOCK_REALTIME, &ts); + ts.tv_sec += usecs / usecs_in_1_sec; + ts.tv_nsec += (usecs % 
usecs_in_1_sec) * 1000; + // sem_timedwait bombs if you have more than 1e9 in tv_nsec + // so we have to clean things up before passing it in + if (ts.tv_nsec >= nsecs_in_1_sec) { + ts.tv_nsec -= nsecs_in_1_sec; + ++ts.tv_sec; + } + + int rc; + do { + rc = sem_timedwait(&m_sema, &ts); + } while (rc == -1 && errno == EINTR); + return !(rc == -1 && errno == ETIMEDOUT); + } + + void signal() { sem_post(&m_sema); } + + void signal(int count) { + while (count-- > 0) { + sem_post(&m_sema); + } + } +}; #else #error Unsupported platform! (No semaphore wrapper available) #endif - //--------------------------------------------------------- - // LightweightSemaphore - //--------------------------------------------------------- - class LightweightSemaphore - { - public: - typedef std::make_signed::type ssize_t; - - private: - std::atomic m_count; - Semaphore m_sema; - - bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) - { - ssize_t oldCount; - // Is there a better way to set the initial spin count? - // If we lower it to 1000, testBenaphore becomes 15x slower on my Core i7-5930K Windows PC, - // as threads start hitting the kernel semaphore. - int spin = 10000; - while (--spin >= 0) - { - oldCount = m_count.load(std::memory_order_relaxed); - if ((oldCount > 0) && m_count.compare_exchange_strong(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) - return true; - std::atomic_signal_fence(std::memory_order_acquire); // Prevent the compiler from collapsing the loop. - } - oldCount = m_count.fetch_sub(1, std::memory_order_acquire); - if (oldCount > 0) - return true; - if (timeout_usecs < 0) - { - m_sema.wait(); - return true; - } - if (m_sema.timed_wait((std::uint64_t)timeout_usecs)) - return true; - // At this point, we've timed out waiting for the semaphore, but the - // count is still decremented indicating we may still be waiting on - // it. 
So we have to re-adjust the count, but only if the semaphore - // wasn't signaled enough times for us too since then. If it was, we - // need to release the semaphore too. - while (true) - { - oldCount = m_count.load(std::memory_order_acquire); - if (oldCount >= 0 && m_sema.try_wait()) - return true; - if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) - return false; - } - } - - ssize_t waitManyWithPartialSpinning(ssize_t max, std::int64_t timeout_usecs = -1) - { - assert(max > 0); - ssize_t oldCount; - int spin = 10000; - while (--spin >= 0) - { - oldCount = m_count.load(std::memory_order_relaxed); - if (oldCount > 0) - { - ssize_t newCount = oldCount > max ? oldCount - max : 0; - if (m_count.compare_exchange_strong(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) - return oldCount - newCount; - } - std::atomic_signal_fence(std::memory_order_acquire); - } - oldCount = m_count.fetch_sub(1, std::memory_order_acquire); - if (oldCount <= 0) - { - if (timeout_usecs < 0) - m_sema.wait(); - else if (!m_sema.timed_wait((std::uint64_t)timeout_usecs)) - { - while (true) - { - oldCount = m_count.load(std::memory_order_acquire); - if (oldCount >= 0 && m_sema.try_wait()) - break; - if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) - return 0; - } - } - } - if (max > 1) - return 1 + tryWaitMany(max - 1); - return 1; - } - - public: - LightweightSemaphore(ssize_t initialCount = 0) : m_count(initialCount) - { - assert(initialCount >= 0); - } - - bool tryWait() - { - ssize_t oldCount = m_count.load(std::memory_order_relaxed); - while (oldCount > 0) - { - if (m_count.compare_exchange_weak(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) - return true; - } - return false; - } - - void wait() - { - if (!tryWait()) - waitWithPartialSpinning(); - } - - bool 
wait(std::int64_t timeout_usecs) - { - return tryWait() || waitWithPartialSpinning(timeout_usecs); - } - - // Acquires between 0 and (greedily) max, inclusive - ssize_t tryWaitMany(ssize_t max) - { - assert(max >= 0); - ssize_t oldCount = m_count.load(std::memory_order_relaxed); - while (oldCount > 0) - { - ssize_t newCount = oldCount > max ? oldCount - max : 0; - if (m_count.compare_exchange_weak(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) - return oldCount - newCount; - } - return 0; - } - - // Acquires at least one, and (greedily) at most max - ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) - { - assert(max >= 0); - ssize_t result = tryWaitMany(max); - if (result == 0 && max > 0) - result = waitManyWithPartialSpinning(max, timeout_usecs); - return result; - } - - ssize_t waitMany(ssize_t max) - { - ssize_t result = waitMany(max, -1); - assert(result > 0); - return result; - } - - void signal(ssize_t count = 1) - { - assert(count >= 0); - ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release); - ssize_t toRelease = -oldCount < count ? -oldCount : count; - if (toRelease > 0) - { - m_sema.signal((int)toRelease); - } - } - - ssize_t availableApprox() const - { - ssize_t count = m_count.load(std::memory_order_relaxed); - return count > 0 ? count : 0; - } - }; - } // end namespace mpmc_sema -} // end namespace details - - -// This is a blocking version of the queue. It has an almost identical interface to -// the normal non-blocking version, with the addition of various wait_dequeue() methods -// and the removal of producer-specific dequeue methods. 
-template -class BlockingConcurrentQueue -{ -private: - typedef ::moodycamel::ConcurrentQueue ConcurrentQueue; - typedef details::mpmc_sema::LightweightSemaphore LightweightSemaphore; - -public: - typedef typename ConcurrentQueue::producer_token_t producer_token_t; - typedef typename ConcurrentQueue::consumer_token_t consumer_token_t; - - typedef typename ConcurrentQueue::index_t index_t; - typedef typename ConcurrentQueue::size_t size_t; - typedef typename std::make_signed::type ssize_t; - - static const size_t BLOCK_SIZE = ConcurrentQueue::BLOCK_SIZE; - static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = ConcurrentQueue::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD; - static const size_t EXPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::EXPLICIT_INITIAL_INDEX_SIZE; - static const size_t IMPLICIT_INITIAL_INDEX_SIZE = ConcurrentQueue::IMPLICIT_INITIAL_INDEX_SIZE; - static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = ConcurrentQueue::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; - static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = ConcurrentQueue::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE; - static const size_t MAX_SUBQUEUE_SIZE = ConcurrentQueue::MAX_SUBQUEUE_SIZE; - -public: - // Creates a queue with at least `capacity` element slots; note that the - // actual number of elements that can be inserted without additional memory - // allocation depends on the number of producers and the block size (e.g. if - // the block size is equal to `capacity`, only a single block will be allocated - // up-front, which means only a single producer will be able to enqueue elements - // without an extra allocation -- blocks aren't shared between producers). - // This method is not thread safe -- it is up to the user to ensure that the - // queue is fully constructed before it starts being used by other threads (this - // includes making the memory effects of construction visible, possibly with a - // memory barrier). 
- explicit BlockingConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) - : inner(capacity), sema(create(), &BlockingConcurrentQueue::template destroy) - { - assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); - if (!sema) { - MOODYCAMEL_THROW(std::bad_alloc()); - } - } - - BlockingConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, size_t maxImplicitProducers) - : inner(minCapacity, maxExplicitProducers, maxImplicitProducers), sema(create(), &BlockingConcurrentQueue::template destroy) - { - assert(reinterpret_cast((BlockingConcurrentQueue*)1) == &((BlockingConcurrentQueue*)1)->inner && "BlockingConcurrentQueue must have ConcurrentQueue as its first member"); - if (!sema) { - MOODYCAMEL_THROW(std::bad_alloc()); - } - } - - // Disable copying and copy assignment - BlockingConcurrentQueue(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; - BlockingConcurrentQueue& operator=(BlockingConcurrentQueue const&) MOODYCAMEL_DELETE_FUNCTION; - - // Moving is supported, but note that it is *not* a thread-safe operation. - // Nobody can use the queue while it's being moved, and the memory effects - // of that move must be propagated to other threads before they can use it. - // Note: When a queue is moved, its tokens are still valid but can only be - // used with the destination queue (i.e. semantically they are moved along - // with the queue itself). - BlockingConcurrentQueue(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT - : inner(std::move(other.inner)), sema(std::move(other.sema)) - { } - - inline BlockingConcurrentQueue& operator=(BlockingConcurrentQueue&& other) MOODYCAMEL_NOEXCEPT - { - return swap_internal(other); - } - - // Swaps this queue's state with the other's. Not thread-safe. 
- // Swapping two queues does not invalidate their tokens, however - // the tokens that were created for one queue must be used with - // only the swapped queue (i.e. the tokens are tied to the - // queue's movable state, not the object itself). - inline void swap(BlockingConcurrentQueue& other) MOODYCAMEL_NOEXCEPT - { - swap_internal(other); - } - -private: - BlockingConcurrentQueue& swap_internal(BlockingConcurrentQueue& other) - { - if (this == &other) { - return *this; - } - - inner.swap(other.inner); - sema.swap(other.sema); - return *this; - } - -public: - // Enqueues a single item (by copying it). - // Allocates memory if required. Only fails if memory allocation fails (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, - // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(T const& item) - { - if ((details::likely)(inner.enqueue(item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible). - // Allocates memory if required. Only fails if memory allocation fails (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, - // or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(T&& item) - { - if ((details::likely)(inner.enqueue(std::move(item)))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by copying it) using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails (or - // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. 
- inline bool enqueue(producer_token_t const& token, T const& item) - { - if ((details::likely)(inner.enqueue(token, item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible) using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails (or - // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Thread-safe. - inline bool enqueue(producer_token_t const& token, T&& item) - { - if ((details::likely)(inner.enqueue(token, std::move(item)))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues several items. - // Allocates memory if required. Only fails if memory allocation fails (or - // implicit production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - // is 0, or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Note: Use std::make_move_iterator if the elements should be moved instead of copied. - // Thread-safe. - template - inline bool enqueue_bulk(It itemFirst, size_t count) - { - if ((details::likely)(inner.enqueue_bulk(std::forward(itemFirst), count))) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - // Enqueues several items using an explicit producer token. - // Allocates memory if required. Only fails if memory allocation fails - // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - inline bool enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) - { - if ((details::likely)(inner.enqueue_bulk(token, std::forward(itemFirst), count))) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - // Enqueues a single item (by copying it). - // Does not allocate memory. 
Fails if not enough room to enqueue (or implicit - // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE - // is 0). - // Thread-safe. - inline bool try_enqueue(T const& item) - { - if (inner.try_enqueue(item)) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible). - // Does not allocate memory (except for one-time implicit producer). - // Fails if not enough room to enqueue (or implicit production is - // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). - // Thread-safe. - inline bool try_enqueue(T&& item) - { - if (inner.try_enqueue(std::move(item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by copying it) using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Thread-safe. - inline bool try_enqueue(producer_token_t const& token, T const& item) - { - if (inner.try_enqueue(token, item)) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues a single item (by moving it, if possible) using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Thread-safe. - inline bool try_enqueue(producer_token_t const& token, T&& item) - { - if (inner.try_enqueue(token, std::move(item))) { - sema->signal(); - return true; - } - return false; - } - - // Enqueues several items. - // Does not allocate memory (except for one-time implicit producer). - // Fails if not enough room to enqueue (or implicit production is - // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. 
- template - inline bool try_enqueue_bulk(It itemFirst, size_t count) - { - if (inner.try_enqueue_bulk(std::forward(itemFirst), count)) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - // Enqueues several items using an explicit producer token. - // Does not allocate memory. Fails if not enough room to enqueue. - // Note: Use std::make_move_iterator if the elements should be moved - // instead of copied. - // Thread-safe. - template - inline bool try_enqueue_bulk(producer_token_t const& token, It itemFirst, size_t count) - { - if (inner.try_enqueue_bulk(token, std::forward(itemFirst), count)) { - sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); - return true; - } - return false; - } - - - // Attempts to dequeue from the queue. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline bool try_dequeue(U& item) - { - if (sema->tryWait()) { - while (!inner.try_dequeue(item)) { - continue; - } - return true; - } - return false; - } - - // Attempts to dequeue from the queue using an explicit consumer token. - // Returns false if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline bool try_dequeue(consumer_token_t& token, U& item) - { - if (sema->tryWait()) { - while (!inner.try_dequeue(token, item)) { - continue; - } - return true; - } - return false; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued. - // Returns 0 if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. 
- template - inline size_t try_dequeue_bulk(It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued. - // Returns 0 if all producer streams appeared empty at the time they - // were checked (so, the queue is likely but not guaranteed to be empty). - // Never allocates. Thread-safe. - template - inline size_t try_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(token, itemFirst, max - count); - } - return count; - } - - - - // Blocks the current thread until there's something to dequeue, then - // dequeues it. - // Never allocates. Thread-safe. - template - inline void wait_dequeue(U& item) - { - sema->wait(); - while (!inner.try_dequeue(item)) { - continue; - } - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout (specified in microseconds) expires. Returns false - // without setting `item` if the timeout expires, otherwise assigns - // to `item` and returns true. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(U& item, std::int64_t timeout_usecs) - { - if (!sema->wait(timeout_usecs)) { - return false; - } - while (!inner.try_dequeue(item)) { - continue; - } - return true; - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout expires. 
Returns false without setting `item` if the - // timeout expires, otherwise assigns to `item` and returns true. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(U& item, std::chrono::duration const& timeout) - { - return wait_dequeue_timed(item, std::chrono::duration_cast(timeout).count()); - } - - // Blocks the current thread until there's something to dequeue, then - // dequeues it using an explicit consumer token. - // Never allocates. Thread-safe. - template - inline void wait_dequeue(consumer_token_t& token, U& item) - { - sema->wait(); - while (!inner.try_dequeue(token, item)) { - continue; - } - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout (specified in microseconds) expires. Returns false - // without setting `item` if the timeout expires, otherwise assigns - // to `item` and returns true. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::int64_t timeout_usecs) - { - if (!sema->wait(timeout_usecs)) { - return false; - } - while (!inner.try_dequeue(token, item)) { - continue; - } - return true; - } - - // Blocks the current thread until either there's something to dequeue - // or the timeout expires. Returns false without setting `item` if the - // timeout expires, otherwise assigns to `item` and returns true. - // Never allocates. Thread-safe. - template - inline bool wait_dequeue_timed(consumer_token_t& token, U& item, std::chrono::duration const& timeout) - { - return wait_dequeue_timed(token, item, std::chrono::duration_cast(timeout).count()); - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued, which will - // always be at least one (this method blocks until the queue - // is non-empty) and at most max. 
- // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk(It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue_bulk. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::int64_t timeout_usecs) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); - while (count != max) { - count += inner.template try_dequeue_bulk(itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, std::chrono::duration const& timeout) - { - return wait_dequeue_bulk_timed(itemFirst, max, std::chrono::duration_cast(timeout).count()); - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued, which will - // always be at least one (this method blocks until the queue - // is non-empty) and at most max. - // Never allocates. Thread-safe. 
- template - inline size_t wait_dequeue_bulk(consumer_token_t& token, It itemFirst, size_t max) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); - while (count != max) { - count += inner.template try_dequeue_bulk(token, itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Using a negative timeout indicates an indefinite timeout, - // and is thus functionally equivalent to calling wait_dequeue_bulk. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::int64_t timeout_usecs) - { - size_t count = 0; - max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, timeout_usecs); - while (count != max) { - count += inner.template try_dequeue_bulk(token, itemFirst, max - count); - } - return count; - } - - // Attempts to dequeue several elements from the queue using an explicit consumer token. - // Returns the number of items actually dequeued, which can - // be 0 if the timeout expires while waiting for elements, - // and at most max. - // Never allocates. Thread-safe. - template - inline size_t wait_dequeue_bulk_timed(consumer_token_t& token, It itemFirst, size_t max, std::chrono::duration const& timeout) - { - return wait_dequeue_bulk_timed(token, itemFirst, max, std::chrono::duration_cast(timeout).count()); - } - - - // Returns an estimate of the total number of elements currently in the queue. This - // estimate is only accurate if the queue has completely stabilized before it is called - // (i.e. all enqueue and dequeue operations have completed and their memory effects are - // visible on the calling thread, and no further operations start while this method is - // being called). 
- // Thread-safe. - inline size_t size_approx() const - { - return (size_t)sema->availableApprox(); - } - - - // Returns true if the underlying atomic variables used by - // the queue are lock-free (they should be on most platforms). - // Thread-safe. - static bool is_lock_free() - { - return ConcurrentQueue::is_lock_free(); - } - - -private: - template - static inline U* create() - { - auto p = (Traits::malloc)(sizeof(U)); - return p != nullptr ? new (p) U : nullptr; - } - - template - static inline U* create(A1&& a1) - { - auto p = (Traits::malloc)(sizeof(U)); - return p != nullptr ? new (p) U(std::forward(a1)) : nullptr; - } - - template - static inline void destroy(U* p) - { - if (p != nullptr) { - p->~U(); - } - (Traits::free)(p); - } - -private: - ConcurrentQueue inner; - std::unique_ptr sema; +//--------------------------------------------------------- +// LightweightSemaphore +//--------------------------------------------------------- +class LightweightSemaphore { + public: + typedef std::make_signed::type ssize_t; + + private: + std::atomic m_count; + Semaphore m_sema; + + bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) { + ssize_t oldCount; + // Is there a better way to set the initial spin count? + // If we lower it to 1000, testBenaphore becomes 15x slower on my Core + // i7-5930K Windows PC, as threads start hitting the kernel semaphore. + int spin = 10000; + while (--spin >= 0) { + oldCount = m_count.load(std::memory_order_relaxed); + if ((oldCount > 0) && + m_count.compare_exchange_strong(oldCount, oldCount - 1, + std::memory_order_acquire, + std::memory_order_relaxed)) + return true; + std::atomic_signal_fence(std::memory_order_acquire); // Prevent the + // compiler from + // collapsing the + // loop. 
+ } + oldCount = m_count.fetch_sub(1, std::memory_order_acquire); + if (oldCount > 0) return true; + if (timeout_usecs < 0) { + m_sema.wait(); + return true; + } + if (m_sema.timed_wait((std::uint64_t)timeout_usecs)) return true; + // At this point, we've timed out waiting for the semaphore, but the + // count is still decremented indicating we may still be waiting on + // it. So we have to re-adjust the count, but only if the semaphore + // wasn't signaled enough times for us too since then. If it was, we + // need to release the semaphore too. + while (true) { + oldCount = m_count.load(std::memory_order_acquire); + if (oldCount >= 0 && m_sema.try_wait()) return true; + if (oldCount < 0 && m_count.compare_exchange_strong( + oldCount, oldCount + 1, std::memory_order_relaxed, + std::memory_order_relaxed)) + return false; + } + } + + ssize_t waitManyWithPartialSpinning(ssize_t max, + std::int64_t timeout_usecs = -1) { + assert(max > 0); + ssize_t oldCount; + int spin = 10000; + while (--spin >= 0) { + oldCount = m_count.load(std::memory_order_relaxed); + if (oldCount > 0) { + ssize_t newCount = oldCount > max ? 
oldCount - max : 0; + if (m_count.compare_exchange_strong(oldCount, newCount, + std::memory_order_acquire, + std::memory_order_relaxed)) + return oldCount - newCount; + } + std::atomic_signal_fence(std::memory_order_acquire); + } + oldCount = m_count.fetch_sub(1, std::memory_order_acquire); + if (oldCount <= 0) { + if (timeout_usecs < 0) + m_sema.wait(); + else if (!m_sema.timed_wait((std::uint64_t)timeout_usecs)) { + while (true) { + oldCount = m_count.load(std::memory_order_acquire); + if (oldCount >= 0 && m_sema.try_wait()) break; + if (oldCount < 0 && + m_count.compare_exchange_strong(oldCount, oldCount + 1, + std::memory_order_relaxed, + std::memory_order_relaxed)) + return 0; + } + } + } + if (max > 1) return 1 + tryWaitMany(max - 1); + return 1; + } + + public: + LightweightSemaphore(ssize_t initialCount = 0) : m_count(initialCount) { + assert(initialCount >= 0); + } + + bool tryWait() { + ssize_t oldCount = m_count.load(std::memory_order_relaxed); + while (oldCount > 0) { + if (m_count.compare_exchange_weak(oldCount, oldCount - 1, + std::memory_order_acquire, + std::memory_order_relaxed)) + return true; + } + return false; + } + + void wait() { + if (!tryWait()) waitWithPartialSpinning(); + } + + bool wait(std::int64_t timeout_usecs) { + return tryWait() || waitWithPartialSpinning(timeout_usecs); + } + + // Acquires between 0 and (greedily) max, inclusive + ssize_t tryWaitMany(ssize_t max) { + assert(max >= 0); + ssize_t oldCount = m_count.load(std::memory_order_relaxed); + while (oldCount > 0) { + ssize_t newCount = oldCount > max ? 
oldCount - max : 0; + if (m_count.compare_exchange_weak(oldCount, newCount, + std::memory_order_acquire, + std::memory_order_relaxed)) + return oldCount - newCount; + } + return 0; + } + + // Acquires at least one, and (greedily) at most max + ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) { + assert(max >= 0); + ssize_t result = tryWaitMany(max); + if (result == 0 && max > 0) + result = waitManyWithPartialSpinning(max, timeout_usecs); + return result; + } + + ssize_t waitMany(ssize_t max) { + ssize_t result = waitMany(max, -1); + assert(result > 0); + return result; + } + + void signal(ssize_t count = 1) { + assert(count >= 0); + ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release); + ssize_t toRelease = -oldCount < count ? -oldCount : count; + if (toRelease > 0) { + m_sema.signal((int)toRelease); + } + } + + ssize_t availableApprox() const { + ssize_t count = m_count.load(std::memory_order_relaxed); + return count > 0 ? count : 0; + } }; +} // end namespace mpmc_sema +} // end namespace details + +// This is a blocking version of the queue. It has an almost identical interface +// to the normal non-blocking version, with the addition of various +// wait_dequeue() methods and the removal of producer-specific dequeue methods. 
+template +class BlockingConcurrentQueue { + private: + typedef ::moodycamel::ConcurrentQueue ConcurrentQueue; + typedef details::mpmc_sema::LightweightSemaphore LightweightSemaphore; + + public: + typedef typename ConcurrentQueue::producer_token_t producer_token_t; + typedef typename ConcurrentQueue::consumer_token_t consumer_token_t; + + typedef typename ConcurrentQueue::index_t index_t; + typedef typename ConcurrentQueue::size_t size_t; + typedef typename std::make_signed::type ssize_t; + + static const size_t BLOCK_SIZE = ConcurrentQueue::BLOCK_SIZE; + static const size_t EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD = + ConcurrentQueue::EXPLICIT_BLOCK_EMPTY_COUNTER_THRESHOLD; + static const size_t EXPLICIT_INITIAL_INDEX_SIZE = + ConcurrentQueue::EXPLICIT_INITIAL_INDEX_SIZE; + static const size_t IMPLICIT_INITIAL_INDEX_SIZE = + ConcurrentQueue::IMPLICIT_INITIAL_INDEX_SIZE; + static const size_t INITIAL_IMPLICIT_PRODUCER_HASH_SIZE = + ConcurrentQueue::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE; + static const std::uint32_t EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE = + ConcurrentQueue::EXPLICIT_CONSUMER_CONSUMPTION_QUOTA_BEFORE_ROTATE; + static const size_t MAX_SUBQUEUE_SIZE = ConcurrentQueue::MAX_SUBQUEUE_SIZE; + + public: + // Creates a queue with at least `capacity` element slots; note that the + // actual number of elements that can be inserted without additional memory + // allocation depends on the number of producers and the block size (e.g. if + // the block size is equal to `capacity`, only a single block will be + // allocated up-front, which means only a single producer will be able to + // enqueue elements without an extra allocation -- blocks aren't shared + // between producers). This method is not thread safe -- it is up to the user + // to ensure that the queue is fully constructed before it starts being used + // by other threads (this includes making the memory effects of construction + // visible, possibly with a memory barrier). 
+ explicit BlockingConcurrentQueue(size_t capacity = 6 * BLOCK_SIZE) + : inner(capacity), + sema(create(), + &BlockingConcurrentQueue::template destroy) { + assert(reinterpret_cast((BlockingConcurrentQueue *)1) == + &((BlockingConcurrentQueue *)1)->inner && + "BlockingConcurrentQueue must have ConcurrentQueue as its first " + "member"); + if (!sema) { + MOODYCAMEL_THROW(std::bad_alloc()); + } + } + + BlockingConcurrentQueue(size_t minCapacity, size_t maxExplicitProducers, + size_t maxImplicitProducers) + : inner(minCapacity, maxExplicitProducers, maxImplicitProducers), + sema(create(), + &BlockingConcurrentQueue::template destroy) { + assert(reinterpret_cast((BlockingConcurrentQueue *)1) == + &((BlockingConcurrentQueue *)1)->inner && + "BlockingConcurrentQueue must have ConcurrentQueue as its first " + "member"); + if (!sema) { + MOODYCAMEL_THROW(std::bad_alloc()); + } + } + + // Disable copying and copy assignment + BlockingConcurrentQueue(BlockingConcurrentQueue const &) + MOODYCAMEL_DELETE_FUNCTION; + BlockingConcurrentQueue &operator=(BlockingConcurrentQueue const &) + MOODYCAMEL_DELETE_FUNCTION; + + // Moving is supported, but note that it is *not* a thread-safe operation. + // Nobody can use the queue while it's being moved, and the memory effects + // of that move must be propagated to other threads before they can use it. + // Note: When a queue is moved, its tokens are still valid but can only be + // used with the destination queue (i.e. semantically they are moved along + // with the queue itself). + BlockingConcurrentQueue(BlockingConcurrentQueue &&other) MOODYCAMEL_NOEXCEPT + : inner(std::move(other.inner)), + sema(std::move(other.sema)) {} + + inline BlockingConcurrentQueue &operator=(BlockingConcurrentQueue &&other) + MOODYCAMEL_NOEXCEPT { + return swap_internal(other); + } + + // Swaps this queue's state with the other's. Not thread-safe. 
+ // Swapping two queues does not invalidate their tokens, however + // the tokens that were created for one queue must be used with + // only the swapped queue (i.e. the tokens are tied to the + // queue's movable state, not the object itself). + inline void swap(BlockingConcurrentQueue &other) MOODYCAMEL_NOEXCEPT { + swap_internal(other); + } + + private: + BlockingConcurrentQueue &swap_internal(BlockingConcurrentQueue &other) { + if (this == &other) { + return *this; + } + inner.swap(other.inner); + sema.swap(other.sema); + return *this; + } + + public: + // Enqueues a single item (by copying it). + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because + // Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T const &item) { + if ((details::likely)(inner.enqueue(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible). + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because + // Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. + inline bool enqueue(T &&item) { + if ((details::likely)(inner.enqueue(std::move(item)))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails (or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Thread-safe. 
+ inline bool enqueue(producer_token_t const &token, T const &item) { + if ((details::likely)(inner.enqueue(token, item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible) using an explicit + // producer token. Allocates memory if required. Only fails if memory + // allocation fails (or Traits::MAX_SUBQUEUE_SIZE has been defined and would + // be surpassed). Thread-safe. + inline bool enqueue(producer_token_t const &token, T &&item) { + if ((details::likely)(inner.enqueue(token, std::move(item)))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues several items. + // Allocates memory if required. Only fails if memory allocation fails (or + // implicit production is disabled because + // Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0, or + // Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). Note: + // Use std::make_move_iterator if the elements should be moved instead of + // copied. Thread-safe. + template + inline bool enqueue_bulk(It itemFirst, size_t count) { + if ((details::likely)( + inner.enqueue_bulk(std::forward(itemFirst), count))) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues several items using an explicit producer token. + // Allocates memory if required. Only fails if memory allocation fails + // (or Traits::MAX_SUBQUEUE_SIZE has been defined and would be surpassed). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool enqueue_bulk(producer_token_t const &token, It itemFirst, + size_t count) { + if ((details::likely)( + inner.enqueue_bulk(token, std::forward(itemFirst), count))) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues a single item (by copying it). + // Does not allocate memory. 
Fails if not enough room to enqueue (or implicit + // production is disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE + // is 0). + // Thread-safe. + inline bool try_enqueue(T const &item) { + if (inner.try_enqueue(item)) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible). + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Thread-safe. + inline bool try_enqueue(T &&item) { + if (inner.try_enqueue(std::move(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by copying it) using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Thread-safe. + inline bool try_enqueue(producer_token_t const &token, T const &item) { + if (inner.try_enqueue(token, item)) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues a single item (by moving it, if possible) using an explicit + // producer token. Does not allocate memory. Fails if not enough room to + // enqueue. Thread-safe. + inline bool try_enqueue(producer_token_t const &token, T &&item) { + if (inner.try_enqueue(token, std::move(item))) { + sema->signal(); + return true; + } + return false; + } + + // Enqueues several items. + // Does not allocate memory (except for one-time implicit producer). + // Fails if not enough room to enqueue (or implicit production is + // disabled because Traits::INITIAL_IMPLICIT_PRODUCER_HASH_SIZE is 0). + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. 
+ template + inline bool try_enqueue_bulk(It itemFirst, size_t count) { + if (inner.try_enqueue_bulk(std::forward(itemFirst), count)) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Enqueues several items using an explicit producer token. + // Does not allocate memory. Fails if not enough room to enqueue. + // Note: Use std::make_move_iterator if the elements should be moved + // instead of copied. + // Thread-safe. + template + inline bool try_enqueue_bulk(producer_token_t const &token, It itemFirst, + size_t count) { + if (inner.try_enqueue_bulk(token, std::forward(itemFirst), count)) { + sema->signal((LightweightSemaphore::ssize_t)(ssize_t)count); + return true; + } + return false; + } + + // Attempts to dequeue from the queue. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue(U &item) { + if (sema->tryWait()) { + while (!inner.try_dequeue(item)) { + continue; + } + return true; + } + return false; + } + + // Attempts to dequeue from the queue using an explicit consumer token. + // Returns false if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. + template + inline bool try_dequeue(consumer_token_t &token, U &item) { + if (sema->tryWait()) { + while (!inner.try_dequeue(token, item)) { + continue; + } + return true; + } + return false; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued. + // Returns 0 if all producer streams appeared empty at the time they + // were checked (so, the queue is likely but not guaranteed to be empty). + // Never allocates. Thread-safe. 
+ template + inline size_t try_dequeue_bulk(It itemFirst, size_t max) { + size_t count = 0; + max = + (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit + // consumer token. Returns the number of items actually dequeued. Returns 0 if + // all producer streams appeared empty at the time they were checked (so, the + // queue is likely but not guaranteed to be empty). Never allocates. + // Thread-safe. + template + inline size_t try_dequeue_bulk(consumer_token_t &token, It itemFirst, + size_t max) { + size_t count = 0; + max = + (size_t)sema->tryWaitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += + inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Blocks the current thread until there's something to dequeue, then + // dequeues it. + // Never allocates. Thread-safe. + template + inline void wait_dequeue(U &item) { + sema->wait(); + while (!inner.try_dequeue(item)) { + continue; + } + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout (specified in microseconds) expires. Returns false + // without setting `item` if the timeout expires, otherwise assigns + // to `item` and returns true. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(U &item, std::int64_t timeout_usecs) { + if (!sema->wait(timeout_usecs)) { + return false; + } + while (!inner.try_dequeue(item)) { + continue; + } + return true; + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout expires. 
Returns false without setting `item` if the + // timeout expires, otherwise assigns to `item` and returns true. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed( + U &item, std::chrono::duration const &timeout) { + return wait_dequeue_timed( + item, + std::chrono::duration_cast(timeout).count()); + } + + // Blocks the current thread until there's something to dequeue, then + // dequeues it using an explicit consumer token. + // Never allocates. Thread-safe. + template + inline void wait_dequeue(consumer_token_t &token, U &item) { + sema->wait(); + while (!inner.try_dequeue(token, item)) { + continue; + } + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout (specified in microseconds) expires. Returns false + // without setting `item` if the timeout expires, otherwise assigns + // to `item` and returns true. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed(consumer_token_t &token, U &item, + std::int64_t timeout_usecs) { + if (!sema->wait(timeout_usecs)) { + return false; + } + while (!inner.try_dequeue(token, item)) { + continue; + } + return true; + } + + // Blocks the current thread until either there's something to dequeue + // or the timeout expires. Returns false without setting `item` if the + // timeout expires, otherwise assigns to `item` and returns true. + // Never allocates. Thread-safe. + template + inline bool wait_dequeue_timed( + consumer_token_t &token, U &item, + std::chrono::duration const &timeout) { + return wait_dequeue_timed( + token, item, + std::chrono::duration_cast(timeout).count()); + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which will + // always be at least one (this method blocks until the queue + // is non-empty) and at most max. 
+ // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk(It itemFirst, size_t max) { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Using a negative timeout indicates an indefinite timeout, + // and is thus functionally equivalent to calling wait_dequeue_bulk. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(It itemFirst, size_t max, + std::int64_t timeout_usecs) { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, + timeout_usecs); + while (count != max) { + count += inner.template try_dequeue_bulk(itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue. + // Returns the number of items actually dequeued, which can + // be 0 if the timeout expires while waiting for elements, + // and at most max. + // Never allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed( + It itemFirst, size_t max, + std::chrono::duration const &timeout) { + return wait_dequeue_bulk_timed( + itemFirst, max, + std::chrono::duration_cast(timeout).count()); + } + + // Attempts to dequeue several elements from the queue using an explicit + // consumer token. Returns the number of items actually dequeued, which will + // always be at least one (this method blocks until the queue + // is non-empty) and at most max. + // Never allocates. Thread-safe. 
+ template + inline size_t wait_dequeue_bulk(consumer_token_t &token, It itemFirst, + size_t max) { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max); + while (count != max) { + count += + inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit + // consumer token. Returns the number of items actually dequeued, which can be + // 0 if the timeout expires while waiting for elements, and at most max. Using + // a negative timeout indicates an indefinite timeout, and is thus + // functionally equivalent to calling wait_dequeue_bulk. Never allocates. + // Thread-safe. + template + inline size_t wait_dequeue_bulk_timed(consumer_token_t &token, It itemFirst, + size_t max, + std::int64_t timeout_usecs) { + size_t count = 0; + max = (size_t)sema->waitMany((LightweightSemaphore::ssize_t)(ssize_t)max, + timeout_usecs); + while (count != max) { + count += + inner.template try_dequeue_bulk(token, itemFirst, max - count); + } + return count; + } + + // Attempts to dequeue several elements from the queue using an explicit + // consumer token. Returns the number of items actually dequeued, which can be + // 0 if the timeout expires while waiting for elements, and at most max. Never + // allocates. Thread-safe. + template + inline size_t wait_dequeue_bulk_timed( + consumer_token_t &token, It itemFirst, size_t max, + std::chrono::duration const &timeout) { + return wait_dequeue_bulk_timed( + token, itemFirst, max, + std::chrono::duration_cast(timeout).count()); + } + + // Returns an estimate of the total number of elements currently in the queue. + // This estimate is only accurate if the queue has completely stabilized + // before it is called (i.e. 
all enqueue and dequeue operations have completed + // and their memory effects are visible on the calling thread, and no further + // operations start while this method is being called). Thread-safe. + inline size_t size_approx() const { return (size_t)sema->availableApprox(); } + + // Returns true if the underlying atomic variables used by + // the queue are lock-free (they should be on most platforms). + // Thread-safe. + static bool is_lock_free() { return ConcurrentQueue::is_lock_free(); } + + private: + template + static inline U *create() { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new (p) U : nullptr; + } + + template + static inline U *create(A1 &&a1) { + auto p = (Traits::malloc)(sizeof(U)); + return p != nullptr ? new (p) U(std::forward(a1)) : nullptr; + } + + template + static inline void destroy(U *p) { + if (p != nullptr) { + p->~U(); + } + (Traits::free)(p); + } + + private: + ConcurrentQueue inner; + std::unique_ptr sema; +}; -template -inline void swap(BlockingConcurrentQueue& a, BlockingConcurrentQueue& b) MOODYCAMEL_NOEXCEPT -{ - a.swap(b); +template +inline void swap(BlockingConcurrentQueue &a, + BlockingConcurrentQueue &b) MOODYCAMEL_NOEXCEPT { + a.swap(b); } -} // end namespace moodycamel +} // end namespace moodycamel