Skip to content

Commit

Permalink
feat: support MySQL CTAS (#351)
Browse files Browse the repository at this point in the history
* Enable more CTAS tests
* add type mapping for hugeint and varint
* Add fallback for simple CTAS
  • Loading branch information
fanyang01 authored Jan 8, 2025
1 parent 052e742 commit 4145be1
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 77 deletions.
53 changes: 38 additions & 15 deletions backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,13 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
}

n := root
ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"NodeType": fmt.Sprintf("%T", n),
}).Traceln("Building node:", n)

if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
log.WithFields(logrus.Fields{
"Query": ctx.Query(),
"NodeType": fmt.Sprintf("%T", n),
}).Traceln("Building node:", n)
}

// TODO: find a better way to fall back to the base builder
switch n.(type) {
Expand Down Expand Up @@ -114,7 +117,13 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
}

// Fallback to the base builder if the plan contains system/user variables or is not a pure data query.
if containsVariable(n) || !IsPureDataQuery(n) {
tree := n
switch n := n.(type) {
case *plan.TableCopier:
tree = n.Source
}
if containsVariable(tree) || !IsPureDataQuery(tree) {
ctx.GetLogger().Traceln("Falling back to the base builder")
return b.base.Build(ctx, root, r)
}

Expand All @@ -133,11 +142,21 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
return nil, err
}
return b.base.Build(ctx, root, r)
// SubqueryAlias is for select * from view
// ResolvedTable is for `SELECT * FROM table` and `TABLE table`
// SubqueryAlias is for `SELECT * FROM view`
case *plan.ResolvedTable, *plan.SubqueryAlias, *plan.TableAlias:
return b.executeQuery(ctx, node, conn)
case *plan.Distinct, *plan.OrderedDistinct:
return b.executeQuery(ctx, node, conn)
case *plan.TableCopier:
// We preserve the table schema in a best-effort manner.
// For simple `CREATE TABLE t1 AS SELECT * FROM t2`,
// we fall back to the framework to create the table and copy the data.
// For more complex cases, we directly execute the CTAS statement in DuckDB.
if _, ok := node.Source.(*plan.ResolvedTable); ok {
return b.base.Build(ctx, root, r)
}
return b.executeDML(ctx, node, conn)
case sql.Expressioner:
return b.executeExpressioner(ctx, node, conn)
case *plan.DeleteFrom:
Expand Down Expand Up @@ -174,7 +193,7 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
case *plan.ShowTables:
duckSQL = ctx.Query()
case *plan.ResolvedTable:
// SQLGlot cannot translate MySQL's `TABLE t` into DuckDB's `FROM t` - it produces `"table" AS t` instead.
// SQLGlot cannot translate MySQL's `TABLE t` into DuckDB's `FROM t` - it produces `"table" AS t` instead.
duckSQL = `FROM ` + catalog.ConnectIdentifiersANSI(n.Database().Name(), n.Name())
default:
duckSQL, err = transpiler.TranslateWithSQLGlot(ctx.Query())
Expand All @@ -183,10 +202,12 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
return nil, catalog.ErrTranspiler.New(err)
}

ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"DuckSQL": duckSQL,
}).Trace("Executing Query...")
if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
log.WithFields(logrus.Fields{
"Query": ctx.Query(),
"DuckSQL": duckSQL,
}).Trace("Executing Query...")
}

// Execute the DuckDB query
rows, err := conn.QueryContext(ctx.Context, duckSQL)
Expand All @@ -204,10 +225,12 @@ func (b *DuckBuilder) executeDML(ctx *sql.Context, n sql.Node, conn *stdsql.Conn
return nil, catalog.ErrTranspiler.New(err)
}

ctx.GetLogger().WithFields(logrus.Fields{
"Query": ctx.Query(),
"DuckSQL": duckSQL,
}).Trace("Executing DML...")
if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) {
log.WithFields(logrus.Fields{
"Query": ctx.Query(),
"DuckSQL": duckSQL,
}).Trace("Executing DML...")
}

