Skip to content

Commit

Permalink
feat: Data partitions for snapshot and durable state (#515)
Browse files Browse the repository at this point in the history
* ddl scripts
* create H2 tables
* durable state currentPersistenceIds and friends
* DurableStateStoreAdditionalColumnSpec
* ChangeHandler not supported for > 1 database
* DurableStateUpdateWithChangeEventStoreSpec
* deprecate old tableWithSchema methods
  • Loading branch information
patriknw authored Feb 9, 2024
1 parent 974428d commit 28b95bc
Show file tree
Hide file tree
Showing 21 changed files with 568 additions and 260 deletions.
140 changes: 105 additions & 35 deletions core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ object R2dbcSettings {
numberOfDatabases * (numberOfDataPartitions / numberOfDatabases) == numberOfDataPartitions,
s"data-partition.number-of-databases [$numberOfDatabases] must be a whole number divisor of " +
s"data-partition.number-of-partitions [$numberOfDataPartitions].")
require(
durableStateChangeHandlerClasses.isEmpty || numberOfDatabases == 1,
"Durable State ChangeHandler not supported with more than one data partition database.")

val connectionFactorySettings =
if (numberOfDatabases == 1) {
Expand Down Expand Up @@ -235,52 +238,113 @@ final class R2dbcSettings private (
val numberOfDataPartitions: Int) {
import R2dbcSettings.NumberOfSlices

// One connection factory is configured per database; this count drives connectionFactorSliceRanges below.
val numberOfDatabases: Int = _connectionFactorySettings.size

/**
 * Slice range owned by each data partition: entry `i` is the range for partition `i`. Ranges are
 * contiguous, equally sized, and together cover all slices.
 */
val dataPartitionSliceRanges: immutable.IndexedSeq[Range] = {
  val slicesPerPartition = NumberOfSlices / numberOfDataPartitions
  Vector.tabulate(numberOfDataPartitions) { partition =>
    val lowerSlice = partition * slicesPerPartition
    lowerSlice until lowerSlice + slicesPerPartition
  }
}

/**
 * Slice range served by each configured database/connection factory: entry `i` is the range for
 * database `i`. Ranges are contiguous, equally sized, and together cover all slices.
 */
val connectionFactorSliceRanges: immutable.IndexedSeq[Range] = {
  val slicesPerDatabase = NumberOfSlices / numberOfDatabases
  Vector.tabulate(numberOfDatabases) { database =>
    val lowerSlice = database * slicesPerDatabase
    lowerSlice until lowerSlice + slicesPerDatabase
  }
}

// Journal table name qualified with the optional schema, without any data partition suffix.
private val _journalTableWithSchema: String = schema.map(_ + ".").getOrElse("") + journalTable

/**
 * The journal table and schema name without data partition suffix.
 */
@deprecated("Use journalTableWithSchema(slice)", "1.2.2")
val journalTableWithSchema: String = _journalTableWithSchema

/**
 * The journal table and schema name with data partition suffix for the given slice. When number-of-partitions is 1
 * the table name is without suffix.
 *
 * @param slice the slice number used to select the data partition
 */
def journalTableWithSchema(slice: Int): String =
  resolveTableName(_journalTableWithSchema, slice)

// Snapshot table name qualified with the optional schema, without any data partition suffix.
private val _snapshotsTableWithSchema: String = schema.map(_ + ".").getOrElse("") + snapshotsTable

/**
 * The snapshot table and schema name without data partition suffix.
 */
@deprecated("Use snapshotTableWithSchema(slice)", "1.2.2")
val snapshotsTableWithSchema: String = _snapshotsTableWithSchema

/**
 * The snapshot table and schema name with data partition suffix for the given slice. When number-of-partitions is 1
 * the table name is without suffix.
 *
 * @param slice the slice number used to select the data partition
 */
def snapshotTableWithSchema(slice: Int): String =
resolveTableName(_snapshotsTableWithSchema, slice)

// Durable state table name qualified with the optional schema, without any data partition suffix.
private val _durableStateTableWithSchema: String = schema.map(_ + ".").getOrElse("") + durableStateTable

/**
 * The durable state table and schema name without data partition suffix.
 */
@deprecated("Use durableStateTableWithSchema(slice)", "1.2.2")
val durableStateTableWithSchema: String = schema.map(_ + ".").getOrElse("") + durableStateTable

/**
 * The durable state table and schema name with data partition suffix for the given slice. When number-of-partitions
 * is 1 the table name is without suffix.
 *
 * @param slice the slice number used to select the data partition
 */
def durableStateTableWithSchema(slice: Int): String =
resolveTableName(_durableStateTableWithSchema, slice)

/**
 * Resolves the physical table name for a slice: the plain `table` name when there is a single data
 * partition, otherwise `table` with a `_<dataPartition>` suffix appended.
 */
private def resolveTableName(table: String, slice: Int): String = {
  if (numberOfDataPartitions == 1)
    table
  else
    s"${table}_${dataPartition(slice)}"
}

// NOTE(review): these two vals duplicate snapshotsTableWithSchema / durableStateTableWithSchema
// defined earlier in this file; this looks like pre-refactoring diff residue — confirm against the
// real source file and remove one copy.
val snapshotsTableWithSchema: String = schema.map(_ + ".").getOrElse("") + snapshotsTable
val durableStateTableWithSchema: String = schema.map(_ + ".").getOrElse("") + durableStateTable

/**
 * INTERNAL API: All journal tables and their lower slice
 */
@InternalApi private[akka] lazy val allJournalTablesWithSchema: Map[String, Int] =
  resolveAllTableNames(journalTableWithSchema(_))

// NOTE(review): duplicates the numberOfDatabases val defined near the top of the class — likely diff residue; confirm and remove.
val numberOfDatabases: Int = _connectionFactorySettings.size
/**
 * INTERNAL API: All snapshot tables and their lower slice
 */
@InternalApi private[akka] lazy val allSnapshotTablesWithSchema: Map[String, Int] =
resolveAllTableNames(snapshotTableWithSchema(_))

// NOTE(review): duplicates dataPartitionSliceRanges defined earlier in the class — likely diff residue; confirm and remove.
val dataPartitionSliceRanges: immutable.IndexedSeq[Range] = {
val rangeSize = NumberOfSlices / numberOfDataPartitions
(0 until numberOfDataPartitions).map { i =>
(i * rangeSize until i * rangeSize + rangeSize)
}.toVector
}
/**
 * INTERNAL API: configured durable state table per entity type, each table name qualified with the
 * optional schema.
 */
@InternalApi private[akka] val durableStateTableByEntityTypeWithSchema: Map[String, String] = {
  val schemaPrefix = schema.map(_ + ".").getOrElse("")
  _durableStateTableByEntityType.map { case (entityType, tableName) =>
    entityType -> (schemaPrefix + tableName)
  }
}

/**
 * INTERNAL API: All durable state tables and their lower slice
 */
@InternalApi private[akka] lazy val allDurableStateTablesWithSchema: Map[String, Int] = {
  // start from the default table resolved per data partition, then add per-entity-type tables
  val defaultTables = resolveAllTableNames(durableStateTableWithSchema(_))
  val entityTypes = _durableStateTableByEntityType.keys
  entityTypes.foldLeft(defaultTables) { case (acc, entityType) =>
    val entityTypeTables = resolveAllTableNames(slice => getDurableStateTableWithSchema(entityType, slice))
    acc ++ entityTypeTables
  }
}

/**
 * Maps each distinct resolved table name to the lowest slice that uses it, probing one slice (the
 * lower bound) per data partition range. First occurrence wins when several partitions share a table.
 */
private def resolveAllTableNames(tableForSlice: Int => String): Map[String, Int] =
  dataPartitionSliceRanges
    .map(range => tableForSlice(range.min) -> range.min)
    .foldLeft(Map.empty[String, Int]) { (acc, entry) =>
      if (acc.contains(entry._1)) acc else acc + entry
    }

/**
* INTERNAL API
*/
Expand All @@ -298,8 +362,22 @@ final class R2dbcSettings private (
/**
 * The durable state table name configured for the `entityType`, falling back to the default
 * `durableStateTable` when none is configured. Without schema prefix and data partition suffix.
 */
def getDurableStateTable(entityType: String): String =
_durableStateTableByEntityType.getOrElse(entityType, durableStateTable)

/**
 * The durable state table and schema name for the `entityType` without data partition suffix.
 */
@deprecated("Use getDurableStateTableWithSchema(entityType, slice)", "1.2.2")
def getDurableStateTableWithSchema(entityType: String): String =
  durableStateTableByEntityTypeWithSchema.getOrElse(entityType, _durableStateTableWithSchema)

/**
 * The durable state table and schema name for the `entityType` with data partition suffix for the given slice. When
 * number-of-partitions is 1 the table name is without suffix.
 */
def getDurableStateTableWithSchema(entityType: String, slice: Int): String =
  durableStateTableByEntityTypeWithSchema
    .get(entityType)
    .map(customTable => resolveTableName(customTable, slice))
    .getOrElse(durableStateTableWithSchema(slice))

/**
* INTERNAL API
Expand All @@ -314,14 +392,6 @@ final class R2dbcSettings private (
// Returns a copy of these settings with the given useAppTimestamp flag; all other settings unchanged.
@InternalApi private[akka] def withUseAppTimestamp(useAppTimestamp: Boolean): R2dbcSettings =
copy(useAppTimestamp = useAppTimestamp)

/**
 * INTERNAL API
 */
// NOTE(review): this val also appears earlier in this file with the same name and body — one of the
// two copies looks like diff residue; confirm against the real source file and remove one.
@InternalApi private[akka] val durableStateTableByEntityTypeWithSchema: Map[String, String] =
_durableStateTableByEntityType.map { case (entityType, table) =>
entityType -> (schema.map(_ + ".").getOrElse("") + table)
}

/**
* INTERNAL API
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ private[r2dbc] trait DurableStateDao extends BySliceQuery.Dao[DurableStateDao.Se

def persistenceIds(afterId: Option[String], limit: Long): Source[String, NotUsed]

def persistenceIds(afterId: Option[String], limit: Long, table: String): Source[String, NotUsed]
def persistenceIds(
afterId: Option[String],
limit: Long,
table: String,
dataPartitionSlice: Int): Source[String, NotUsed]

def persistenceIds(entityType: String, afterId: Option[String], limit: Long): Source[String, NotUsed]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,22 @@ private[r2dbc] object H2Dialect extends Dialect {
}
val snapshotTable = config.getString("snapshot-table")
val snapshotTableWithSchema = schema.map(_ + ".").getOrElse("") + snapshotTable
val allSnapshotTablesWithSchema =
if (numberOfDataPartitions == 1)
Vector(snapshotTableWithSchema)
else
(0 until numberOfDataPartitions).map { dataPartition =>
s"${snapshotTableWithSchema}_$dataPartition"
}
val durableStateTable = config.getString("state-table")
val durableStateTableWithSchema = schema.map(_ + ".").getOrElse("") + durableStateTable
val allDurableStateTablesWithSchema =
if (numberOfDataPartitions == 1)
Vector(durableStateTableWithSchema)
else
(0 until numberOfDataPartitions).map { dataPartition =>
s"${durableStateTableWithSchema}_$dataPartition"
}

implicit val queryAdapter: QueryAdapter = IdentityAdapter

Expand All @@ -134,12 +148,17 @@ private[r2dbc] object H2Dialect extends Dialect {
val sliceIndexWithSchema = table + "_slice_idx"
sql"""CREATE INDEX IF NOT EXISTS $sliceIndexWithSchema ON $table(slice, entity_type, db_timestamp, seq_nr)"""
}
val snapshotSliceIndexWithSchema = snapshotTableWithSchema + "_slice_idx"
val durableStateSliceIndexWithSchema = durableStateTableWithSchema + "_slice_idx"
val snapshotSliceIndexes = allSnapshotTablesWithSchema.map { table =>
val sliceIndexWithSchema = table + "_slice_idx"
sql"""CREATE INDEX IF NOT EXISTS $sliceIndexWithSchema ON $table(slice, entity_type, db_timestamp)"""
}
val durableStateSliceIndexes = allDurableStateTablesWithSchema.map { table =>
val sliceIndexWithSchema = table + "_slice_idx"
sql"""CREATE INDEX IF NOT EXISTS $sliceIndexWithSchema ON $table(slice, entity_type, db_timestamp, revision)"""
}
journalSliceIndexes ++
Seq(
sql"""CREATE INDEX IF NOT EXISTS $snapshotSliceIndexWithSchema ON $snapshotTableWithSchema(slice, entity_type, db_timestamp)""",
sql"""CREATE INDEX IF NOT EXISTS $durableStateSliceIndexWithSchema ON durable_state(slice, entity_type, db_timestamp, revision)""")
snapshotSliceIndexes ++
durableStateSliceIndexes
} else Seq.empty[String]

val createJournalTables = allJournalTablesWithSchema.map { table =>
Expand All @@ -166,11 +185,9 @@ private[r2dbc] object H2Dialect extends Dialect {
PRIMARY KEY(persistence_id, seq_nr)
)"""
}

(createJournalTables ++
Seq(
val createSnapshotTables = allSnapshotTablesWithSchema.map { table =>
sql"""
CREATE TABLE IF NOT EXISTS $snapshotTableWithSchema (
CREATE TABLE IF NOT EXISTS $table (
slice INT NOT NULL,
entity_type VARCHAR(255) NOT NULL,
persistence_id VARCHAR(255) NOT NULL,
Expand All @@ -186,9 +203,11 @@ private[r2dbc] object H2Dialect extends Dialect {
meta_payload BYTEA,

PRIMARY KEY(persistence_id)
)""",
)"""
}
val createDurableStateTables = allDurableStateTablesWithSchema.map { table =>
sql"""
CREATE TABLE IF NOT EXISTS $durableStateTableWithSchema (
CREATE TABLE IF NOT EXISTS $table (
slice INT NOT NULL,
entity_type VARCHAR(255) NOT NULL,
persistence_id VARCHAR(255) NOT NULL,
Expand All @@ -202,7 +221,12 @@ private[r2dbc] object H2Dialect extends Dialect {

PRIMARY KEY(persistence_id, revision)
)
""") ++
"""
}

(createJournalTables ++
createSnapshotTables ++
createDurableStateTables ++
sliceIndexes ++
(if (additionalInit.trim.nonEmpty) Seq(additionalInit) else Seq.empty[String]))
.mkString(";") // r2dbc h2 driver replaces with '\;' as needed for INIT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ private[r2dbc] final class H2SnapshotDao(executorProvider: R2dbcExecutorProvider

override protected lazy val log: Logger = LoggerFactory.getLogger(classOf[H2SnapshotDao])

override protected def createUpsertSql: String = {
override protected def upsertSql(slice: Int): String = {
// db_timestamp and tags columns were added in 1.2.0
if (settings.querySettings.startFromSnapshotEnabled)
sql"""
MERGE INTO $snapshotTable
MERGE INTO ${snapshotTable(slice)}
(slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest, db_timestamp, tags)
KEY (persistence_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
else
sql"""
MERGE INTO $snapshotTable
MERGE INTO ${snapshotTable(slice)}
(slice, entity_type, persistence_id, seq_nr, write_timestamp, snapshot, ser_id, ser_manifest, meta_payload, meta_ser_id, meta_ser_manifest)
KEY (persistence_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
Expand Down
Loading

0 comments on commit 28b95bc

Please sign in to comment.