Commit: Save work

fanyang01 committed Nov 19, 2024
1 parent 63c7da3 commit 65bb037
Showing 3 changed files with 94 additions and 54 deletions.
18 changes: 9 additions & 9 deletions delta/controller.go
@@ -118,13 +118,13 @@ func (c *DeltaController) Flush(ctx *sql.Context, tx *stdsql.Tx, reason FlushRea
 	}
 
 	if stats.DeltaSize > 0 {
-		if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
-			ctx.GetLogger().WithFields(logrus.Fields{
+		if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) {
+			log.WithFields(logrus.Fields{
 				"DeltaSize":  stats.DeltaSize,
 				"Insertions": stats.Insertions,
 				"Deletions":  stats.Deletions,
 				"Reason":     reason.String(),
-			}).Trace("Flushed delta buffer")
+			}).Debug("Flushed delta buffer")
 		}
 	}

@@ -236,11 +236,11 @@ func (c *DeltaController) updateTable(
 	stats.DeltaSize += affected
 	defer tx.ExecContext(ctx, "DROP TABLE IF EXISTS temp.main.delta")
 
-	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) {
 		log.WithFields(logrus.Fields{
 			"table": qualifiedTableName,
 			"rows":  affected,
-		}).Trace("Delta created")
+		}).Debug("Delta created")
 	}
 
 	// Insert or replace new rows (action = INSERT) into the base table.
@@ -257,11 +257,11 @@ func (c *DeltaController) updateTable(
 	}
 	stats.Insertions += affected
 
-	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) {
 		log.WithFields(logrus.Fields{
 			"table": qualifiedTableName,
 			"rows":  affected,
-		}).Trace("Inserted")
+		}).Debug("Inserted")
 	}
 
 	// Delete rows that have been deleted.
@@ -304,11 +304,11 @@ func (c *DeltaController) updateTable(
 	//	fmt.Printf("row:%+v\n", row)
 	// }
 
-	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
+	if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) {
 		log.WithFields(logrus.Fields{
 			"table": qualifiedTableName,
 			"rows":  affected,
-		}).Trace("Deleted")
+		}).Debug("Deleted")
 	}
 
 	return nil
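
Side note on this file's logging changes: all three sites move from Trace to Debug but keep the level guard, which avoids building the logrus.Fields map when the level is disabled (the first hunk also reuses the bound log entry instead of calling ctx.GetLogger() twice). A minimal standalone sketch of the guard-then-log pattern, using plain logrus with hypothetical names rather than the repo's *sql.Context wrapper:

	package sketch

	import "github.com/sirupsen/logrus"

	// logFlush shows the guard-then-log pattern: check the level first so the
	// Fields map is never allocated when Debug output is off.
	func logFlush(log *logrus.Entry, deltaSize, insertions, deletions int64) {
		if log.Logger.IsLevelEnabled(logrus.DebugLevel) {
			log.WithFields(logrus.Fields{
				"DeltaSize":  deltaSize,
				"Insertions": insertions,
				"Deletions":  deletions,
			}).Debug("Flushed delta buffer")
		}
	}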
107 changes: 73 additions & 34 deletions pgserver/logrepl/replication.go
@@ -47,7 +47,8 @@ type rcvMsg struct {
 }
 
 type LogicalReplicator struct {
-	primaryDns string
+	primaryDns    string
+	flushInterval time.Duration
 
 	running         bool
 	messageReceived bool
@@ -62,8 +63,9 @@ type LogicalReplicator struct {
 // connection to the primary is established when StartReplication is called.
 func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) {
 	return &LogicalReplicator{
-		primaryDns: primaryDns,
-		mu:         &sync.Mutex{},
+		primaryDns:    primaryDns,
+		flushInterval: 200 * time.Millisecond,
+		mu:            &sync.Mutex{},
 		logger: logrus.WithFields(logrus.Fields{
 			"component": "replicator",
 			"protocol":  "pg",
@@ -109,11 +111,10 @@ func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) {
 	}
 	defer conn.Close(context.Background())
 
-	result, err := conn.Query(context.Background(), "SELECT pg_wal_lsn_diff(write_lsn, sent_lsn) AS replication_lag FROM pg_stat_replication")
+	result, err := conn.Query(context.Background(), "SELECT pg_wal_lsn_diff(sent_lsn, flush_lsn) AS replication_lag FROM pg_stat_replication")
 	if err != nil {
 		return false, err
 	}
-
 	defer result.Close()
 
 	for result.Next() {
@@ -123,12 +124,10 @@ func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) {
 		}
 
 		row := rows[0]
+		r.logger.Debugf("Current replication lag: %+v", row)
 		lag, ok := row.(pgtype.Numeric)
 		if ok && lag.Valid {
-			r.logger.Debugf("Current replication lag: %v", row)
 			return int(math.Abs(float64(lag.Int.Int64()))) < threshold, nil
-		} else {
-			r.logger.Debugf("Replication lag unknown: %v", row)
 		}
 	}
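
A note on the corrected lag query above: pg_wal_lsn_diff(sent_lsn, flush_lsn) yields the number of WAL bytes the primary has sent that the standby has not yet flushed, which matches the "caught up within threshold" intent of this function. A rough standalone sketch of reading the same value with pgx v5 (hypothetical helper; assumes an established connection to the primary):

	package sketch

	import (
		"context"
		"fmt"

		"github.com/jackc/pgx/v5"
		"github.com/jackc/pgx/v5/pgtype"
	)

	// printLag runs the same lag query outside the replicator.
	func printLag(ctx context.Context, conn *pgx.Conn) error {
		var lag pgtype.Numeric
		err := conn.QueryRow(ctx,
			"SELECT pg_wal_lsn_diff(sent_lsn, flush_lsn) FROM pg_stat_replication").Scan(&lag)
		if err != nil {
			return err
		}
		if lag.Valid {
			// lag.Int carries the byte count as a *big.Int; values near zero mean caught up.
			fmt.Println("replication lag (bytes):", lag.Int)
		}
		return nil
	}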

@@ -193,9 +192,33 @@ type replicationState struct {
 	inTxnStmtID atomic.Uint64 // statement ID within transaction
 }
 
+func (state *replicationState) reset(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) {
+	if state.deltas != nil {
+		state.deltas.Close()
+	}
+	if state.relations != nil {
+		clear(state.relations)
+		clear(state.schemas)
+		clear(state.keys)
+	}
+	*state = replicationState{
+		replicaCtx:     ctx,
+		slotName:       slotName,
+		lastWrittenLSN: lsn,
+		lastCommitLSN:  lsn,
+		typeMap:        pgtype.NewMap(),
+		relations:      map[uint32]*pglogrepl.RelationMessageV2{},
+		schemas:        map[uint32]sql.Schema{},
+		keys:           map[uint32][]uint16{},
+		deltas:         delta.NewController(),
+		lastCommitTime: time.Now(),
+	}
+}
+
 // StartReplication starts the replication process for the given slot name. This function blocks until replication is
 // stopped via the Stop method, or an error occurs.
 func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName string) error {
+	sqlCtx.SetLogger(r.logger)
 	standbyMessageTimeout := 10 * time.Second
 	nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)

@@ -204,17 +227,8 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 		return err
 	}
 
-	state := &replicationState{
-		replicaCtx:     sqlCtx,
-		slotName:       slotName,
-		lastCommitLSN:  lastWrittenLsn,
-		lastWrittenLSN: lastWrittenLsn,
-		typeMap:        pgtype.NewMap(),
-		relations:      map[uint32]*pglogrepl.RelationMessageV2{},
-		schemas:        map[uint32]sql.Schema{},
-		keys:           map[uint32][]uint16{},
-		deltas:         delta.NewController(),
-	}
+	state := &replicationState{}
+	state.reset(sqlCtx, slotName, lastWrittenLsn)
 
 	// Switch to the `public` schema.
 	if _, err := adapter.Exec(sqlCtx, "USE public"); err != nil {
@@ -234,13 +248,16 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 	connErrCnt := 0
 	handleErrWithRetry := func(err error, incrementErrorCount bool) error {
 		if err != nil {
+			r.logger.Warnf("Handle error: %v", err)
 			if incrementErrorCount {
 				connErrCnt++
 			}
 			if connErrCnt < maxConsecutiveFailures {
-				r.logger.Debugf("Error: %v. Retrying", err)
+				r.logger.Warnf("Retrying (%d/%d) on error %v", connErrCnt, maxConsecutiveFailures, err)
 				if primaryConn != nil {
-					_ = primaryConn.Close(context.Background())
+					if err := primaryConn.Close(context.Background()); err != nil {
+						r.logger.Warnf("Failed to close connection: %v", err)
+					}
 				}
 				primaryConn = nil
 				return nil
@@ -276,7 +293,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 	r.stop = make(chan struct{})
 	r.mu.Unlock()
 
-	ticker := time.NewTicker(200 * time.Millisecond)
+	ticker := time.NewTicker(r.flushInterval)
 	defer ticker.Stop()
 
 	for {
@@ -298,6 +315,12 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 				time.Sleep(3 * time.Second)
 				return handleErrWithRetry(err, true)
 			}
+
+			// Reset the state on reconnection.
+			if err := r.rollback(sqlCtx); err != nil {
+				return err
+			}
+			state.reset(sqlCtx, slotName, state.lastWrittenLSN)
 		}
 
 		if time.Now().After(nextStandbyMessageDeadline) && state.lastReceivedLSN > 0 {
@@ -330,7 +353,10 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 			cancel()
 		case <-ticker.C:
 			cancel()
-			return r.commitOngoingTxnIfClean(state, delta.TimeTickFlushReason)
+			if time.Since(state.lastCommitTime) > r.flushInterval {
+				return r.commitOngoingTxnIfClean(state, delta.TimeTickFlushReason)
+			}
+			return nil
 		}
 
 		if msgAndErr.err != nil {
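
The ticker change above separates wakeup frequency from flush policy: the loop still wakes every r.flushInterval, but a time-tick flush now only happens when nothing has committed within the last interval, so steady WAL traffic (which commits on its own) is not interrupted by redundant flushes. A hypothetical isolated sketch of the same debounce shape, with invented names (committed, flush):

	package sketch

	import "time"

	// debouncedFlush: the ticker only drives a flush when no commit
	// happened within the last interval.
	func debouncedFlush(interval time.Duration, committed <-chan struct{}, flush func()) {
		lastCommit := time.Now()
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-committed: // another path (e.g. WAL-driven commit) just ran
				lastCommit = time.Now()
			case <-ticker.C:
				if time.Since(lastCommit) > interval {
					flush() // quiet period: flush whatever is buffered
					lastCommit = time.Now()
				}
			}
		}
	}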
@@ -363,12 +389,11 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 				return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err)
 			}
 
-			r.logger.Traceln("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested)
+			r.logger.Debugln("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested)
 			state.lastReceivedLSN = pkm.ServerWALEnd
 
 			if pkm.ReplyRequested {
-				// Send our reply the next time through the loop
-				nextStandbyMessageDeadline = time.Time{}
+				return sendStandbyStatusUpdate(state)
 			}
 		case pglogrepl.XLogDataByteID:
 			xld, err := pglogrepl.ParseXLogData(msg.Data[1:])
@@ -382,7 +407,8 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 				return handleErrWithRetry(err, true)
 			}
 
-			return sendStandbyStatusUpdate(state)
+			// return sendStandbyStatusUpdate(state)
+			return nil
 		default:
 			r.logger.Debugf("Received unexpected message: %T\n", rawMsg)
 		}
@@ -400,6 +426,20 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
 	}
 }
 
+func (r *LogicalReplicator) rollback(ctx *sql.Context) error {
+	defer adapter.CloseTxn(ctx)
+	txn := adapter.TryGetTxn(ctx)
+	if txn == nil {
+		return nil
+	}
+	err := txn.Rollback()
+	if err != nil && !strings.Contains(err.Error(), "no transaction is active") {
+		r.logger.Debugf("Failed to roll back transaction: %v", err)
+		return err
+	}
+	return nil
+}
+
 func (r *LogicalReplicator) shutdown(ctx *sql.Context, state *replicationState) {
 	r.mu.Lock()
 	defer r.mu.Unlock()
@@ -408,10 +448,7 @@ func (r *LogicalReplicator) shutdown(ctx *sql.Context, state *replicationState)
 	r.commitOngoingTxnIfClean(state, delta.OnCloseFlushReason)
 
 	// Rollback any open transaction
-	_, err := adapter.ExecCatalog(ctx, "ROLLBACK")
-	if err != nil && !strings.Contains(err.Error(), "no transaction is active") {
-		r.logger.Debugf("Failed to roll back transaction: %v", err)
-	}
+	r.rollback(ctx)
 
 	r.running = false
 	close(r.stop)
@@ -468,7 +505,7 @@ func (r *LogicalReplicator) beginReplication(slotName string, lastFlushLsn pglog
 	if err != nil {
 		return nil, err
 	}
-	r.logger.Infoln("Logical replication started on slot", slotName)
+	r.logger.Infoln("Logical replication started on slot", slotName, "at WAL location", lastFlushLsn+1)
 
 	return conn, nil
 }
@@ -586,8 +623,10 @@ func (r *LogicalReplicator) processMessage(
 
 	switch logicalMsg := logicalMsg.(type) {
 	case *pglogrepl.RelationMessageV2:
-		// When a Relation message is received, commit any buffered ongoing batch transactions.
-		if state.dirtyTxn.Load() {
+		_, exists := state.relations[logicalMsg.RelationID]
+		if exists {
+			// This means schema changes have occurred, so we need to
+			// commit any buffered ongoing batch transactions.
 			err := r.commitOngoingTxn(state, delta.DDLStmtFlushReason)
 			if err != nil {
 				return false, err
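
On the RelationMessageV2 change: pgoutput re-sends a Relation message for a table whose definition may have changed on the primary, so encountering a RelationID that is already cached in state.relations is used as the schema-change signal, replacing the old dirtyTxn check that flushed on every Relation message with pending work. The shape of the check, as a hypothetical isolated sketch (relations, msg, and flushPending are invented stand-ins):

	package sketch

	// onRelationMessage: relations caches the last-seen definition per
	// relation ID; seeing an ID again may mean DDL ran upstream, so
	// buffered rows are flushed before the schema shifts underneath them.
	func onRelationMessage(relations map[uint32]string, id uint32, def string, flushPending func()) {
		if _, seen := relations[id]; seen {
			flushPending()
		}
		relations[id] = def
	}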
23 changes: 12 additions & 11 deletions pgserver/logrepl/replication_test.go
@@ -30,6 +30,7 @@ import (
 	"github.com/cockroachdb/errors"
 	"github.com/dolthub/go-mysql-server/sql"
 	"github.com/jackc/pgx/v5"
+	"github.com/sirupsen/logrus"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -528,7 +529,7 @@ var replicationTests = []ReplicationTest{
 }
 
 func TestReplication(t *testing.T) {
-	// logrus.SetLevel(logrus.DebugLevel)
+	logrus.SetLevel(logrus.DebugLevel)
 	RunReplicationScripts(t, replicationTests)
 }

@@ -568,14 +569,14 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) {
 	require.NoError(t, logrepl.CreatePublication(primaryDns, slotName))
 	time.Sleep(500 * time.Millisecond)
 
-	// for i, script := range scripts {
-	// 	if i == 9 {
-	// 		RunReplicationScript(t, dsn, script)
-	// 	}
-	// }
-	for _, script := range scripts {
-		RunReplicationScript(t, dsn, script)
+	for i, script := range scripts {
+		if i == 0 {
+			RunReplicationScript(t, dsn, script)
+		}
 	}
+	// for _, script := range scripts {
+	// 	RunReplicationScript(t, dsn, script)
+	// }
 }
 
 const slotName = "myduck_slot"
@@ -605,7 +606,7 @@ func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) {
 	})
 }
 
-func newReplicator(t *testing.T, server *pgserver.Server, primaryDns string) *logrepl.LogicalReplicator {
+func newReplicator(t *testing.T, primaryDns string) *logrepl.LogicalReplicator {
 	r, err := logrepl.NewLogicalReplicator(primaryDns)
 	require.NoError(t, err)
 	return r
@@ -620,7 +621,7 @@ func runReplicationScript(
 	replicaConn *pgx.Conn,
 	primaryDns string,
 ) {
-	r := newReplicator(t, server, primaryDns)
+	r := newReplicator(t, primaryDns)
 	defer r.Stop()
 
 	if script.Skip {
@@ -808,7 +809,7 @@ func waitForCaughtUp(r *logrepl.LogicalReplicator) error {
 		if time.Since(start) >= 5*time.Second {
 			return errors.New("Replication did not catch up")
 		}
-		time.Sleep(50 * time.Millisecond)
+		time.Sleep(1 * time.Second)
 	}
 
 	return nil