// Execute the DuckDB query
result, err := conn.ExecContext(ctx.Context, duckSQL)
Expand Down
22 changes: 22 additions & 0 deletions catalog/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,25 @@ func (d *Database) GetCollation(ctx *sql.Context) sql.CollationID {
func (d *Database) SetCollation(ctx *sql.Context, collation sql.CollationID) error {
return nil
}

// CopyTableData implements the sql.TableCopierDatabase interface.
// It copies every row from sourceTable into destinationTable within this
// database and returns the number of rows copied. Both tables must already
// exist and have compatible schemas.
func (d *Database) CopyTableData(ctx *sql.Context, sourceTable string, destinationTable string) (uint64, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	// Copy the data with DuckDB's `INSERT INTO dst FROM src` shorthand
	// (equivalent to `INSERT INTO dst SELECT * FROM src`).
	// Note: the local is named insertSQL rather than `sql` so it does not
	// shadow the imported sql package used in the signature.
	insertSQL := `INSERT INTO ` + FullTableName(d.catalog, d.name, destinationTable) +
		` FROM ` + FullTableName(d.catalog, d.name, sourceTable)

	res, err := adapter.Exec(ctx, insertSQL)
	if err != nil {
		return 0, ErrDuckDB.New(err)
	}

	// Report how many rows were copied; the interface requires the count.
	count, err := res.RowsAffected()
	if err != nil {
		return 0, ErrDuckDB.New(err)
	}

	return uint64(count), nil
}
5 changes: 4 additions & 1 deletion catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,10 @@ func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) (
}

decodedComment := DecodeComment[MySQLType](comment.String)
dataType := mysqlDataType(AnnotatedDuckType{dataTypes, decodedComment.Meta}, uint8(numericPrecision.Int32), uint8(numericScale.Int32))
dataType, err := mysqlDataType(AnnotatedDuckType{dataTypes, decodedComment.Meta}, uint8(numericPrecision.Int32), uint8(numericScale.Int32))
if err != nil {
return nil, err
}

columnInfo := &ColumnInfo{
ColumnName: columnName,
Expand Down
69 changes: 39 additions & 30 deletions catalog/type_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func DuckdbDataType(mysqlType sql.Type) (AnnotatedDuckType, error) {
}
}

func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericScale uint8) sql.Type {
func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericScale uint8) (sql.Type, error) {
// TODO: The current type mappings are not lossless. We need to store the original type in the column comments.
duckName := strings.TrimSpace(strings.ToUpper(duckType.name))

Expand All @@ -219,7 +219,7 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc
intBaseType = sqltypes.Uint8
case "SMALLINT":
if mysqlName == "YEAR" {
return types.Year
return types.Year, nil
}
intBaseType = sqltypes.Int16
case "USMALLINT":
Expand All @@ -240,13 +240,13 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc
intBaseType = sqltypes.Int64
case "UBIGINT":
if mysqlName == "BIT" {
return types.MustCreateBitType(duckType.mysql.Precision)
return types.CreateBitType(duckType.mysql.Precision)
}
intBaseType = sqltypes.Uint64
}

if intBaseType != sqltypes.Null {
return types.MustCreateNumberTypeWithDisplayWidth(intBaseType, int(duckType.mysql.Display))
return types.CreateNumberTypeWithDisplayWidth(intBaseType, int(duckType.mysql.Display))
}

length := int64(duckType.mysql.Length)
Expand All @@ -255,70 +255,79 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc

