From 65bb037e660357da68215c119972c03c48866c0b Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Tue, 19 Nov 2024 21:27:17 +0800 Subject: [PATCH] Save work --- delta/controller.go | 18 ++--- pgserver/logrepl/replication.go | 107 ++++++++++++++++++--------- pgserver/logrepl/replication_test.go | 23 +++--- 3 files changed, 94 insertions(+), 54 deletions(-) diff --git a/delta/controller.go b/delta/controller.go index 651dd9dd..073463f6 100644 --- a/delta/controller.go +++ b/delta/controller.go @@ -118,13 +118,13 @@ func (c *DeltaController) Flush(ctx *sql.Context, tx *stdsql.Tx, reason FlushRea } if stats.DeltaSize > 0 { - if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { - ctx.GetLogger().WithFields(logrus.Fields{ + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) { + log.WithFields(logrus.Fields{ "DeltaSize": stats.DeltaSize, "Insertions": stats.Insertions, "Deletions": stats.Deletions, "Reason": reason.String(), - }).Trace("Flushed delta buffer") + }).Debug("Flushed delta buffer") } } @@ -236,11 +236,11 @@ func (c *DeltaController) updateTable( stats.DeltaSize += affected defer tx.ExecContext(ctx, "DROP TABLE IF EXISTS temp.main.delta") - if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) { log.WithFields(logrus.Fields{ "table": qualifiedTableName, "rows": affected, - }).Trace("Delta created") + }).Debug("Delta created") } // Insert or replace new rows (action = INSERT) into the base table. @@ -257,11 +257,11 @@ func (c *DeltaController) updateTable( } stats.Insertions += affected - if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) { log.WithFields(logrus.Fields{ "table": qualifiedTableName, "rows": affected, - }).Trace("Inserted") + }).Debug("Inserted") } // Delete rows that have been deleted. @@ -304,11 +304,11 @@ func (c *DeltaController) updateTable( // fmt.Printf("row:%+v\n", row) // } - if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.DebugLevel) { log.WithFields(logrus.Fields{ "table": qualifiedTableName, "rows": affected, - }).Trace("Deleted") + }).Debug("Deleted") } return nil diff --git a/pgserver/logrepl/replication.go b/pgserver/logrepl/replication.go index 5442f047..9d611116 100644 --- a/pgserver/logrepl/replication.go +++ b/pgserver/logrepl/replication.go @@ -47,7 +47,8 @@ type rcvMsg struct { } type LogicalReplicator struct { - primaryDns string + primaryDns string + flushInterval time.Duration running bool messageReceived bool @@ -62,8 +63,9 @@ type LogicalReplicator struct { // connection to the primary is established when StartReplication is called. 
func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) { return &LogicalReplicator{ - primaryDns: primaryDns, - mu: &sync.Mutex{}, + primaryDns: primaryDns, + flushInterval: 200 * time.Millisecond, + mu: &sync.Mutex{}, logger: logrus.WithFields(logrus.Fields{ "component": "replicator", "protocol": "pg", @@ -109,11 +111,10 @@ func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) { } defer conn.Close(context.Background()) - result, err := conn.Query(context.Background(), "SELECT pg_wal_lsn_diff(write_lsn, sent_lsn) AS replication_lag FROM pg_stat_replication") + result, err := conn.Query(context.Background(), "SELECT pg_wal_lsn_diff(sent_lsn, flush_lsn) AS replication_lag FROM pg_stat_replication") if err != nil { return false, err } - defer result.Close() for result.Next() { @@ -123,12 +124,10 @@ func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) { } row := rows[0] + r.logger.Debugf("Current replication lag: %+v", row) lag, ok := row.(pgtype.Numeric) if ok && lag.Valid { - r.logger.Debugf("Current replication lag: %v", row) return int(math.Abs(float64(lag.Int.Int64()))) < threshold, nil - } else { - r.logger.Debugf("Replication lag unknown: %v", row) } } @@ -193,9 +192,33 @@ type replicationState struct { inTxnStmtID atomic.Uint64 // statement ID within transaction } +func (state *replicationState) reset(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) { + if state.deltas != nil { + state.deltas.Close() + } + if state.relations != nil { + clear(state.relations) + clear(state.schemas) + clear(state.keys) + } + *state = replicationState{ + replicaCtx: ctx, + slotName: slotName, + lastWrittenLSN: lsn, + lastCommitLSN: lsn, + typeMap: pgtype.NewMap(), + relations: map[uint32]*pglogrepl.RelationMessageV2{}, + schemas: map[uint32]sql.Schema{}, + keys: map[uint32][]uint16{}, + deltas: delta.NewController(), + lastCommitTime: time.Now(), + } +} + // StartReplication starts the replication process for the given slot name. This function blocks until replication is // stopped via the Stop method, or an error occurs. func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName string) error { + sqlCtx.SetLogger(r.logger) standbyMessageTimeout := 10 * time.Second nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) @@ -204,17 +227,8 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin return err } - state := &replicationState{ - replicaCtx: sqlCtx, - slotName: slotName, - lastCommitLSN: lastWrittenLsn, - lastWrittenLSN: lastWrittenLsn, - typeMap: pgtype.NewMap(), - relations: map[uint32]*pglogrepl.RelationMessageV2{}, - schemas: map[uint32]sql.Schema{}, - keys: map[uint32][]uint16{}, - deltas: delta.NewController(), - } + state := &replicationState{} + state.reset(sqlCtx, slotName, lastWrittenLsn) // Switch to the `public` schema. if _, err := adapter.Exec(sqlCtx, "USE public"); err != nil { @@ -234,13 +248,16 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin connErrCnt := 0 handleErrWithRetry := func(err error, incrementErrorCount bool) error { if err != nil { + r.logger.Warnf("Handle error: %v", err) if incrementErrorCount { connErrCnt++ } if connErrCnt < maxConsecutiveFailures { - r.logger.Debugf("Error: %v. 
Retrying", err) + r.logger.Warnf("Retrying (%d/%d) on error %v", connErrCnt, maxConsecutiveFailures, err) if primaryConn != nil { - _ = primaryConn.Close(context.Background()) + if err := primaryConn.Close(context.Background()); err != nil { + r.logger.Warnf("Failed to close connection: %v", err) + } } primaryConn = nil return nil @@ -276,7 +293,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin r.stop = make(chan struct{}) r.mu.Unlock() - ticker := time.NewTicker(200 * time.Millisecond) + ticker := time.NewTicker(r.flushInterval) defer ticker.Stop() for { @@ -298,6 +315,12 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin time.Sleep(3 * time.Second) return handleErrWithRetry(err, true) } + + // Reset the state on reconnection. + if err := r.rollback(sqlCtx); err != nil { + return err + } + state.reset(sqlCtx, slotName, state.lastWrittenLSN) } if time.Now().After(nextStandbyMessageDeadline) && state.lastReceivedLSN > 0 { @@ -330,7 +353,10 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin cancel() case <-ticker.C: cancel() - return r.commitOngoingTxnIfClean(state, delta.TimeTickFlushReason) + if time.Since(state.lastCommitTime) > r.flushInterval { + return r.commitOngoingTxnIfClean(state, delta.TimeTickFlushReason) + } + return nil } if msgAndErr.err != nil { @@ -363,12 +389,11 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err) } - r.logger.Traceln("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested) + r.logger.Debugln("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested) state.lastReceivedLSN = pkm.ServerWALEnd if pkm.ReplyRequested { - // Send our reply the next time through the loop - nextStandbyMessageDeadline = time.Time{} + return sendStandbyStatusUpdate(state) } case pglogrepl.XLogDataByteID: xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) @@ -382,7 +407,8 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin return handleErrWithRetry(err, true) } - return sendStandbyStatusUpdate(state) + // return sendStandbyStatusUpdate(state) + return nil default: r.logger.Debugf("Received unexpected message: %T\n", rawMsg) } @@ -400,6 +426,20 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin } } +func (r *LogicalReplicator) rollback(ctx *sql.Context) error { + defer adapter.CloseTxn(ctx) + txn := adapter.TryGetTxn(ctx) + if txn == nil { + return nil + } + err := txn.Rollback() + if err != nil && !strings.Contains(err.Error(), "no transaction is active") { + r.logger.Debugf("Failed to roll back transaction: %v", err) + return err + } + return nil +} + func (r *LogicalReplicator) shutdown(ctx *sql.Context, state *replicationState) { r.mu.Lock() defer r.mu.Unlock() @@ -408,10 +448,7 @@ func (r *LogicalReplicator) shutdown(ctx *sql.Context, state *replicationState) r.commitOngoingTxnIfClean(state, delta.OnCloseFlushReason) // Rollback any open transaction - _, err := adapter.ExecCatalog(ctx, "ROLLBACK") - if err != nil && !strings.Contains(err.Error(), "no transaction is active") { - r.logger.Debugf("Failed to roll back transaction: %v", err) - } + r.rollback(ctx) r.running = false close(r.stop) @@ -468,7 +505,7 @@ func (r *LogicalReplicator) beginReplication(slotName 
string, lastFlushLsn pglog if err != nil { return nil, err } - r.logger.Infoln("Logical replication started on slot", slotName) + r.logger.Infoln("Logical replication started on slot", slotName, "at WAL location", lastFlushLsn+1) return conn, nil } @@ -586,8 +623,10 @@ func (r *LogicalReplicator) processMessage( switch logicalMsg := logicalMsg.(type) { case *pglogrepl.RelationMessageV2: - // When a Relation message is received, commit any buffered ongoing batch transactions. - if state.dirtyTxn.Load() { + _, exists := state.relations[logicalMsg.RelationID] + if exists { + // This means schema changes have occurred, so we need to + // commit any buffered ongoing batch transactions. err := r.commitOngoingTxn(state, delta.DDLStmtFlushReason) if err != nil { return false, err diff --git a/pgserver/logrepl/replication_test.go b/pgserver/logrepl/replication_test.go index c9ad08b9..8a41e5f7 100644 --- a/pgserver/logrepl/replication_test.go +++ b/pgserver/logrepl/replication_test.go @@ -30,6 +30,7 @@ import ( "github.com/cockroachdb/errors" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pgx/v5" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -528,7 +529,7 @@ var replicationTests = []ReplicationTest{ } func TestReplication(t *testing.T) { - // logrus.SetLevel(logrus.DebugLevel) + logrus.SetLevel(logrus.DebugLevel) RunReplicationScripts(t, replicationTests) } @@ -568,14 +569,14 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) { require.NoError(t, logrepl.CreatePublication(primaryDns, slotName)) time.Sleep(500 * time.Millisecond) - // for i, script := range scripts { - // if i == 9 { - // RunReplicationScript(t, dsn, script) - // } - // } - for _, script := range scripts { - RunReplicationScript(t, dsn, script) + for i, script := range scripts { + if i == 0 { + RunReplicationScript(t, dsn, script) + } } + // for _, script := range scripts { + // RunReplicationScript(t, dsn, script) + // } } const slotName = "myduck_slot" @@ -605,7 +606,7 @@ func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) { }) } -func newReplicator(t *testing.T, server *pgserver.Server, primaryDns string) *logrepl.LogicalReplicator { +func newReplicator(t *testing.T, primaryDns string) *logrepl.LogicalReplicator { r, err := logrepl.NewLogicalReplicator(primaryDns) require.NoError(t, err) return r @@ -620,7 +621,7 @@ func runReplicationScript( replicaConn *pgx.Conn, primaryDns string, ) { - r := newReplicator(t, server, primaryDns) + r := newReplicator(t, primaryDns) defer r.Stop() if script.Skip { @@ -808,7 +809,7 @@ func waitForCaughtUp(r *logrepl.LogicalReplicator) error { if time.Since(start) >= 5*time.Second { return errors.New("Replication did not catch up") } - time.Sleep(50 * time.Millisecond) + time.Sleep(1 * time.Second) } return nil
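
Note on the flush behavior above: the ticker in StartReplication now fires at the configurable flushInterval, but the delta buffer is only committed when more than flushInterval has elapsed since the last commit (state.lastCommitTime). A minimal sketch of that gating pattern, using illustrative names (flushState, maybeFlush) that are not part of this change:

    package main

    import (
    	"fmt"
    	"time"
    )

    // flushState mirrors the two fields the gating relies on: the configured
    // interval and the time of the last successful commit.
    type flushState struct {
    	flushInterval  time.Duration
    	lastCommitTime time.Time
    }

    // maybeFlush reports whether a flush should happen on this tick. In the real
    // replicator the commit itself advances lastCommitTime; here it is updated
    // in place so the sketch stays self-contained.
    func (s *flushState) maybeFlush(now time.Time) bool {
    	if now.Sub(s.lastCommitTime) <= s.flushInterval {
    		return false // a commit happened recently; skip this tick
    	}
    	s.lastCommitTime = now
    	return true
    }

    func main() {
    	s := &flushState{flushInterval: 200 * time.Millisecond, lastCommitTime: time.Now()}
    	ticker := time.NewTicker(120 * time.Millisecond)
    	defer ticker.Stop()
    	for i := 0; i < 5; i++ {
    		now := <-ticker.C
    		fmt.Println("tick", i, "flushed:", s.maybeFlush(now))
    	}
    }

The effect is that ticks arriving shortly after a WAL-driven commit no longer force a near-empty flush, while replica staleness stays bounded to roughly two flush intervals.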
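
For reference, the revised lag query in CaughtUp compares sent_lsn against flush_lsn from pg_stat_replication, so the reported lag is the number of WAL bytes the primary has sent that the standby has not yet confirmed flushing. A small sketch of that arithmetic using pglogrepl's LSN type (the two LSN values below are made up for illustration):

    package main

    import (
    	"fmt"

    	"github.com/jackc/pglogrepl"
    )

    func main() {
    	// pg_wal_lsn_diff(sent_lsn, flush_lsn) is equivalent to subtracting the
    	// two positions, since an LSN is a byte offset into the WAL stream.
    	sent, _ := pglogrepl.ParseLSN("0/16B3748")
    	flushed, _ := pglogrepl.ParseLSN("0/16B3608")
    	lag := int64(sent) - int64(flushed)
    	fmt.Printf("sent=%s flushed=%s lag=%d bytes\n", sent, flushed, lag)
    }

CaughtUp then treats the replica as caught up once the absolute value of this difference falls below the caller-supplied threshold.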