Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pg): transaction batching and columnar buffer for replication #160

Merged
merged 8 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions delta/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"unsafe"

"github.com/apache/arrow-go/v18/arrow/ipc"
"github.com/apecloud/myduckserver/backend"
"github.com/apecloud/myduckserver/binlog"
"github.com/apecloud/myduckserver/catalog"
"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -28,12 +27,10 @@ type FlushStats struct {
type DeltaController struct {
mutex sync.Mutex
tables map[tableIdentifier]*DeltaAppender
pool *backend.ConnectionPool
}

func NewController(pool *backend.ConnectionPool) *DeltaController {
func NewController() *DeltaController {
return &DeltaController{
pool: pool,
tables: make(map[tableIdentifier]*DeltaAppender),
}
}
Expand Down Expand Up @@ -142,6 +139,10 @@ func (c *DeltaController) updateTable(
buf *bytes.Buffer,
stats *FlushStats,
) error {
if tx == nil {
return fmt.Errorf("no active transaction")
}

buf.Reset()

schema := appender.BaseSchema() // schema of the base table
Expand Down Expand Up @@ -284,6 +285,25 @@ func (c *DeltaController) updateTable(
}
stats.Deletions += affected

// For debugging:
//
// rows, err := tx.QueryContext(ctx, "SELECT * FROM "+qualifiedTableName)
// if err != nil {
// return err
// }
// defer rows.Close()
// row := make([]any, len(schema))
// pointers := make([]any, len(row))
// for i := range row {
// pointers[i] = &row[i]
// }
// for rows.Next() {
// if err := rows.Scan(pointers...); err != nil {
// return err
// }
// fmt.Printf("row:%+v\n", row)
// }

if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
log.WithFields(logrus.Fields{
"table": qualifiedTableName,
Expand Down
4 changes: 2 additions & 2 deletions delta/delta.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type DeltaAppender struct {
// https://mariadb.com/kb/en/gtid/
// https://dev.mysql.com/doc/refman/9.0/en/replication-gtids-concepts.html
func newDeltaAppender(schema sql.Schema) (*DeltaAppender, error) {
augmented := make(sql.Schema, 0, len(schema)+5)
augmented := make(sql.Schema, 0, len(schema)+6)
augmented = append(augmented, &sql.Column{
Name: "action", // delete = 0, update = 1, insert = 2
Type: types.Int8,
Expand Down Expand Up @@ -73,7 +73,7 @@ func (a *DeltaAppender) Schema() sql.Schema {
}

func (a *DeltaAppender) BaseSchema() sql.Schema {
return a.schema[5:]
return a.schema[6:]
}

func (a *DeltaAppender) Action() *array.Int8Builder {
Expand Down
4 changes: 4 additions & 0 deletions myarrow/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package myarrow

import (
"github.com/apache/arrow-go/v18/arrow"
"github.com/apecloud/myduckserver/pgtypes"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/vt/proto/query"
)
Expand Down Expand Up @@ -38,6 +39,9 @@ func ToArrowType(t sql.Type) (arrow.DataType, error) {
}

func toArrowType(t sql.Type) arrow.DataType {
if pgType, ok := t.(pgtypes.PostgresType); ok {
return pgtypes.PostgresTypeToArrowType(pgType.PG.OID)
}
switch t.Type() {
case query.Type_UINT8:
return arrow.PrimitiveTypes.Uint8
Expand Down
1 change: 1 addition & 0 deletions pgserver/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type PortalData struct {
IsEmptyQuery bool
Fields []pgproto3.FieldDescription
Stmt *duckdb.Stmt
Vars []any
}

type PreparedStatementData struct {
Expand Down
12 changes: 9 additions & 3 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ func (h *ConnectionHandler) handleBind(message *pgproto3.Bind) error {
Query: preparedData.Query,
Fields: fields,
Stmt: preparedData.Stmt,
Vars: bindVars,
}
return h.send(&pgproto3.BindComplete{})
}
Expand Down Expand Up @@ -775,19 +776,24 @@ func (h *ConnectionHandler) deletePortal(name string) {
}

// convertBindParameters handles the conversion from bind parameters to variable values.
func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) ([]string, error) {
func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int16, values [][]byte) ([]any, error) {
if len(types) != len(values) {
return nil, fmt.Errorf("number of values does not match number of parameters")
}
bindings := make([]string, len(values))
bindings := make([]pgtype.Text, len(values))
for i := range values {
typ := types[i]
// We'll rely on a library to decode each format, which will deal with text and binary representations for us
if err := h.pgTypeMap.Scan(typ, formatCodes[i], values[i], &bindings[i]); err != nil {
return nil, err
}
}
return bindings, nil

vars := make([]any, len(bindings))
for i, b := range bindings {
vars[i] = b.String
}
return vars, nil
}

// query runs the given query and sends a CommandComplete message to the client
Expand Down
103 changes: 80 additions & 23 deletions pgserver/duck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/backend"
"github.com/apecloud/myduckserver/pgtypes"
"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
sqle "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/server"
Expand Down Expand Up @@ -80,7 +81,7 @@ type DuckHandler struct {
var _ Handler = &DuckHandler{}

// ComBind implements the Handler interface.
func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []string) ([]pgproto3.FieldDescription, error) {
func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []any) ([]pgproto3.FieldDescription, error) {
vars := make([]driver.NamedValue, len(bindVars))
for i, v := range bindVars {
vars[i] = driver.NamedValue{
Expand All @@ -100,7 +101,7 @@ func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared Prepa

// ComExecuteBound implements the Handler interface.
func (h *DuckHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error {
err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, h.executeBoundPlan, callback)
err := h.doQuery(ctx, conn, portal.Query.String, portal.Query.AST, portal.Stmt, portal.Vars, h.executeBoundPlan, callback)
if err != nil {
err = sql.CastSQLError(err)
}
Expand Down Expand Up @@ -161,7 +162,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query

paramOIDs := make([]uint32, len(paramTypes))
for i, t := range paramTypes {
paramOIDs[i] = duckdbTypeToPostgresOID[t]
paramOIDs[i] = pgtypes.DuckdbTypeToPostgresOID[t]
}

var (
Expand All @@ -188,7 +189,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query
break
}
defer rows.Close()
schema, err := inferSchema(rows)
schema, err := pgtypes.InferSchema(rows)
if err != nil {
break
}
Expand All @@ -213,7 +214,7 @@ func (h *DuckHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query

// ComQuery implements the Handler interface.
func (h *DuckHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, callback func(*Result) error) error {
err := h.doQuery(ctx, c, query, parsed, nil, h.executeQuery, callback)
err := h.doQuery(ctx, c, query, parsed, nil, nil, h.executeQuery, callback)
if err != nil {
err = sql.CastSQLError(err)
}
Expand Down Expand Up @@ -290,7 +291,7 @@ func (h *DuckHandler) getStatementTag(mysqlConn *mysql.Conn, query string) (stri

var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`)

func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, queryExec QueryExecutor, callback func(*Result) error) error {
func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any, queryExec QueryExecutor, callback func(*Result) error) error {
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
if err != nil {
return err
Expand Down Expand Up @@ -326,7 +327,7 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string,
}
}()

schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, stmt)
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, stmt, vars)
if err != nil {
if printErrorStackTraces {
fmt.Printf("error running query: %+v\n", err)
Expand Down Expand Up @@ -378,11 +379,11 @@ func (h *DuckHandler) doQuery(ctx context.Context, c *mysql.Conn, query string,

// QueryExecutor is a function that executes a query and returns the result as a schema and iterator. Either of
// |parsed| or |stmt| can be nil depending on the use case.
type QueryExecutor func(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error)
type QueryExecutor func(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt, vars []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error)

// executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed
// statement, which may be nil.
func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.Statement, _ *duckdb.Stmt, _ []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
// return h.e.QueryWithBindings(ctx, query, parsed, nil, nil)

sql.IncrementStatusVariable(ctx, "Questions", 1)
Expand Down Expand Up @@ -421,7 +422,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S
if err != nil {
break
}
schema, err = inferSchema(rows)
schema, err = pgtypes.InferSchema(rows)
if err != nil {
rows.Close()
break
Expand All @@ -438,38 +439,91 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S

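// executeBoundPlan is a QueryExecutor for statements that have already been prepared and bound. It ignores the
// parsed statement and dispatches on the statement type of |stmt|: SELECT, RELATION, CALL, PRAGMA and EXPLAIN
// statements are run as queries that return rows; every other statement type is executed for its side effects.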
func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.Statement, stmt *duckdb.Stmt) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.Statement, stmt *duckdb.Stmt, vars []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
// return h.e.PrepQueryPlanForExecution(ctx, query, plan, nil)

// TODO(fan): Currently, the result of executing the bound query is occasionally incorrect.
// For example, for the "concurrent writes" test in the "TestReplication" test case,
// this approach returns [[2 x] [4 i]] instead of [[2 three] [4 five]].
// However, `x` and `i` never appear in the data.
// The reason is not clear and needs further investigation.
// Therefore, we fall back to the unbound query execution for now.
//
// var (
// schema sql.Schema
// iter sql.RowIter
// rows driver.Rows
// result driver.Result
// err error
// )
// switch stmt.StatementType() {
// case duckdb.DUCKDB_STATEMENT_TYPE_SELECT,
// duckdb.DUCKDB_STATEMENT_TYPE_RELATION,
// duckdb.DUCKDB_STATEMENT_TYPE_CALL,
// duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA,
// duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN:
// rows, err = stmt.QueryBound(ctx)
// if err != nil {
// break
// }
// schema, err = pgtypes.InferDriverSchema(rows)
// if err != nil {
// rows.Close()
// break
// }
// iter, err = NewDriverRowIter(rows, schema)
// if err != nil {
// rows.Close()
// break
// }
// default:
// result, err = stmt.ExecBound(ctx)
// if err != nil {
// break
// }
// affected, _ := result.RowsAffected()
// insertId, _ := result.LastInsertId()
// schema = types.OkResultSchema
// iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{
// RowsAffected: uint64(affected),
// InsertID: uint64(insertId),
// }))
// }
// if err != nil {
// return nil, nil, nil, err
// }

var (
schema sql.Schema
iter sql.RowIter
rows driver.Rows
result driver.Result
err error
stmtType = stmt.StatementType()
schema sql.Schema
iter sql.RowIter
rows *stdsql.Rows
result stdsql.Result
err error
)
switch stmt.StatementType() {

switch stmtType {
case duckdb.DUCKDB_STATEMENT_TYPE_SELECT,
duckdb.DUCKDB_STATEMENT_TYPE_RELATION,
duckdb.DUCKDB_STATEMENT_TYPE_CALL,
duckdb.DUCKDB_STATEMENT_TYPE_PRAGMA,
duckdb.DUCKDB_STATEMENT_TYPE_EXPLAIN:
rows, err = stmt.QueryBound(ctx)
rows, err = adapter.QueryCatalog(ctx, query, vars...)
if err != nil {
break
}
schema, err = inferDriverSchema(rows)
schema, err = pgtypes.InferSchema(rows)
if err != nil {
rows.Close()
break
}
iter, err = NewDriverRowIter(rows, schema)
iter, err = backend.NewSQLRowIter(rows, schema)
if err != nil {
rows.Close()
break
}
default:
result, err = stmt.ExecBound(ctx)
result, err = adapter.ExecCatalog(ctx, query, vars...)
if err != nil {
break
}
Expand All @@ -481,6 +535,9 @@ func (h *DuckHandler) executeBoundPlan(ctx *sql.Context, query string, _ tree.St
InsertID: uint64(insertId),
}))
}
if err != nil {
return nil, nil, nil, err
}

return schema, iter, nil, nil
}
Expand Down Expand Up @@ -509,7 +566,7 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD
var size int16
var format int16
var err error
if pgType, ok := c.Type.(PostgresType); ok {
if pgType, ok := c.Type.(pgtypes.PostgresType); ok {
oid = pgType.PG.OID
format = pgType.PG.Codec.PreferredFormat()
size = int16(pgType.Size)
Expand Down Expand Up @@ -747,7 +804,7 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) {
}

// TODO(fan): Preallocate the buffer
if pgType, ok := s[i].Type.(PostgresType); ok {
if pgType, ok := s[i].Type.(pgtypes.PostgresType); ok {
bytes, err := pgType.Encode(v, []byte{})
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pgserver/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

type Handler interface {
// ComBind is called when a connection receives a request to bind a prepared statement to a set of values.
ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []string) ([]pgproto3.FieldDescription, error)
ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []any) ([]pgproto3.FieldDescription, error)
// ComExecuteBound is called when a connection receives a request to execute a prepared statement that has already bound to a set of values.
ComExecuteBound(ctx context.Context, conn *mysql.Conn, portal PortalData, callback func(*Result) error) error
// ComPrepareParsed is called when a connection receives a prepared statement query that has already been parsed.
Expand Down
5 changes: 4 additions & 1 deletion pgserver/iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgserver

import (
"database/sql/driver"
"fmt"
"io"
"strings"

Expand Down Expand Up @@ -31,7 +32,7 @@ func NewDriverRowIter(rows driver.Rows, schema sql.Schema) (*DriverRowIter, erro
}
sb.WriteString(col.Type.String())
}
logrus.Debugf("New DriverRowIter: columns=%v, schema=[%s]", columns, sb.String())
logrus.Debugf("New DriverRowIter: columns=%v, schema=[%s]\n", columns, sb.String())

return &DriverRowIter{rows, columns, schema, buf, row}, nil
}
Expand All @@ -45,6 +46,8 @@ func (iter *DriverRowIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

fmt.Printf("DriverRowIter.Next: buffer=%+v\n", iter.buffer)
fanyang01 marked this conversation as resolved.
Show resolved Hide resolved

// Prune or fill the values to match the schema
width := len(iter.schema) // the desired width
if width == 0 {
Expand Down
Loading
Loading