Changes from all commits
core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

@@ -150,6 +150,14 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
    */
   private[this] val blockInfoWrappers = new ConcurrentHashMap[BlockId, BlockInfoWrapper]
 
+  // Reverse mappings from RDD/broadcast/session IDs to block IDs; avoids O(n) scans on remove.
+  private[this] val rddToBlockIds =
+    new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+  private[this] val broadcastToBlockIds =
+    new ConcurrentHashMap[Long, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+  private[this] val sessionToBlockIds =
+    new ConcurrentHashMap[String, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+
   /**
    * Record invisible rdd blocks stored in the block manager, entries will be removed when blocks
    * are marked as visible or blocks are removed by [[removeBlock()]].
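
The three maps added above implement a reverse index: each owner ID (RDD, broadcast, or session) points to a concurrent set of its block IDs, so removal never has to scan every block in the manager. A minimal, self-contained sketch of the pattern (illustrative names, not Spark's):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object ReverseIndexSketch {
  // Owner ID -> concurrent set of block-like IDs, mirroring the rddToBlockIds layout.
  private val index =
    new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[String, java.lang.Boolean]]()

  def add(ownerId: Int, blockId: String): Unit =
    // computeIfAbsent creates the per-owner set atomically on first use.
    index.computeIfAbsent(ownerId, _ => ConcurrentHashMap.newKeySet[String]()).add(blockId)

  def blocksOf(ownerId: Int): Seq[String] =
    Option(index.get(ownerId)).map(_.asScala.toSeq).getOrElse(Seq.empty)

  def main(args: Array[String]): Unit = {
    add(1, "rdd_1_0"); add(1, "rdd_1_1"); add(2, "rdd_2_0")
    println(blocksOf(1)) // owner 1's two IDs, found without touching owner 2's blocks
  }
}
```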
@@ -446,6 +454,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     }
 
     if (previous == null) {
+      addToMapping(blockId) // first registration of this block: keep the reverse index in sync
       // New block; lock it for writing.
       val result = lockForWriting(blockId, blocking = false)
       assert(result.isDefined)
@@ -536,6 +545,23 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     blockInfoWrappers.entrySet().iterator().asScala.map(kv => kv.getKey -> kv.getValue.info)
   }
 
+  /**
+   * Return all blocks belonging to the given RDD.
+   */
+  def rddBlockIds(rddId: Int): Seq[BlockId] = getBlockIdsFromMapping(rddToBlockIds, rddId)
+
+  /**
+   * Return all blocks belonging to the given broadcast.
+   */
+  def broadcastBlockIds(broadcastId: Long): Seq[BlockId] =
+    getBlockIdsFromMapping(broadcastToBlockIds, broadcastId)
+
+  /**
+   * Return all cache blocks created by the given session, e.g. for cached local relations.
+   */
+  def sessionBlockIds(sessionUUID: String): Seq[BlockId] =
+    getBlockIdsFromMapping(sessionToBlockIds, sessionUUID)
+
   /**
    * Removes the given block and releases the write lock on it.
    *
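
One subtlety: getBlockIdsFromMapping (defined further down) copies the live key set into a Seq, so these accessors hand back a point-in-time snapshot, and callers can remove blocks while iterating without tripping over their own mutations. A small sketch of why the copy matters (illustrative, not Spark code):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object SnapshotSketch {
  def main(args: Array[String]): Unit = {
    val live = ConcurrentHashMap.newKeySet[String]()
    live.add("rdd_1_0")
    live.add("rdd_1_1")

    // Copy first: the snapshot stays stable while the live set shrinks underneath it.
    val snapshot = live.asScala.toSeq
    snapshot.foreach(live.remove)

    println(snapshot.size) // 2 -- the snapshot still lists both blocks
    println(live.size)     // 0 -- every block was removed exactly once
  }
}
```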
@@ -552,6 +578,7 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
     } else {
       invisibleRDDBlocks.synchronized {
         blockInfoWrappers.remove(blockId)
+        removeFromMapping(blockId)
         blockId.asRDDId.foreach(invisibleRDDBlocks.remove)
       }
       info.readerCount = 0
@@ -574,11 +601,75 @@ private[storage] class BlockInfoManager(trackingCacheVisibility: Boolean = false
       }
     }
     blockInfoWrappers.clear()
+    rddToBlockIds.clear()
+    broadcastToBlockIds.clear()
+    sessionToBlockIds.clear()
     readLocksByTask.clear()
     writeLocksByTask.clear()
     invisibleRDDBlocks.synchronized {
       invisibleRDDBlocks.clear()
     }
   }
 
+  /**
+   * Return all blocks in the cache mapping for a given key.
+   */
+  private def getBlockIdsFromMapping[K](
+      map: ConcurrentHashMap[K, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]],
+      key: K): Seq[BlockId] = {
+    Option(map.get(key)).map(_.asScala.toSeq).getOrElse(Seq.empty)
+  }
+
+  /**
+   * Add a block ID to the corresponding cache mapping based on its type.
+   */
+  private def addToMapping(blockId: BlockId): Unit = {
+    blockId match {
+      case rddBlockId: RDDBlockId =>
+        rddToBlockIds
+          .computeIfAbsent(rddBlockId.rddId, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case broadcastBlockId: BroadcastBlockId =>
+        broadcastToBlockIds
+          .computeIfAbsent(broadcastBlockId.broadcastId, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case cacheId: CacheId =>
+        sessionToBlockIds
+          .computeIfAbsent(cacheId.sessionUUID, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case _ => // Do nothing for other block types.
+    }
+  }
+
+  /**
+   * Remove a block ID from the corresponding cache mapping based on its type.
+   */
+  private def removeFromMapping(blockId: BlockId): Unit = {
+    def doRemove[K](
+        map: ConcurrentHashMap[K, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]],
+        key: K,
+        block: BlockId): Unit = {
+      map.compute(key,
+        (_, set) => {
+          if (set != null) {
+            set.remove(block)
+            if (set.isEmpty) null else set // drop drained keys so the map never leaks empty sets
+          } else {
+            // Key already absent; nothing to remove.
+            null
+          }
+        }
+      )
+    }
+
+    blockId match {
+      case rddBlockId: RDDBlockId =>
+        doRemove(rddToBlockIds, rddBlockId.rddId, rddBlockId)
+      case broadcastBlockId: BroadcastBlockId =>
+        doRemove(broadcastToBlockIds, broadcastBlockId.broadcastId, broadcastBlockId)
+      case cacheId: CacheId =>
+        doRemove(sessionToBlockIds, cacheId.sessionUUID, cacheId)
+      case _ => // Do nothing for other block types.
+    }
+  }
 }
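
The compute-based doRemove above is what keeps the index tidy under concurrency: the membership check, the removal, and the empty-set cleanup all run atomically under the map's internal lock for that key, so a naive get-then-remove race can neither leave an empty set behind nor drop a set another thread just repopulated. A self-contained sketch of the same idiom (illustrative names, not Spark's API):

```scala
import java.util.concurrent.ConcurrentHashMap

object AtomicRemoveSketch {
  type IdSet = ConcurrentHashMap.KeySetView[String, java.lang.Boolean]
  private val index = new ConcurrentHashMap[Int, IdSet]()

  def add(key: Int, id: String): Unit =
    index.computeIfAbsent(key, _ => ConcurrentHashMap.newKeySet[String]()).add(id)

  // Remove `id` and drop the key when its set drains, in one atomic step.
  def remove(key: Int, id: String): Unit =
    index.compute(key, (_, set) => {
      if (set != null) {
        set.remove(id)
        if (set.isEmpty) null else set // returning null deletes the mapping
      } else {
        null // key already absent; nothing to do
      }
    })

  def main(args: Array[String]): Unit = {
    add(7, "a"); add(7, "b")
    remove(7, "a")
    println(index.containsKey(7)) // true  -- "b" is still indexed
    remove(7, "b")
    println(index.containsKey(7)) // false -- key dropped with its last element
  }
}
```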
core/src/main/scala/org/apache/spark/storage/BlockManager.scala (3 additions, 8 deletions)

@@ -2028,9 +2028,8 @@ private[spark] class BlockManager(
    * @return The number of blocks removed.
    */
   def removeRdd(rddId: Int): Int = {
-    // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
     logInfo(log"Removing RDD ${MDC(RDD_ID, rddId)}")
-    val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId)
+    val blocksToRemove = blockInfoManager.rddBlockIds(rddId)
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
     blocksToRemove.size
   }
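
This resolves the long-standing TODO: the old path walked every entry in the manager on each call, while rddBlockIds does work proportional only to the RDD's own block count. A rough, self-contained illustration of the difference (synthetic data and timings; not a benchmark of Spark itself):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object ScanVsIndexSketch {
  def main(args: Array[String]): Unit = {
    val all = new ConcurrentHashMap[String, Int]() // blockId -> rddId, the flat layout
    val byRdd =
      new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[String, java.lang.Boolean]]()
    for (rdd <- 0 until 1000; part <- 0 until 100) {
      val id = s"rdd_${rdd}_$part"
      all.put(id, rdd)
      byRdd.computeIfAbsent(rdd, _ => ConcurrentHashMap.newKeySet[String]()).add(id)
    }

    def timed[A](body: => A): (A, Long) = {
      val t0 = System.nanoTime(); val r = body; (r, (System.nanoTime() - t0) / 1000)
    }

    // Old style: O(total blocks) scan to find one RDD's blocks.
    val (scanned, scanUs) =
      timed(all.entrySet().asScala.collect { case e if e.getValue == 42 => e.getKey }.toSeq)
    // New style: O(blocks of that RDD) lookup in the reverse index.
    val (indexed, idxUs) = timed(byRdd.get(42).asScala.toSeq)

    println(s"scan: ${scanned.size} blocks in ${scanUs}us; index: ${indexed.size} in ${idxUs}us")
  }
}
```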
@@ -2064,9 +2063,7 @@ private[spark] class BlockManager(
    */
   def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
     logDebug(s"Removing broadcast $broadcastId")
-    val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case bid @ BroadcastBlockId(`broadcastId`, _) => bid
-    }
+    val blocksToRemove = blockInfoManager.broadcastBlockIds(broadcastId)
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
     blocksToRemove.size
   }
@@ -2078,9 +2075,7 @@ private[spark] class BlockManager(
    */
   def removeCache(sessionUUID: String): Int = {
     logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
-    val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case cid: CacheId if cid.sessionUUID == sessionUUID => cid
-    }
+    val blocksToRemove = blockInfoManager.sessionBlockIds(sessionUUID)
     blocksToRemove.foreach { blockId => removeBlock(blockId) }
     blocksToRemove.size
   }