From 75a9cd76307e596c13d1dda3e4f3b51630c8594c Mon Sep 17 00:00:00 2001 From: TianyuZhang1214 Date: Thu, 5 Dec 2024 19:06:43 +0800 Subject: [PATCH] to #246 feat: refactor the subscription. --- binlogreplication/binlog_replica_applier.go | 2 +- catalog/internal_tables.go | 62 ++++++-- main.go | 8 +- pgserver/connection_handler.go | 16 +- pgserver/logrepl/replication.go | 25 +-- pgserver/logrepl/subscription.go | 168 ++++++++++---------- pgserver/pg_catalog_handler.go | 13 +- pgserver/subscription_handler.go | 101 ++++-------- 8 files changed, 188 insertions(+), 207 deletions(-) diff --git a/binlogreplication/binlog_replica_applier.go b/binlogreplication/binlog_replica_applier.go index 3feca741..55532052 100644 --- a/binlogreplication/binlog_replica_applier.go +++ b/binlogreplication/binlog_replica_applier.go @@ -1432,7 +1432,7 @@ func convertSqlTypesValue(ctx *sql.Context, engine *gms.Engine, value sqltypes.V default: convertedValue, _, err = column.Type.Convert(value.ToString()) - // logrus.WithField("column", column.Name).WithField("type", column.Type).Infof( + // logrus.WithField("column", column.Subscription).WithField("type", column.Type).Infof( // "Converting value[%s %v %s] to %v %T", // value.Type(), value.Raw(), value.ToString(), convertedValue, convertedValue, // ) diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go index 8220b2e8..b26badd3 100644 --- a/catalog/internal_tables.go +++ b/catalog/internal_tables.go @@ -15,6 +15,29 @@ func (it *InternalTable) QualifiedName() string { return it.Schema + "." + it.Name } +func (it *InternalTable) UpdateStmt(keyColumns []string, valueColumns []string) string { + var b strings.Builder + b.Grow(128) + b.WriteString("UPDATE ") + b.WriteString(it.QualifiedName()) + b.WriteString(" SET " + valueColumns[0] + " = ?") + + for _, valueColumn := range valueColumns[1:] { + b.WriteString(", ") + b.WriteString(valueColumn) + b.WriteString(" = ?") + } + + b.WriteString(" WHERE " + keyColumns[0] + " = ?") + for _, keyColumn := range keyColumns[1:] { + b.WriteString(", ") + b.WriteString(keyColumn) + b.WriteString(" = ?") + } + + return b.String() +} + func (it *InternalTable) UpsertStmt() string { var b strings.Builder b.Grow(128) @@ -84,6 +107,30 @@ func (it *InternalTable) SelectStmt() string { return b.String() } +func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string { + var b strings.Builder + b.Grow(128) + b.WriteString("SELECT ") + b.WriteString(valueColumns[0]) + for _, c := range valueColumns[1:] { + b.WriteString(", ") + b.WriteString(c) + } + b.WriteString(" FROM ") + b.WriteString(it.Schema) + b.WriteByte('.') + b.WriteString(it.Name) + b.WriteString(" WHERE ") + b.WriteString(it.KeyColumns[0]) + b.WriteString(" = ?") + for _, c := range it.KeyColumns[1:] { + b.WriteString(" AND ") + b.WriteString(c) + b.WriteString(" = ?") + } + return b.String() +} + func (it *InternalTable) SelectAllStmt() string { var b strings.Builder b.Grow(128) @@ -108,7 +155,6 @@ func (it *InternalTable) CountAllStmt() string { var InternalTables = struct { PersistentVariable InternalTable BinlogPosition InternalTable - PgReplicationLSN InternalTable PgSubscription InternalTable GlobalStatus InternalTable // TODO(sean): This is a temporary work around for clients that query the 'pg_catalog.pg_stat_replication'. @@ -130,19 +176,12 @@ var InternalTables = struct { ValueColumns: []string{"position"}, DDL: "channel TEXT PRIMARY KEY, position TEXT", }, - PgReplicationLSN: InternalTable{ - Schema: "__sys__", - Name: "pg_replication_lsn", - KeyColumns: []string{"slot_name"}, - ValueColumns: []string{"lsn"}, - DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT", - }, PgSubscription: InternalTable{ Schema: "__sys__", Name: "pg_subscription", - KeyColumns: []string{"name"}, - ValueColumns: []string{"connection", "publication"}, - DDL: "name TEXT PRIMARY KEY, connection TEXT, publication TEXT", + KeyColumns: []string{"subname"}, + ValueColumns: []string{"subconninfo", "subpublication", "subenabled", "subskiplsn"}, + DDL: "subname TEXT PRIMARY KEY, subconninfo TEXT, subpublication TEXT, subenabled BOOLEAN, subskiplsn TEXT", }, GlobalStatus: InternalTable{ Schema: "performance_schema", @@ -237,7 +276,6 @@ var InternalTables = struct { var internalTables = []InternalTable{ InternalTables.PersistentVariable, InternalTables.BinlogPosition, - InternalTables.PgReplicationLSN, InternalTables.PgSubscription, InternalTables.GlobalStatus, InternalTables.PGStatReplication, diff --git a/main.go b/main.go index 745aa058..17a52295 100644 --- a/main.go +++ b/main.go @@ -182,13 +182,9 @@ func main() { } // Check if there is a replication subscription and start replication if there is. - subscriptions, err := logrepl.GetAllSubscriptions(pgServer.NewInternalCtx()) + err = logrepl.UpdateSubscriptions(pgServer.NewInternalCtx()) if err != nil { - logrus.WithError(err).Warnln("Failed to find replication") - } else { - for _, subscription := range subscriptions { - go subscription.Replicator.StartReplication(pgServer.NewInternalCtx(), subscription.Publication) - } + logrus.WithError(err).Warnln("Failed to update subscriptions") } // Load the configuration for the Postgres server. diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index c3c330f4..2051e03f 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -1059,7 +1059,7 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) statement = strings.ToLower(statement) // Command: \l if statement == "select d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" { - query, err := h.convertQuery(`select d.datname as "Name", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`) + query, err := h.convertQuery(`select d.datname as "Subscription", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`) if err != nil { return false, err } @@ -1067,7 +1067,7 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) } // Command: \l on psql 16 if statement == "select\n d.datname as \"name\",\n pg_catalog.pg_get_userbyid(d.datdba) as \"owner\",\n pg_catalog.pg_encoding_to_char(d.encoding) as \"encoding\",\n case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as \"locale provider\",\n d.datcollate as \"collate\",\n d.datctype as \"ctype\",\n d.daticulocale as \"icu locale\",\n null as \"icu rules\",\n pg_catalog.array_to_string(d.datacl, e'\\n') as \"access privileges\"\nfrom pg_catalog.pg_database d\norder by 1;" { - query, err := h.convertQuery(`select d.datname as "Name", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`) + query, err := h.convertQuery(`select d.datname as "Subscription", 'postgres' as "Owner", 'UTF8' as "Encoding", 'en_US.UTF-8' as "Collate", 'en_US.UTF-8' as "Ctype", 'en-US' as "ICU Locale", case d.datlocprovider when 'c' then 'libc' when 'i' then 'icu' end as "locale provider", '' as "access privileges" from pg_catalog.pg_database d order by 1;`) if err != nil { return false, err } @@ -1076,21 +1076,21 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) // Command: \dt if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { return true, h.query(ConvertedQuery{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Subscription", 'table' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' ORDER BY 2;`, StatementTag: "SELECT", }) } // Command: \d if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { return true, h.query(ConvertedQuery{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Subscription", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, StatementTag: "SELECT", }) } // Alternate \d for psql 14 if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 's' then 'special' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\n left join pg_catalog.pg_am am on am.oid = c.relam\nwhere c.relkind in ('r','p','v','m','s','f','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { return true, h.query(ConvertedQuery{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Subscription", IF(TABLE_TYPE = 'VIEW', 'view', 'table') AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'BASE TABLE' OR TABLE_TYPE = 'VIEW' ORDER BY 2;`, StatementTag: "SELECT", }) } @@ -1103,21 +1103,21 @@ func (h *ConnectionHandler) handledPSQLCommands(statement string) (bool, error) // Command: \dn if statement == "select n.nspname as \"name\",\n pg_catalog.pg_get_userbyid(n.nspowner) as \"owner\"\nfrom pg_catalog.pg_namespace n\nwhere n.nspname !~ '^pg_' and n.nspname <> 'information_schema'\norder by 1;" { return true, h.query(ConvertedQuery{ - String: `SELECT 'public' AS "Name", 'pg_database_owner' AS "Owner";`, + String: `SELECT 'public' AS "Subscription", 'pg_database_owner' AS "Owner";`, StatementTag: "SELECT", }) } // Command: \df if statement == "select n.nspname as \"schema\",\n p.proname as \"name\",\n pg_catalog.pg_get_function_result(p.oid) as \"result data type\",\n pg_catalog.pg_get_function_arguments(p.oid) as \"argument data types\",\n case p.prokind\n when 'a' then 'agg'\n when 'w' then 'window'\n when 'p' then 'proc'\n else 'func'\n end as \"type\"\nfrom pg_catalog.pg_proc p\n left join pg_catalog.pg_namespace n on n.oid = p.pronamespace\nwhere pg_catalog.pg_function_is_visible(p.oid)\n and n.nspname <> 'pg_catalog'\n and n.nspname <> 'information_schema'\norder by 1, 2, 4;" { return true, h.query(ConvertedQuery{ - String: `SELECT '' AS "Schema", '' AS "Name", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`, + String: `SELECT '' AS "Schema", '' AS "Subscription", '' AS "Result data type", '' AS "Argument data types", '' AS "Type" LIMIT 0;`, StatementTag: "SELECT", }) } // Command: \dv if statement == "select n.nspname as \"schema\",\n c.relname as \"name\",\n case c.relkind when 'r' then 'table' when 'v' then 'view' when 'm' then 'materialized view' when 'i' then 'index' when 's' then 'sequence' when 't' then 'toast table' when 'f' then 'foreign table' when 'p' then 'partitioned table' when 'i' then 'partitioned index' end as \"type\",\n pg_catalog.pg_get_userbyid(c.relowner) as \"owner\"\nfrom pg_catalog.pg_class c\n left join pg_catalog.pg_namespace n on n.oid = c.relnamespace\nwhere c.relkind in ('v','')\n and n.nspname <> 'pg_catalog'\n and n.nspname !~ '^pg_toast'\n and n.nspname <> 'information_schema'\n and pg_catalog.pg_table_is_visible(c.oid)\norder by 1,2;" { return true, h.query(ConvertedQuery{ - String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Name", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`, + String: `SELECT table_schema AS "Schema", TABLE_NAME AS "Subscription", 'view' AS "Type", 'postgres' AS "Owner" FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA <> 'pg_catalog' AND TABLE_SCHEMA <> 'information_schema' AND TABLE_TYPE = 'VIEW' ORDER BY 2;`, StatementTag: "SELECT", }) } diff --git a/pgserver/logrepl/replication.go b/pgserver/logrepl/replication.go index afc1d1bd..3c2c1c93 100644 --- a/pgserver/logrepl/replication.go +++ b/pgserver/logrepl/replication.go @@ -26,7 +26,6 @@ import ( "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/binlog" - "github.com/apecloud/myduckserver/catalog" "github.com/apecloud/myduckserver/delta" "github.com/apecloud/myduckserver/pgtypes" "github.com/dolthub/go-mysql-server/sql" @@ -222,7 +221,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin standbyMessageTimeout := 10 * time.Second nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) - lastWrittenLsn, err := r.readWALPosition(sqlCtx, slotName) + lastWrittenLsn, err := SelectSubscriptionLsn(sqlCtx, slotName) if err != nil { return err } @@ -881,26 +880,6 @@ func (r *LogicalReplicator) processMessage( return false, nil } -// readWALPosition reads the recorded WAL position from the WAL position table -func (r *LogicalReplicator) readWALPosition(ctx *sql.Context, slotName string) (pglogrepl.LSN, error) { - var lsn string - if err := adapter.QueryRowCatalog(ctx, catalog.InternalTables.PgReplicationLSN.SelectStmt(), slotName).Scan(&lsn); err != nil { - if errors.Is(err, stdsql.ErrNoRows) { - // if the LSN doesn't exist, consider this a cold start and return 0 - return pglogrepl.LSN(0), nil - } - return 0, err - } - - return pglogrepl.ParseLSN(lsn) -} - -// WriteWALPosition writes the recorded WAL position to the WAL position table -func (r *LogicalReplicator) WriteWALPosition(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) error { - _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String()) - return err -} - // whereClause returns a WHERE clause string with the contents of the builder if it's non-empty, or the empty // string otherwise func whereClause(str strings.Builder) string { @@ -1000,7 +979,7 @@ func (r *LogicalReplicator) commitOngoingTxn(state *replicationState, flushReaso } r.logger.Debugf("Writing LSN %s\n", state.lastCommitLSN) - if err = r.WriteWALPosition(state.replicaCtx, state.slotName, state.lastCommitLSN); err != nil { + if err = UpdateSubscriptionLsn(state.replicaCtx, state.slotName, state.lastCommitLSN.String()); err != nil { return err } diff --git a/pgserver/logrepl/subscription.go b/pgserver/logrepl/subscription.go index 2e268cd9..63367bff 100644 --- a/pgserver/logrepl/subscription.go +++ b/pgserver/logrepl/subscription.go @@ -1,39 +1,31 @@ package logrepl import ( + stdsql "database/sql" + "errors" "fmt" "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/sql" + "github.com/jackc/pglogrepl" "sync" ) type Subscription struct { - Name string - Conn string - Publication string - Replicator *LogicalReplicator + Subscription string + Conn string + Publication string + Lsn pglogrepl.LSN + Enabled bool + Replicator *LogicalReplicator } -var subscriptionMap = sync.Map{} -var createMutex sync.Mutex -var deleteMutex sync.Mutex - -func GetAllSubscriptions(ctx *sql.Context) ([]*Subscription, error) { - if err := loadAllSubscriptions(ctx); err != nil { - return nil, err - } - - var subscriptions []*Subscription - subscriptionMap.Range(func(key, value interface{}) bool { - if sub, ok := value.(*Subscription); ok { - subscriptions = append(subscriptions, sub) - } - return true - }) +var keyColumns = []string{"subname"} +var statusValueColumns = []string{"substatus"} +var lsnValueColumns = []string{"subskiplsn"} - return subscriptions, nil -} +var subscriptionMap = sync.Map{} +var mu sync.Mutex func GetSubscription(ctx *sql.Context, name string) (*Subscription, error) { if value, ok := subscriptionMap.Load(name); ok { @@ -43,7 +35,7 @@ func GetSubscription(ctx *sql.Context, name string) (*Subscription, error) { } // Attempt to reload all subscriptions if not found - if err := loadAllSubscriptions(ctx); err != nil { + if err := UpdateSubscriptions(ctx); err != nil { return nil, err } @@ -56,92 +48,106 @@ func GetSubscription(ctx *sql.Context, name string) (*Subscription, error) { return nil, nil } -func CreateSubscription(ctx *sql.Context, name string, conn string, publication string) (*Subscription, error) { - createMutex.Lock() - defer createMutex.Unlock() - - subscription, err := GetSubscription(ctx, name) - if err != nil { - return nil, err - } - - if subscription != nil { - return nil, fmt.Errorf("subscription %s already exists", name) - } - - replicator, err := NewLogicalReplicator(conn) - if err != nil { - return nil, fmt.Errorf("failed to create logical replicator: %v", err) - } - - subscription = &Subscription{ - Name: name, - Conn: conn, - Publication: publication, - Replicator: replicator, - } - - subscriptionMap.Store(name, subscription) - - return subscription, nil -} - -func DeleteSubscription(name string) (*Subscription, error) { - deleteMutex.Lock() - defer deleteMutex.Unlock() - - val, loaded := subscriptionMap.LoadAndDelete(name) - if !loaded { - return nil, fmt.Errorf("subscription %s does not exist", name) - } - - return val.(*Subscription), nil -} - -func loadAllSubscriptions(ctx *sql.Context) error { +func UpdateSubscriptions(ctx *sql.Context) error { + mu.Lock() + defer mu.Unlock() rows, err := adapter.QueryCatalog(ctx, catalog.InternalTables.PgSubscription.SelectAllStmt()) if err != nil { return err } defer rows.Close() + var tempMap = make(map[string]*Subscription) for rows.Next() { var name, conn, pub string - if err := rows.Scan(&name, &conn, &pub); err != nil { + var enabled bool + if err := rows.Scan(&name, &conn, &pub, &enabled); err != nil { return err } - if _, loaded := subscriptionMap.LoadOrStore(name, &Subscription{ - Name: name, - Conn: conn, - Publication: pub, - Replicator: nil, - }); !loaded { - replicator, err := NewLogicalReplicator(conn) + tempMap[name] = &Subscription{ + Subscription: name, + Conn: conn, + Publication: pub, + Enabled: enabled, + Replicator: nil, + } + } + + for tempName, tempSub := range tempMap { + if _, loaded := subscriptionMap.LoadOrStore(tempName, tempSub); !loaded { + replicator, err := NewLogicalReplicator(tempSub.Conn) if err != nil { - return err + return fmt.Errorf("failed to create logical replicator: %v", err) } - if sub, ok := subscriptionMap.Load(name); ok { + + if sub, ok := subscriptionMap.Load(tempName); ok { if subscription, ok := sub.(*Subscription); ok { subscription.Replicator = replicator } } + + err = replicator.CreateReplicationSlotIfNotExists(tempSub.Publication) + if err != nil { + return fmt.Errorf("failed to create replication slot: %v", err) + } + } else { + if sub, ok := subscriptionMap.Load(tempSub); ok { + if subscription, ok := sub.(*Subscription); ok { + if tempSub.Enabled != subscription.Enabled { + subscription.Enabled = tempSub.Enabled + if subscription.Enabled { + go subscription.Replicator.StartReplication(ctx, subscription.Publication) + } else { + subscription.Replicator.Stop() + } + } + } + } } } + subscriptionMap.Range(func(key, value interface{}) bool { + name, _ := key.(string) + subscription, _ := value.(*Subscription) + if _, ok := tempMap[name]; !ok { + subscription.Replicator.Stop() + subscriptionMap.Delete(name) + } + return true + }) + return nil } -func WriteSubscriptionIntoTable(ctx *sql.Context, name, conn, pub string) error { - _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpsertStmt(), name, conn, pub) +func CreateSubscription(ctx *sql.Context, name, conn, pub, lsn string, enabled bool) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpsertStmt(), name, conn, pub, lsn, enabled) return err } -func DeleteSubscriptionFromTable(ctx *sql.Context, name string) error { +func UpdateSubscriptionStatus(ctx *sql.Context, name string, enabled bool) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpdateStmt(keyColumns, statusValueColumns), name, enabled) + return err +} + +func DeleteSubscription(ctx *sql.Context, name string) error { _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.DeleteStmt(), name) return err } -func DeleteAllSubscriptions(ctx *sql.Context) error { - _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.DeleteAllStmt()) +func UpdateSubscriptionLsn(ctx *sql.Context, name, lsn string) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpdateStmt(keyColumns, lsnValueColumns), name, lsn) return err } + +func SelectSubscriptionLsn(ctx *sql.Context, subscription string) (pglogrepl.LSN, error) { + var lsn string + if err := adapter.QueryRowCatalog(ctx, catalog.InternalTables.PgSubscription.SelectColumnsStmt(lsnValueColumns), subscription).Scan(&lsn); err != nil { + if errors.Is(err, stdsql.ErrNoRows) { + // if the LSN doesn't exist, consider this a cold start and return 0 + return pglogrepl.LSN(0), nil + } + return 0, err + } + + return pglogrepl.ParseLSN(lsn) +} diff --git a/pgserver/pg_catalog_handler.go b/pgserver/pg_catalog_handler.go index 71cdd83c..f3d0dcb6 100644 --- a/pgserver/pg_catalog_handler.go +++ b/pgserver/pg_catalog_handler.go @@ -11,7 +11,7 @@ import ( "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/catalog" duckConfig "github.com/apecloud/myduckserver/configuration" - tree "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pgx/v5/pgproto3" ) @@ -36,7 +36,7 @@ func (h *ConnectionHandler) isInRecovery() (string, error) { return "f", err } var count int - if err := adapter.QueryRow(ctx, catalog.InternalTables.PgReplicationLSN.CountAllStmt()).Scan(&count); err != nil { + if err := adapter.QueryRow(ctx, catalog.InternalTables.PgSubscription.CountAllStmt()).Scan(&count); err != nil { return "f", err } @@ -54,9 +54,12 @@ func (h *ConnectionHandler) readOneWALPositionStr() (string, error) { if err != nil { return "0/0", err } - var slotName string - var lsn string - if err := adapter.QueryRow(ctx, catalog.InternalTables.PgReplicationLSN.SelectAllStmt()).Scan(&slotName, &lsn); err != nil { + + // TODO(neo.zty): needs to be fixed + var subscription, conn, publication, lsn string + var enabled bool + + if err := adapter.QueryRow(ctx, catalog.InternalTables.PgSubscription.SelectAllStmt()).Scan(&subscription, &conn, &publication, &lsn, &enabled); err != nil { if errors.Is(err, stdsql.ErrNoRows) { // if no lsn is stored, return 0 return "0/0", nil diff --git a/pgserver/subscription_handler.go b/pgserver/subscription_handler.go index 722f7146..5e8e57d9 100644 --- a/pgserver/subscription_handler.go +++ b/pgserver/subscription_handler.go @@ -181,23 +181,12 @@ func (h *ConnectionHandler) executeAlterEnable(subscriptionConfig *SubscriptionC return fmt.Errorf("failed to create context for query: %w", err) } - subscription, err := logrepl.GetSubscription(sqlCtx, subscriptionConfig.SubscriptionName) + err = logrepl.UpdateSubscriptionStatus(sqlCtx, subscriptionConfig.SubscriptionName, true) if err != nil { - return fmt.Errorf("failed to get subscription: %w", err) - } - - if subscription == nil { - return fmt.Errorf("subscription not found: %s", subscriptionConfig.SubscriptionName) - } - - if subscription.Replicator.Running() { - return fmt.Errorf("subscription is already enabled: %s", subscriptionConfig.SubscriptionName) + return fmt.Errorf("failed to delete subscription: %w", err) } - go subscription.Replicator.StartReplication(sqlCtx, subscription.Publication) - - return nil - + return commitAndUpdate(sqlCtx) } func (h *ConnectionHandler) executeAlterDisable(subscriptionConfig *SubscriptionConfig) error { @@ -206,18 +195,12 @@ func (h *ConnectionHandler) executeAlterDisable(subscriptionConfig *Subscription return fmt.Errorf("failed to create context for query: %w", err) } - subscription, err := logrepl.GetSubscription(sqlCtx, subscriptionConfig.SubscriptionName) - + err = logrepl.UpdateSubscriptionStatus(sqlCtx, subscriptionConfig.SubscriptionName, false) if err != nil { - return fmt.Errorf("failed to get subscription: %w", err) - } - - if subscription == nil { - return fmt.Errorf("subscription not found: %s", subscriptionConfig.SubscriptionName) + return fmt.Errorf("failed to delete subscription: %w", err) } - subscription.Replicator.Stop() - return nil + return commitAndUpdate(sqlCtx) } func (h *ConnectionHandler) executeDrop(subscriptionConfig *SubscriptionConfig) error { @@ -226,26 +209,12 @@ func (h *ConnectionHandler) executeDrop(subscriptionConfig *SubscriptionConfig) return fmt.Errorf("failed to create context for query: %w", err) } - subscription, err := logrepl.DeleteSubscription(subscriptionConfig.SubscriptionName) + err = logrepl.DeleteSubscription(sqlCtx, subscriptionConfig.SubscriptionName) if err != nil { return fmt.Errorf("failed to delete subscription: %w", err) } - subscription.Replicator.Stop() - - if err := logrepl.DeleteSubscriptionFromTable(sqlCtx, subscriptionConfig.SubscriptionName); err != nil { - return fmt.Errorf("failed to delete subscription from table: %w", err) - } - - tx := adapter.TryGetTxn(sqlCtx) - if tx != nil { - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit transaction: %w", err) - } - adapter.CloseTxn(sqlCtx) - } - - return nil + return commitAndUpdate(sqlCtx) } func (h *ConnectionHandler) executeCreate(subscriptionConfig *SubscriptionConfig) error { @@ -259,13 +228,11 @@ func (h *ConnectionHandler) executeCreate(subscriptionConfig *SubscriptionConfig return fmt.Errorf("failed to create snapshot for CREATE SUBSCRIPTION: %w", err) } - replicator, err := h.doCreateSubscription(sqlCtx, subscriptionConfig, lsn) + err = h.doCreateSubscription(sqlCtx, subscriptionConfig, lsn) if err != nil { return fmt.Errorf("failed to execute CREATE SUBSCRIPTION: %w", err) } - go replicator.StartReplication(sqlCtx, subscriptionConfig.PublicationName) - return nil } @@ -366,48 +333,40 @@ func (h *ConnectionHandler) doSnapshot(sqlCtx *sql.Context, subscriptionConfig * return lsn, txn.Commit() } -func (h *ConnectionHandler) doCreateSubscription(sqlCtx *sql.Context, subscriptionConfig *SubscriptionConfig, lsn pglogrepl.LSN) (*logrepl.LogicalReplicator, error) { - - subscription, err := logrepl.CreateSubscription( - sqlCtx, - subscriptionConfig.SubscriptionName, - subscriptionConfig.ToDNS(), - subscriptionConfig.PublicationName) - - if err != nil { - return nil, fmt.Errorf("failed to create subscription: %w", err) - } - - replicator := subscription.Replicator - - err = logrepl.CreatePublicationIfNotExists(subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) - if err != nil { - return nil, fmt.Errorf("failed to create publication: %w", err) - } - - err = replicator.CreateReplicationSlotIfNotExists(subscriptionConfig.PublicationName) +func (h *ConnectionHandler) doCreateSubscription(sqlCtx *sql.Context, subscriptionConfig *SubscriptionConfig, lsn pglogrepl.LSN) error { + err := logrepl.CreatePublicationIfNotExists(subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) if err != nil { - return nil, fmt.Errorf("failed to create replication slot: %w", err) + return fmt.Errorf("failed to create publication: %w", err) } - // `WriteWALPosition` and `WriteSubscriptionIntoTable` execute in a transaction internally, - // so we start a transaction here and commit it after writing the WAL position. tx, err := adapter.GetCatalogTxn(sqlCtx, nil) if err != nil { - return nil, fmt.Errorf("failed to get transaction: %w", err) + return fmt.Errorf("failed to get transaction: %w", err) } defer tx.Rollback() defer adapter.CloseTxn(sqlCtx) - err = replicator.WriteWALPosition(sqlCtx, subscriptionConfig.PublicationName, lsn) + err = logrepl.CreateSubscription(sqlCtx, subscriptionConfig.SubscriptionName, subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName, lsn.String(), true) if err != nil { - return nil, fmt.Errorf("failed to write WAL position: %w", err) + return fmt.Errorf("failed to write subscription: %w", err) } - err = logrepl.WriteSubscriptionIntoTable(sqlCtx, subscriptionConfig.SubscriptionName, subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) + return commitAndUpdate(sqlCtx) +} + +func commitAndUpdate(sqlCtx *sql.Context) error { + tx := adapter.TryGetTxn(sqlCtx) + if tx != nil { + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + adapter.CloseTxn(sqlCtx) + } + + err := logrepl.UpdateSubscriptions(sqlCtx) if err != nil { - return nil, fmt.Errorf("failed to write subscription: %w", err) + return fmt.Errorf("failed to update subscriptions: %w", err) } - return replicator, tx.Commit() + return nil }