diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1e1cb8bf9fd53..da455cacf98ad 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -288,6 +288,65 @@ private[spark] class BlockManager(
     securityManager.getIOEncryptionKey())
   private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
 
+  // SPARK-53446 Performance optimization: Cache mappings to avoid O(n) scans in remove operations
+  private[this] val rddToBlockIds =
+    new ConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+  private[this] val broadcastToBlockIds =
+    new ConcurrentHashMap[Long, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+  private[this] val sessionToBlockIds =
+    new ConcurrentHashMap[String, ConcurrentHashMap.KeySetView[BlockId, java.lang.Boolean]]()
+
+  /**
+   * Add a block ID to the appropriate cache mapping based on its type.
+   */
+  private def addToCache(blockId: BlockId): Unit = {
+    blockId match {
+      case rddBlockId: RDDBlockId =>
+        rddToBlockIds
+          .computeIfAbsent(rddBlockId.rddId, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case broadcastBlockId: BroadcastBlockId =>
+        broadcastToBlockIds
+          .computeIfAbsent(broadcastBlockId.broadcastId, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case cacheId: CacheId =>
+        sessionToBlockIds
+          .computeIfAbsent(cacheId.sessionUUID, _ => ConcurrentHashMap.newKeySet())
+          .add(blockId)
+      case _ => // Do nothing for other block types
+    }
+  }
+
+  /**
+   * Remove a block ID from the appropriate cache mapping based on its type.
+   */
+  private def removeFromCache(blockId: BlockId): Unit = {
+    blockId match {
+      case rddBlockId: RDDBlockId =>
+        Option(rddToBlockIds.get(rddBlockId.rddId)).foreach { blockSet =>
+          blockSet.remove(blockId)
+          if (blockSet.isEmpty) {
+            rddToBlockIds.remove(rddBlockId.rddId)
+          }
+        }
+      case broadcastBlockId: BroadcastBlockId =>
+        Option(broadcastToBlockIds.get(broadcastBlockId.broadcastId)).foreach { blockSet =>
+          blockSet.remove(blockId)
+          if (blockSet.isEmpty) {
+            broadcastToBlockIds.remove(broadcastBlockId.broadcastId)
+          }
+        }
+      case cacheId: CacheId =>
+        Option(sessionToBlockIds.get(cacheId.sessionUUID)).foreach { blockSet =>
+          blockSet.remove(blockId)
+          if (blockSet.isEmpty) {
+            sessionToBlockIds.remove(cacheId.sessionUUID)
+          }
+        }
+      case _ => // Do nothing for other block types
+    }
+  }
+
   var hostLocalDirManager: Option[HostLocalDirManager] = None
 
   @inline final private def isDecommissioning() = {
@@ -1560,6 +1619,7 @@ private[spark] class BlockManager(
       exceptionWasThrown = false
       if (res.isEmpty) {
         // the block was successfully stored
+        addToCache(blockId)
         if (keepReadLock) {
           blockInfoManager.downgradeLock(blockId)
         } else {
@@ -2028,9 +2088,13 @@ private[spark] class BlockManager(
    * @return The number of blocks removed.
    */
   def removeRdd(rddId: Int): Int = {
-    // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
     logInfo(log"Removing RDD ${MDC(RDD_ID, rddId)}")
-    val blocksToRemove = blockInfoManager.entries.flatMap(_._1.asRDDId).filter(_.rddId == rddId)
+    val blocksToRemove = Option(rddToBlockIds.get(rddId)) match {
+      case Some(blockSet) =>
+        blockSet.asScala.toSeq
+      case None =>
+        Seq.empty
+    }
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
     blocksToRemove.size
   }
@@ -2064,8 +2128,11 @@ private[spark] class BlockManager(
    */
   def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
     logDebug(s"Removing broadcast $broadcastId")
-    val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case bid @ BroadcastBlockId(`broadcastId`, _) => bid
+    val blocksToRemove = Option(broadcastToBlockIds.get(broadcastId)) match {
+      case Some(blockSet) =>
+        blockSet.asScala.toSeq
+      case None =>
+        Seq.empty
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
     blocksToRemove.size
@@ -2078,8 +2145,11 @@ private[spark] class BlockManager(
    */
   def removeCache(sessionUUID: String): Int = {
     logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
-    val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case cid: CacheId if cid.sessionUUID == sessionUUID => cid
+    val blocksToRemove = Option(sessionToBlockIds.get(sessionUUID)) match {
+      case Some(blockSet) =>
+        blockSet.asScala.toSeq
+      case None =>
+        Seq.empty
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId) }
     blocksToRemove.size
@@ -2122,6 +2192,7 @@ private[spark] class BlockManager(
     }
 
     blockInfoManager.removeBlock(blockId)
+    removeFromCache(blockId)
     hasRemoveBlock = true
     if (tellMaster) {
       // Only update storage level from the captured block status before deleting, so that
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 5b86345dd5f9a..cc41fa8a3d0a4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -2504,6 +2504,98 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     assert(acc.value === 6)
   }
 
+  test("cache optimization: remove operations with cached mappings") {
+    val store = makeBlockManager(8000)
+
+    // Test removeRdd with cached mappings
+    val rddId = 123
+    val rddBlock1 = RDDBlockId(rddId, 0)
+    val rddBlock2 = RDDBlockId(rddId, 1)
+    store.putSingle(rddBlock1, "rdd_data_1", StorageLevel.MEMORY_ONLY)
+    store.putSingle(rddBlock2, "rdd_data_2", StorageLevel.MEMORY_ONLY)
+
+    assert(store.hasLocalBlock(rddBlock1))
+    assert(store.hasLocalBlock(rddBlock2))
+
+    val removedRddCount = store.removeRdd(rddId)
+    assert(removedRddCount === 2)
+    assert(!store.hasLocalBlock(rddBlock1))
+    assert(!store.hasLocalBlock(rddBlock2))
+
+    // Test removeBroadcast with cached mappings
+    val broadcastId = 456L
+    val broadcastBlock1 = BroadcastBlockId(broadcastId, "piece0")
+    val broadcastBlock2 = BroadcastBlockId(broadcastId, "piece1")
+    store.putSingle(broadcastBlock1, "broadcast_data_1", StorageLevel.MEMORY_ONLY)
+    store.putSingle(broadcastBlock2, "broadcast_data_2", StorageLevel.MEMORY_ONLY)
+
+    assert(store.hasLocalBlock(broadcastBlock1))
+    assert(store.hasLocalBlock(broadcastBlock2))
+
+    val removedBroadcastCount = store.removeBroadcast(broadcastId, tellMaster = false)
+    assert(removedBroadcastCount === 2)
+    assert(!store.hasLocalBlock(broadcastBlock1))
+    assert(!store.hasLocalBlock(broadcastBlock2))
+
+    // Test removeCache with cached mappings
+    val sessionUUID = "test-session-uuid"
+    val cacheBlock1 = CacheId(sessionUUID, "hash1")
+    val cacheBlock2 = CacheId(sessionUUID, "hash2")
+    store.putSingle(cacheBlock1, "cache_data_1", StorageLevel.MEMORY_ONLY)
+    store.putSingle(cacheBlock2, "cache_data_2", StorageLevel.MEMORY_ONLY)
+
+    assert(store.hasLocalBlock(cacheBlock1))
+    assert(store.hasLocalBlock(cacheBlock2))
+
+    val removedCacheCount = store.removeCache(sessionUUID)
+    assert(removedCacheCount === 2)
+    assert(!store.hasLocalBlock(cacheBlock1))
+    assert(!store.hasLocalBlock(cacheBlock2))
+
+    // Test removing non-existent items (should return 0)
+    assert(store.removeRdd(999) === 0)
+    assert(store.removeBroadcast(999L, tellMaster = false) === 0)
+    assert(store.removeCache("non-existent-uuid") === 0)
+  }
+
+  test("cache optimization: performance improvement verification") {
+    val store = makeBlockManager(800000)
+
+    // Create many blocks to test the performance difference
+    val numNormalBlocks = 100000
+    val normalBroadcastId = 10000L
+    val numBlocks = 1000
+    val broadcastId = 12345L
+
+    // Add a large number of unrelated broadcast blocks, then the target broadcast blocks
+    for (i <- 0 until numNormalBlocks) {
+      val blockId = BroadcastBlockId(normalBroadcastId, s"piece_$i")
+      store.putSingle(blockId, s"data_$i", StorageLevel.MEMORY_ONLY)
+    }
+    for (i <- 0 until numBlocks) {
+      val blockId = BroadcastBlockId(broadcastId, s"piece_$i")
+      store.putSingle(blockId, s"data_$i", StorageLevel.MEMORY_ONLY)
+    }
+
+    // Verify all target blocks exist
+    assert((0 until numBlocks)
+      .forall(i => store.hasLocalBlock(BroadcastBlockId(broadcastId, s"piece_$i"))))
+
+    // Time the removal operation (should be much faster with the O(1) cache lookup)
+    val startTime = System.nanoTime()
+    val removedCount = store.removeBroadcast(broadcastId, tellMaster = false)
+    val endTime = System.nanoTime()
+
+    // Verify correctness
+    assert(removedCount === numBlocks)
+    assert((0 until numBlocks)
+      .forall(i => !store.hasLocalBlock(BroadcastBlockId(broadcastId, s"piece_$i"))))
+
+    val durationMs = (endTime - startTime) / 1000000.0
+    logInfo(s"Removed $numBlocks broadcast blocks in ${durationMs}ms " +
+      s"(avg: ${durationMs / numBlocks}ms per block)")
+  }
+
   private def createKryoSerializerWithDiskCorruptedInputStream(): KryoSerializer = {
     class TestDiskCorruptedInputStream extends InputStream {
       override def read(): Int = throw new IOException("Input/output error")