Skip to content

Commit

Permalink
KAFKA-18137: Unloading transaction state incorrectly removes loading …
Browse files Browse the repository at this point in the history
…partitions (#18011)

When a transaction coordinator state partition goes through a become-follower transition, we intend to unload that state partition. However, we pass the new epoch to the method that does the unloading. In that method, we create a `TransactionPartitionAndLeaderEpoch` object comprising the topic partition and the epoch, and use it as the key to remove the partition from `loadingPartitions`. However, we would never expect to find an entry with this new epoch there, since we only load on the leader. See the code snippet: https://github.com/apache/kafka/blob/d00f0ecf1a1a082c97564f4b807e7a342472b57a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala#L602

We could have a partition load after the unloading occurs, and that partition will be stuck storing stale state on the broker until it restarts. While this may not immediately cause a correctness issue, we should try to properly clean up state.

Check that the epoch is less than the new epoch when removing the partition from loadingPartitions.

Added a test that failed before this change was made.

Reviewers: Artem Livshits <[email protected]>, Jeff Kim <[email protected]>
  • Loading branch information
jolshan authored Dec 3, 2024
1 parent fe88232 commit dbae448
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,9 @@ class TransactionStateManager(brokerId: Int,
*/
def removeTransactionsForTxnTopicPartition(partitionId: Int, coordinatorEpoch: Int): Unit = {
val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId)
val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch)

inWriteLock(stateLock) {
loadingPartitions.remove(partitionAndLeaderEpoch)
removeLoadingPartitionWithEpoch(partitionId, coordinatorEpoch)
transactionMetadataCache.remove(partitionId) match {
case Some(txnMetadataCacheEntry) =>
info(s"Unloaded transaction metadata $txnMetadataCacheEntry for $topicPartition on become-follower transition")
Expand All @@ -610,6 +609,18 @@ class TransactionStateManager(brokerId: Int,
}
}

/**
 * Cancel in-progress loads of the given transaction topic partition that were registered under an
 * older coordinator epoch.
 *
 * Every entry in `loadingPartitions` whose partition id equals `partitionId` and whose coordinator
 * epoch is strictly less than `coordinatorEpoch` is removed, and one info line is logged per removed
 * entry. Entries with an equal or greater epoch are left untouched.
 *
 * Note: This method must be called under the write state lock.
 *
 * @param partitionId      id of the transaction state topic partition whose loads should be cancelled
 * @param coordinatorEpoch epoch to compare against; only loads with a smaller epoch are cancelled
 */
private def removeLoadingPartitionWithEpoch(partitionId: Int, coordinatorEpoch: Int): Unit = {
  // Collect all stale entries up front: more than one epoch may be registered for the same
  // partition, and we must not mutate loadingPartitions while traversing it.
  val staleLoads = loadingPartitions.filter { loading =>
    loading.txnPartitionId == partitionId && loading.coordinatorEpoch < coordinatorEpoch
  }
  staleLoads.foreach { partitionAndLeaderEpoch =>
    loadingPartitions.remove(partitionAndLeaderEpoch)
    info(s"Cancelling load of currently loading partition $partitionAndLeaderEpoch")
  }
}

private def validateTransactionTopicPartitionCountIsStable(): Unit = {
val previouslyDeterminedPartitionCount = transactionTopicPartitionCount
val curTransactionTopicPartitionCount = retrieveTransactionTopicPartitionCount()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,65 @@ class TransactionStateManagerTest {
assertEquals(Left(Errors.NOT_COORDINATOR), transactionManager.getTransactionState(txnMetadata1.transactionalId))
}

@Test
def testMakeFollowerLoadingPartition(): Unit = {
  // Verify the handling of a call to make a partition a follower while it is in the
  // process of being loaded. The partition should not be loaded.

  val startOffset = 0L
  val endOffset = 1L

  // Mock the log so that reading from startOffset yields the mocked file records.
  val fileRecordsMock = mock[FileRecords](classOf[FileRecords])
  val logMock = mock[UnifiedLog](classOf[UnifiedLog])
  when(replicaManager.getLog(topicPartition)).thenReturn(Some(logMock))
  when(logMock.logStartOffset).thenReturn(startOffset)
  when(logMock.read(ArgumentMatchers.eq(startOffset),
    maxLength = anyInt(),
    isolation = ArgumentMatchers.eq(FetchIsolation.LOG_END),
    minOneMessage = ArgumentMatchers.eq(true))
  ).thenReturn(new FetchDataInfo(new LogOffsetMetadata(startOffset), fileRecordsMock))
  when(replicaManager.getLogEndOffset(topicPartition)).thenReturn(Some(endOffset))

  // A single transaction record that the load would read.
  txnMetadata1.state = PrepareCommit
  txnMetadata1.addPartitions(Set[TopicPartition](
    new TopicPartition("topic1", 0),
    new TopicPartition("topic1", 1)))
  val records = MemoryRecords.withRecords(startOffset, Compression.NONE,
    new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit(), TV_2)))

  // We create a latch which is awaited while the log is loading. This ensures that the follower transition
  // is triggered before the loading returns
  val latch = new CountDownLatch(1)

  when(fileRecordsMock.sizeInBytes()).thenReturn(records.sizeInBytes)
  val bufferCapture: ArgumentCaptor[ByteBuffer] = ArgumentCaptor.forClass(classOf[ByteBuffer])
  when(fileRecordsMock.readInto(bufferCapture.capture(), anyInt())).thenAnswer(_ => {
    // Block the read until the test has performed the follower transition.
    latch.await()
    val buffer = bufferCapture.getValue
    buffer.put(records.buffer.duplicate)
    buffer.flip()
  })

  val coordinatorEpoch = 0
  val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch)

  val loadingThread = new Thread(() => {
    transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch, (_, _, _, _) => ())
  })
  loadingThread.start()
  try {
    TestUtils.waitUntilTrue(() => transactionManager.loadingPartitions.contains(partitionAndLeaderEpoch),
      "Timed out waiting for loading partition", pause = 10)

    // Become follower with a newer epoch while the load is still blocked on the latch.
    transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch + 1)
    assertFalse(transactionManager.loadingPartitions.contains(partitionAndLeaderEpoch))
  } finally {
    // Always release the loading thread, even if a wait or assertion above failed, so the
    // thread is not left blocked on the latch.
    latch.countDown()
    loadingThread.join()
  }

  // Verify that transaction state was not loaded
  assertEquals(Left(Errors.NOT_COORDINATOR), transactionManager.getTransactionState(txnMetadata1.transactionalId))
}

@Test
def testLoadAndRemoveTransactionsForPartition(): Unit = {
// generate transaction log messages for two pids traces:
Expand Down

0 comments on commit dbae448

Please sign in to comment.