diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 94db9ff7..f1b817dd 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -53,15 +53,3 @@ jobs: go test -v -cover --timeout 600s . | tee query.log cat query.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}' cat query.log | grep -q "FAIL" && exit 1 || exit 0 - - - name: Test Binlog Replication With GTID Enabled - run: | - GTID_ENABLED=true go test -v -p 1 --timeout 600s ./binlogreplication | tee replication.log - cat replication.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}' - cat replication.log | grep -q "FAIL" && exit 1 || exit 0 - - - name: Test Binlog Replication With GTID Disabled - run: | - GTID_ENABLED=false go test -v -p 1 --timeout 600s ./binlogreplication | tee replication.log - cat replication.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}' - cat replication.log | grep -q "FAIL" && exit 1 || exit 0 diff --git a/.github/workflows/mysql-replication.yml b/.github/workflows/mysql-replication.yml index f1c3a0b9..3488229d 100644 --- a/.github/workflows/mysql-replication.yml +++ b/.github/workflows/mysql-replication.yml @@ -7,9 +7,12 @@ on: branches: [ "main" ] jobs: - build: runs-on: ubuntu-latest + strategy: + matrix: + GTID_ENABLED: [true, false] + steps: - uses: actions/checkout@v4 @@ -26,9 +29,7 @@ jobs: - name: Install dependencies run: | go get . - pip3 install "sqlglot[rs]" - curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip unzip duckdb_cli-linux-amd64.zip chmod +x duckdb @@ -39,14 +40,8 @@ jobs: - name: Build run: go build -v - - name: Test Binlog Replication With GTID Enabled - run: | - GTID_ENABLED=true go test -v -p 1 --timeout 600s ./binlogreplication | tee replication.log - cat replication.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}' - cat replication.log | grep -q "FAIL" && exit 1 || exit 0 - - - name: Test Binlog Replication With GTID Disabled + - name: Test Binlog Replication With GTID ${{ matrix.GTID_ENABLED }} run: | - GTID_ENABLED=false go test -v -p 1 --timeout 600s ./binlogreplication | tee replication.log + GTID_ENABLED=${{ matrix.GTID_ENABLED }} go test -v -p 1 --timeout 600s ./binlogreplication | tee replication.log cat replication.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. 
%s\n", count++, $0}' cat replication.log | grep -q "FAIL" && exit 1 || exit 0 diff --git a/delta/controller.go b/delta/controller.go index 84a2d79a..651dd9dd 100644 --- a/delta/controller.go +++ b/delta/controller.go @@ -11,7 +11,6 @@ import ( "unsafe" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apecloud/myduckserver/backend" "github.com/apecloud/myduckserver/binlog" "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/sql" @@ -28,12 +27,10 @@ type FlushStats struct { type DeltaController struct { mutex sync.Mutex tables map[tableIdentifier]*DeltaAppender - pool *backend.ConnectionPool } -func NewController(pool *backend.ConnectionPool) *DeltaController { +func NewController() *DeltaController { return &DeltaController{ - pool: pool, tables: make(map[tableIdentifier]*DeltaAppender), } } @@ -142,6 +139,10 @@ func (c *DeltaController) updateTable( buf *bytes.Buffer, stats *FlushStats, ) error { + if tx == nil { + return fmt.Errorf("no active transaction") + } + buf.Reset() schema := appender.BaseSchema() // schema of the base table @@ -284,6 +285,25 @@ func (c *DeltaController) updateTable( } stats.Deletions += affected + // For debugging: + // + // rows, err := tx.QueryContext(ctx, "SELECT * FROM "+qualifiedTableName) + // if err != nil { + // return err + // } + // defer rows.Close() + // row := make([]any, len(schema)) + // pointers := make([]any, len(row)) + // for i := range row { + // pointers[i] = &row[i] + // } + // for rows.Next() { + // if err := rows.Scan(pointers...); err != nil { + // return err + // } + // fmt.Printf("row:%+v\n", row) + // } + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { log.WithFields(logrus.Fields{ "table": qualifiedTableName, diff --git a/delta/delta.go b/delta/delta.go index 05442c51..b914773a 100644 --- a/delta/delta.go +++ b/delta/delta.go @@ -27,7 +27,7 @@ type DeltaAppender struct { // https://mariadb.com/kb/en/gtid/ // https://dev.mysql.com/doc/refman/9.0/en/replication-gtids-concepts.html func newDeltaAppender(schema sql.Schema) (*DeltaAppender, error) { - augmented := make(sql.Schema, 0, len(schema)+5) + augmented := make(sql.Schema, 0, len(schema)+6) augmented = append(augmented, &sql.Column{ Name: "action", // delete = 0, update = 1, insert = 2 Type: types.Int8, @@ -73,7 +73,7 @@ func (a *DeltaAppender) Schema() sql.Schema { } func (a *DeltaAppender) BaseSchema() sql.Schema { - return a.schema[5:] + return a.schema[6:] } func (a *DeltaAppender) Action() *array.Int8Builder { diff --git a/myarrow/schema.go b/myarrow/schema.go index 090e5bce..b4d321bf 100644 --- a/myarrow/schema.go +++ b/myarrow/schema.go @@ -2,6 +2,7 @@ package myarrow import ( "github.com/apache/arrow-go/v18/arrow" + "github.com/apecloud/myduckserver/pgtypes" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/vt/proto/query" ) @@ -38,6 +39,9 @@ func ToArrowType(t sql.Type) (arrow.DataType, error) { } func toArrowType(t sql.Type) arrow.DataType { + if pgType, ok := t.(pgtypes.PostgresType); ok { + return pgtypes.PostgresTypeToArrowType(pgType.PG.OID) + } switch t.Type() { case query.Type_UINT8: return arrow.PrimitiveTypes.Uint8 diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index 03e1e8ba..18b6d98c 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -81,6 +81,7 @@ type PortalData struct { IsEmptyQuery bool Fields []pgproto3.FieldDescription Stmt *duckdb.Stmt + Vars []any } type PreparedStatementData struct { diff --git 
a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 3503f6f4..3e764ffb 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -24,6 +24,7 @@ import ( "io" "net" "os" + "runtime/debug" "strings" "unicode" @@ -106,7 +107,7 @@ func (h *ConnectionHandler) HandleConnection() { if HandlePanics { defer func() { if r := recover(); r != nil { - fmt.Printf("Listener recovered panic: %v", r) + fmt.Printf("Listener recovered panic: %v\n%s\n", r, string(debug.Stack())) var eomErr error if returnErr != nil { @@ -120,7 +121,7 @@ func (h *ConnectionHandler) HandleConnection() { // Sending eom can panic, which means we must recover again defer func() { if r := recover(); r != nil { - fmt.Printf("Listener recovered panic: %v", r) + fmt.Printf("Listener recovered panic: %v\n%s\n", r, string(debug.Stack())) } }() h.endOfMessages(eomErr) @@ -288,7 +289,7 @@ func (h *ConnectionHandler) receiveMessage() (bool, error) { if HandlePanics { defer func() { if r := recover(); r != nil { - fmt.Printf("Listener recovered panic: %v", r) + fmt.Printf("Listener recovered panic: %v\n%s\n", r, string(debug.Stack())) var eomErr error if rErr, ok := r.(error); ok { @@ -553,6 +554,7 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error { Query: preparedData.Query, Fields: fields, Stmt: preparedData.Stmt, + Vars: bindVars, } return h.send(&pgproto3.BindComplete{}) } @@ -668,7 +670,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st } fallthrough case tree.CopyFormatCSV: - dataLoader, err = NewCsvDataLoader(sqlCtx, h.duckHandler, &schemaName, insertableTable, copyFrom.Columns, ©From.Options) + dataLoader, err = NewCsvDataLoader(sqlCtx, h.duckHandler, schemaName, insertableTable, copyFrom.Columns, ©From.Options) case tree.CopyFormatBinary: err = fmt.Errorf("BINARY format is not supported for COPY FROM") default: @@ -787,11 +789,11 @@ func (h *ConnectionHandler) deletePortal(name string) { } // convertBindParameters handles the conversion from bind parameters to variable values. 
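The revised convertBindParameters below decodes each bind parameter through the connection's pgtype.Map into pgtype.Text and forwards the resulting strings as []any. A standalone sketch of that decoding, under assumed names (decodeParams is not part of this PR):

    // decodeParams mirrors the approach taken in this PR: scan each raw
    // value (text or binary wire format) into pgtype.Text via the type map,
    // then hand the strings to the engine as plain []any bindings.
    // Simplification: per the protocol, a single format code may apply to
    // all parameters; this sketch only handles the zero-or-per-parameter cases.
    func decodeParams(m *pgtype.Map, oids []uint32, formats []int16, values [][]byte) ([]any, error) {
    	out := make([]any, len(values))
    	for i, raw := range values {
    		format := int16(pgtype.TextFormatCode)
    		if i < len(formats) {
    			format = formats[i]
    		}
    		var txt pgtype.Text // a nil raw decodes to an invalid (NULL) Text
    		if err := m.Scan(oids[i], format, raw, &txt); err != nil {
    			return nil, err
    		}
    		out[i] = txt.String
    	}
    	return out, nil
    }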
-func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) ([]string, error) { +func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) ([]any, error) { if len(types) != len(values) { return nil, fmt.Errorf("number of values does not match number of parameters") } - bindings := make([]string, len(values)) + bindings := make([]pgtype.Text, len(values)) for i := range values { typ := types[i] // We'll rely on a library to decode each format, which will deal with text and binary representations for us @@ -799,7 +801,12 @@ func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes [] return nil, err } } - return bindings, nil + + vars := make([]any, len(bindings)) + for i, b := range bindings { + vars[i] = b.String + } + return vars, nil } // query runs the given query and sends a CommandComplete message to the client diff --git a/pgserver/dataloader.go b/pgserver/dataloader.go index b4390eff..fef7750b 100644 --- a/pgserver/dataloader.go +++ b/pgserver/dataloader.go @@ -3,6 +3,7 @@ package pgserver import ( "bufio" "context" + "errors" "fmt" "io" "os" @@ -49,19 +50,20 @@ var ErrCopyAborted = fmt.Errorf("COPY operation aborted") type CsvDataLoader struct { ctx *sql.Context cancel context.CancelFunc - schema *string + schema string table sql.InsertableTable columns tree.NameList options *tree.CopyOptions pipePath string pipe *os.File + errPipe *os.File // for error handling rowCount chan int64 err atomic.Pointer[error] } var _ DataLoader = (*CsvDataLoader)(nil) -func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema *string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) { +func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) { duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) dataDir := duckBuilder.Provider().DataDir() @@ -95,7 +97,9 @@ func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema *string, // Execute the DuckDB COPY statement in a goroutine. sql := loader.buildSQL() loader.ctx.GetLogger().Trace(sql) - go loader.executeCopy(sql) + go loader.executeCopy(sql, pipePath) + + // TODO(fan): If the reader fails to open the pipe, the writer will block forever. // Open the pipe for writing. // This operation will block until the reader opens the pipe for reading. @@ -103,6 +107,12 @@ func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema *string, if err != nil { return nil, err } + + // If the COPY operation failed to start, close the pipe and return the error. 
+ if loader.errPipe != nil { + return nil, errors.Join(*loader.err.Load(), pipe.Close(), loader.errPipe.Close()) + } + loader.pipe = pipe return loader, nil @@ -114,8 +124,8 @@ func (loader *CsvDataLoader) buildSQL() string { b.Grow(256) b.WriteString("COPY ") - if loader.schema != nil { - b.WriteString(*loader.schema) + if loader.schema != "" { + b.WriteString(loader.schema) b.WriteString(".") } b.WriteString(loader.table.Name()) @@ -161,12 +171,14 @@ func (loader *CsvDataLoader) buildSQL() string { return b.String() } -func (loader *CsvDataLoader) executeCopy(sql string) { +func (loader *CsvDataLoader) executeCopy(sql string, pipePath string) { defer close(loader.rowCount) result, err := adapter.Exec(loader.ctx, sql) if err != nil { loader.ctx.GetLogger().Error(err) loader.err.Store(&err) + // Open the pipe once to unblock the writer + loader.errPipe, _ = os.OpenFile(pipePath, os.O_RDONLY, 0600) return } diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index d88cd543..a4032262 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -29,6 +29,7 @@ import ( "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/backend" + "github.com/apecloud/myduckserver/pgtypes" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/server" @@ -80,7 +81,7 @@ type DuckHandler struct { var _ Handler = &DuckHandler{} // ComBind implements the Handler interface. -func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []string) ([]pgproto3.FieldDescription, error) { +func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []any) ([]pgproto3.FieldDescription, error) { vars := make([]driver.NamedValue, len(bindVars)) for i, v := range bindVars { vars[i] = driver.NamedValue{ @@ -100,7 +101,7 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa // ComExecuteBound implements the Handler interface. func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error { - err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, h.executeBoundPlan, callback) + err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, h.executeBoundPlan, callback) if err != nil { err = sql.CastSQLError(err) } @@ -161,7 +162,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query paramOIDs := make([]uint32, len(paramTypes)) for i, t := range paramTypes { - paramOIDs[i] = duckdbTypeToPostgresOID[t] + paramOIDs[i] = pgtypes.DuckdbTypeToPostgresOID[t] } var ( @@ -188,7 +189,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query break } defer rows.Close() - schema, err := inferSchema(rows) + schema, err := pgtypes.InferSchema(rows) if err != nil { break } @@ -213,7 +214,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query // ComQuery implements the Handler interface. 
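A note on the dataloader change above: opening a FIFO with os.O_WRONLY blocks until some process opens the read end, so if DuckDB's COPY fails before it ever reads from the pipe, NewCsvDataLoader would hang on os.OpenFile. The new errPipe read-open in executeCopy is what releases the writer. A minimal, self-contained sketch of the pattern (hypothetical path and names, Linux FIFO semantics):

    package main

    import (
    	"fmt"
    	"os"
    	"syscall"
    )

    func main() {
    	pipePath := "/tmp/example.pipe" // hypothetical path
    	if err := syscall.Mkfifo(pipePath, 0600); err != nil {
    		panic(err)
    	}
    	defer os.Remove(pipePath)

    	go func() {
    		// Stand-in for a COPY that fails before reading: opening the read
    		// end once is enough to unblock the writer below.
    		errPipe, _ := os.OpenFile(pipePath, os.O_RDONLY, 0600)
    		defer errPipe.Close()
    	}()

    	// Blocks until the goroutine above opens the pipe for reading.
    	w, err := os.OpenFile(pipePath, os.O_WRONLY, 0600)
    	if err != nil {
    		panic(err)
    	}
    	defer w.Close()
    	fmt.Println("writer unblocked; the stored COPY error can now be returned")
    }

This is also why the TODO remains: if neither the COPY nor the error path ever opens the read end, the writer still blocks forever.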
func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error { - err := h.doQuery(ctx, c, query, parsed, nil, h.executeQuery, callback) + err := h.doQuery(ctx, c, query, parsed, nil, nil, h.executeQuery, callback) if err != nil { err = sql.CastSQLError(err) } @@ -290,7 +291,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`) -func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, queryExec QueryExecutor, callback func(*Result) error) error { +func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, queryExec QueryExecutor, callback func(*Result) error) error { sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) if err != nil { return err @@ -326,7 +327,7 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, } }() - schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, stmt) + schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, stmt, vars) if err != nil { if printErrorStackTraces { fmt.Printf("error running query: %+v\n", err) @@ -378,11 +379,11 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, // QueryExecutor is a function that executes a query and returns the result as a schema and iterator. Either of // |parsed| or |analyzed| can be nil depending on the use case -type QueryExecutor func(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) +type QueryExecutor func(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) // executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed // statement, which may be nil. -func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { +func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.Statement, _ *duckdb.Stmt, _ []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { // return h.e.QueryWithBindings(ctx, query, parsed, nil, nil) sql.IncrementStatusVariable(ctx, "Questions", 1) @@ -421,7 +422,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S if err != nil { break } - schema, err = inferSchema(rows) + schema, err = pgtypes.InferSchema(rows) if err != nil { rows.Close() break @@ -432,44 +433,100 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S break } } + if err != nil { + return nil, nil, nil, err + } return schema, iter, nil, nil } // executeBoundPlan is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed // statement, which may be nil. 
-func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { +func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.Statement, stmt *duckdb.Stmt, vars []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { // return h.e.PrepQueryPlanForExecution(ctx, query, plan, nil) + // TODO(fan): Currently, the result of executing the bound query is occasionally incorrect. + // For example, for the "concurrent writes" test in the "TestReplication" test case, + // this approach returns [[2 x] [4 i]] instead of [[2 three] [4 five]]. + // However, `x` and `i` never appear in the data. + // The reason is not clear and needs further investigation. + // Therefore, we fall back to the unbound query execution for now. + // + // var ( + // schema sql.Schema + // iter sql.RowIter + // rows driver.Rows + // result driver.Result + // err error + // ) + // switch stmt.StatementType() { + // case duckdb.DUCKDB_STATEMENT_TYPE_SELECT, + // duckdb.DUCKDB_STATEMENT_TYPE_RELATION, + // duckdb.DUCKDB_STATEMENT_TYPE_CALL, + // duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA, + // duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN: + // rows, err = stmt.QueryBound(ctx) + // if err != nil { + // break + // } + // schema, err = pgtypes.InferDriverSchema(rows) + // if err != nil { + // rows.Close() + // break + // } + // iter, err = NewDriverRowIter(rows, schema) + // if err != nil { + // rows.Close() + // break + // } + // default: + // result, err = stmt.ExecBound(ctx) + // if err != nil { + // break + // } + // affected, _ := result.RowsAffected() + // insertId, _ := result.LastInsertId() + // schema = types.OkResultSchema + // iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{ + // RowsAffected: uint64(affected), + // InsertID: uint64(insertId), + // })) + // } + // if err != nil { + // return nil, nil, nil, err + // } + var ( - schema sql.Schema - iter sql.RowIter - rows driver.Rows - result driver.Result - err error + stmtType = stmt.StatementType() + schema sql.Schema + iter sql.RowIter + rows *stdsql.Rows + result stdsql.Result + err error ) - switch stmt.StatementType() { + + switch stmtType { case duckdb.DUCKDB_STATEMENT_TYPE_SELECT, duckdb.DUCKDB_STATEMENT_TYPE_RELATION, duckdb.DUCKDB_STATEMENT_TYPE_CALL, duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA, duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN: - rows, err = stmt.QueryBound(ctx) + rows, err = adapter.QueryCatalog(ctx, query, vars...) if err != nil { break } - schema, err = inferDriverSchema(rows) + schema, err = pgtypes.InferSchema(rows) if err != nil { rows.Close() break } - iter, err = NewDriverRowIter(rows, schema) + iter, err = backend.NewSQLRowIter(rows, schema) if err != nil { rows.Close() break } default: - result, err = stmt.ExecBound(ctx) + result, err = adapter.ExecCatalog(ctx, query, vars...) 
if err != nil { break } @@ -481,6 +538,9 @@ func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.St InsertID: uint64(insertId), })) } + if err != nil { + return nil, nil, nil, err + } return schema, iter, nil, nil } @@ -509,7 +569,7 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD var size int16 var format int16 var err error - if pgType, ok := c.Type.(PostgresType); ok { + if pgType, ok := c.Type.(pgtypes.PostgresType); ok { oid = pgType.PG.OID format = pgType.PG.Codec.PreferredFormat() size = int16(pgType.Size) @@ -747,7 +807,7 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) { } // TODO(fan): Preallocate the buffer - if pgType, ok := s[i].Type.(PostgresType); ok { + if pgType, ok := s[i].Type.(pgtypes.PostgresType); ok { bytes, err := pgType.Encode(v, []byte{}) if err != nil { return nil, err diff --git a/pgserver/handler.go b/pgserver/handler.go index 887dfc8f..5f60fc22 100644 --- a/pgserver/handler.go +++ b/pgserver/handler.go @@ -26,7 +26,7 @@ import ( type Handler interface { // ComBind is called when a connection receives a request to bind a prepared statement to a set of values. - ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []string) ([]pgproto3.FieldDescription, error) + ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []any) ([]pgproto3.FieldDescription, error) // ComExecuteBound is called when a connection receives a request to execute a prepared statement that has already bound to a set of values. ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error // ComPrepareParsed is called when a connection receives a prepared statement query that has already been parsed. 
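With the pgtypes package factored out, schemaToFieldDescriptions above reads the OID, size, and preferred wire format directly from pgtypes.PostgresType. For reference, each column ends up as roughly this pgproto3.FieldDescription (the helper is illustrative, not code from this PR):

    // fieldDesc assembles a RowDescription entry from the pgtypes wrapper:
    // OID and preferred format come from the embedded pgtype.Type, the size
    // from the PostgresTypeSizes table introduced in this PR.
    func fieldDesc(name string, t pgtypes.PostgresType) pgproto3.FieldDescription {
    	return pgproto3.FieldDescription{
    		Name:         []byte(name),
    		DataTypeOID:  t.PG.OID,
    		DataTypeSize: int16(t.Size),
    		TypeModifier: -1, // no typmod information is tracked here
    		Format:       t.PG.Codec.PreferredFormat(),
    	}
    }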
diff --git a/pgserver/iter.go b/pgserver/iter.go index d4b6fb00..761e6e45 100644 --- a/pgserver/iter.go +++ b/pgserver/iter.go @@ -31,7 +31,7 @@ func NewDriverRowIter(rows driver.Rows, schema sql.Schema) (*DriverRowIter, erro } sb.WriteString(col.Type.String()) } - logrus.Debugf("New DriverRowIter: columns=%v, schema=[%s]", columns, sb.String()) + logrus.Debugf("New DriverRowIter: columns=%v, schema=[%s]\n", columns, sb.String()) return &DriverRowIter{rows, columns, schema, buf, row}, nil } diff --git a/pgserver/logrepl/decode.go b/pgserver/logrepl/decode.go new file mode 100644 index 00000000..2bd2ae5a --- /dev/null +++ b/pgserver/logrepl/decode.go @@ -0,0 +1,278 @@ +package logrepl + +import ( + "fmt" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgtype" +) + +// decodeToArrow decodes Postgres text format data and appends directly to Arrow builder +func decodeToArrow(typeMap *pgtype.Map, columnType *pglogrepl.RelationMessageColumn, data []byte, format int16, builder array.Builder) (int, error) { + if data == nil { + builder.AppendNull() + return 0, nil + } + + dt, ok := typeMap.TypeForOID(columnType.DataType) + if !ok { + // Unknown type, store as string + if b, ok := builder.(*array.StringBuilder); ok { + b.Append(string(data)) + return len(data), nil + } + return 0, fmt.Errorf("column %s: unsupported type conversion for OID %d to %T", columnType.Name, columnType.DataType, builder) + } + + oid := dt.OID + switch oid { + case pgtype.BoolOID: + if b, ok := builder.(*array.BooleanBuilder); ok { + var v bool + var codec pgtype.BoolCodec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return 1, nil + } + + case pgtype.Int2OID: + if b, ok := builder.(*array.Int16Builder); ok { + var v int16 + var codec pgtype.Int2Codec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return 2, nil + } + + case pgtype.Int4OID: + if b, ok := builder.(*array.Int32Builder); ok { + var v int32 + var codec pgtype.Int4Codec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return 4, nil + } + + case pgtype.Int8OID: + if b, ok := builder.(*array.Int64Builder); ok { + var v int64 + var codec pgtype.Int8Codec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return 8, nil + } + + case pgtype.Float4OID: + if b, ok := builder.(*array.Float32Builder); ok { + var v float32 + var codec pgtype.Float4Codec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return 4, nil + } + + case pgtype.Float8OID: + if b, ok := builder.(*array.Float64Builder); ok { + var v float64 + var codec pgtype.Float8Codec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return 8, nil + } + + case pgtype.TimestampOID: + if b, ok := builder.(*array.TimestampBuilder); ok { + var v pgtype.Timestamp + codec := pgtype.TimestampCodec{ScanLocation: time.UTC} + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.AppendTime(v.Time) + return 8, nil + } + + case pgtype.TimestamptzOID: + if b, ok := builder.(*array.TimestampBuilder); ok { + var v pgtype.Timestamptz + codec := 
pgtype.TimestamptzCodec{ScanLocation: time.UTC} + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.AppendTime(v.Time) + return 8, nil + } + + case pgtype.DateOID: + if b, ok := builder.(*array.Date32Builder); ok { + var v pgtype.Date + var codec pgtype.DateCodec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(arrow.Date32FromTime(v.Time)) + return 4, nil + } + + case pgtype.NumericOID: + // TODO(fan): write small decimal as Decimal128 + if b, ok := builder.(*array.StringBuilder); ok { + var v pgtype.Text + var codec pgtype.NumericCodec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.AppendString(v.String) + return len(data), nil + } + + case pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID, pgtype.NameOID: + var buf [32]byte // Stack-allocated buffer for small string + v := pgtype.PreallocBytes(buf[:]) + var codec pgtype.TextCodec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + switch b := builder.(type) { + case *array.StringBuilder: + b.BinaryBuilder.Append(v) + return len(v), nil + case *array.BinaryBuilder: + b.Append(v) + return len(v), nil + } + + case pgtype.ByteaOID: + if b, ok := builder.(*array.BinaryBuilder); ok { + var buf [32]byte // Stack-allocated buffer for small byte array + v := pgtype.PreallocBytes(buf[:]) + var codec pgtype.ByteaCodec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + b.Append(v) + return len(v), nil + } + + case pgtype.UUIDOID: + var v pgtype.UUID + var codec pgtype.UUIDCodec + if err := codec.PlanScan(typeMap, oid, format, &v).Scan(data, &v); err != nil { + return 0, err + } + switch b := builder.(type) { + case *array.FixedSizeBinaryBuilder: + b.Append(v.Bytes[:]) + return 16, nil + case *array.StringBuilder: + var buf [36]byte + codec.PlanEncode(typeMap, oid, pgtype.TextFormatCode, &v).Encode(&v, buf[:0]) + b.BinaryBuilder.Append(buf[:]) + return 36, nil + } + } + // TODO(fan): add support for other types + + // Fallback + v, err := dt.Codec.DecodeValue(typeMap, oid, format, data) + if err != nil { + return 0, err + } + return writeValue(builder, v) +} + +// Keep writeValue as a fallback for handling Go values from pgtype codec +func writeValue(builder array.Builder, val any) (int, error) { + switch b := builder.(type) { + case *array.BooleanBuilder: + if v, ok := val.(bool); ok { + b.Append(v) + return 1, nil + } + case *array.Int8Builder: + if v, ok := val.(int8); ok { + b.Append(v) + return 1, nil + } + case *array.Int16Builder: + if v, ok := val.(int16); ok { + b.Append(v) + return 2, nil + } + case *array.Int32Builder: + if v, ok := val.(int32); ok { + b.Append(v) + return 4, nil + } + case *array.Int64Builder: + if v, ok := val.(int64); ok { + b.Append(v) + return 8, nil + } + case *array.Uint8Builder: + if v, ok := val.(uint8); ok { + b.Append(v) + return 1, nil + } + case *array.Uint16Builder: + if v, ok := val.(uint16); ok { + b.Append(v) + return 2, nil + } + case *array.Uint32Builder: + if v, ok := val.(uint32); ok { + b.Append(v) + return 4, nil + } + case *array.Uint64Builder: + if v, ok := val.(uint64); ok { + b.Append(v) + return 8, nil + } + case *array.Float32Builder: + if v, ok := val.(float32); ok { + b.Append(v) + return 4, nil + } + case *array.Float64Builder: + if v, ok := val.(float64); ok { + b.Append(v) + return 8, nil + } + case 
*array.StringBuilder: + if v, ok := val.(string); ok { + b.Append(v) + return len(v), nil + } + case *array.BinaryBuilder: + if v, ok := val.([]byte); ok { + b.Append(v) + return len(v), nil + } + case *array.TimestampBuilder: + if v, ok := val.(pgtype.Timestamp); ok { + b.AppendTime(v.Time) + return 8, nil + } + case *array.DurationBuilder: + if v, ok := val.(time.Duration); ok { + b.Append(arrow.Duration(v)) + return 8, nil + } + } + return 0, fmt.Errorf("unsupported type conversion: %T -> %T", val, builder) +} diff --git a/pgserver/logrepl/replication.go b/pgserver/logrepl/replication.go index 63f3ea4f..1f78e93f 100644 --- a/pgserver/logrepl/replication.go +++ b/pgserver/logrepl/replication.go @@ -19,20 +19,24 @@ import ( stdsql "database/sql" "errors" "fmt" - "log" "math" "strings" "sync" + "sync/atomic" "time" "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/binlog" "github.com/apecloud/myduckserver/catalog" + "github.com/apecloud/myduckserver/delta" + "github.com/apecloud/myduckserver/pgtypes" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" + "github.com/sirupsen/logrus" ) const outputPlugin = "pgoutput" @@ -49,6 +53,8 @@ type LogicalReplicator struct { messageReceived bool stop chan struct{} mu *sync.Mutex + + logger *logrus.Entry } // NewLogicalReplicator creates a new logical replicator instance which connects to the primary and replication @@ -58,6 +64,10 @@ func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) { return &LogicalReplicator{ primaryDns: primaryDns, mu: &sync.Mutex{}, + logger: logrus.WithFields(logrus.Fields{ + "component": "replicator", + "protocol": "pg", + }), }, nil } @@ -92,7 +102,7 @@ func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) { } r.mu.Unlock() - log.Printf("Checking replication lag with threshold %d\n", threshold) + r.logger.Debugf("Checking replication lag with threshold %d\n", threshold) conn, err := pgx.Connect(context.Background(), r.PrimaryDns()) if err != nil { return false, err @@ -115,10 +125,10 @@ func (r *LogicalReplicator) CaughtUp(threshold int) (bool, error) { row := rows[0] lag, ok := row.(pgtype.Numeric) if ok && lag.Valid { - log.Printf("Current replication lag: %v", row) + r.logger.Debugf("Current replication lag: %v", row) return int(math.Abs(float64(lag.Int.Int64()))) < threshold, nil } else { - log.Printf("Replication lag unknown: %v", row) + r.logger.Debugf("Replication lag unknown: %v", row) } } @@ -148,10 +158,14 @@ type replicationState struct { // SendStandbyStatusUpdate after every message we get. lastReceivedLSN pglogrepl.LSN - // currentTransactionLSN is the LSN of the current transaction we are processing. This becomes the lastWrittenLSN - // when we get a CommitMessage + // currentTransactionLSN is the LSN of the current transaction we are processing. + // This becomes the lastCommitLSN when we get a CommitMessage. currentTransactionLSN pglogrepl.LSN + // lastCommitLSN is the LSN of the last commit message we received. + // This becomes the lastWrittenLSN when we commit the transaction to the database. + lastCommitLSN pglogrepl.LSN + // inStream tracks the state of the replication stream. When we receive a StreamStartMessage, we set inStream to // true, and then back to false when we receive a StreamStopMessage. 
inStream bool @@ -164,8 +178,19 @@ type replicationState struct { // the final LSN of the transaction, as recorded in the Begin message. So for every Begin, we decide whether to // process or ignore all messages until a corresponding Commit message. processMessages bool - relations map[uint32]*pglogrepl.RelationMessageV2 - typeMap *pgtype.Map + + typeMap *pgtype.Map + relations map[uint32]*pglogrepl.RelationMessageV2 + schemas map[uint32]sql.Schema + keys map[uint32][]uint16 // relationID -> slice of key column indices + deltas *delta.DeltaController + + deltaBufSize atomic.Uint64 // size of the delta buffer in bytes + lastCommitTime time.Time // time of last commit + ongoingBatchTxn atomic.Bool // true if we're in a batched transaction + dirtyTxn atomic.Bool // true if we have uncommitted changes + dirtyStream atomic.Bool // true if the binlog stream does not end with a commit + inTxnStmtID atomic.Uint64 // statement ID within transaction } // StartReplication starts the replication process for the given slot name. This function blocks until replication is @@ -183,8 +208,11 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin replicaCtx: sqlCtx, slotName: slotName, lastWrittenLSN: lastWrittenLsn, - relations: map[uint32]*pglogrepl.RelationMessageV2{}, typeMap: pgtype.NewMap(), + relations: map[uint32]*pglogrepl.RelationMessageV2{}, + schemas: map[uint32]sql.Schema{}, + keys: map[uint32][]uint16{}, + deltas: delta.NewController(), } // Switch to the `public` schema. @@ -199,7 +227,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin _ = primaryConn.Close(context.Background()) } // We always shut down here and only here, so we do the cleanup on thread exit in exactly one place - r.shutdown(sqlCtx) + r.shutdown(sqlCtx, state) }() connErrCnt := 0 @@ -209,7 +237,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin connErrCnt++ } if connErrCnt < maxConsecutiveFailures { - log.Printf("Error: %v. Retrying", err) + r.logger.Debugf("Error: %v. 
Retrying", err) if primaryConn != nil { _ = primaryConn.Close(context.Background()) } @@ -235,18 +263,21 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin return handleErrWithRetry(err, false) } - log.Printf("Sent Standby status message with WALWritePosition = %s, WALApplyPosition = %s\n", state.lastWrittenLSN+1, state.lastReceivedLSN+1) + r.logger.Debugf("Sent Standby status message with WALWritePosition = %s, WALApplyPosition = %s\n", state.lastWrittenLSN+1, state.lastReceivedLSN+1) nextStandbyMessageDeadline = time.Now().Add(standbyMessageTimeout) return nil } - log.Printf("Starting replicator: primaryDsn=%s, slotName=%s", r.PrimaryDns(), slotName) + r.logger.Debugf("Starting replicator: primaryDsn=%s, slotName=%s", r.PrimaryDns(), slotName) r.mu.Lock() r.running = true r.messageReceived = false r.stop = make(chan struct{}) r.mu.Unlock() + ticker := time.NewTicker(200 * time.Millisecond) + defer ticker.Stop() + for { err := func() error { // Shutdown if requested @@ -296,6 +327,9 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin return nil case msgAndErr = <-receiveMsgChan: cancel() + case <-ticker.C: + cancel() + return r.commitOngoingTxnIfClean(state, delta.TimeTickFlushReason) } if msgAndErr.err != nil { @@ -317,7 +351,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin msg, ok := rawMsg.(*pgproto3.CopyData) if !ok { - log.Printf("Received unexpected message: %T\n", rawMsg) + r.logger.Debugf("Received unexpected message: %T\n", rawMsg) return nil } @@ -325,10 +359,10 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin case pglogrepl.PrimaryKeepaliveMessageByteID: pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { - log.Fatalln("ParsePrimaryKeepaliveMessage failed:", err) + return fmt.Errorf("ParsePrimaryKeepaliveMessage failed: %w", err) } - log.Println("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested) + r.logger.Traceln("Primary Keepalive Message =>", "ServerWALEnd:", pkm.ServerWALEnd, "ServerTime:", pkm.ServerTime, "ReplyRequested:", pkm.ReplyRequested) state.lastReceivedLSN = pkm.ServerWALEnd if pkm.ReplyRequested { @@ -349,7 +383,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin return sendStandbyStatusUpdate(state) default: - log.Printf("Received unexpected message: %T\n", rawMsg) + r.logger.Debugf("Received unexpected message: %T\n", rawMsg) } return nil @@ -359,21 +393,23 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin if errors.Is(err, errShutdownRequested) { return nil } - log.Println("Error during replication:", err) + r.logger.Errorln("Error during replication:", err) return err } } } -func (r *LogicalReplicator) shutdown(ctx *sql.Context) { +func (r *LogicalReplicator) shutdown(ctx *sql.Context, state *replicationState) { r.mu.Lock() defer r.mu.Unlock() - log.Print("shutting down replicator") + r.logger.Info("shutting down replicator") + + r.commitOngoingTxnIfClean(state, delta.OnCloseFlushReason) // Rollback any open transaction _, err := adapter.ExecCatalog(ctx, "ROLLBACK") if err != nil && !strings.Contains(err.Error(), "no transaction is active") { - log.Printf("Failed to roll back transaction: %v", err) + r.logger.Debugf("Failed to roll back transaction: %v", err) } r.running = false @@ -396,27 +432,16 @@ func (r *LogicalReplicator) Stop() { } 
r.mu.Unlock() - log.Print("stopping replication...") + r.logger.Info("stopping replication...") r.stop <- struct{}{} // wait for the channel to be closed, acknowledging that the replicator has stopped <-r.stop } -// replicateQuery executes the query provided on the replica connection -func (r *LogicalReplicator) replicateQuery(replicaCtx *sql.Context, query string) error { - log.Printf("replicating query: %s", query) - result, err := adapter.Exec(replicaCtx, query) - if err == nil { - affected, _ := result.RowsAffected() - log.Printf("Affected rows: %d", affected) - } - return err -} - // beginReplication starts a new replication connection to the primary server and returns it. The LSN provided is the // last one we have confirmed that we flushed to disk. func (r *LogicalReplicator) beginReplication(slotName string, lastFlushLsn pglogrepl.LSN) (*pgconn.PgConn, error) { - log.Printf("Connecting to primary for replication: %s", r.ReplicationDns()) + r.logger.Debugf("Connecting to primary for replication: %s", r.ReplicationDns()) conn, err := pgconn.Connect(context.Background(), r.ReplicationDns()) if err != nil { return nil, err @@ -435,14 +460,14 @@ func (r *LogicalReplicator) beginReplication(slotName string, lastFlushLsn pglog // not rewind to previous entries that we've already confirmed to the primary that we flushed. We still pass an LSN // for the edge case where we have flushed an entry to disk, but crashed before the primary received confirmation. // In that edge case, we want to "skip" entries (from the primary's perspective) that we have already flushed to disk. - log.Printf("Starting logical replication on slot %s at WAL location %s", slotName, lastFlushLsn+1) + r.logger.Debugf("Starting logical replication on slot %s at WAL location %s", slotName, lastFlushLsn+1) err = pglogrepl.StartReplication(context.Background(), conn, slotName, lastFlushLsn+1, pglogrepl.StartReplicationOptions{ PluginArgs: pluginArguments, }) if err != nil { return nil, err } - log.Println("Logical replication started on slot", slotName) + r.logger.Infoln("Logical replication started on slot", slotName) return conn, nil } @@ -529,7 +554,7 @@ func (r *LogicalReplicator) CreateReplicationSlotIfNecessary(slotName string) er } } - log.Println("Created replication slot:", slotName) + r.logger.Infoln("Created replication slot:", slotName) } return nil @@ -551,18 +576,48 @@ func (r *LogicalReplicator) processMessage( return false, err } - log.Printf("XLogData (%T) => WALStart %s ServerWALEnd %s ServerTime %s", logicalMsg, xld.WALStart, xld.ServerWALEnd, xld.ServerTime) - state.lastReceivedLSN = xld.ServerWALEnd + r.logger.Debugf("XLogData (%T) => WALStart %s ServerWALEnd %s ServerTime %s", logicalMsg, xld.WALStart, xld.ServerWALEnd, xld.ServerTime) + + // Update the last received LSN + if xld.ServerWALEnd > state.lastReceivedLSN { + state.lastReceivedLSN = xld.ServerWALEnd + } switch logicalMsg := logicalMsg.(type) { case *pglogrepl.RelationMessageV2: + // When a Relation message is received, commit any buffered ongoing batch transactions. 
+ err := r.commitOngoingTxn(state, delta.DDLStmtFlushReason) + if err != nil { + return false, err + } + state.relations[logicalMsg.RelationID] = logicalMsg + + schema := make(sql.Schema, len(logicalMsg.Columns)) + var keys []uint16 + for i, col := range logicalMsg.Columns { + pgType, err := pgtypes.NewPostgresType(col.DataType) + if err != nil { + return false, err + } + schema[i] = &sql.Column{ + Name: col.Name, + Type: pgType, + PrimaryKey: col.Flags == 1, + } + if col.Flags == 1 { + keys = append(keys, uint16(i)) + } + } + state.schemas[logicalMsg.RelationID] = schema + state.keys[logicalMsg.RelationID] = keys + case *pglogrepl.BeginMessage: // Indicates the beginning of a group of changes in a transaction. // This is only sent for committed transactions. We won't get any events from rolled back transactions. if state.lastWrittenLSN > logicalMsg.FinalLSN { - log.Printf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, logicalMsg.FinalLSN) + r.logger.Debugf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, logicalMsg.FinalLSN) state.processMessages = false return false, nil } @@ -570,202 +625,127 @@ func (r *LogicalReplicator) processMessage( state.processMessages = true state.currentTransactionLSN = logicalMsg.FinalLSN - log.Printf("BeginMessage: %v", logicalMsg) - err = r.replicateQuery(state.replicaCtx, "BEGIN TRANSACTION") - if err != nil { - return false, err + // Start a new transaction or extend existing batch + extend, reason := r.mayExtendBatchTxn(state) + if !extend { + err := r.commitOngoingTxn(state, reason) + if err != nil { + return false, err + } + _, err = adapter.GetCatalogTxn(state.replicaCtx, nil) + if err != nil { + return false, err + } + state.ongoingBatchTxn.Store(true) } + case *pglogrepl.CommitMessage: - log.Printf("CommitMessage: %v", logicalMsg) + r.logger.Debugf("CommitMessage: %v", logicalMsg) - // Record the LSN before we commit the transaction - log.Printf("Writing LSN %s\n", state.currentTransactionLSN) - err := r.writeWALPosition(state.replicaCtx, state.slotName, state.currentTransactionLSN) - if err != nil { - return false, err - } + state.lastCommitLSN = logicalMsg.CommitLSN - err = r.replicateQuery(state.replicaCtx, "COMMIT") - if err != nil { - return false, err + extend, reason := r.mayExtendBatchTxn(state) + if !extend { + err = r.commitOngoingTxn(state, reason) + if err != nil { + return false, err + } } + state.dirtyStream.Store(false) + state.inTxnStmtID.Store(0) - state.lastWrittenLSN = state.currentTransactionLSN state.processMessages = false return true, nil case *pglogrepl.InsertMessageV2: if !state.processMessages { - log.Printf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) + r.logger.Debugf("Received stale message, ignoring. 
Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) return false, nil } - rel, ok := state.relations[logicalMsg.RelationID] - if !ok { - log.Fatalf("unknown relation ID %d", logicalMsg.RelationID) - } - - columnStr := strings.Builder{} - valuesStr := strings.Builder{} - for idx, col := range logicalMsg.Tuple.Columns { - if idx > 0 { - columnStr.WriteString(", ") - valuesStr.WriteString(", ") - } - - colName := rel.Columns[idx].Name - columnStr.WriteString(colName) - - switch col.DataType { - case 'n': // null - valuesStr.WriteString("NULL") - case 't': // text - - // We have to round-trip the data through the encodings to get an accurate text rep back - val, err := decodeTextColumnData(state.typeMap, col.Data, rel.Columns[idx].DataType) - if err != nil { - log.Fatalln("error decoding column data:", err) - } - colData, err := encodeColumnData(state.typeMap, val, rel.Columns[idx].DataType) - if err != nil { - return false, err - } - valuesStr.WriteString(colData) - default: - log.Printf("unknown column data type: %c", col.DataType) - } - } - - err = r.replicateQuery(state.replicaCtx, fmt.Sprintf("INSERT INTO %s.%s (%s) VALUES (%s)", rel.Namespace, rel.RelationName, columnStr.String(), valuesStr.String())) + err = r.append(state, logicalMsg.RelationID, logicalMsg.Tuple.Columns, binlog.InsertRowEvent, false) if err != nil { return false, err } + + state.dirtyTxn.Store(true) + state.dirtyStream.Store(true) + state.inTxnStmtID.Add(1) + case *pglogrepl.UpdateMessageV2: if !state.processMessages { - log.Printf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) + r.logger.Debugf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) return false, nil } - // TODO: this won't handle primary key changes correctly - // TODO: this probably doesn't work for unkeyed tables - rel, ok := state.relations[logicalMsg.RelationID] - if !ok { - log.Fatalf("unknown relation ID %d", logicalMsg.RelationID) + // Delete the old tuple + switch logicalMsg.OldTupleType { + case pglogrepl.UpdateMessageTupleTypeKey: + err = r.append(state, logicalMsg.RelationID, logicalMsg.OldTuple.Columns, binlog.DeleteRowEvent, true) + case pglogrepl.UpdateMessageTupleTypeOld: + err = r.append(state, logicalMsg.RelationID, logicalMsg.OldTuple.Columns, binlog.DeleteRowEvent, false) + default: + // No old tuple provided; it means the key columns are unchanged } - - updateStr := strings.Builder{} - whereStr := strings.Builder{} - for idx, col := range logicalMsg.NewTuple.Columns { - colName := rel.Columns[idx].Name - colFlags := rel.Columns[idx].Flags - - var stringVal string - switch col.DataType { - case 'n': // null - stringVal = "NULL" - case 'u': // unchanged toast - case 't': // text - val, err := decodeTextColumnData(state.typeMap, col.Data, rel.Columns[idx].DataType) - if err != nil { - log.Fatalln("error decoding column data:", err) - } - - stringVal, err = encodeColumnData(state.typeMap, val, rel.Columns[idx].DataType) - if err != nil { - return false, err - } - default: - log.Printf("unknown column data type: %c", col.DataType) - } - - // TODO: quote column names? 
- if colFlags == 0 { - if updateStr.Len() > 0 { - updateStr.WriteString(", ") - } - updateStr.WriteString(fmt.Sprintf("%s = %v", colName, stringVal)) - } else { - if whereStr.Len() > 0 { - updateStr.WriteString(", ") - } - whereStr.WriteString(fmt.Sprintf("%s = %v", colName, stringVal)) - } + if err != nil { + return false, err } - err = r.replicateQuery(state.replicaCtx, fmt.Sprintf("UPDATE %s.%s SET %s%s", rel.Namespace, rel.RelationName, updateStr.String(), whereClause(whereStr))) + // Insert the new tuple + err = r.append(state, logicalMsg.RelationID, logicalMsg.NewTuple.Columns, binlog.InsertRowEvent, false) if err != nil { return false, err } + + state.dirtyTxn.Store(true) + state.dirtyStream.Store(true) + state.inTxnStmtID.Add(1) + case *pglogrepl.DeleteMessageV2: if !state.processMessages { - log.Printf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) + r.logger.Debugf("Received stale message, ignoring. Last written LSN: %s Message LSN: %s", state.lastWrittenLSN, xld.ServerWALEnd) return false, nil + // Determine which columns to use based on OldTupleType } - // TODO: this probably doesn't work for unkeyed tables - rel, ok := state.relations[logicalMsg.RelationID] - if !ok { - log.Fatalf("unknown relation ID %d", logicalMsg.RelationID) - } - - whereStr := strings.Builder{} - for idx, col := range logicalMsg.OldTuple.Columns { - colName := rel.Columns[idx].Name - colFlags := rel.Columns[idx].Flags - - var stringVal string - switch col.DataType { - case 'n': // null - stringVal = "NULL" - case 'u': // unchanged toast - case 't': // text - val, err := decodeTextColumnData(state.typeMap, col.Data, rel.Columns[idx].DataType) - if err != nil { - log.Fatalln("error decoding column data:", err) - } - - stringVal, err = encodeColumnData(state.typeMap, val, rel.Columns[idx].DataType) - if err != nil { - return false, err - } - default: - log.Printf("unknown column data type: %c", col.DataType) - } - - if colFlags == 0 { - // nothing to do - } else { - if whereStr.Len() > 0 { - whereStr.WriteString(", ") - } - whereStr.WriteString(fmt.Sprintf("%s = %v", colName, stringVal)) - } + switch logicalMsg.OldTupleType { + case pglogrepl.UpdateMessageTupleTypeKey: + err = r.append(state, logicalMsg.RelationID, logicalMsg.OldTuple.Columns, binlog.DeleteRowEvent, true) + case pglogrepl.UpdateMessageTupleTypeOld: + err = r.append(state, logicalMsg.RelationID, logicalMsg.OldTuple.Columns, binlog.DeleteRowEvent, false) + default: + // No old tuple provided; cannot perform delete + err = fmt.Errorf("DeleteMessage without OldTuple") } - err = r.replicateQuery(state.replicaCtx, fmt.Sprintf("DELETE FROM %s.%s WHERE %s", rel.Namespace, rel.RelationName, whereStr.String())) if err != nil { return false, err } + + state.dirtyTxn.Store(true) + state.dirtyStream.Store(true) + state.inTxnStmtID.Add(1) + case *pglogrepl.TruncateMessageV2: - log.Printf("truncate for xid %d\n", logicalMsg.Xid) + r.logger.Debugf("truncate for xid %d\n", logicalMsg.Xid) case *pglogrepl.TypeMessageV2: - log.Printf("typeMessage for xid %d\n", logicalMsg.Xid) + r.logger.Debugf("typeMessage for xid %d\n", logicalMsg.Xid) case *pglogrepl.OriginMessage: - log.Printf("originMessage for xid %s\n", logicalMsg.Name) + r.logger.Debugf("originMessage for xid %s\n", logicalMsg.Name) case *pglogrepl.LogicalDecodingMessageV2: - log.Printf("Logical decoding message: %q, %q, %d", logicalMsg.Prefix, logicalMsg.Content, logicalMsg.Xid) + r.logger.Debugf("Logical decoding message: %q, %q, %d", 
logicalMsg.Prefix, logicalMsg.Content, logicalMsg.Xid) case *pglogrepl.StreamStartMessageV2: state.inStream = true - log.Printf("Stream start message: xid %d, first segment? %d", logicalMsg.Xid, logicalMsg.FirstSegment) + r.logger.Debugf("Stream start message: xid %d, first segment? %d", logicalMsg.Xid, logicalMsg.FirstSegment) case *pglogrepl.StreamStopMessageV2: state.inStream = false - log.Printf("Stream stop message") + r.logger.Debugf("Stream stop message") case *pglogrepl.StreamCommitMessageV2: - log.Printf("Stream commit message: xid %d", logicalMsg.Xid) + r.logger.Debugf("Stream commit message: xid %d", logicalMsg.Xid) case *pglogrepl.StreamAbortMessageV2: - log.Printf("Stream abort message: xid %d", logicalMsg.Xid) + r.logger.Debugf("Stream abort message: xid %d", logicalMsg.Xid) default: - log.Printf("Unknown message type in pgoutput stream: %T", logicalMsg) + r.logger.Debugf("Unknown message type in pgoutput stream: %T", logicalMsg) } return false, nil @@ -787,7 +767,7 @@ func (r *LogicalReplicator) readWALPosition(ctx *sql.Context, slotName string) ( // writeWALPosition writes the recorded WAL position to the WAL position table func (r *LogicalReplicator) writeWALPosition(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) error { - _, err := adapter.Exec(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String()) + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String()) return err } @@ -843,3 +823,139 @@ func encodeColumnData(mi *pgtype.Map, data interface{}, dataType uint32) (string return value, nil } } + +// mayExtendBatchTxn checks if we should extend the current batch transaction +func (r *LogicalReplicator) mayExtendBatchTxn(state *replicationState) (bool, delta.FlushReason) { + extend, reason := false, delta.UnknownFlushReason + if state.ongoingBatchTxn.Load() { + extend = true + switch { + case time.Since(state.lastCommitTime) >= 200*time.Millisecond: + extend, reason = false, delta.TimeTickFlushReason + case state.deltaBufSize.Load() >= (128 << 20): // 128MB + extend, reason = false, delta.MemoryLimitFlushReason + } + } + return extend, reason +} + +func (r *LogicalReplicator) commitOngoingTxnIfClean(state *replicationState, reason delta.FlushReason) error { + if state.dirtyTxn.Load() && !state.dirtyStream.Load() { + return r.commitOngoingTxn(state, reason) + } + return nil +} + +// commitOngoingTxn commits the current transaction +func (r *LogicalReplicator) commitOngoingTxn(state *replicationState, flushReason delta.FlushReason) error { + tx := adapter.TryGetTxn(state.replicaCtx) + if tx == nil { + return nil + } + + defer tx.Rollback() + defer adapter.CloseTxn(state.replicaCtx) + + // Flush the delta buffer if too large + err := r.flushDeltaBuffer(state, tx, flushReason) + if err != nil { + return err + } + + r.logger.Debugf("Writing LSN %s\n", state.currentTransactionLSN) + if err = r.writeWALPosition(state.replicaCtx, state.slotName, state.currentTransactionLSN); err != nil { + return err + } + + // Commit the transaction + if err := tx.Commit(); err != nil { + return err + } + + // Reset transaction state + state.ongoingBatchTxn.Store(false) + state.dirtyTxn.Store(false) + state.dirtyStream.Store(false) + state.inTxnStmtID.Store(0) + state.lastCommitTime = time.Now() + + state.lastWrittenLSN = state.lastCommitLSN + + return nil +} + +// flushDeltaBuffer flushes the accumulated changes in the delta buffer +func (r *LogicalReplicator) flushDeltaBuffer(state *replicationState, tx 
*stdsql.Tx, reason delta.FlushReason) error { + defer state.deltaBufSize.Store(0) + + _, err := state.deltas.Flush(state.replicaCtx, tx, reason) + return err +} + +func (r *LogicalReplicator) append(state *replicationState, relationID uint32, tuple []*pglogrepl.TupleDataColumn, eventType binlog.RowEventType, onlyKeys bool) error { + rel, ok := state.relations[relationID] + if !ok { + return fmt.Errorf("unknown relation ID %d", relationID) + } + appender, err := state.deltas.GetDeltaAppender(rel.Namespace, rel.RelationName, state.schemas[relationID]) + if err != nil { + return err + } + + fields := appender.Fields() + actions := appender.Action() + txnTags := appender.TxnTag() + txnServers := appender.TxnServer() + txnGroups := appender.TxnGroup() + txnSeqNumbers := appender.TxnSeqNumber() + txnStmtOrdinals := appender.TxnStmtOrdinal() + + actions.Append(int8(eventType)) + txnTags.AppendNull() + txnServers.Append([]byte("")) + txnGroups.AppendNull() + txnSeqNumbers.Append(uint64(state.currentTransactionLSN)) + txnStmtOrdinals.Append(state.inTxnStmtID.Load()) + + size := 0 + idx := 0 + + for i, metadata := range rel.Columns { + builder := fields[i] + var col *pglogrepl.TupleDataColumn + if onlyKeys { + if metadata.Flags != 1 { // not a key column + builder.AppendNull() + continue + } + col = tuple[idx] + idx++ + } else { + col = tuple[i] + } + switch col.DataType { + case pglogrepl.TupleDataTypeNull: + builder.AppendNull() + case pglogrepl.TupleDataTypeText, pglogrepl.TupleDataTypeBinary: + length, err := decodeToArrow(state.typeMap, metadata, col.Data, tupleDataFormat(col.DataType), builder) + if err != nil { + return err + } + size += length + default: + return fmt.Errorf("unsupported replication data format %d", col.DataType) + } + } + + state.deltaBufSize.Add(uint64(size)) + return nil +} + +func tupleDataFormat(dataType uint8) int16 { + switch dataType { + case pglogrepl.TupleDataTypeBinary: + return pgtype.BinaryFormatCode + default: + return pgtype.TextFormatCode + } +} diff --git a/pgserver/logrepl/replication_test.go b/pgserver/logrepl/replication_test.go index 1a41e096..9261a7af 100644 --- a/pgserver/logrepl/replication_test.go +++ b/pgserver/logrepl/replication_test.go @@ -504,7 +504,7 @@ var replicationTests = []ReplicationTest{ } func TestReplication(t *testing.T) { - // logrus.SetLevel(logrus.TraceLevel) + // logrus.SetLevel(logrus.DebugLevel) RunReplicationScripts(t, replicationTests) } @@ -545,7 +545,7 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) { time.Sleep(500 * time.Millisecond) // for i, script := range scripts { - // if i == 4 { + // if i == 0 { // RunReplicationScript(t, dsn, script) // } // } @@ -555,7 +555,6 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) { } const slotName = "myduck_slot" -const localPostgresPort = 5432 // RunReplicationScript runs the given ReplicationTest. 
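The batching helpers above boil down to a simple flush policy: a batch transaction keeps absorbing replicated transactions until it has been open for about 200ms or the delta buffer exceeds 128MiB, and the 200ms ticker in StartReplication commits a clean-but-idle batch. Distilled into a standalone predicate (constants copied from the diff, names invented):

    const (
    	flushInterval = 200 * time.Millisecond // matches the ticker and mayExtendBatchTxn
    	maxDeltaBytes = 128 << 20              // 128 MiB delta-buffer cap
    )

    // shouldFlush reports whether the ongoing batch transaction must be
    // committed now rather than extended with the next transaction.
    func shouldFlush(lastCommit time.Time, deltaBufBytes uint64) bool {
    	return time.Since(lastCommit) >= flushInterval || deltaBufBytes >= maxDeltaBytes
    }

Trading commit frequency for batch size this way bounds both replication lag (time) and memory (bytes) with a single predicate.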
func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) { diff --git a/pgserver/type_mapping.go b/pgtypes/pgtypes.go similarity index 71% rename from pgserver/type_mapping.go rename to pgtypes/pgtypes.go index 4434da82..250298e5 100644 --- a/pgserver/type_mapping.go +++ b/pgtypes/pgtypes.go @@ -1,4 +1,4 @@ -package pgserver +package pgtypes import ( stdsql "database/sql" @@ -6,16 +6,18 @@ import ( "fmt" "reflect" - "github.com/dolthub/go-mysql-server/sql" + "github.com/apache/arrow-go/v18/arrow" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/jackc/pgx/v5/pgtype" "github.com/marcboeker/go-duckdb" + + "github.com/dolthub/go-mysql-server/sql" ) -var defaultTypeMap = pgtype.NewMap() +var DefaultTypeMap = pgtype.NewMap() -var duckdbTypeStrToPostgresTypeStr = map[string]string{ +var DuckdbTypeStrToPostgresTypeStr = map[string]string{ "INVALID": "unknown", "BOOLEAN": "bool", "TINYINT": "int2", @@ -49,7 +51,7 @@ var duckdbTypeStrToPostgresTypeStr = map[string]string{ "VARINT": "numeric", // Variable integer, mapped to numeric } -var duckdbTypeToPostgresOID = map[duckdb.Type]uint32{ +var DuckdbTypeToPostgresOID = map[duckdb.Type]uint32{ duckdb.TYPE_INVALID: pgtype.UnknownOID, duckdb.TYPE_BOOLEAN: pgtype.BoolOID, duckdb.TYPE_TINYINT: pgtype.Int2OID, @@ -83,7 +85,7 @@ var duckdbTypeToPostgresOID = map[duckdb.Type]uint32{ duckdb.TYPE_VARINT: pgtype.NumericOID, } -var pgTypeSizes = map[uint32]int32{ +var PostgresTypeSizes = map[uint32]int32{ pgtype.BoolOID: 1, // bool pgtype.ByteaOID: -1, // bytea pgtype.NameOID: -1, // name @@ -119,7 +121,50 @@ var pgTypeSizes = map[uint32]int32{ pgtype.UUIDOID: 16, // uuid } -func inferSchema(rows *stdsql.Rows) (sql.Schema, error) { +func PostgresTypeToArrowType(oid uint32) arrow.DataType { + switch oid { + case pgtype.BoolOID: + return arrow.FixedWidthTypes.Boolean + case pgtype.ByteaOID: + return arrow.BinaryTypes.Binary + case pgtype.NameOID, pgtype.TextOID, pgtype.VarcharOID, pgtype.BPCharOID, pgtype.JSONOID, pgtype.XMLOID: + return arrow.BinaryTypes.String + case pgtype.Int8OID: + return arrow.PrimitiveTypes.Int64 + case pgtype.Int2OID: + return arrow.PrimitiveTypes.Int16 + case pgtype.Int4OID: + return arrow.PrimitiveTypes.Int32 + case pgtype.OIDOID: + return arrow.PrimitiveTypes.Uint32 + case pgtype.TIDOID: + return &arrow.FixedSizeBinaryType{ByteWidth: 8} + case pgtype.Float4OID: + return arrow.PrimitiveTypes.Float32 + case pgtype.Float8OID: + return arrow.PrimitiveTypes.Float64 + case pgtype.PointOID: + return arrow.StructOf(arrow.Field{Name: "x", Type: arrow.PrimitiveTypes.Float64}, + arrow.Field{Name: "y", Type: arrow.PrimitiveTypes.Float64}) + case pgtype.DateOID: + return arrow.FixedWidthTypes.Date32 + case pgtype.TimeOID: + return arrow.FixedWidthTypes.Time64ns + case pgtype.TimestampOID, pgtype.TimestamptzOID: + return arrow.FixedWidthTypes.Timestamp_s + case pgtype.NumericOID: // TODO: Use Decimal128Type for precision <= 38 + return arrow.BinaryTypes.String + case pgtype.UUIDOID: + // TODO(fan): Currently, DuckDB does not support BLOB -> UUID conversion, + // so we use a string type for UUIDs. 
+ // return &arrow.FixedSizeBinaryType{ByteWidth: 16} + return arrow.BinaryTypes.String + default: + return arrow.BinaryTypes.Binary // fall back for unknown types + } +} + +func InferSchema(rows *stdsql.Rows) (sql.Schema, error) { types, err := rows.ColumnTypes() if err != nil { return nil, err @@ -127,11 +172,11 @@ func inferSchema(rows *stdsql.Rows) (sql.Schema, error) { schema := make(sql.Schema, len(types)) for i, t := range types { - pgTypeName, ok := duckdbTypeStrToPostgresTypeStr[t.DatabaseTypeName()] + pgTypeName, ok := DuckdbTypeStrToPostgresTypeStr[t.DatabaseTypeName()] if !ok { return nil, fmt.Errorf("unsupported type %s", t.DatabaseTypeName()) } - pgType, ok := defaultTypeMap.TypeForName(pgTypeName) + pgType, ok := DefaultTypeMap.TypeForName(pgTypeName) if !ok { return nil, fmt.Errorf("unsupported type %s", pgTypeName) } @@ -141,7 +186,7 @@ func inferSchema(rows *stdsql.Rows) (sql.Schema, error) { Name: t.Name(), Type: PostgresType{ PG: pgType, - Size: pgTypeSizes[pgType.OID], + Size: PostgresTypeSizes[pgType.OID], }, Nullable: nullable, } @@ -150,18 +195,18 @@ func inferSchema(rows *stdsql.Rows) (sql.Schema, error) { return schema, nil } -func inferDriverSchema(rows driver.Rows) (sql.Schema, error) { +func InferDriverSchema(rows driver.Rows) (sql.Schema, error) { columns := rows.Columns() schema := make(sql.Schema, len(columns)) for i, colName := range columns { var pgTypeName string if colType, ok := rows.(driver.RowsColumnTypeDatabaseTypeName); ok { - pgTypeName = duckdbTypeStrToPostgresTypeStr[colType.ColumnTypeDatabaseTypeName(i)] + pgTypeName = DuckdbTypeStrToPostgresTypeStr[colType.ColumnTypeDatabaseTypeName(i)] } else { pgTypeName = "text" // Default to text if type name is not available } - pgType, ok := defaultTypeMap.TypeForName(pgTypeName) + pgType, ok := DefaultTypeMap.TypeForName(pgTypeName) if !ok { return nil, fmt.Errorf("unsupported type %s", pgTypeName) } @@ -175,7 +220,7 @@ func inferDriverSchema(rows driver.Rows) (sql.Schema, error) { Name: colName, Type: PostgresType{ PG: pgType, - Size: pgTypeSizes[pgType.OID], + Size: PostgresTypeSizes[pgType.OID], }, Nullable: nullable, } @@ -189,8 +234,19 @@ type PostgresType struct { Size int32 } +func NewPostgresType(oid uint32) (PostgresType, error) { + t, ok := DefaultTypeMap.TypeForOID(oid) + if !ok { + return PostgresType{}, fmt.Errorf("unsupported type OID %d", oid) + } + return PostgresType{ + PG: t, + Size: PostgresTypeSizes[oid], + }, nil +} + func (p PostgresType) Encode(v any, buf []byte) ([]byte, error) { - return defaultTypeMap.Encode(p.PG.OID, p.PG.Codec.PreferredFormat(), v, buf) + return DefaultTypeMap.Encode(p.PG.OID, p.PG.Codec.PreferredFormat(), v, buf) } var _ sql.Type = PostgresType{} diff --git a/replica/replication.go b/replica/replication.go index a543edaf..5b2f07fc 100644 --- a/replica/replication.go +++ b/replica/replication.go @@ -42,7 +42,7 @@ func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle. replica.SetExecutionContext(ctx) twp := &tableWriterProvider{pool: pool} - twp.controller = delta.NewController(pool) + twp.controller = delta.NewController() replica.SetTableWriterProvider(twp) builder.FlushDeltaBuffer = nil // TODO: implement this
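Taken together, the new pgtypes package lets the logical-replication path go from a pgoutput relation straight to an Arrow schema: NewPostgresType builds the sql.Schema that processMessage stores per relation, and PostgresTypeToArrowType picks the Arrow type for each delta column. A hypothetical helper combining the two mappings (not part of this PR):

    // relationToArrowSchema maps a relation's column OIDs to Arrow fields
    // using the PostgresTypeToArrowType table above. The helper and its
    // Nullable choice are assumptions: pgoutput does not carry NOT NULL info.
    func relationToArrowSchema(rel *pglogrepl.RelationMessageV2) *arrow.Schema {
    	fields := make([]arrow.Field, len(rel.Columns))
    	for i, col := range rel.Columns {
    		fields[i] = arrow.Field{
    			Name:     col.Name,
    			Type:     pgtypes.PostgresTypeToArrowType(col.DataType),
    			Nullable: true,
    		}
    	}
    	return arrow.NewSchema(fields, nil)
    }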