switch duckName {
case "FLOAT":
return types.Float32
return types.Float32, nil
case "DOUBLE":
return types.Float64
return types.Float64, nil

case "TIMESTAMP", "TIMESTAMP_S", "TIMESTAMP_MS":
if mysqlName == "DATETIME" {
return types.MustCreateDatetimeType(sqltypes.Datetime, precision)
return types.CreateDatetimeType(sqltypes.Datetime, precision)
}
return types.MustCreateDatetimeType(sqltypes.Timestamp, precision)
return types.CreateDatetimeType(sqltypes.Timestamp, precision)

case "DATE":
return types.Date
return types.Date, nil
case "INTERVAL", "TIME":
return types.Time
return types.Time, nil

case "DECIMAL":
return types.MustCreateDecimalType(numericPrecision, numericScale)
return types.CreateDecimalType(numericPrecision, numericScale)

case "UHUGEINT", "HUGEINT":
// MySQL does not have these types. We store them as DECIMAL.
return types.CreateDecimalType(39, 0)

case "VARINT":
// MySQL does not have this type. We store it as DECIMAL.
// Here we use the maximum supported precision for DECIMAL in MySQL.
return types.CreateDecimalType(65, 0)

case "VARCHAR":
if mysqlName == "TEXT" {
if length <= types.TinyTextBlobMax {
return types.TinyText
return types.TinyText, nil
} else if length <= types.TextBlobMax {
return types.Text
return types.Text, nil
} else if length <= types.MediumTextBlobMax {
return types.MediumText
return types.MediumText, nil
} else {
return types.LongText
return types.LongText, nil
}
} else if mysqlName == "VARCHAR" {
return types.MustCreateString(sqltypes.VarChar, length, collation)
return types.CreateString(sqltypes.VarChar, length, collation)
} else if mysqlName == "CHAR" {
return types.MustCreateString(sqltypes.Char, length, collation)
return types.CreateString(sqltypes.Char, length, collation)
} else if mysqlName == "SET" {
return types.MustCreateSetType(duckType.mysql.Values, collation)
return types.CreateSetType(duckType.mysql.Values, collation)
}
return types.Text
return types.Text, nil

case "BLOB":
if mysqlName == "BLOB" {
if length <= types.TinyTextBlobMax {
return types.TinyBlob
return types.TinyBlob, nil
} else if length <= types.TextBlobMax {
return types.Blob
return types.Blob, nil
} else if length <= types.MediumTextBlobMax {
return types.MediumBlob
return types.MediumBlob, nil
} else {
return types.LongBlob
return types.LongBlob, nil
}
} else if mysqlName == "VARBINARY" {
return types.MustCreateBinary(sqltypes.VarBinary, length)
return types.CreateBinary(sqltypes.VarBinary, length)
} else if mysqlName == "BINARY" {
return types.MustCreateBinary(sqltypes.Binary, length)
return types.CreateBinary(sqltypes.Binary, length)
}
return types.Blob
return types.Blob, nil

case "JSON":
return types.JSON
return types.JSON, nil
case "ENUM":
return types.MustCreateEnumType(duckType.mysql.Values, collation)
return types.CreateEnumType(duckType.mysql.Values, collation)
case "SET":
return types.MustCreateSetType(duckType.mysql.Values, collation)
return types.CreateSetType(duckType.mysql.Values, collation)
default:
panic(fmt.Sprintf("encountered unknown DuckDB type(%v). This is likely a bug - please check the duckdbDataType function for missing type mappings", duckType))
return nil, fmt.Errorf("encountered unknown DuckDB type(%v)", duckType)
}
}

Expand Down
32 changes: 1 addition & 31 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1110,45 +1110,15 @@ func TestCreateTable(t *testing.T) {
"create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_unique_index(b1(123),_b2(456)))",
"create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_index(b1(10)),_index(b2(20)),_index(b1(123),_b2(456)))",
"create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_index(b1(10)),_index(b2(20)),_index(b1(123),_b2(456)))",
"CREATE_TABLE_t1_as_select_*_from_mytable",
"CREATE_TABLE_t1_as_select_*_from_mytable",
"CREATE_TABLE_t1_as_select_*_from_mytable#01",
"CREATE_TABLE_t1_as_select_*_from_mytable",
"CREATE_TABLE_t1_as_select_s,_i_from_mytable",
"CREATE_TABLE_t1_as_select_s,_i_from_mytable",
"CREATE_TABLE_t1_as_select_distinct_s,_i_from_mytable",
"CREATE_TABLE_t1_as_select_distinct_s,_i_from_mytable",
"CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s",
"CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s",
// SUM(VARCHAR) is not supported by DuckDB
"CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s",
"CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s",
"CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s_having_sum(i)_>_2",
"CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s_having_sum(i)_>_2",
"CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s_limit_1",
"CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s_limit_1",
"CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable",
"CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable",
"display_width_for_numeric_types",
"SHOW_FULL_FIELDS_FROM_numericDisplayWidthTest;",
"datetime_precision",
"CREATE_TABLE_tt_(pk_int_primary_key,_d_datetime(6)_default_current_timestamp(6))",
"Identifier_lengths",
"table_charset_options",
"show_create_table_t3",
"show_create_table_t4",
"create_table_with_select_preserves_default",
"create_table_t1_select_*_from_a;",
"create_table_t2_select_j_from_a;",
"create_table_t3_select_j_as_i_from_a;",
"create_table_t4_select_j_+_1_from_a;",
"create_table_t5_select_a.j_from_a;",
"create_table_t6_select_sqa.j_from_(select_i,_j_from_a)_sqa;",
"show_create_table_t7;",
"create_table_t8_select_*_from_(select_*_from_a)_a_join_(select_*_from_b)_b;",
"show_create_table_t9;",
"create_table_t11_select_sum(j)_over()_as_jj_from_a;",
"create_table_t12_select_j_from_a_group_by_j;",
"create_table_t13_select_*_from_c;",
"event_contains_CREATE_TABLE_AS",
"CREATE_EVENT_foo_ON_SCHEDULE_EVERY_1_YEAR_DO_CREATE_TABLE_bar_AS_SELECT_1;",
"trigger_contains_CREATE_TABLE_AS",
Expand Down

0 comments on commit 4145be1

Please sign in to comment.