Skip to content

Commit

Permalink
Merge branch 'main' into 119-zero-etl-pg
Browse files Browse the repository at this point in the history
  • Loading branch information
TianyuZhang1214 authored Nov 25, 2024
2 parents f8be443 + 6806cbf commit 7761052
Show file tree
Hide file tree
Showing 34 changed files with 2,320 additions and 220 deletions.
66 changes: 66 additions & 0 deletions .github/workflows/postgres-compatibility.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
name: Compatibility Test for Postgres

on:
  push:
    branches:
      - main
      - compatibility
  pull_request:
    branches: [ "main" ]

jobs:

  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.23'

      - name: Set up Node.js
        uses: actions/setup-node@v4
        with:
          node-version: '16'

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          go get .
          pip3 install "sqlglot[rs]"
          pip3 install psycopg2
          curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip
          unzip duckdb_cli-linux-amd64.zip
          chmod +x duckdb
          sudo mv duckdb /usr/local/bin
          duckdb -c 'INSTALL json from core'
          duckdb -c 'SELECT extension_name, loaded, install_path FROM duckdb_extensions() where installed'
          sudo apt-get update
          sudo apt-get install --yes --no-install-recommends postgresql-client bats cpanminus
          cd compatibility/pg
          # Download the PostgreSQL JDBC driver. The URL version must match the
          # saved filename (was 42.7.3 vs 42.7.4 — fixed to 42.7.4 on both sides).
          curl -L -o ./java/postgresql-42.7.4.jar https://jdbc.postgresql.org/download/postgresql-42.7.4.jar
          npm install pg
          sudo cpanm DBD::Pg
          sudo gem install pg
      - name: Build
        run: go build -v

      - name: Start MyDuck Server
        run: |
          ./myduckserver &
          sleep 5
      - name: Run the Compatibility Test
        run: |
          bats ./compatibility/pg/test.bats
2 changes: 1 addition & 1 deletion .github/workflows/psql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
run: |
go get .
pip3 install "sqlglot[rs]"
pip3 install "sqlglot[rs]" pyarrow pandas
curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip
unzip duckdb_cli-linux-amd64.zip
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ __debug_*
.DS_Store
*.csv
*.parquet
*.arrow
5 changes: 5 additions & 0 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ type ConnectionHolder interface {
GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
TryGetTxn() *stdsql.Tx
CloseTxn()
CloseBackendConn()
}

// GetConn returns the backend connection associated with the current session.
func GetConn(ctx *sql.Context) (*stdsql.Conn, error) {
	holder := ctx.Session.(ConnectionHolder)
	return holder.GetConn(ctx)
}

// CloseBackendConn releases the backend connection held by the current session.
func CloseBackendConn(ctx *sql.Context) {
	holder := ctx.Session.(ConnectionHolder)
	holder.CloseBackendConn()
}

