From 4145be11df601e0f4b1cb508b18f017627431c3b Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Wed, 8 Jan 2025 18:42:01 +0800 Subject: [PATCH] feat: support MySQL CTAS (#351) * Enable more CTAS tests * add type mapping for hugeint and varint * Add fallback for simple CTAS --- backend/executor.go | 53 ++++++++++++++++++++++--------- catalog/database.go | 22 +++++++++++++ catalog/table.go | 5 ++- catalog/type_mapping.go | 69 +++++++++++++++++++++++------------------ main_test.go | 32 +------------------ 5 files changed, 104 insertions(+), 77 deletions(-) diff --git a/backend/executor.go b/backend/executor.go index 6f7ec30..f7d103b 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -61,10 +61,13 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row } n := root - ctx.GetLogger().WithFields(logrus.Fields{ - "Query": ctx.Query(), - "NodeType": fmt.Sprintf("%T", n), - }).Traceln("Building node:", n) + + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "Query": ctx.Query(), + "NodeType": fmt.Sprintf("%T", n), + }).Traceln("Building node:", n) + } // TODO; find a better way to fallback to the base builder switch n.(type) { @@ -114,7 +117,13 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row } // Fallback to the base builder if the plan contains system/user variables or is not a pure data query. - if containsVariable(n) || !IsPureDataQuery(n) { + tree := n + switch n := n.(type) { + case *plan.TableCopier: + tree = n.Source + } + if containsVariable(tree) || !IsPureDataQuery(tree) { + ctx.GetLogger().Traceln("Falling back to the base builder") return b.base.Build(ctx, root, r) } @@ -133,11 +142,21 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row return nil, err } return b.base.Build(ctx, root, r) - // SubqueryAlias is for select * from view + // ResolvedTable is for `SELECT * FROM table` and `TABLE table` + // SubqueryAlias is for `SELECT * FROM view` case *plan.ResolvedTable, *plan.SubqueryAlias, *plan.TableAlias: return b.executeQuery(ctx, node, conn) case *plan.Distinct, *plan.OrderedDistinct: return b.executeQuery(ctx, node, conn) + case *plan.TableCopier: + // We preserve the table schema in a best-effort manner. + // For simple `CREATE TABLE t AS SELECT * FROM t`, + // we fall back to the framework to create the table and copy the data. + // For more complex cases, we directly execute the CTAS statement in DuckDB. + if _, ok := node.Source.(*plan.ResolvedTable); ok { + return b.base.Build(ctx, root, r) + } + return b.executeDML(ctx, node, conn) case sql.Expressioner: return b.executeExpressioner(ctx, node, conn) case *plan.DeleteFrom: @@ -174,7 +193,7 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co case *plan.ShowTables: duckSQL = ctx.Query() case *plan.ResolvedTable: - // SQLGlot cannot translate MySQL's `TABLE t` into DuckDB's `FROM t` - it produces `"table" AS t` instead. + // SQLGlot cannot translate MySQL's `TABLE t` into DuckDB's `FROM t` - it produces `"table" AS t` instead. duckSQL = `FROM ` + catalog.ConnectIdentifiersANSI(n.Database().Name(), n.Name()) default: duckSQL, err = transpiler.TranslateWithSQLGlot(ctx.Query()) @@ -183,10 +202,12 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co return nil, catalog.ErrTranspiler.New(err) } - ctx.GetLogger().WithFields(logrus.Fields{ - "Query": ctx.Query(), - "DuckSQL": duckSQL, - }).Trace("Executing Query...") + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "Query": ctx.Query(), + "DuckSQL": duckSQL, + }).Trace("Executing Query...") + } // Execute the DuckDB query rows, err := conn.QueryContext(ctx.Context, duckSQL) @@ -204,10 +225,12 @@ func (b *DuckBuilder) executeDML(ctx *sql.Context, n sql.Node, conn *stdsql.Conn return nil, catalog.ErrTranspiler.New(err) } - ctx.GetLogger().WithFields(logrus.Fields{ - "Query": ctx.Query(), - "DuckSQL": duckSQL, - }).Trace("Executing DML...") + if log := ctx.GetLogger(); log.Logger.IsLevelEnabled(logrus.TraceLevel) { + log.WithFields(logrus.Fields{ + "Query": ctx.Query(), + "DuckSQL": duckSQL, + }).Trace("Executing DML...") + } // Execute the DuckDB query result, err := conn.ExecContext(ctx.Context, duckSQL) diff --git a/catalog/database.go b/catalog/database.go index 76d214f..63b2cd0 100644 --- a/catalog/database.go +++ b/catalog/database.go @@ -446,3 +446,25 @@ func (d *Database) GetCollation(ctx *sql.Context) sql.CollationID { func (d *Database) SetCollation(ctx *sql.Context, collation sql.CollationID) error { return nil } + +// CopyTableData implements sql.TableCopierDatabase interface. +func (d *Database) CopyTableData(ctx *sql.Context, sourceTable string, destinationTable string) (uint64, error) { + d.mu.Lock() + defer d.mu.Unlock() + + // Use INSERT INTO ... SELECT to copy data + sql := `INSERT INTO ` + FullTableName(d.catalog, d.name, destinationTable) + ` FROM ` + FullTableName(d.catalog, d.name, sourceTable) + + res, err := adapter.Exec(ctx, sql) + if err != nil { + return 0, ErrDuckDB.New(err) + } + + // Get count of affected rows + count, err := res.RowsAffected() + if err != nil { + return 0, ErrDuckDB.New(err) + } + + return uint64(count), nil +} diff --git a/catalog/table.go b/catalog/table.go index fa5d6e2..0e25829 100644 --- a/catalog/table.go +++ b/catalog/table.go @@ -754,7 +754,10 @@ func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) ( } decodedComment := DecodeComment[MySQLType](comment.String) - dataType := mysqlDataType(AnnotatedDuckType{dataTypes, decodedComment.Meta}, uint8(numericPrecision.Int32), uint8(numericScale.Int32)) + dataType, err := mysqlDataType(AnnotatedDuckType{dataTypes, decodedComment.Meta}, uint8(numericPrecision.Int32), uint8(numericScale.Int32)) + if err != nil { + return nil, err + } columnInfo := &ColumnInfo{ ColumnName: columnName, diff --git a/catalog/type_mapping.go b/catalog/type_mapping.go index 3e199bb..c16681a 100644 --- a/catalog/type_mapping.go +++ b/catalog/type_mapping.go @@ -198,7 +198,7 @@ func DuckdbDataType(mysqlType sql.Type) (AnnotatedDuckType, error) { } } -func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericScale uint8) sql.Type { +func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericScale uint8) (sql.Type, error) { // TODO: The current type mappings are not lossless. We need to store the original type in the column comments. duckName := strings.TrimSpace(strings.ToUpper(duckType.name)) @@ -219,7 +219,7 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc intBaseType = sqltypes.Uint8 case "SMALLINT": if mysqlName == "YEAR" { - return types.Year + return types.Year, nil } intBaseType = sqltypes.Int16 case "USMALLINT": @@ -240,13 +240,13 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc intBaseType = sqltypes.Int64 case "UBIGINT": if mysqlName == "BIT" { - return types.MustCreateBitType(duckType.mysql.Precision) + return types.CreateBitType(duckType.mysql.Precision) } intBaseType = sqltypes.Uint64 } if intBaseType != sqltypes.Null { - return types.MustCreateNumberTypeWithDisplayWidth(intBaseType, int(duckType.mysql.Display)) + return types.CreateNumberTypeWithDisplayWidth(intBaseType, int(duckType.mysql.Display)) } length := int64(duckType.mysql.Length) @@ -255,70 +255,79 @@ func mysqlDataType(duckType AnnotatedDuckType, numericPrecision uint8, numericSc switch duckName { case "FLOAT": - return types.Float32 + return types.Float32, nil case "DOUBLE": - return types.Float64 + return types.Float64, nil case "TIMESTAMP", "TIMESTAMP_S", "TIMESTAMP_MS": if mysqlName == "DATETIME" { - return types.MustCreateDatetimeType(sqltypes.Datetime, precision) + return types.CreateDatetimeType(sqltypes.Datetime, precision) } - return types.MustCreateDatetimeType(sqltypes.Timestamp, precision) + return types.CreateDatetimeType(sqltypes.Timestamp, precision) case "DATE": - return types.Date + return types.Date, nil case "INTERVAL", "TIME": - return types.Time + return types.Time, nil case "DECIMAL": - return types.MustCreateDecimalType(numericPrecision, numericScale) + return types.CreateDecimalType(numericPrecision, numericScale) + + case "UHUGEINT", "HUGEINT": + // MySQL does not have these types. We store them as DECIMAL. + return types.CreateDecimalType(39, 0) + + case "VARINT": + // MySQL does not have this type. We store it as DECIMAL. + // Here we use the maximum supported precision for DECIMAL in MySQL. + return types.CreateDecimalType(65, 0) case "VARCHAR": if mysqlName == "TEXT" { if length <= types.TinyTextBlobMax { - return types.TinyText + return types.TinyText, nil } else if length <= types.TextBlobMax { - return types.Text + return types.Text, nil } else if length <= types.MediumTextBlobMax { - return types.MediumText + return types.MediumText, nil } else { - return types.LongText + return types.LongText, nil } } else if mysqlName == "VARCHAR" { - return types.MustCreateString(sqltypes.VarChar, length, collation) + return types.CreateString(sqltypes.VarChar, length, collation) } else if mysqlName == "CHAR" { - return types.MustCreateString(sqltypes.Char, length, collation) + return types.CreateString(sqltypes.Char, length, collation) } else if mysqlName == "SET" { - return types.MustCreateSetType(duckType.mysql.Values, collation) + return types.CreateSetType(duckType.mysql.Values, collation) } - return types.Text + return types.Text, nil case "BLOB": if mysqlName == "BLOB" { if length <= types.TinyTextBlobMax { - return types.TinyBlob + return types.TinyBlob, nil } else if length <= types.TextBlobMax { - return types.Blob + return types.Blob, nil } else if length <= types.MediumTextBlobMax { - return types.MediumBlob + return types.MediumBlob, nil } else { - return types.LongBlob + return types.LongBlob, nil } } else if mysqlName == "VARBINARY" { - return types.MustCreateBinary(sqltypes.VarBinary, length) + return types.CreateBinary(sqltypes.VarBinary, length) } else if mysqlName == "BINARY" { - return types.MustCreateBinary(sqltypes.Binary, length) + return types.CreateBinary(sqltypes.Binary, length) } - return types.Blob + return types.Blob, nil case "JSON": - return types.JSON + return types.JSON, nil case "ENUM": - return types.MustCreateEnumType(duckType.mysql.Values, collation) + return types.CreateEnumType(duckType.mysql.Values, collation) case "SET": - return types.MustCreateSetType(duckType.mysql.Values, collation) + return types.CreateSetType(duckType.mysql.Values, collation) default: - panic(fmt.Sprintf("encountered unknown DuckDB type(%v). This is likely a bug - please check the duckdbDataType function for missing type mappings", duckType)) + return nil, fmt.Errorf("encountered unknown DuckDB type(%v)", duckType) } } diff --git a/main_test.go b/main_test.go index 178ebe3..202f969 100644 --- a/main_test.go +++ b/main_test.go @@ -1110,45 +1110,15 @@ func TestCreateTable(t *testing.T) { "create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_unique_index(b1(123),_b2(456)))", "create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_index(b1(10)),_index(b2(20)),_index(b1(123),_b2(456)))", "create_table_t1_(i_int_primary_key,_b1_blob,_b2_blob,_index(b1(10)),_index(b2(20)),_index(b1(123),_b2(456)))", - "CREATE_TABLE_t1_as_select_*_from_mytable", - "CREATE_TABLE_t1_as_select_*_from_mytable", - "CREATE_TABLE_t1_as_select_*_from_mytable#01", - "CREATE_TABLE_t1_as_select_*_from_mytable", - "CREATE_TABLE_t1_as_select_s,_i_from_mytable", - "CREATE_TABLE_t1_as_select_s,_i_from_mytable", - "CREATE_TABLE_t1_as_select_distinct_s,_i_from_mytable", - "CREATE_TABLE_t1_as_select_distinct_s,_i_from_mytable", - "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s", - "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s", + // SUM(VARCHAR) is not supported by DuckDB "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s", - "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s", - "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s_having_sum(i)_>_2", "CREATE_TABLE_t1_as_select_s,_sum(i)_from_mytable_group_by_s_having_sum(i)_>_2", - "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s_limit_1", - "CREATE_TABLE_t1_as_select_s,_i_from_mytable_order_by_s_limit_1", - "CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable", - "CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable", "display_width_for_numeric_types", "SHOW_FULL_FIELDS_FROM_numericDisplayWidthTest;", "datetime_precision", "CREATE_TABLE_tt_(pk_int_primary_key,_d_datetime(6)_default_current_timestamp(6))", "Identifier_lengths", "table_charset_options", - "show_create_table_t3", - "show_create_table_t4", - "create_table_with_select_preserves_default", - "create_table_t1_select_*_from_a;", - "create_table_t2_select_j_from_a;", - "create_table_t3_select_j_as_i_from_a;", - "create_table_t4_select_j_+_1_from_a;", - "create_table_t5_select_a.j_from_a;", - "create_table_t6_select_sqa.j_from_(select_i,_j_from_a)_sqa;", - "show_create_table_t7;", - "create_table_t8_select_*_from_(select_*_from_a)_a_join_(select_*_from_b)_b;", - "show_create_table_t9;", - "create_table_t11_select_sum(j)_over()_as_jj_from_a;", - "create_table_t12_select_j_from_a_group_by_j;", - "create_table_t13_select_*_from_c;", "event_contains_CREATE_TABLE_AS", "CREATE_EVENT_foo_ON_SCHEDULE_EVERY_1_YEAR_DO_CREATE_TABLE_bar_AS_SELECT_1;", "trigger_contains_CREATE_TABLE_AS",