Skip to content

Commit

Permalink
fix: Don't cache failed PreparedStatement (#1056)
Browse files Browse the repository at this point in the history
* the lazy val introduced in #816 will keep the failed
  future if prepare failed
  • Loading branch information
patriknw authored Sep 28, 2023
1 parent 6ed1e6d commit a012ff1
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 140 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (C) 2016-2023 Lightbend Inc. <https://www.lightbend.com>
*/

package akka.persistence.cassandra

import scala.concurrent.Future

import akka.annotation.InternalApi
import akka.dispatch.ExecutionContexts
import akka.util.OptionVal
import com.datastax.oss.driver.api.core.cql.PreparedStatement

/**
* INTERNAL API
*/
@InternalApi private[cassandra] class CachedPreparedStatement(init: () => Future[PreparedStatement]) {
@volatile private var preparedStatement: OptionVal[Future[PreparedStatement]] = OptionVal.None

def get(): Future[PreparedStatement] =
preparedStatement match {
case OptionVal.Some(ps) => ps
case _ =>
// ok to init multiple times in case of concurrent access
val ps = init()
ps.foreach { p =>
// only cache successful futures, ok to overwrite
preparedStatement = OptionVal.Some(Future.successful(p))
}(ExecutionContexts.parasitic)
ps
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
package akka.persistence.cassandra.cleanup

import java.lang.{ Integer => JInt, Long => JLong }

import scala.collection.immutable
import scala.concurrent.Future
import scala.util.Failure
import scala.util.Success

import akka.{ Done, NotUsed }
import akka.actor.{ ActorRef, ActorSystem, ClassicActorSystemProvider }
import akka.annotation.ApiMayChange
import akka.event.Logging
import akka.pattern.ask
import akka.persistence.JournalProtocol.DeleteMessagesTo
import akka.persistence.cassandra.CachedPreparedStatement
import akka.persistence.{ Persistence, SnapshotMetadata }
import akka.persistence.cassandra.PluginSettings
import akka.persistence.cassandra.journal.CassandraJournal
Expand Down Expand Up @@ -67,8 +70,10 @@ final class Cleanup(systemProvider: ClassicActorSystemProvider, settings: Cleanu

private lazy val pluginSettings = PluginSettings(system, system.settings.config.getConfig(pluginLocation))
private lazy val statements = new CassandraSnapshotStatements(pluginSettings.snapshotSettings)
private lazy val selectLatestSnapshotsPs = session.prepare(statements.selectLatestSnapshotMeta)
private lazy val selectAllSnapshotMetaPs = session.prepare(statements.selectAllSnapshotMeta)
private val selectLatestSnapshotsPs =
new CachedPreparedStatement(() => session.prepare(statements.selectLatestSnapshotMeta))
private val selectAllSnapshotMetaPs =
new CachedPreparedStatement(() => session.prepare(statements.selectAllSnapshotMeta))

if (dryRun) {
log.info("Cleanup running in dry run mode. No operations will be executed against the database, only logged")
Expand Down Expand Up @@ -139,6 +144,7 @@ final class Cleanup(systemProvider: ClassicActorSystemProvider, settings: Cleanu
require(snapshotsToKeep >= 1, "must keep at least one snapshot")
require(keepAfterUnixTimestamp >= 0, "keepAfter must be greater than 0")
selectAllSnapshotMetaPs
.get()
.flatMap { ps =>
val allRows: Source[Row, NotUsed] = session.select(ps.bind(persistenceId))
allRows.zipWithIndex
Expand Down Expand Up @@ -169,7 +175,7 @@ final class Cleanup(systemProvider: ClassicActorSystemProvider, settings: Cleanu
*/
def deleteBeforeSnapshot(persistenceId: String, maxSnapshotsToKeep: Int): Future[Option[SnapshotMetadata]] = {
require(maxSnapshotsToKeep >= 1, "Must keep at least one snapshot")
val snapshots: Future[immutable.Seq[Row]] = selectLatestSnapshotsPs.flatMap { ps =>
val snapshots: Future[immutable.Seq[Row]] = selectLatestSnapshotsPs.get().flatMap { ps =>
session.select(ps.bind(persistenceId, maxSnapshotsToKeep: JInt)).runWith(Sink.seq)
}
snapshots.flatMap(rows => issueSnapshotDelete(persistenceId, maxSnapshotsToKeep, rows))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ import akka.event.LoggingAdapter
import akka.persistence.cassandra.PluginSettings
import akka.persistence.cassandra.journal.CassandraJournal.{ Serialized, TagPidSequenceNr }
import com.datastax.oss.driver.api.core.cql.{ PreparedStatement, Row, Statement }

import akka.util.ccompat.JavaConverters._
import scala.concurrent.{ ExecutionContext, Future }
import java.lang.{ Long => JLong }

import akka.annotation.InternalApi
import akka.persistence.cassandra.CachedPreparedStatement
import akka.stream.alpakka.cassandra.scaladsl.CassandraSession

/** INTERNAL API */
Expand All @@ -27,11 +27,14 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession

private def journalSettings = settings.journalSettings
private lazy val journalStatements = new CassandraJournalStatements(settings)
lazy val psUpdateMessage: Future[PreparedStatement] = session.prepare(journalStatements.updateMessagePayloadAndTags)
lazy val psSelectTagPidSequenceNr: Future[PreparedStatement] =
session.prepare(journalStatements.selectTagPidSequenceNr)
lazy val psUpdateTagView: Future[PreparedStatement] = session.prepare(journalStatements.updateMessagePayloadInTagView)
lazy val psSelectMessages: Future[PreparedStatement] = session.prepare(journalStatements.selectMessages)
val psUpdateMessage: CachedPreparedStatement =
new CachedPreparedStatement(() => session.prepare(journalStatements.updateMessagePayloadAndTags))
val psSelectTagPidSequenceNr: CachedPreparedStatement =
new CachedPreparedStatement(() => session.prepare(journalStatements.selectTagPidSequenceNr))
val psUpdateTagView: CachedPreparedStatement =
new CachedPreparedStatement(() => session.prepare(journalStatements.updateMessagePayloadInTagView))
val psSelectMessages: CachedPreparedStatement =
new CachedPreparedStatement(() => session.prepare(journalStatements.selectMessages))

/**
* Update the given event in the messages table and the tag_views table.
Expand All @@ -41,7 +44,7 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession
def updateEvent(event: Serialized): Future[Done] =
for {
(partitionNr, existingTags) <- findEvent(event)
psUM <- psUpdateMessage
psUM <- psUpdateMessage.get()
e = event.copy(tags = existingTags) // do not allow updating of tags
_ <- session.executeWrite(prepareUpdate(psUM, e, partitionNr))
_ <- Future.traverse(existingTags) { tag =>
Expand All @@ -52,7 +55,7 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession
private def findEvent(s: Serialized): Future[(Long, Set[String])] = {
val firstPartition = partitionNr(s.sequenceNr, journalSettings.targetPartitionSize)
for {
ps <- psSelectMessages
ps <- psSelectMessages.get()
row <- findEvent(ps, s.persistenceId, s.sequenceNr, firstPartition)
} yield (row.getLong("partition_nr"), row.getSet[String]("tags", classOf[String]).asScala.toSet)
}
Expand All @@ -78,6 +81,7 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession

private def updateEventInTagViews(event: Serialized, tag: String): Future[Done] =
psSelectTagPidSequenceNr
.get()
.flatMap { ps =>
val bound = ps
.bind()
Expand All @@ -98,7 +102,7 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession
}

private def updateEventInTagViews(event: Serialized, tag: String, tagPidSequenceNr: TagPidSequenceNr): Future[Done] =
psUpdateTagView.flatMap { ps =>
psUpdateTagView.get().flatMap { ps =>
// primary key
val bound = ps
.bind()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,42 +107,43 @@ import akka.stream.scaladsl.Source
private val tagRecovery: Option[CassandraTagRecovery] =
tagWrites.map(ref => new CassandraTagRecovery(context.system, session, settings, taggedPreparedStatements, ref))

private lazy val preparedWriteMessage =
session.prepare(statements.journalStatements.writeMessage(withMeta = false))
private lazy val preparedSelectDeletedTo: Option[Future[PreparedStatement]] = {
private val preparedWriteMessage: CachedPreparedStatement =
new CachedPreparedStatement(() => session.prepare(statements.journalStatements.writeMessage(withMeta = false)))
private val preparedSelectDeletedTo: Option[CachedPreparedStatement] = {
if (settings.journalSettings.supportDeletes)
Some(session.prepare(statements.journalStatements.selectDeletedTo))
Some(new CachedPreparedStatement(() => session.prepare(statements.journalStatements.selectDeletedTo)))
else
None
}
private lazy val preparedSelectHighestSequenceNr: Future[PreparedStatement] =
session.prepare(statements.journalStatements.selectHighestSequenceNr)
private val preparedSelectHighestSequenceNr =
new CachedPreparedStatement(() => session.prepare(statements.journalStatements.selectHighestSequenceNr))

private def deletesNotSupportedException: Future[PreparedStatement] =
private lazy val deletesNotSupportedException: Future[PreparedStatement] =
Future.failed(new IllegalArgumentException(s"Deletes not supported because config support-deletes=off"))

private lazy val preparedInsertDeletedTo: Future[PreparedStatement] = {
private val preparedInsertDeletedTo: CachedPreparedStatement = {
if (settings.journalSettings.supportDeletes)
session.prepare(statements.journalStatements.insertDeletedTo)
new CachedPreparedStatement(() => session.prepare(statements.journalStatements.insertDeletedTo))
else
deletesNotSupportedException
new CachedPreparedStatement(() => deletesNotSupportedException)
}
private lazy val preparedDeleteMessages: Future[PreparedStatement] = {
private val preparedDeleteMessages: CachedPreparedStatement = {
if (settings.journalSettings.supportDeletes) {
session.serverMetaData.flatMap { meta =>
session.prepare(statements.journalStatements.deleteMessages(meta.isVersion2 || settings.cosmosDb))
}
new CachedPreparedStatement(() =>
session.serverMetaData.flatMap { meta =>
session.prepare(statements.journalStatements.deleteMessages(meta.isVersion2 || settings.cosmosDb))
})
} else
deletesNotSupportedException
new CachedPreparedStatement(() => deletesNotSupportedException)
}
private lazy val preparedInsertIntoAllPersistenceIds: Future[PreparedStatement] = {
session.prepare(statements.journalStatements.insertIntoAllPersistenceIds)
private val preparedInsertIntoAllPersistenceIds: CachedPreparedStatement = {
new CachedPreparedStatement(() => session.prepare(statements.journalStatements.insertIntoAllPersistenceIds))
}

private lazy val preparedWriteMessageWithMeta =
session.prepare(statements.journalStatements.writeMessage(withMeta = true))
private lazy val preparedSelectMessages =
session.prepare(statements.journalStatements.selectMessages)
private val preparedWriteMessageWithMeta =
new CachedPreparedStatement(() => session.prepare(statements.journalStatements.writeMessage(withMeta = true)))
private val preparedSelectMessages =
new CachedPreparedStatement(() => session.prepare(statements.journalStatements.selectMessages))

private lazy val queries: CassandraReadJournal =
PersistenceQuery(context.system.asInstanceOf[ExtendedActorSystem])
Expand Down Expand Up @@ -192,16 +193,16 @@ import akka.stream.scaladsl.Source

case CassandraJournal.Init =>
// try initialize early, to be prepared for first real request
preparedWriteMessage
preparedWriteMessageWithMeta
preparedSelectMessages
preparedSelectHighestSequenceNr
preparedWriteMessage.get()
preparedWriteMessageWithMeta.get()
preparedSelectMessages.get()
preparedSelectHighestSequenceNr.get()
if (settings.journalSettings.supportAllPersistenceIds)
preparedInsertIntoAllPersistenceIds
preparedInsertIntoAllPersistenceIds.get()
if (settings.journalSettings.supportDeletes) {
preparedDeleteMessages
preparedSelectDeletedTo
preparedInsertDeletedTo
preparedDeleteMessages.get()
preparedSelectDeletedTo.foreach(_.get())
preparedInsertDeletedTo.get()
}
queries.initialize()

Expand Down Expand Up @@ -358,7 +359,7 @@ import akka.stream.scaladsl.Source
require(atomicWrites.head.payload.nonEmpty)
val allPersistenceId =
if (settings.journalSettings.supportAllPersistenceIds && atomicWrites.head.payload.head.sequenceNr == 1L)
preparedInsertIntoAllPersistenceIds.map(_.bind(atomicWrites.head.persistenceId)).flatMap(execute(_))
preparedInsertIntoAllPersistenceIds.get().map(_.bind(atomicWrites.head.persistenceId)).flatMap(execute(_))
else
FutureUnit

Expand Down Expand Up @@ -398,7 +399,7 @@ import akka.stream.scaladsl.Source
if (m.meta.isDefined) preparedWriteMessageWithMeta
else preparedWriteMessage

stmt.map { stmt =>
stmt.get().map { stmt =>
val bs = stmt
.bind()
.setString("persistence_id", persistenceId)
Expand Down Expand Up @@ -533,7 +534,7 @@ import akka.stream.scaladsl.Source
val deleteResult =
Future.sequence((lowestPartition to highestPartition).map { partitionNr =>
val boundDeleteMessages =
preparedDeleteMessages.map(_.bind(persistenceId, partitionNr: JLong, toSeqNr: JLong))
preparedDeleteMessages.get().map(_.bind(persistenceId, partitionNr: JLong, toSeqNr: JLong))
boundDeleteMessages.flatMap(execute(_))
})
deleteResult.failed.foreach { e =>
Expand All @@ -557,7 +558,7 @@ import akka.stream.scaladsl.Source
toSeqNr: TagPidSequenceNr): Future[Done] = {
def asyncDeleteMessages(partitionNr: TagPidSequenceNr, messageIds: Seq[MessageId]): Future[Unit] = {
val boundStatements = messageIds.map(mid =>
preparedDeleteMessages.map(_.bind(mid.persistenceId, partitionNr: JLong, mid.sequenceNr: JLong)))
preparedDeleteMessages.get().map(_.bind(mid.persistenceId, partitionNr: JLong, mid.sequenceNr: JLong)))
Future.sequence(boundStatements).flatMap { stmts =>
executeBatch(batch => stmts.foldLeft(batch) { case (acc, next) => acc.add(next) })
}
Expand Down Expand Up @@ -601,7 +602,7 @@ import akka.stream.scaladsl.Source
FutureUnit
} else {
val boundInsertDeletedTo =
preparedInsertDeletedTo.map(_.bind(persistenceId, toSeqNr: JLong))
preparedInsertDeletedTo.get().map(_.bind(persistenceId, toSeqNr: JLong))
boundInsertDeletedTo.flatMap(execute)
}
logicalDelete.flatMap(_ => physicalDelete(lowestPartition, highestPartition, toSeqNr))
Expand Down Expand Up @@ -636,7 +637,8 @@ import akka.stream.scaladsl.Source
}

private def partitionInfo(persistenceId: String, partitionNr: Long, maxSequenceNr: Long): Future[PartitionInfo] = {
val boundSelectHighestSequenceNr = preparedSelectHighestSequenceNr.map(_.bind(persistenceId, partitionNr: JLong))
val boundSelectHighestSequenceNr =
preparedSelectHighestSequenceNr.get().map(_.bind(persistenceId, partitionNr: JLong))
boundSelectHighestSequenceNr
.flatMap(selectOne)
.map(
Expand All @@ -650,7 +652,7 @@ import akka.stream.scaladsl.Source
private def asyncHighestDeletedSequenceNumber(persistenceId: String): Future[Long] = {
preparedSelectDeletedTo match {
case Some(pstmt) =>
val boundSelectDeletedTo = pstmt.map(_.bind(persistenceId))
val boundSelectDeletedTo = pstmt.get().map(_.bind(persistenceId))
boundSelectDeletedTo.flatMap(selectOne).map(rowOption => rowOption.map(_.getLong("deleted_to")).getOrElse(0))
case None =>
Future.successful(0L)
Expand All @@ -663,11 +665,13 @@ import akka.stream.scaladsl.Source
partitionSize: Long): Future[Long] = {
def find(currentPnr: Long, currentSnr: Long, foundEmptyPartition: Boolean): Future[Long] = {
// if every message has been deleted and thus no sequence_nr the driver gives us back 0 for "null" :(
val boundSelectHighestSequenceNr = preparedSelectHighestSequenceNr.map(ps => {
val bound = ps.bind(persistenceId, currentPnr: JLong)
bound
val boundSelectHighestSequenceNr = preparedSelectHighestSequenceNr
.get()
.map(ps => {
val bound = ps.bind(persistenceId, currentPnr: JLong)
bound

})
})
boundSelectHighestSequenceNr
.flatMap(selectOne)
.map { rowOption =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession
// The result set size will be the number of distinct tags that this pid has used, expecting
// that to be small (<10) so call to all should be safe
def lookupTagProgress(persistenceId: String)(implicit ec: ExecutionContext): Future[Map[Tag, TagProgress]] =
SelectTagProgressForPersistenceId
selectTagProgressForPersistenceId
.get()
.map(_.bind(persistenceId).setExecutionProfileName(settings.journalSettings.readProfile))
.flatMap(stmt => {
session.select(stmt).runWith(Sink.seq)
Expand All @@ -72,7 +73,8 @@ import akka.stream.alpakka.cassandra.scaladsl.CassandraSession
// or min tag scanning sequence number, and fix any tags. This recovers any tag writes that
// happened before the latest snapshot
def tagScanningStartingSequenceNr(persistenceId: String): Future[SequenceNr] =
SelectTagScanningForPersistenceId
selectTagScanningForPersistenceId
.get()
.map(_.bind(persistenceId).setExecutionProfileName(settings.journalSettings.readProfile))
.flatMap(session.selectOne)
.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ import scala.util.Try
val batch = new BatchStatementBuilder(BatchType.UNLOGGED)
batch.setExecutionProfileName(writeProfile)
val tagWritePSs = for {
withMeta <- taggedPreparedStatements.WriteTagViewWithMeta
withoutMeta <- taggedPreparedStatements.WriteTagViewWithoutMeta
withMeta <- taggedPreparedStatements.writeTagViewWithMeta.get()
withoutMeta <- taggedPreparedStatements.writeTagViewWithoutMeta.get()
} yield (withMeta, withoutMeta)

tagWritePSs
Expand Down Expand Up @@ -104,7 +104,8 @@ import scala.util.Try

def writeProgress(tag: Tag, persistenceId: String, seqNr: Long, tagPidSequenceNr: Long, offset: UUID)(
implicit ec: ExecutionContext): Future[Done] = {
WriteTagProgress
writeTagProgress
.get()
.map(
ps =>
ps.bind(persistenceId, tag, seqNr: JLong, tagPidSequenceNr: JLong, offset)
Expand Down Expand Up @@ -385,7 +386,7 @@ import scala.util.Try
updates.take(maxPrint).mkString(",") + s" ...and ${updates.size - 20} more")
}

tagWriterSession.taggedPreparedStatements.WriteTagScanning.foreach { ps =>
tagWriterSession.taggedPreparedStatements.writeTagScanning.get().foreach { ps =>
val startTime = System.nanoTime()

def writeTagScanningBatch(group: Seq[(String, Long)]): Future[Done] = {
Expand Down
Loading

0 comments on commit a012ff1

Please sign in to comment.