// GetTxn returns the session's transaction, created with the given options if
// one is not already in progress.
func GetTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
	holder := ctx.Session.(ConnectionHolder)
	return holder.GetTxn(ctx, options)
}
Expand Down
2 changes: 1 addition & 1 deletion backend/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (p *ConnectionPool) Close() error {
return true
})
for _, conn := range conns {
if err := conn.Close(); err != nil {
if err := conn.Close(); err != nil && !errors.Is(err, stdsql.ErrConnDone) {
logrus.WithError(err).Warn("Failed to close connection")
lastErr = err
}
Expand Down
68 changes: 58 additions & 10 deletions backend/iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package backend
import (
stdsql "database/sql"
"io"
"math/big"
"reflect"
"strings"

"github.com/apecloud/myduckserver/charset"
Expand All @@ -28,17 +30,23 @@ import (

var _ sql.RowIter = (*SQLRowIter)(nil)

type typeConversion struct {
idx int
kind reflect.Kind
}

// SQLRowIter wraps a standard sql.Rows as a RowIter.
type SQLRowIter struct {
rows *stdsql.Rows
columns []*stdsql.ColumnType
schema sql.Schema
buffer []any // pre-allocated buffer for scanning values
pointers []any // pointers to the buffer
decimals []int
intervals []int
nonUTF8 []int
charsets []sql.CharacterSetID
rows *stdsql.Rows
columns []*stdsql.ColumnType
schema sql.Schema
buffer []any // pre-allocated buffer for scanning values
pointers []any // pointers to the buffer
decimals []int
intervals []int
nonUTF8 []int
charsets []sql.CharacterSetID
conversions []typeConversion
}

func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) {
Expand Down Expand Up @@ -72,14 +80,32 @@ func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) {
}
}

var conversions []typeConversion
for i, c := range columns {
if c.DatabaseTypeName() == "HUGEINT" {
expectedType := schema[i].Type
if ok := types.IsFloat(expectedType); ok {
conversions = append(conversions, typeConversion{idx: i, kind: reflect.Float64})
} else {
conversions = append(conversions, typeConversion{idx: i, kind: reflect.Int64})
}
}
if c.DatabaseTypeName() == "DOUBLE" || c.DatabaseTypeName() == "FLOAT" {
expectedType := schema[i].Type
if ok := types.IsInteger(expectedType); ok {
conversions = append(conversions, typeConversion{idx: i, kind: reflect.Int64})
}
}
}

width := max(len(columns), len(schema))
buf := make([]any, width)
ptrs := make([]any, width)
for i := range buf {
ptrs[i] = &buf[i]
}

return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals, intervals, nonUTF8, charsets}, nil
return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals, intervals, nonUTF8, charsets, conversions}, nil
}

// Next retrieves the next row. It will return io.EOF if it's the last row.
Expand Down Expand Up @@ -115,6 +141,28 @@ func (iter *SQLRowIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}

// Process type conversions
for _, targetType := range iter.conversions {
idx := targetType.idx
rawValue := iter.buffer[idx]
if targetType.kind == reflect.Float64 {
switch v := rawValue.(type) {
case *big.Int:
iter.buffer[idx], _ = v.Float64()
}
}
if targetType.kind == reflect.Int64 {
switch v := rawValue.(type) {
case float64:
iter.buffer[idx] = int64(v)
case float32:
iter.buffer[idx] = int64(v)
case *big.Int:
iter.buffer[idx] = v.Int64()
}
}
}

// Prune or fill the values to match the schema
width := len(iter.schema) // the desired width
if width == 0 {
Expand Down
5 changes: 5 additions & 0 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@ func (sess *Session) CloseTxn() {
sess.pool.CloseTxn(sess.ID())
}

// CloseBackendConn implements adapter.ConnectionHolder by returning the
// pooled connection keyed by this session's ID.
func (sess *Session) CloseBackendConn() {
	id := sess.ID()
	sess.pool.CloseConn(id)
}

func (sess *Session) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
conn, err := sess.GetCatalogConn(ctx)
if err != nil {
Expand Down
42 changes: 30 additions & 12 deletions binlogreplication/binlog_replica_applier.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,18 @@ const (
// Match any strings starting with "ON" (case insensitive)
var gtidModeIsOnRegex = regexp.MustCompile(`(?i)^ON$`)

// tableIdentifier is a (database, table) name pair used as a cache key for
// resolved tables; callers are expected to lowercase both components.
type tableIdentifier struct {
	dbName    string
	tableName string
}

// binlogReplicaApplier represents the process that applies updates from a binlog connection.
//
// This type is NOT used concurrently – there is currently only one single applier process running to process binlog
// events, so the state in this type is NOT protected with a mutex.
type binlogReplicaApplier struct {
format *mysql.BinlogFormat
tableMapsById map[uint64]*mysql.TableMap
tablesByName map[tableIdentifier]sql.Table
stopReplicationChan chan struct{}
currentGtid replication.GTID
replicationSourceUuid string
Expand All @@ -86,6 +91,7 @@ type binlogReplicaApplier struct {
func newBinlogReplicaApplier(filters *filterConfiguration) *binlogReplicaApplier {
return &binlogReplicaApplier{
tableMapsById: make(map[uint64]*mysql.TableMap),
tablesByName: make(map[tableIdentifier]sql.Table),
stopReplicationChan: make(chan struct{}),
filters: filters,
}
Expand Down Expand Up @@ -895,6 +901,7 @@ func (a *binlogReplicaApplier) executeQueryWithEngine(ctx *sql.Context, engine *
}
if mysqlutil.CauseSchemaChange(node) {
flushReason = delta.DDLStmtFlushReason
a.tablesByName = make(map[tableIdentifier]sql.Table)
}
}

Expand Down Expand Up @@ -1013,7 +1020,7 @@ func (a *binlogReplicaApplier) processRowEvent(ctx *sql.Context, event mysql.Bin
ctx.GetLogger().Errorf(msg)
MyBinlogReplicaController.setSqlError(sqlerror.ERUnknownError, msg)
}
pkSchema, tableName, err := getTableSchema(ctx, engine, tableMap.Name, tableMap.Database)
pkSchema, tableName, err := a.getTableSchema(ctx, engine, tableMap.Name, tableMap.Database)
if err != nil {
return err
}
Expand Down Expand Up @@ -1199,6 +1206,7 @@ func (a *binlogReplicaApplier) appendRowFormatChanges(
}
a.deltaBufSize.Add(uint64(pos))
}
appender.UpdateActionStats(binlog.DeleteRowEvent, len(rows.Rows))
}

// Insert the after image
Expand Down Expand Up @@ -1229,8 +1237,10 @@ func (a *binlogReplicaApplier) appendRowFormatChanges(
}
a.deltaBufSize.Add(uint64(pos))
}
appender.UpdateActionStats(binlog.InsertRowEvent, len(rows.Rows))
}

appender.ObserveEvents(event, len(rows.Rows))
return nil
}

