Skip to content

Commit

Permalink
comments from pr
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastian-alfers committed Jan 22, 2024
1 parent 754593f commit a3b0573
Show file tree
Hide file tree
Showing 12 changed files with 36 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,6 @@ import akka.annotation.InternalApi
}

implicit class TagsCodecRichRow(val row: Row)(implicit codec: TagsCodec) extends AnyRef {
def getTags(column: String = "tags"): Set[String] = codec.getTags(row, column)
def getTags(column: String): Set[String] = codec.getTags(row, column)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ import akka.persistence.r2dbc.internal.InstantFactory
def decode(row: Row, name: String): Instant
def decode(row: Row, index: Int): Instant

// should we name it just `now()`? The type should not be in the name...
def instantNow(): Instant = InstantFactory.now()

def now[T](): T
}

/**
Expand All @@ -38,8 +37,6 @@ import akka.persistence.r2dbc.internal.InstantFactory
override def decode(row: Row, index: Int): Instant = row.get(index, classOf[Instant])

override def encode(timestamp: Instant): Any = timestamp

override def now[T](): T = instantNow().asInstanceOf[T]
}
object PostgresTimestampCodec extends PostgresTimestampCodec

Expand All @@ -55,8 +52,6 @@ import akka.persistence.r2dbc.internal.InstantFactory

override def encode(timestamp: Instant): LocalDateTime = LocalDateTime.ofInstant(timestamp, zone)

override def now[T](): T = LocalDateTime.ofInstant(instantNow(), zone).asInstanceOf[T]

override def decode(row: Row, index: Int): Instant = toInstant(row.get(index, classOf[LocalDateTime]))
}

Expand All @@ -68,6 +63,6 @@ import akka.persistence.r2dbc.internal.InstantFactory
def bindTimestamp(index: Int, timestamp: Instant): Statement = statement.bind(index, codec.encode(timestamp))
}
implicit class TimestampCodecRichRow[T](val row: Row)(implicit codec: TimestampCodec) extends AnyRef {
def getTimestamp(index: String = "db_timestamp"): Instant = codec.decode(row, index)
def getTimestamp(index: String): Instant = codec.decode(row, index)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ private[r2dbc] class PostgresDurableStateDao(
SerializedStateRow(
persistenceId = persistenceId,
revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = Instant.EPOCH, // not needed here
payload = getPayload(row),
serId = row.get[Integer]("state_ser_id", classOf[Integer]),
Expand Down Expand Up @@ -671,7 +671,7 @@ private[r2dbc] class PostgresDurableStateDao(
r2dbcExecutor
.selectOne("select current db timestamp")(
connection => connection.createStatement(currentDbTimestampSql),
row => row.getTimestamp())
row => row.getTimestamp("db_timestamp"))
.map {
case Some(time) => time
case None => throw new IllegalStateException(s"Expected one row for: $currentDbTimestampSql")
Expand Down Expand Up @@ -730,7 +730,7 @@ private[r2dbc] class PostgresDurableStateDao(
SerializedStateRow(
persistenceId = row.get("persistence_id", classOf[String]),
revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
// payload = null => lazy loaded for backtracking (ugly, but not worth changing UpdatedDurableState in Akka)
// payload = None => DeletedDurableState (no lazy loading)
Expand All @@ -743,7 +743,7 @@ private[r2dbc] class PostgresDurableStateDao(
SerializedStateRow(
persistenceId = row.get("persistence_id", classOf[String]),
revision = row.get[java.lang.Long]("revision", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = getPayload(row),
serId = row.get[Integer]("state_ser_id", classOf[Integer]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
val result = r2dbcExecutor.updateOneReturning(s"insert [$persistenceId]")(
connection =>
bindInsertStatement(connection.createStatement(insertSql), events.head, useTimestampFromDb, previousSeqNr),
row => row.getTimestamp())
row => row.getTimestamp("db_timestamp"))
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", 1, persistenceId)
Expand All @@ -183,7 +183,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
stmt.add()
bindInsertStatement(stmt, write, useTimestampFromDb, previousSeqNr)
},
row => row.getTimestamp())
row => row.getTimestamp("db_timestamp"))
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] events for persistenceId [{}]", totalEvents, persistenceId)
Expand All @@ -204,7 +204,7 @@ private[r2dbc] class PostgresJournalDao(journalSettings: R2dbcSettings, connecti
else insertEventWithParameterTimestampSql

val stmt = bindInsertStatement(connection.createStatement(insertSql), event, useTimestampFromDb, previousSeqNr)
val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.getTimestamp())
val result = R2dbcExecutor.updateOneReturningInTx(stmt, row => row.getTimestamp("db_timestamp"))
if (log.isDebugEnabled())
result.foreach { _ =>
log.debug("Wrote [{}] event for persistenceId [{}]", 1, persistenceId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,13 @@ private[r2dbc] class PostgresQueryDao(settings: R2dbcSettings, connectionFactory
r2dbcExecutor
.selectOne("select current db timestamp")(
connection => connection.createStatement(currentDbTimestampSql),
row => row.getTimestamp())
row => row.getTimestamp("db_timestamp"))
.map {
case Some(time) => time
case None => throw new IllegalStateException(s"Expected one row for: $currentDbTimestampSql")
}
}

//protected def tagsFromDb(row: Row, columnName: String): Set[String] = row.getTags(columnName)

protected def bindEventsBySlicesRangeSql(
stmt: Statement,
entityType: String,
Expand Down Expand Up @@ -224,27 +222,27 @@ private[r2dbc] class PostgresQueryDao(settings: R2dbcSettings, connectionFactory
entityType,
persistenceId = row.get("persistence_id", classOf[String]),
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = None, // lazy loaded for backtracking
serId = row.get[Integer]("event_ser_id", classOf[Integer]),
serManifest = "",
writerUuid = "", // not need in this query
tags = row.getTags(),
tags = row.getTags("tags"),
metadata = None)
else
SerializedJournalRow(
slice = row.get[Integer]("slice", classOf[Integer]),
entityType,
persistenceId = row.get("persistence_id", classOf[String]),
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = Some(row.getPayload("event_payload")),
serId = row.get[Integer]("event_ser_id", classOf[Integer]),
serManifest = row.get("event_ser_manifest", classOf[String]),
writerUuid = "", // not need in this query
tags = row.getTags(),
tags = row.getTags("tags"),
metadata = readMetadata(row)))

if (log.isDebugEnabled)
Expand All @@ -258,7 +256,7 @@ private[r2dbc] class PostgresQueryDao(settings: R2dbcSettings, connectionFactory
entityType: String,
fromTimestamp: Instant,
toTimestamp: Instant,
limit: Int): _root_.io.r2dbc.spi.Statement = {
limit: Int): Statement = {
stmt
.bind(0, entityType)
.bindTimestamp(1, fromTimestamp)
Expand Down Expand Up @@ -313,7 +311,7 @@ private[r2dbc] class PostgresQueryDao(settings: R2dbcSettings, connectionFactory
.createStatement(selectTimestampOfEventSql)
.bind(0, persistenceId)
.bind(1, seqNr),
row => row.getTimestamp())
row => row.getTimestamp("db_timestamp"))
}

override def loadEvent(
Expand All @@ -338,13 +336,13 @@ private[r2dbc] class PostgresQueryDao(settings: R2dbcSettings, connectionFactory
entityType = row.get("entity_type", classOf[String]),
persistenceId,
seqNr,
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload,
serId = row.get[Integer]("event_ser_id", classOf[Integer]),
serManifest = row.get("event_ser_manifest", classOf[String]),
writerUuid = "", // not need in this query
tags = row.getTags(),
tags = row.getTags("tags"),
metadata = readMetadata(row))
})

Expand All @@ -364,13 +362,13 @@ private[r2dbc] class PostgresQueryDao(settings: R2dbcSettings, connectionFactory
entityType = row.get("entity_type", classOf[String]),
persistenceId = row.get("persistence_id", classOf[String]),
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
readDbTimestamp = row.getTimestamp("read_db_timestamp"),
payload = Some(row.getPayload("event_payload")),
serId = row.get[Integer]("event_ser_id", classOf[Integer]),
serManifest = row.get("event_ser_manifest", classOf[String]),
writerUuid = row.get("writer", classOf[String]),
tags = row.getTags(),
tags = row.getTags("tags"),
metadata = readMetadata(row)))

if (log.isDebugEnabled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,15 @@ private[r2dbc] class PostgresSnapshotDao(settings: R2dbcSettings, connectionFact
// db_timestamp and tags columns were added in 1.2.0
val dbTimestamp =
if (settings.querySettings.startFromSnapshotEnabled)
row.getTimestamp() match {
row.getTimestamp("db_timestamp") match {
case null => Instant.ofEpochMilli(writeTimestamp)
case t => t
}
else
Instant.ofEpochMilli(writeTimestamp)
val tags =
if (settings.querySettings.startFromSnapshotEnabled)
row.getTags()
row.getTags("tags")
else
Set.empty[String]

Expand Down Expand Up @@ -354,7 +354,7 @@ private[r2dbc] class PostgresSnapshotDao(settings: R2dbcSettings, connectionFact
r2dbcExecutor
.selectOne("select current db timestamp")(
connection => connection.createStatement(currentDbTimestampSql),
row => row.getTimestamp())
row => row.getTimestamp("db_timestamp"))
.map {
case Some(time) => time
case None => throw new IllegalStateException(s"Expected one row for: $currentDbTimestampSql")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private[r2dbc] class SqlServerDurableStateDao(
.bindTimestamp("@fromTimestamp", fromTimestamp)
stmt.bind("@limit", settings.querySettings.bufferSize)
if (behindCurrentTime > Duration.Zero) {
stmt.bind("@now", timestampCodec.now())
stmt.bindTimestamp("@now", timestampCodec.instantNow())
}
toTimestamp.foreach(until => stmt.bindTimestamp("@until", until))
stmt
Expand Down Expand Up @@ -154,7 +154,7 @@ private[r2dbc] class SqlServerDurableStateDao(
}

override protected def bindTimestampNow(stmt: Statement, getAndIncIndex: () => Int): Statement =
stmt.bind(getAndIncIndex(), timestampCodec.now())
stmt.bindTimestamp(getAndIncIndex(), timestampCodec.instantNow())

override protected def persistenceIdsForEntityTypeAfterSql(table: String): String =
sql"SELECT TOP(@limit) persistence_id from $table WHERE persistence_id LIKE @persistenceIdLike AND persistence_id > @after ORDER BY persistence_id"
Expand Down Expand Up @@ -194,6 +194,6 @@ private[r2dbc] class SqlServerDurableStateDao(
.bind("@persistenceId", after)
.bind("@limit", limit)

override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.now())
override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.instantNow())

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import akka.actor.typed.ActorSystem
import akka.annotation.InternalApi
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.Sql.Interpolation
import akka.persistence.r2dbc.internal.codec.TimestampCodec.TimestampCodecRichStatement
import akka.persistence.r2dbc.internal.postgres.PostgresJournalDao
import io.r2dbc.spi.ConnectionFactory
import io.r2dbc.spi.Statement
Expand Down Expand Up @@ -51,7 +52,7 @@ private[r2dbc] class SqlServerJournalDao(settings: R2dbcSettings, connectionFact
VALUES (@slice, @entityType, @persistenceId, @seqNr, @writer, @adapterManifest, @eventSerId, @eventSerManifest, @eventPayload, @tags, @metaSerId, @metaSerManifest, @metaSerPayload, @dbTimestamp)"""

override protected def bindTimestampNow(stmt: Statement, getAndIncIndex: () => Int): Statement =
stmt.bind(getAndIncIndex(), timestampCodec.now())
stmt.bindTimestamp(getAndIncIndex(), timestampCodec.instantNow())

override def insertDeleteMarkerSql(timestamp: String): String = super.insertDeleteMarkerSql("?")
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import akka.annotation.InternalApi
import akka.persistence.r2dbc.R2dbcSettings
import akka.persistence.r2dbc.internal.InstantFactory
import akka.persistence.r2dbc.internal.Sql.Interpolation
import akka.persistence.r2dbc.internal.codec.TimestampCodec.SqlServerTimestampCodec
import akka.persistence.r2dbc.internal.codec.TimestampCodec.TimestampCodecRichStatement
import akka.persistence.r2dbc.internal.postgres.PostgresQueryDao
import io.r2dbc.spi.ConnectionFactory
Expand Down Expand Up @@ -94,7 +95,7 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor
entityType: String,
fromTimestamp: Instant,
toTimestamp: Instant,
limit: Int): _root_.io.r2dbc.spi.Statement = {
limit: Int): Statement = {
stmt
.bind("@limit", limit)
.bind("@entityType", entityType)
Expand All @@ -112,7 +113,8 @@ private[r2dbc] class SqlServerQueryDao(settings: R2dbcSettings, connectionFactor
def toDbTimestampParamCondition =
if (toDbTimestampParam) "AND db_timestamp <= @until" else ""

def localNow: LocalDateTime = timestampCodec.now[LocalDateTime]()
// we know this is a LocalDateTime, so the cast should be ok
def localNow: LocalDateTime = timestampCodec.encode(timestampCodec.instantNow()).asInstanceOf[LocalDateTime]

def behindCurrentTimeIntervalCondition =
if (behindCurrentTime > Duration.Zero)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,6 @@ private[r2dbc] class SqlServerSnapshotDao(settings: R2dbcSettings, connectionFac
ORDER BY db_timestamp, seq_nr
"""

override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.now())
override def currentDbTimestamp(): Future[Instant] = Future.successful(timestampCodec.instantNow())

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PersistTagsSpec
Row(
pid = row.get("persistence_id", classOf[String]),
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
row.getTags()))
row.getTags("tags")))
.futureValue

rows.foreach { case Row(pid, _, tags) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class PersistTimestampSpec
Row(
pid = row.get("persistence_id", classOf[String]),
seqNr = row.get[java.lang.Long]("seq_nr", classOf[java.lang.Long]),
dbTimestamp = row.getTimestamp(),
dbTimestamp = row.getTimestamp("db_timestamp"),
event)
})
.futureValue
Expand Down

0 comments on commit a3b0573

Please sign in to comment.