Skip to content

Commit

Permalink
feat(pg): restart the replication (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 authored Nov 28, 2024
1 parent 42200b9 commit 0cd1d5b
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 29 deletions.
8 changes: 8 additions & 0 deletions catalog/internal_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ var InternalTables = struct {
PersistentVariable InternalTable
BinlogPosition InternalTable
PgReplicationLSN InternalTable
PgSubscription InternalTable
GlobalStatus InternalTable
// TODO(sean): This is a temporary work around for clients that query the 'pg_catalog.pg_stat_replication'.
// Once we add 'pg_catalog' and support views for PG, replace this by a view.
Expand Down Expand Up @@ -129,6 +130,13 @@ var InternalTables = struct {
ValueColumns: []string{"lsn"},
DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT",
},
PgSubscription: InternalTable{
Schema: "__sys__",
Name: "pg_subscription",
KeyColumns: []string{"name"},
ValueColumns: []string{"connection", "publication"},
DDL: "name TEXT PRIMARY KEY, connection TEXT, publication TEXT",
},
GlobalStatus: InternalTable{
Schema: "performance_schema",
Name: "global_status",
Expand Down
14 changes: 14 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/apecloud/myduckserver/catalog"
"github.com/apecloud/myduckserver/myfunc"
"github.com/apecloud/myduckserver/pgserver"
"github.com/apecloud/myduckserver/pgserver/logrepl"
"github.com/apecloud/myduckserver/plugin"
"github.com/apecloud/myduckserver/replica"
"github.com/apecloud/myduckserver/transpiler"
Expand Down Expand Up @@ -171,6 +172,19 @@ func main() {
if err != nil {
logrus.WithError(err).Fatalln("Failed to create Postgres-protocol server")
}

// Check if there is a replication subscription and start replication if there is.
_, conn, pub, ok, err := logrepl.FindReplication(pool.DB)
if err != nil {
logrus.WithError(err).Warnln("Failed to find replication")
} else if ok {
replicator, err := logrepl.NewLogicalReplicator(conn)
if err != nil {
logrus.WithError(err).Fatalln("Failed to create logical replicator")
}
replicator.StartReplication(pgServer.NewInternalCtx(), pub)
}

go pgServer.Start()
}

Expand Down
5 changes: 4 additions & 1 deletion pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type ConnectionHandler struct {
// COPY DATA messages from the client to import data into tables.
copyFromStdinState *copyFromStdinState

server *Server
logger *logrus.Entry
}

Expand Down Expand Up @@ -87,7 +88,7 @@ func init() {
}

// NewConnectionHandler returns a new ConnectionHandler for the connection provided
func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engine, sm *server.SessionManager, connID uint32) *ConnectionHandler {
func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engine, sm *server.SessionManager, connID uint32, server *Server) *ConnectionHandler {
mysqlConn := &mysql.Conn{
Conn: conn,
PrepareData: make(map[uint32]*mysql.PrepareData),
Expand Down Expand Up @@ -117,6 +118,8 @@ func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engi
duckHandler: duckHandler,
backend: pgproto3.NewBackend(conn, conn),
pgTypeMap: pgtype.NewMap(),

server: server,
logger: logrus.WithFields(logrus.Fields{
"connectionID": connID,
"protocol": "pg",
Expand Down
6 changes: 2 additions & 4 deletions pgserver/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ type Listener struct {
connID *atomic.Uint32
}

var _ server.ProtocolListener = (*Listener)(nil)

type ListenerOpt func(*Listener)

func WithCertificate(cert tls.Certificate) ListenerOpt {
Expand Down Expand Up @@ -89,7 +87,7 @@ func NewListenerWithOpts(listenerCfg mysql.ListenerConfig, opts ...ListenerOpt)
}

// Accept handles incoming connections.
func (l *Listener) Accept() {
func (l *Listener) Accept(server *Server) {
for {
conn, err := l.listener.Accept()
if err != nil {
Expand All @@ -106,7 +104,7 @@ func (l *Listener) Accept() {
conn = netutil.NewConnWithTimeouts(conn, l.cfg.ConnReadTimeout, l.cfg.ConnWriteTimeout)
}

connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.engine, l.sm, l.connID.Add(1))
connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.engine, l.sm, l.connID.Add(1), server)
go connectionHandler.HandleConnection()
}
}
Expand Down
35 changes: 35 additions & 0 deletions pgserver/logrepl/subscription.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package logrepl

import (
"context"
stdsql "database/sql"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
"github.com/dolthub/go-mysql-server/sql"
)

func WriteSubscription(ctx *sql.Context, name, conn, pub string) error {
_, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpsertStmt(), name, conn, pub)
return err
}

func FindReplication(db *stdsql.DB) (name, conn, pub string, ok bool, err error) {
var rows *stdsql.Rows
rows, err = db.QueryContext(context.Background(), catalog.InternalTables.PgSubscription.SelectAllStmt())
if err != nil {
return
}
defer rows.Close()

if !rows.Next() {
return
}

if err = rows.Scan(&name, &conn, &pub); err != nil {
return
}

ok = true
return
}
11 changes: 1 addition & 10 deletions pgserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package pgserver
import (
"fmt"

"github.com/apecloud/myduckserver/pgserver/logrepl"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
Expand Down Expand Up @@ -36,15 +35,7 @@ func NewServer(host string, port int, newCtx func() *sql.Context, options ...Lis
}

func (s *Server) Start() {
s.Listener.Accept()
}

func (s *Server) StartReplication(primaryDsn string, slotName string) error {
replicator, err := logrepl.NewLogicalReplicator(primaryDsn)
if err != nil {
return err
}
return replicator.StartReplication(s.NewInternalCtx(), slotName)
s.Listener.Accept(s)
}

func (s *Server) Close() {
Expand Down
38 changes: 24 additions & 14 deletions pgserver/subscription_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"context"
stdsql "database/sql"
"fmt"
"regexp"
"strings"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/pgserver/logrepl"
"github.com/jackc/pglogrepl"
"regexp"
"strings"
)

// This file implements the logic for handling CREATE SUBSCRIPTION SQL statements.
Expand Down Expand Up @@ -116,11 +117,13 @@ func executeCreateSubscriptionSQL(h *ConnectionHandler, subscriptionConfig *Subs
return fmt.Errorf("failed to execute snapshot in CREATE SUBSCRIPTION: %w", err)
}

err = doCreateSubscription(h, subscriptionConfig)
replicator, err := doCreateSubscription(h, subscriptionConfig)
if err != nil {
return fmt.Errorf("failed to execute CREATE SUBSCRIPTION: %w", err)
}

go replicator.StartReplication(h.server.NewInternalCtx(), subscriptionConfig.PublicationName)

return nil
}

Expand Down Expand Up @@ -158,40 +161,47 @@ func doSnapshot(h *ConnectionHandler, subscriptionConfig *SubscriptionConfig) er
return nil
}

func doCreateSubscription(h *ConnectionHandler, subscriptionConfig *SubscriptionConfig) error {
func doCreateSubscription(h *ConnectionHandler, subscriptionConfig *SubscriptionConfig) (*logrepl.LogicalReplicator, error) {
replicator, err := logrepl.NewLogicalReplicator(subscriptionConfig.ToDNS())
if err != nil {
return fmt.Errorf("failed to create logical replicator: %w", err)
return nil, fmt.Errorf("failed to create logical replicator: %w", err)
}

err = logrepl.CreatePublicationIfNotExists(subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName)
if err != nil {
return fmt.Errorf("failed to create publication: %w", err)
return nil, fmt.Errorf("failed to create publication: %w", err)
}

err = replicator.CreateReplicationSlotIfNotExists(subscriptionConfig.PublicationName)
if err != nil {
return fmt.Errorf("failed to create replication slot: %w", err)
return nil, fmt.Errorf("failed to create replication slot: %w", err)
}

sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "")
if err != nil {
return fmt.Errorf("failed to create context for query: %w", err)
return nil, fmt.Errorf("failed to create context for query: %w", err)
}

err = replicator.WriteWALPosition(sqlCtx, subscriptionConfig.PublicationName, subscriptionConfig.LSN)
// `WriteWALPosition` and `WriteSubscription` execute in a transaction internally,
// so we start a transaction here and commit it after writing the WAL position.
tx, err := adapter.GetCatalogTxn(sqlCtx, nil)
if err != nil {
return fmt.Errorf("failed to write WAL position: %w", err)
return nil, fmt.Errorf("failed to get transaction: %w", err)
}
defer tx.Rollback()
defer adapter.CloseTxn(sqlCtx)

sqlCtx, err = h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "")
err = replicator.WriteWALPosition(sqlCtx, subscriptionConfig.PublicationName, subscriptionConfig.LSN)
if err != nil {
return fmt.Errorf("failed to create context for query: %w", err)
return nil, fmt.Errorf("failed to write WAL position: %w", err)
}

go replicator.StartReplication(sqlCtx, subscriptionConfig.PublicationName)
err = logrepl.WriteSubscription(sqlCtx, subscriptionConfig.SubscriptionName, subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName)
if err != nil {
return nil, fmt.Errorf("failed to write subscription: %w", err)
}

return nil
return replicator, tx.Commit()
}

// processLSN scans the rows for the LSN value and updates the subscriptionConfig.
Expand Down

0 comments on commit 0cd1d5b

Please sign in to comment.