Expand Down Expand Up @@ -1291,17 +1301,25 @@ varsLoop:

// getTableSchema returns a sql.Schema for the case-insensitive |tableName| in the database named
// |databaseName|, along with the exact, case-sensitive table name.
func getTableSchema(ctx *sql.Context, engine *gms.Engine, tableName, databaseName string) (sql.PrimaryKeySchema, string, error) {
database, err := engine.Analyzer.Catalog.Database(ctx, databaseName)
if err != nil {
return sql.PrimaryKeySchema{}, "", err
}
table, ok, err := database.GetTableInsensitive(ctx, tableName)
if err != nil {
return sql.PrimaryKeySchema{}, "", err
}
if !ok {
return sql.PrimaryKeySchema{}, "", fmt.Errorf("unable to find table %q", tableName)
func (a *binlogReplicaApplier) getTableSchema(ctx *sql.Context, engine *gms.Engine, tableName, databaseName string) (sql.PrimaryKeySchema, string, error) {
key := tableIdentifier{dbName: strings.ToLower(databaseName), tableName: strings.ToLower(tableName)}
table, found := a.tablesByName[key]

if !found {
database, err := engine.Analyzer.Catalog.Database(ctx, databaseName)
if err != nil {
return sql.PrimaryKeySchema{}, "", err
}
var ok bool
table, ok, err = database.GetTableInsensitive(ctx, tableName)
if err != nil {
return sql.PrimaryKeySchema{}, "", err
}
if !ok {
return sql.PrimaryKeySchema{}, "", fmt.Errorf("unable to find table %q", tableName)
}

a.tablesByName[key] = table
}

if pkTable, ok := table.(sql.PrimaryKeyTable); ok {
Expand Down
3 changes: 3 additions & 0 deletions binlogreplication/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ type DeltaAppender interface {
TxnGroup() *array.BinaryDictionaryBuilder
TxnSeqNumber() *array.Uint64Builder
TxnStmtOrdinal() *array.Uint64Builder

UpdateActionStats(action binlog.RowEventType, count int)
ObserveEvents(event binlog.RowEventType, count int)
}

type TableWriterProvider interface {
Expand Down
1 change: 1 addition & 0 deletions compatibility/pg/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.class
Loading

0 comments on commit 7761052

Please sign in to comment.