Skip to content

Commit

Permalink
feat: Data partitions
Browse files Browse the repository at this point in the history
  • Loading branch information
patriknw committed Jan 30, 2024
1 parent 0ec1de4 commit 91ded34
Show file tree
Hide file tree
Showing 16 changed files with 404 additions and 137 deletions.
13 changes: 13 additions & 0 deletions core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ akka.persistence.r2dbc {
# name of the table to use for events
table = "event_journal"

# Number of tables that the journal will be split into. The selection of data partition is
# made from the slice of the persistenceId. Must be between 1 and 1024 and a whole number
# divisor of 1024 (number of slices).
# For example, 4 data-partitions means that slice range (0 to 255) maps to data partition 0,
# (256 to 511) to data partition 1, (512 to 767) to data partition 2, and (768 to 1023) to
# data partition 3.
# The event_journal tables will have the data partition as suffix, e.g. _0, _1, _2, _3.
# When data-partitions is 1 there will only be one journal table, without suffix.
# This configuration cannot be changed in a rolling update, since the data must be moved
# between the tables if number of data partitions is changed.
table-data-partitions = 1

# the column type to use for event payloads (BYTEA or JSONB)
payload-column-type = "BYTEA"

Expand Down Expand Up @@ -51,6 +63,7 @@ akka.persistence.r2dbc {

# replay filter not needed for this plugin
replay-filter.mode = off

}
}
// #journal-settings
Expand Down
33 changes: 32 additions & 1 deletion core/src/main/scala/akka/persistence/r2dbc/R2dbcSettings.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ import scala.concurrent.duration._
@InternalStableApi
object R2dbcSettings {

// must correspond to akka.persistence.Persistence.numberOfSlices
private val NumberOfSlices = 1024

def apply(config: Config): R2dbcSettings = {
if (config.hasPath("dialect")) {
throw new IllegalArgumentException(
Expand All @@ -45,6 +48,7 @@ object R2dbcSettings {
val schema: Option[String] = Option(config.getString("schema")).filterNot(_.trim.isEmpty)

val journalTable: String = config.getString("journal.table")
val journalTableDataPartitions = config.getInt("journal.table-data-partitions")

def useJsonPayload(prefix: String) = config.getString(s"$prefix.payload-column-type").toUpperCase match {
case "BYTEA" => false
Expand Down Expand Up @@ -116,6 +120,7 @@ object R2dbcSettings {
val settingsFromConfig = new R2dbcSettings(
schema,
journalTable,
journalTableDataPartitions,
journalPayloadCodec,
journalPublishEvents,
snapshotsTable,
Expand Down Expand Up @@ -153,6 +158,7 @@ object R2dbcSettings {
final class R2dbcSettings private (
val schema: Option[String],
val journalTable: String,
val journalTableDataPartitions: Int,
val journalPayloadCodec: PayloadCodec,
val journalPublishEvents: Boolean,
val snapshotsTable: String,
Expand All @@ -172,11 +178,34 @@ final class R2dbcSettings private (
_durableStateAdditionalColumnClasses: Map[String, immutable.IndexedSeq[String]],
_durableStateChangeHandlerClasses: Map[String, String],
_useAppTimestamp: Boolean) {
import R2dbcSettings.NumberOfSlices

// Validate the data partition count eagerly, at construction time.
// The lower bound must be 1: a value of 0 would cause a division by zero when the
// number of slices per data partition is computed (NumberOfSlices / journalTableDataPartitions).
require(
  1 <= journalTableDataPartitions && journalTableDataPartitions <= NumberOfSlices,
  s"journalTableDataPartitions must be between 1 and $NumberOfSlices")
// The partition count must divide the number of slices without remainder, so that every
// data partition covers the same number of slices. A modulo check is used because
// `n * NumberOfSlices / n == NumberOfSlices` holds for every non-zero n and rejects nothing.
require(
  NumberOfSlices % journalTableDataPartitions == 0,
  s"journalTableDataPartitions [$journalTableDataPartitions] must be a whole number divisor of numberOfSlices [$NumberOfSlices].")

val journalTableWithSchema: String = schema.map(_ + ".").getOrElse("") + journalTable
private val journalTableWithSchema: String = schema.map(_ + ".").getOrElse("") + journalTable
val snapshotsTableWithSchema: String = schema.map(_ + ".").getOrElse("") + snapshotsTable
val durableStateTableWithSchema: String = schema.map(_ + ".").getOrElse("") + durableStateTable

/**
 * Resolves the schema-qualified journal table name for the given slice.
 *
 * With a single data partition the plain journal table name is returned. Otherwise the
 * data partition index, computed as `slice / (NumberOfSlices / journalTableDataPartitions)`,
 * is appended as a `_<index>` suffix.
 */
def journalTableWithSchema(slice: Int): String =
  journalTableDataPartitions match {
    case 1 =>
      journalTableWithSchema
    case partitions =>
      val slicesPerPartition = NumberOfSlices / partitions
      val partitionIndex = slice / slicesPerPartition
      s"${journalTableWithSchema}_$partitionIndex"
  }

/**
 * The distinct journal table names obtained by resolving the table for every slice
 * in the range `0 until NumberOfSlices`.
 */
val alljournalTablesWithSchema: Set[String] = {
  val tablePerSlice =
    for (slice <- 0 until NumberOfSlices) yield journalTableWithSchema(slice)
  tablePerSlice.toSet
}

/**
* One of the supported dialects 'postgres', 'yugabyte', 'sqlserver' or 'h2'
*/
Expand Down Expand Up @@ -234,6 +263,7 @@ final class R2dbcSettings private (
private def copy(
schema: Option[String] = schema,
journalTable: String = journalTable,
journalTableDataPartitions: Int = journalTableDataPartitions,
journalPayloadCodec: PayloadCodec = journalPayloadCodec,
journalPublishEvents: Boolean = journalPublishEvents,
snapshotsTable: String = snapshotsTable,
Expand All @@ -257,6 +287,7 @@ final class R2dbcSettings private (
new R2dbcSettings(
schema,
journalTable,
journalTableDataPartitions: Int,
journalPayloadCodec,
journalPublishEvents,
snapshotsTable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact
require(journalSettings.useAppTimestamp)
require(journalSettings.dbTimestampMonotonicIncreasing)

private val insertSql = sql"INSERT INTO $journalTable " +
private def insertSql(slice: Int) = sql"INSERT INTO ${journalTable(slice)} " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"

Expand All @@ -58,15 +58,16 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact

// it's always the same persistenceId for all events
val persistenceId = events.head.persistenceId
val slice = persistenceExt.sliceForPersistenceId(persistenceId)

val totalEvents = events.size
val result =
if (totalEvents == 1) {
r2dbcExecutor.updateOne(s"insert [$persistenceId]")(connection =>
bindInsertStatement(connection.createStatement(insertSql), events.head))
bindInsertStatement(connection.createStatement(insertSql(slice)), events.head))
} else {
r2dbcExecutor.updateInBatch(s"batch insert [$persistenceId], [$totalEvents] events")(connection =>
events.foldLeft(connection.createStatement(insertSql)) { (stmt, write) =>
events.foldLeft(connection.createStatement(insertSql(slice))) { (stmt, write) =>
stmt.add()
bindInsertStatement(stmt, write)
})
Expand All @@ -81,8 +82,9 @@ private[r2dbc] class H2JournalDao(journalSettings: R2dbcSettings, connectionFact

override def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant] = {
val persistenceId = event.persistenceId
val slice = persistenceExt.sliceForPersistenceId(persistenceId)

val stmt = bindInsertStatement(connection.createStatement(insertSql), event)
val stmt = bindInsertStatement(connection.createStatement(insertSql(slice)), event)
val result = R2dbcExecutor.updateOneInTx(stmt)

if (log.isDebugEnabled())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ private[r2dbc] class H2QueryDao(settings: R2dbcSettings, connectionFactory: Conn

sql"""
$selectColumns
FROM $journalTable
FROM ${journalTable(minSlice)}
WHERE entity_type = ?
AND ${sliceCondition(minSlice, maxSlice)}
AND db_timestamp >= ? $toDbTimestampParamCondition $behindCurrentTimeIntervalCondition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
import JournalDao.SerializedJournalRow
protected def log: Logger = PostgresJournalDao.log

private val persistenceExt = Persistence(system)
protected val persistenceExt: Persistence = Persistence(system)

protected val r2dbcExecutor =
new R2dbcExecutor(
Expand All @@ -79,66 +79,67 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
journalSettings.logDbCallsExceeding,
journalSettings.connectionFactorySettings.poolSettings.closeCallsExceeding)(ec, system)

protected val journalTable: String = journalSettings.journalTableWithSchema
protected def journalTable(slice: Int): String = journalSettings.journalTableWithSchema(slice)
protected implicit val journalPayloadCodec: PayloadCodec = journalSettings.journalPayloadCodec
protected implicit val tagsCodec: TagsCodec = journalSettings.tagsCodec
protected implicit val timestampCodec: TimestampCodec = journalSettings.timestampCodec
protected implicit val queryAdapter: QueryAdapter = journalSettings.queryAdapter

protected val (insertEventWithParameterTimestampSql, insertEventWithTransactionTimestampSql) = {
val baseSql =
s"INSERT INTO $journalTable " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, "

// The subselect of the db_timestamp of previous seqNr for same pid is to ensure that db_timestamp is
// always increasing for a pid (time not going backwards).
// TODO we could skip the subselect when inserting seqNr 1 as a possible optimization
def timestampSubSelect =
s"(SELECT db_timestamp + '1 microsecond'::interval FROM $journalTable " +
"WHERE persistence_id = ? AND seq_nr = ?)"

val insertEventWithParameterTimestampSql = {
if (journalSettings.dbTimestampMonotonicIncreasing)
sql"$baseSql ?) RETURNING db_timestamp"
else
sql"$baseSql GREATEST(?, $timestampSubSelect)) RETURNING db_timestamp"
}
/**
 * INSERT statement for the journal table that holds the given slice, where the
 * `db_timestamp` value is supplied as a bind parameter.
 *
 * Unless `dbTimestampMonotonicIncreasing` is enabled, the bound timestamp is combined via
 * `GREATEST` with the sub-select built by `timestampSubSelect` for the same table.
 */
protected def insertEventWithParameterTimestampSql(slice: Int): String = {
  val targetTable = journalTable(slice)
  val insertPrefix = insertEvenBaseSql(targetTable)
  if (!journalSettings.dbTimestampMonotonicIncreasing)
    sql"$insertPrefix GREATEST(?, ${timestampSubSelect(targetTable)})) RETURNING db_timestamp"
  else
    sql"$insertPrefix ?) RETURNING db_timestamp"
}

val insertEventWithTransactionTimestampSql = {
if (journalSettings.dbTimestampMonotonicIncreasing)
sql"$baseSql CURRENT_TIMESTAMP) RETURNING db_timestamp"
else
sql"$baseSql GREATEST(CURRENT_TIMESTAMP, $timestampSubSelect)) RETURNING db_timestamp"
}
/**
 * INSERT statement for the journal table that holds the given slice, where the
 * `db_timestamp` value is taken from the database's `CURRENT_TIMESTAMP`.
 *
 * Unless `dbTimestampMonotonicIncreasing` is enabled, `CURRENT_TIMESTAMP` is combined via
 * `GREATEST` with the sub-select built by `timestampSubSelect` for the same table.
 */
private def insertEventWithTransactionTimestampSql(slice: Int) = {
  val targetTable = journalTable(slice)
  val insertPrefix = insertEvenBaseSql(targetTable)
  if (!journalSettings.dbTimestampMonotonicIncreasing)
    sql"$insertPrefix GREATEST(CURRENT_TIMESTAMP, ${timestampSubSelect(targetTable)})) RETURNING db_timestamp"
  else
    sql"$insertPrefix CURRENT_TIMESTAMP) RETURNING db_timestamp"
}

(insertEventWithParameterTimestampSql, insertEventWithTransactionTimestampSql)
// Builds the common prefix of the event INSERT statement for the given table.
// The VALUES list is intentionally left open: it contains 13 placeholders followed by a
// trailing ", " so that the caller appends the expression for the 14th column (db_timestamp)
// and the closing parenthesis.
private def insertEvenBaseSql(table: String) = {
s"INSERT INTO $table " +
"(slice, entity_type, persistence_id, seq_nr, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, tags, meta_ser_id, meta_ser_manifest, meta_payload, db_timestamp) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, "
}

private val selectHighestSequenceNrSql = sql"""
SELECT MAX(seq_nr) from $journalTable
// The subselect of the db_timestamp of previous seqNr for same pid is to ensure that db_timestamp is
// always increasing for a pid (time not going backwards).
// The returned fragment is parenthesized, selects db_timestamp plus one microsecond from the given
// table and adds two bind parameters to the statement: persistence_id and seq_nr.
// TODO we could skip the subselect when inserting seqNr 1 as a possible optimization
private def timestampSubSelect(table: String) =
s"(SELECT db_timestamp + '1 microsecond'::interval FROM $table " +
"WHERE persistence_id = ? AND seq_nr = ?)"

private def selectHighestSequenceNrSql(slice: Int) = sql"""
SELECT MAX(seq_nr) from ${journalTable(slice)}
WHERE persistence_id = ? AND seq_nr >= ?"""

private val selectLowestSequenceNrSql =
private def selectLowestSequenceNrSql(slice: Int) =
sql"""
SELECT MIN(seq_nr) from $journalTable
SELECT MIN(seq_nr) from ${journalTable(slice)}
WHERE persistence_id = ?"""

private val deleteEventsSql = sql"""
DELETE FROM $journalTable
private def deleteEventsSql(slice: Int) = sql"""
DELETE FROM ${journalTable(slice)}
WHERE persistence_id = ? AND seq_nr >= ? AND seq_nr <= ?"""

protected def insertDeleteMarkerSql(timestamp: String = "CURRENT_TIMESTAMP"): String = sql"""
INSERT INTO $journalTable
protected def insertDeleteMarkerSql(slice: Int, timestamp: String = "CURRENT_TIMESTAMP"): String = sql"""
INSERT INTO ${journalTable(slice)}
(slice, entity_type, persistence_id, seq_nr, db_timestamp, writer, adapter_manifest, event_ser_id, event_ser_manifest, event_payload, deleted)
VALUES (?, ?, ?, ?, $timestamp, ?, ?, ?, ?, ?, ?)"""

private val deleteEventsByPersistenceIdBeforeTimestampSql = sql"""
DELETE FROM $journalTable
private def deleteEventsByPersistenceIdBeforeTimestampSql(slice: Int) = sql"""
DELETE FROM ${journalTable(slice)}
WHERE persistence_id = ? AND db_timestamp < ?"""

private val deleteEventsBySliceBeforeTimestampSql = sql"""
DELETE FROM $journalTable
private def deleteEventsBySliceBeforeTimestampSql(slice: Int) = sql"""
DELETE FROM ${journalTable(slice)}
WHERE slice = ? AND entity_type = ? AND db_timestamp < ?"""

/**
Expand All @@ -156,14 +157,15 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti

// it's always the same persistenceId for all events
val persistenceId = events.head.persistenceId
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
val previousSeqNr = events.head.seqNr - 1

// The MigrationTool defines the dbTimestamp to preserve the original event timestamp
val useTimestampFromDb = events.head.dbTimestamp == Instant.EPOCH

val insertSql =
if (useTimestampFromDb) insertEventWithTransactionTimestampSql
else insertEventWithParameterTimestampSql
if (useTimestampFromDb) insertEventWithTransactionTimestampSql(slice)
else insertEventWithParameterTimestampSql(slice)

val totalEvents = events.size
if (totalEvents == 1) {
Expand Down Expand Up @@ -194,14 +196,15 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti

override def writeEventInTx(event: SerializedJournalRow, connection: Connection): Future[Instant] = {
val persistenceId = event.persistenceId
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
val previousSeqNr = event.seqNr - 1

// The MigrationTool defines the dbTimestamp to preserve the original event timestamp
val useTimestampFromDb = event.dbTimestamp == Instant.EPOCH

val insertSql =
if (useTimestampFromDb) insertEventWithTransactionTimestampSql
else insertEventWithParameterTimestampSql
if (useTimestampFromDb) insertEventWithTransactionTimestampSql(slice)
else insertEventWithParameterTimestampSql(slice)

val stmt = bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr)
val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.getTimestamp("db_timestamp"))
Expand Down Expand Up @@ -267,11 +270,12 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
}

override def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = {
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
val result = r2dbcExecutor
.select(s"select highest seqNr [$persistenceId]")(
connection =>
connection
.createStatement(selectHighestSequenceNrSql)
.createStatement(selectHighestSequenceNrSql(slice))
.bind(0, persistenceId)
.bind(1, fromSequenceNr),
row => {
Expand All @@ -287,11 +291,12 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
}

override def readLowestSequenceNr(persistenceId: String): Future[Long] = {
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
val result = r2dbcExecutor
.select(s"select lowest seqNr [$persistenceId]")(
connection =>
connection
.createStatement(selectLowestSequenceNrSql)
.createStatement(selectLowestSequenceNrSql(slice))
.bind(0, persistenceId),
row => {
val seqNr = row.get(0, classOf[java.lang.Long])
Expand Down Expand Up @@ -321,14 +326,12 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
}
protected def bindTimestampNow(stmt: Statement, getAndIncIndex: () => Int): Statement = stmt
override def deleteEventsTo(persistenceId: String, toSequenceNr: Long, resetSequenceNumber: Boolean): Future[Unit] = {
val slice = persistenceExt.sliceForPersistenceId(persistenceId)

def insertDeleteMarkerStmt(deleteMarkerSeqNr: Long, connection: Connection): Statement = {

val idx = Iterator.range(0, Int.MaxValue)

val entityType = PersistenceId.extractEntityType(persistenceId)
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
val stmt = connection.createStatement(insertDeleteMarkerSql())
val stmt = connection.createStatement(insertDeleteMarkerSql(slice))
stmt
.bind(idx.next(), slice)
.bind(idx.next(), entityType)
Expand All @@ -349,14 +352,14 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
r2dbcExecutor
.update(s"delete [$persistenceId] and insert marker") { connection =>
Vector(
connection.createStatement(deleteEventsSql).bind(0, persistenceId).bind(1, from).bind(2, to),
connection.createStatement(deleteEventsSql(slice)).bind(0, persistenceId).bind(1, from).bind(2, to),
insertDeleteMarkerStmt(to, connection))
}
.map(_.head)
} else {
r2dbcExecutor
.updateOne(s"delete [$persistenceId]") { connection =>
connection.createStatement(deleteEventsSql).bind(0, persistenceId).bind(1, from).bind(2, to)
connection.createStatement(deleteEventsSql(slice)).bind(0, persistenceId).bind(1, from).bind(2, to)
}
}).map(deletedRows =>
if (log.isDebugEnabled) {
Expand Down Expand Up @@ -388,10 +391,11 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
}

override def deleteEventsBefore(persistenceId: String, timestamp: Instant): Future[Unit] = {
val slice = persistenceExt.sliceForPersistenceId(persistenceId)
r2dbcExecutor
.updateOne(s"delete [$persistenceId]") { connection =>
connection
.createStatement(deleteEventsByPersistenceIdBeforeTimestampSql)
.createStatement(deleteEventsByPersistenceIdBeforeTimestampSql(slice))
.bind(0, persistenceId)
.bindTimestamp(1, timestamp)
}
Expand All @@ -404,7 +408,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
r2dbcExecutor
.updateOne(s"delete [$entityType]") { connection =>
connection
.createStatement(deleteEventsBySliceBeforeTimestampSql)
.createStatement(deleteEventsBySliceBeforeTimestampSql(slice))
.bind(0, slice)
.bind(1, entityType)
.bindTimestamp(2, timestamp)
Expand Down
Loading

0 comments on commit 91ded34

Please sign in to comment.