Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 77 additions & 6 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,65 @@ private[spark] class BlockManager(
securityManager.getIOEncryptionKey())
private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)

// SPARK-53446 Performance optimization: Cache mappings to avoid O(n) scans in remove operations
private[this] val rddToBlockIds =
new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
private[this] val broadcastToBlockIds =
new ConcurrentHashMap[Long, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
private[this] val sessionToBlockIds =
new ConcurrentHashMap[String, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()

/**
* Add a block ID to the appropriate cache mapping based on its type.
*/
private def addToCache(blockId: BlockId): Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Move these helper methods towards the bottom.

blockId match {
case rddBlockId: RDDBlockId =>
rddToBlockIds
.computeIfAbsent(rddBlockId.rddId, _ => ConcurrentHashMap.newKeySet())
.add(blockId)
case broadcastBlockId: BroadcastBlockId =>
broadcastToBlockIds
.computeIfAbsent(broadcastBlockId.broadcastId, _ => ConcurrentHashMap.newKeySet())
.add(blockId)
case cacheId: CacheId =>
sessionToBlockIds
.computeIfAbsent(cacheId.sessionUUID, _ => ConcurrentHashMap.newKeySet())
.add(blockId)
case _ => // Do nothing for other block types
}
}

/**
* Remove a block ID from the appropriate cache mapping based on its type.
*/
private def removeFromCache(blockId: BlockId): Unit = {
blockId match {
case rddBlockId: RDDBlockId =>
Option(rddToBlockIds.get(rddBlockId.rddId)).foreach { blockSet =>
blockSet.remove(blockId)
if (blockSet.isEmpty) {
rddToBlockIds.remove(rddBlockId.rddId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a race condition here:

  1. thread 1 calls addToCache and gets the set for its RDD id
  2. thread 2 calls removeFromCache, gets the set for the same RDD id, removes the last block id, and then removes this RDD id from the cache
  3. thread 1 adds the block id, but it's a no-op because that set is no longer referenced by the map (the entry is dangling now).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with @cloud-fan, use compute instead.
For example, something like this:

// Use import java.util.{Set => JSet} and change the value type of the maps to JSet[BlockId]

def removeFromCache(
<snip>

    def doRemove[K](map: ConcurrentHashMap[K, JSet[BlockId]], key: K, block: BlockId): Unit = {
      map.compute(key,
        (_, set) => {
          if (null != set) {
            set.remove(block)
            if (set.isEmpty) null else set
          } else {
            // missing
            null
          }
        }
      )
    }

<snip>

case rddBlockId: RDDBlockId =>
  doRemove(rddToBlockIds, rddBlockId.rddId, blockId)
case broadcastBlockId: BroadcastBlockId =>
  doRemove(broadcastToBlockIds, broadcastBlockId.broadcastId, blockId)

// and so on

}
}
case broadcastBlockId: BroadcastBlockId =>
Option(broadcastToBlockIds.get(broadcastBlockId.broadcastId)).foreach { blockSet =>
blockSet.remove(blockId)
if (blockSet.isEmpty) {
broadcastToBlockIds.remove(broadcastBlockId.broadcastId)
}
}
case cacheId: CacheId =>
Option(sessionToBlockIds.get(cacheId.sessionUUID)).foreach { blockSet =>
blockSet.remove(blockId)
if (blockSet.isEmpty) {
sessionToBlockIds.remove(cacheId.sessionUUID)
}
}
case _ => // Do nothing for other block types
}
}

var hostLocalDirManager: Option[HostLocalDirManager] = None

@inline final private def isDecommissioning() = {
Expand Down Expand Up @@ -1560,6 +1619,7 @@ private[spark] class BlockManager(
exceptionWasThrown = false
if (res.isEmpty) {
// the block was successfully stored
addToCache(blockId)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is doPut the only entry point that can add blocks?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For adding a new block, yes; it goes through doPut

if (keepReadLock) {
blockInfoManager.downgradeLock(blockId)
} else {
Expand Down Expand Up @@ -2028,9 +2088,13 @@ private[spark] class BlockManager(
* @return The number of blocks removed.
*/
def removeRdd(rddId: Int): Int = {
// TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
logInfo(log"Removing RDD ${MDC(RDD_ID, rddId)}")
val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId)
val blocksToRemove = Option(rddToBlockIds.get(rddId)) match {
case Some(blockSet) =>
blockSet.asScala.toSeq
case None =>
Seq.empty
}
Comment on lines +2092 to +2097
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove it proactively from the map. This also makes the size returned consistent with what is actually removed (in case of race conditions).
Same for the other cases as well

Suggested change
val blocksToRemove = Option(rddToBlockIds.get(rddId)) match {
case Some(blockSet) =>
blockSet.asScala.toSeq
case None =>
Seq.empty
}
val blocksToRemove = Option(rddToBlockIds.remove(rddId)).
map(_.asScala.toSeq).getOrElse(Seq.empty)

blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
blocksToRemove.size
}
Expand Down Expand Up @@ -2064,8 +2128,11 @@ private[spark] class BlockManager(
*/
def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
logDebug(s"Removing broadcast $broadcastId")
val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
case bid @ BroadcastBlockId(`broadcastId`, _) => bid
val blocksToRemove = Option(broadcastToBlockIds.get(broadcastId)) match {
case Some(blockSet) =>
blockSet.asScala.toSeq
case None =>
Seq.empty
}
blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
blocksToRemove.size
Expand All @@ -2078,8 +2145,11 @@ private[spark] class BlockManager(
*/
def removeCache(sessionUUID: String): Int = {
logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
case cid: CacheId if cid.sessionUUID == sessionUUID => cid
val blocksToRemove = Option(sessionToBlockIds.get(sessionUUID)) match {
case Some(blockSet) =>
blockSet.asScala.toSeq
case None =>
Seq.empty
}
blocksToRemove.foreach { blockId => removeBlock(blockId) }
blocksToRemove.size
Expand Down Expand Up @@ -2122,6 +2192,7 @@ private[spark] class BlockManager(
}

blockInfoManager.removeBlock(blockId)
removeFromCache(blockId)
hasRemoveBlock = true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do this in the finally block as well.

if (tellMaster) {
// Only update storage level from the captured block status before deleting, so that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2504,6 +2504,98 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
assert(acc.value === 6)
}

test("cache optimization: remove operations with cached mappings") {
  val manager = makeBlockManager(8000)

  // RDD blocks: store two partitions of one RDD, then drop them together through removeRdd.
  val rddId = 123
  val rddEntries = Seq(
    RDDBlockId(rddId, 0) -> "rdd_data_1",
    RDDBlockId(rddId, 1) -> "rdd_data_2")
  rddEntries.foreach { case (id, payload) =>
    manager.putSingle(id, payload, StorageLevel.MEMORY_ONLY)
  }
  rddEntries.foreach { case (id, _) => assert(manager.hasLocalBlock(id)) }
  assert(manager.removeRdd(rddId) === 2)
  rddEntries.foreach { case (id, _) => assert(!manager.hasLocalBlock(id)) }

  // Broadcast blocks: store two pieces of one broadcast, then drop them through removeBroadcast.
  val broadcastId = 456L
  val broadcastEntries = Seq(
    BroadcastBlockId(broadcastId, "piece0") -> "broadcast_data_1",
    BroadcastBlockId(broadcastId, "piece1") -> "broadcast_data_2")
  broadcastEntries.foreach { case (id, payload) =>
    manager.putSingle(id, payload, StorageLevel.MEMORY_ONLY)
  }
  broadcastEntries.foreach { case (id, _) => assert(manager.hasLocalBlock(id)) }
  assert(manager.removeBroadcast(broadcastId, tellMaster = false) === 2)
  broadcastEntries.foreach { case (id, _) => assert(!manager.hasLocalBlock(id)) }

  // Session cache blocks: store two entries for one session UUID, then drop them via removeCache.
  val sessionUUID = "test-session-uuid"
  val cacheEntries = Seq(
    CacheId(sessionUUID, "hash1") -> "cache_data_1",
    CacheId(sessionUUID, "hash2") -> "cache_data_2")
  cacheEntries.foreach { case (id, payload) =>
    manager.putSingle(id, payload, StorageLevel.MEMORY_ONLY)
  }
  cacheEntries.foreach { case (id, _) => assert(manager.hasLocalBlock(id)) }
  assert(manager.removeCache(sessionUUID) === 2)
  cacheEntries.foreach { case (id, _) => assert(!manager.hasLocalBlock(id)) }

  // Identifiers that were never stored must report zero removed blocks.
  assert(manager.removeRdd(999) === 0)
  assert(manager.removeBroadcast(999L, tellMaster = false) === 0)
  assert(manager.removeCache("non-existent-uuid") === 0)
}

test("cache optimization: performance improvement verification") {
  val store = makeBlockManager(800000)

  // Two broadcasts are stored: a large "background" one (normalBroadcastId) that is never
  // removed here, and a much smaller target one (broadcastId) whose removal is timed.
  val numNormalBlocks = 100000
  val normalBroadcastId = 10000L
  val numBlocks = 1000
  val broadcastId = 12345L

  // Store the background pieces first, then the target pieces.
  for (i <- 0 until numNormalBlocks) {
    val blockId = BroadcastBlockId(normalBroadcastId, s"piece_$i")
    store.putSingle(blockId, s"data_$i", StorageLevel.MEMORY_ONLY)
  }
  for (i <- 0 until numBlocks) {
    val blockId = BroadcastBlockId(broadcastId, s"piece_$i")
    store.putSingle(blockId, s"data_$i", StorageLevel.MEMORY_ONLY)
  }

  // Every piece of the target broadcast must be present before the removal.
  assert((0 until numBlocks)
    .forall(i => store.hasLocalBlock(BroadcastBlockId(broadcastId, s"piece_$i"))))

  // Time the removal of the target broadcast only. The elapsed time is logged for
  // inspection; it is not asserted on.
  val startTime = System.nanoTime()
  val removedCount = store.removeBroadcast(broadcastId, tellMaster = false)
  val endTime = System.nanoTime()

  // Verify correctness: exactly the target pieces were removed and none of them remain.
  assert(removedCount === numBlocks)
  assert((0 until numBlocks)
    .forall(i => !store.hasLocalBlock(BroadcastBlockId(broadcastId, s"piece_$i"))))

  // Report the number of blocks actually removed (numBlocks, not numNormalBlocks) and
  // derive the per-block average from that same count instead of a hard-coded divisor.
  val durationMs = (endTime - startTime) / 1000000.0
  logInfo(s"Removed $numBlocks broadcast blocks in ${durationMs}ms " +
    s"(avg: ${durationMs / numBlocks}ms per block)")
}

private def createKryoSerializerWithDiskCorruptedInputStream(): KryoSerializer = {
class TestDiskCorruptedInputStream extends InputStream {
override def read(): Int = throw new IOException("Input/output error")
Expand Down