diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go index e258f409..0dabd5dd 100644 --- a/catalog/internal_tables.go +++ b/catalog/internal_tables.go @@ -99,6 +99,7 @@ var InternalTables = struct { PersistentVariable InternalTable BinlogPosition InternalTable PgReplicationLSN InternalTable + PgSubscription InternalTable GlobalStatus InternalTable // TODO(sean): This is a temporary work around for clients that query the 'pg_catalog.pg_stat_replication'. // Once we add 'pg_catalog' and support views for PG, replace this by a view. @@ -129,6 +130,13 @@ var InternalTables = struct { ValueColumns: []string{"lsn"}, DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT", }, + PgSubscription: InternalTable{ + Schema: "__sys__", + Name: "pg_subscription", + KeyColumns: []string{"name"}, + ValueColumns: []string{"connection", "publication"}, + DDL: "name TEXT PRIMARY KEY, connection TEXT, publication TEXT", + }, GlobalStatus: InternalTable{ Schema: "performance_schema", Name: "global_status", diff --git a/main.go b/main.go index 3e67152b..bdd071d4 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "github.com/apecloud/myduckserver/catalog" "github.com/apecloud/myduckserver/myfunc" "github.com/apecloud/myduckserver/pgserver" + "github.com/apecloud/myduckserver/pgserver/logrepl" "github.com/apecloud/myduckserver/plugin" "github.com/apecloud/myduckserver/replica" "github.com/apecloud/myduckserver/transpiler" @@ -171,6 +172,19 @@ func main() { if err != nil { logrus.WithError(err).Fatalln("Failed to create Postgres-protocol server") } + + // Check if there is a replication subscription and start replication if there is. + _, conn, pub, ok, err := logrepl.FindReplication(pool.DB) + if err != nil { + logrus.WithError(err).Warnln("Failed to find replication") + } else if ok { + replicator, err := logrepl.NewLogicalReplicator(conn) + if err != nil { + logrus.WithError(err).Fatalln("Failed to create logical replicator") + } + replicator.StartReplication(pgServer.NewInternalCtx(), pub) + } + go pgServer.Start() } diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index c4806016..0561e8e7 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -59,6 +59,7 @@ type ConnectionHandler struct { // COPY DATA messages from the client to import data into tables. copyFromStdinState *copyFromStdinState + server *Server logger *logrus.Entry } @@ -87,7 +88,7 @@ func init() { } // NewConnectionHandler returns a new ConnectionHandler for the connection provided -func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engine, sm *server.SessionManager, connID uint32) *ConnectionHandler { +func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engine, sm *server.SessionManager, connID uint32, server *Server) *ConnectionHandler { mysqlConn := &mysql.Conn{ Conn: conn, PrepareData: make(map[uint32]*mysql.PrepareData), @@ -117,6 +118,8 @@ func NewConnectionHandler(conn net.Conn, handler mysql.Handler, engine *gms.Engi duckHandler: duckHandler, backend: pgproto3.NewBackend(conn, conn), pgTypeMap: pgtype.NewMap(), + + server: server, logger: logrus.WithFields(logrus.Fields{ "connectionID": connID, "protocol": "pg", diff --git a/pgserver/listener.go b/pgserver/listener.go index 3132b859..970e27d2 100644 --- a/pgserver/listener.go +++ b/pgserver/listener.go @@ -42,8 +42,6 @@ type Listener struct { connID *atomic.Uint32 } -var _ server.ProtocolListener = (*Listener)(nil) - type ListenerOpt func(*Listener) func WithCertificate(cert tls.Certificate) ListenerOpt { @@ -89,7 +87,7 @@ func NewListenerWithOpts(listenerCfg mysql.ListenerConfig, opts ...ListenerOpt) } // Accept handles incoming connections. -func (l *Listener) Accept() { +func (l *Listener) Accept(server *Server) { for { conn, err := l.listener.Accept() if err != nil { @@ -106,7 +104,7 @@ func (l *Listener) Accept() { conn = netutil.NewConnWithTimeouts(conn, l.cfg.ConnReadTimeout, l.cfg.ConnWriteTimeout) } - connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.engine, l.sm, l.connID.Add(1)) + connectionHandler := NewConnectionHandler(conn, l.cfg.Handler, l.engine, l.sm, l.connID.Add(1), server) go connectionHandler.HandleConnection() } } diff --git a/pgserver/logrepl/subscription.go b/pgserver/logrepl/subscription.go new file mode 100644 index 00000000..34b1f876 --- /dev/null +++ b/pgserver/logrepl/subscription.go @@ -0,0 +1,35 @@ +package logrepl + +import ( + "context" + stdsql "database/sql" + + "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/catalog" + "github.com/dolthub/go-mysql-server/sql" +) + +func WriteSubscription(ctx *sql.Context, name, conn, pub string) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpsertStmt(), name, conn, pub) + return err +} + +func FindReplication(db *stdsql.DB) (name, conn, pub string, ok bool, err error) { + var rows *stdsql.Rows + rows, err = db.QueryContext(context.Background(), catalog.InternalTables.PgSubscription.SelectAllStmt()) + if err != nil { + return + } + defer rows.Close() + + if !rows.Next() { + return + } + + if err = rows.Scan(&name, &conn, &pub); err != nil { + return + } + + ok = true + return +} diff --git a/pgserver/server.go b/pgserver/server.go index bdd0cfdb..d2bdffee 100644 --- a/pgserver/server.go +++ b/pgserver/server.go @@ -3,7 +3,6 @@ package pgserver import ( "fmt" - "github.com/apecloud/myduckserver/pgserver/logrepl" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" @@ -36,15 +35,7 @@ func NewServer(host string, port int, newCtx func() *sql.Context, options ...Lis } func (s *Server) Start() { - s.Listener.Accept() -} - -func (s *Server) StartReplication(primaryDsn string, slotName string) error { - replicator, err := logrepl.NewLogicalReplicator(primaryDsn) - if err != nil { - return err - } - return replicator.StartReplication(s.NewInternalCtx(), slotName) + s.Listener.Accept(s) } func (s *Server) Close() { diff --git a/pgserver/subscription_handler.go b/pgserver/subscription_handler.go index 22848b31..874350c8 100644 --- a/pgserver/subscription_handler.go +++ b/pgserver/subscription_handler.go @@ -4,11 +4,12 @@ import ( "context" stdsql "database/sql" "fmt" + "regexp" + "strings" + "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/pgserver/logrepl" "github.com/jackc/pglogrepl" - "regexp" - "strings" ) // This file implements the logic for handling CREATE SUBSCRIPTION SQL statements. @@ -116,11 +117,13 @@ func executeCreateSubscriptionSQL(h *ConnectionHandler, subscriptionConfig *Subs return fmt.Errorf("failed to execute snapshot in CREATE SUBSCRIPTION: %w", err) } - err = doCreateSubscription(h, subscriptionConfig) + replicator, err := doCreateSubscription(h, subscriptionConfig) if err != nil { return fmt.Errorf("failed to execute CREATE SUBSCRIPTION: %w", err) } + go replicator.StartReplication(h.server.NewInternalCtx(), subscriptionConfig.PublicationName) + return nil } @@ -158,40 +161,47 @@ func doSnapshot(h *ConnectionHandler, subscriptionConfig *SubscriptionConfig) er return nil } -func doCreateSubscription(h *ConnectionHandler, subscriptionConfig *SubscriptionConfig) error { +func doCreateSubscription(h *ConnectionHandler, subscriptionConfig *SubscriptionConfig) (*logrepl.LogicalReplicator, error) { replicator, err := logrepl.NewLogicalReplicator(subscriptionConfig.ToDNS()) if err != nil { - return fmt.Errorf("failed to create logical replicator: %w", err) + return nil, fmt.Errorf("failed to create logical replicator: %w", err) } err = logrepl.CreatePublicationIfNotExists(subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) if err != nil { - return fmt.Errorf("failed to create publication: %w", err) + return nil, fmt.Errorf("failed to create publication: %w", err) } err = replicator.CreateReplicationSlotIfNotExists(subscriptionConfig.PublicationName) if err != nil { - return fmt.Errorf("failed to create replication slot: %w", err) + return nil, fmt.Errorf("failed to create replication slot: %w", err) } sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") if err != nil { - return fmt.Errorf("failed to create context for query: %w", err) + return nil, fmt.Errorf("failed to create context for query: %w", err) } - err = replicator.WriteWALPosition(sqlCtx, subscriptionConfig.PublicationName, subscriptionConfig.LSN) + // `WriteWALPosition` and `WriteSubscription` execute in a transaction internally, + // so we start a transaction here and commit it after writing the WAL position. + tx, err := adapter.GetCatalogTxn(sqlCtx, nil) if err != nil { - return fmt.Errorf("failed to write WAL position: %w", err) + return nil, fmt.Errorf("failed to get transaction: %w", err) } + defer tx.Rollback() + defer adapter.CloseTxn(sqlCtx) - sqlCtx, err = h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") + err = replicator.WriteWALPosition(sqlCtx, subscriptionConfig.PublicationName, subscriptionConfig.LSN) if err != nil { - return fmt.Errorf("failed to create context for query: %w", err) + return nil, fmt.Errorf("failed to write WAL position: %w", err) } - go replicator.StartReplication(sqlCtx, subscriptionConfig.PublicationName) + err = logrepl.WriteSubscription(sqlCtx, subscriptionConfig.SubscriptionName, subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) + if err != nil { + return nil, fmt.Errorf("failed to write subscription: %w", err) + } - return nil + return replicator, tx.Commit() } // processLSN scans the rows for the LSN value and updates the subscriptionConfig.