Skip to content

Commit

Permalink
[WIP] replica: Refactor transaction handling
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 committed Oct 8, 2024
1 parent 868caff commit 4e13bca
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 177 deletions.
5 changes: 5 additions & 0 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ import (

type ConnectionHolder interface {
GetConn(ctx context.Context) (*stdsql.Conn, error)
GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)
}

func GetConn(ctx *sql.Context) (*stdsql.Conn, error) {
return ctx.Session.(ConnectionHolder).GetConn(ctx)
}

func GetTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return ctx.Session.(ConnectionHolder).GetTxn(ctx, options)
}

func QueryContext(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
conn, err := GetConn(ctx)
if err != nil {
Expand Down
16 changes: 7 additions & 9 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"github.com/dolthub/vitess/go/mysql"
)

const PersistentVariableTableName = "main.persistent_variables"

type Session struct {
*memory.Session
db *catalog.DatabaseProvider
Expand All @@ -39,11 +41,6 @@ func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *

// NewSessionBuilder returns a session builder for the given database provider.
func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
_, err := pool.Exec("CREATE TABLE IF NOT EXISTS main.persistent_variables (name TEXT PRIMARY KEY, value TEXT, type TEXT)")
if err != nil {
panic(err)
}

return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
Expand Down Expand Up @@ -137,7 +134,7 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
}
_, err := sess.ExecContext(
context.Background(),
"INSERT OR REPLACE INTO main.persistent_variables (name, value, vtype) VALUES (?, ?, ?)",
catalog.InternalTables.PersistentVariable.UpsertStmt(),
sysVarName, value, fmt.Sprintf("%T", value),
)
return err
Expand All @@ -147,15 +144,15 @@ func (sess Session) PersistGlobal(sysVarName string, value interface{}) error {
func (sess Session) RemovePersistedGlobal(sysVarName string) error {
_, err := sess.ExecContext(
context.Background(),
"DELETE FROM main.persistent_variables WHERE name = ?",
catalog.InternalTables.PersistentVariable.DeleteStmt(),
sysVarName,
)
return err
}

// RemoveAllPersistedGlobals implements sql.PersistableSession.
func (sess Session) RemoveAllPersistedGlobals() error {
_, err := sess.ExecContext(context.Background(), "DELETE FROM main.persistent_variables")
_, err := sess.ExecContext(context.Background(), "DELETE FROM "+catalog.InternalTables.PersistentVariable.Name)
return err
}

Expand All @@ -164,7 +161,8 @@ func (sess Session) GetPersistedValue(k string) (interface{}, error) {
var value, vtype string
err := sess.QueryRow(
context.Background(),
"SELECT value, vtype FROM main.persistent_variables WHERE name = ?", k,
catalog.InternalTables.PersistentVariable.SelectStmt(),
k,
).Scan(&value, &vtype)
switch {
case err == stdsql.ErrNoRows:
Expand Down
95 changes: 42 additions & 53 deletions binlogreplication/binlog_replica_applier.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"sync/atomic"
"time"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/binlog"
"github.com/apecloud/myduckserver/charset"
"github.com/apecloud/myduckserver/delta"
Expand Down Expand Up @@ -65,6 +66,7 @@ type binlogReplicaApplier struct {
running atomic.Bool
engine *gms.Engine
tableWriterProvider TableWriterProvider
ongoingTxn atomic.Bool // true if we're in a transaction
inTxnStmtID atomic.Uint64 // auto-incrementing ID for statements within a transaction
deltaBufSize atomic.Uint64 // size of the delta buffer
}
Expand Down Expand Up @@ -342,8 +344,6 @@ func (a *binlogReplicaApplier) replicaBinlogEventHandler(ctx *sql.Context) error
// processing it.
func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.Engine, event mysql.BinlogEvent) error {
var err error
createCommit := false
commitToAllDatabases := false

// We don't support checksum validation, so we MUST strip off any checksum bytes if present, otherwise it gets
// interpreted as part of the payload and corrupts the data. Future checksum sizes, are not guaranteed to be the
Expand Down Expand Up @@ -372,8 +372,7 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
// An XID event is generated for a COMMIT of a transaction that modifies one or more tables of an
// XA-capable storage engine. For more details, see: https://mariadb.com/kb/en/xid_event/
ctx.GetLogger().Trace("Received binlog event: XID")
createCommit = true
commitToAllDatabases = true
return a.commit(ctx, engine)

case event.IsQuery():
// A Query event represents a statement executed on the source server that should be executed on the
Expand All @@ -395,13 +394,6 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
"sql_mode": fmt.Sprintf("0x%x", mode),
}).Infoln("Received binlog event: Query")

// When executing SQL statements sent from the primary, we can't be sure what database was modified unless we
// look closely at the statement. For example, we could be connected to db01, but executed
// "create table db02.t (...);" – i.e., looking at query.Database is NOT enough to always determine the correct
// database that was modified, so instead, we commit to all databases when we see a Query binlog event to
// avoid issues with correctness, at the cost of being slightly less efficient
commitToAllDatabases = true

if flags&doltvtmysql.QFlagOptionAutoIsNull > 0 {
ctx.GetLogger().Tracef("Setting sql_auto_is_null ON")
ctx.SetSessionVariable(ctx, "sql_auto_is_null", 1)
Expand Down Expand Up @@ -435,11 +427,8 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
ctx.SetSessionVariable(ctx, "unique_checks", 1)
}

createCommit = !strings.EqualFold(query.SQL, "begin")
// TODO(fan): Here we
// skip the transaction for now;
// skip the operations on `mysql.time_zone*` tables, which are not supported by go-mysql-server yet.
if createCommit && !(query.Database == "mysql" && strings.HasPrefix(query.SQL, "TRUNCATE TABLE time_zone")) {
// TODO(fan): Here we skip the operations on `mysql.time_zone*` tables, which are not supported by go-mysql-server yet.
if !(query.Database == "mysql" && strings.HasPrefix(query.SQL, "TRUNCATE TABLE time_zone")) {
ctx.SetCurrentDatabase(query.Database)
if err := a.executeQueryWithEngine(ctx, engine, query.SQL); err != nil {
ctx.GetLogger().WithFields(logrus.Fields{
Expand All @@ -450,6 +439,10 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, msg)
}
}

if strings.EqualFold(query.SQL, "begin") {
a.ongoingTxn.Store(true)
}
a.inTxnStmtID.Add(1)

case event.IsRotate():
Expand Down Expand Up @@ -576,39 +569,31 @@ func (a *binlogReplicaApplier) processBinlogEvent(ctx *sql.Context, engine *gms.
}
}

if createCommit {
// TODO(fan): Skip the transaction commit for now
_ = commitToAllDatabases
// var databasesToCommit []string
// if commitToAllDatabases {
// databasesToCommit = getAllUserDatabaseNames(ctx, engine)
// for _, database := range databasesToCommit {
// executeQueryWithEngine(ctx, engine, "use `"+database+"`;")
// executeQueryWithEngine(ctx, engine, "commit;")
// }
// }

// Record the last GTID processed after the commit
a.currentPosition.GTIDSet = a.currentPosition.GTIDSet.AddGTID(a.currentGtid)
err := sql.SystemVariables.AssignValues(map[string]interface{}{"gtid_executed": a.currentPosition.GTIDSet.String()})
if err != nil {
ctx.GetLogger().Errorf("unable to set @@GLOBAL.gtid_executed: %s", err.Error())
}
err = positionStore.Save(ctx, engine, a.currentPosition)
if err != nil {
return fmt.Errorf("unable to store GTID executed metadata to disk: %s", err.Error())
}
return nil
}

// Reset the statement ID after a commit
a.inTxnStmtID.Store(0)
func (a *binlogReplicaApplier) commit(ctx *sql.Context, engine *gms.Engine) error {
a.executeQueryWithEngine(ctx, engine, "commit")

// Flush the delta buffer if it's grown too large
// TODO(fan): Make the threshold configurable
if a.deltaBufSize.Load() > (64 << 20) { // 64MB
return a.flushDeltaBuffer(ctx, delta.MemoryLimitFlushReason)
}
// Record the last GTID processed after the commit
a.currentPosition.GTIDSet = a.currentPosition.GTIDSet.AddGTID(a.currentGtid)
err := sql.SystemVariables.AssignValues(map[string]interface{}{"gtid_executed": a.currentPosition.GTIDSet.String()})
if err != nil {
ctx.GetLogger().Errorf("unable to set @@GLOBAL.gtid_executed: %s", err.Error())
}
err = positionStore.Save(ctx, engine, a.currentPosition)
if err != nil {
return fmt.Errorf("unable to store GTID executed metadata to disk: %s", err.Error())
}

// Reset the statement ID after a commit
a.inTxnStmtID.Store(0)

// Flush the delta buffer if it's grown too large
// TODO(fan): Make the threshold configurable
if a.deltaBufSize.Load() > (64 << 20) { // 64MB
return a.flushDeltaBuffer(ctx, delta.MemoryLimitFlushReason)
}
return nil
}

Expand Down Expand Up @@ -695,7 +680,7 @@ func (a *binlogReplicaApplier) processRowEvent(ctx *sql.Context, event mysql.Bin

if isRowFormat && len(pkSchema.PkOrdinals) > 0 {
// --binlog-format=ROW & --binlog-row-image=full
return a.appendRowFormatChanges(ctx, engine, tableMap, tableName, schema, eventType, &rows)
return a.appendRowFormatChanges(ctx, tableMap, tableName, schema, eventType, &rows)
} else {
return a.writeChanges(ctx, engine, tableMap, tableName, pkSchema, eventType, &rows, foreignKeyChecksDisabled)
}
Expand Down Expand Up @@ -734,8 +719,13 @@ func (a *binlogReplicaApplier) writeChanges(
dataRows = append(dataRows, dataRow)
}

txn, err := adapter.GetTxn(ctx, nil)
if err != nil {
return err
}
tableWriter, err := a.tableWriterProvider.GetTableWriter(
ctx, engine,
ctx,
txn,
tableMap.Database, tableName,
pkSchema,
len(tableMap.Types), len(rows.Rows),
Expand All @@ -746,7 +736,6 @@ func (a *binlogReplicaApplier) writeChanges(
if err != nil {
return err
}
defer tableWriter.Rollback()

switch event {
case binlog.DeleteRowEvent:
Expand All @@ -767,15 +756,15 @@ func (a *binlogReplicaApplier) writeChanges(
"rows": len(rows.Rows),
}).Infoln("processRowEvent")

return tableWriter.Commit()
return nil
}

func (a *binlogReplicaApplier) appendRowFormatChanges(
ctx *sql.Context, engine *gms.Engine,
ctx *sql.Context,
tableMap *mysql.TableMap, tableName string, schema sql.Schema,
event binlog.RowEventType, rows *mysql.Rows,
) error {
appender, err := a.tableWriterProvider.GetDeltaAppender(ctx, engine, tableMap.Database, tableName, schema)
appender, err := a.tableWriterProvider.GetDeltaAppender(ctx, tableMap.Database, tableName, schema)
if err != nil {
return err
}
Expand Down Expand Up @@ -1129,8 +1118,8 @@ func (a *binlogReplicaApplier) executeQueryWithEngine(ctx *sql.Context, engine *
switch node.(type) {
case *plan.InsertInto, *plan.Update, *plan.DeleteFrom, *plan.LoadData:
flushChangelog, flushReason = true, delta.DMLStmtFlushReason
case *plan.DropDB,
*plan.DropTable, *plan.RenameTable,
case *plan.CreateDB, *plan.DropDB,
*plan.CreateTable, *plan.DropTable, *plan.RenameTable,
*plan.AddColumn, *plan.RenameColumn, *plan.DropColumn, *plan.ModifyColumn,
*plan.CreateIndex, *plan.DropIndex, *plan.AlterIndex,
*plan.AlterDefaultSet, *plan.AlterDefaultDrop:
Expand Down
12 changes: 7 additions & 5 deletions binlogreplication/writer.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package binlogreplication

import (
stdsql "database/sql"

"github.com/apache/arrow/go/v17/arrow/array"
"github.com/apecloud/myduckserver/binlog"
"github.com/apecloud/myduckserver/delta"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
"vitess.io/vitess/go/mysql"
)
Expand All @@ -13,8 +14,6 @@ type TableWriter interface {
Insert(ctx *sql.Context, keyRows []sql.Row) error
Delete(ctx *sql.Context, keyRows []sql.Row) error
Update(ctx *sql.Context, keyRows []sql.Row, valueRows []sql.Row) error
Commit() error
Rollback() error
}

type DeltaAppender interface {
Expand All @@ -31,7 +30,8 @@ type DeltaAppender interface {
type TableWriterProvider interface {
// GetTableWriter returns a TableWriter for writing to the specified |table| in the specified |database|.
GetTableWriter(
ctx *sql.Context, engine *sqle.Engine,
ctx *sql.Context,
txn *stdsql.Tx,
databaseName, tableName string,
schema sql.PrimaryKeySchema,
columnCount, rowCount int,
Expand All @@ -42,11 +42,13 @@ type TableWriterProvider interface {

// GetDeltaAppender returns a DeltaAppender for appending updates to the specified |table| in the specified |database|.
GetDeltaAppender(
ctx *sql.Context, engine *sqle.Engine,
ctx *sql.Context,
databaseName, tableName string,
schema sql.Schema,
) (DeltaAppender, error)

UpdateLogPosition(position string)

// FlushDelta writes the accumulated changes to the database.
FlushDelta(ctx *sql.Context, reason delta.FlushReason) error
}
Loading

0 comments on commit 4e13bca

Please sign in to comment.