diff --git a/.github/workflows/mysql-copy-tests.yml b/.github/workflows/mysql-copy-tests.yml index 8699fee..233292e 100644 --- a/.github/workflows/mysql-copy-tests.yml +++ b/.github/workflows/mysql-copy-tests.yml @@ -35,12 +35,6 @@ jobs: with: python-version: '3.13' - - name: Install system packages - uses: awalsh128/cache-apt-pkgs-action@latest - with: - packages: libnsl2 # required by MySQL Shell - version: 1.1 - - name: Install dependencies run: | go get . @@ -48,7 +42,7 @@ jobs: pip3 install "sqlglot[rs]" curl -LJO https://dev.mysql.com/get/Downloads/MySQL-Shell/mysql-shell_9.1.0-1debian12_amd64.deb - sudo dpkg -i ./mysql-shell_9.1.0-1debian12_amd64.deb + sudo apt-get install -y ./mysql-shell_9.1.0-1debian12_amd64.deb - name: Setup test data in source MySQL run: | @@ -67,10 +61,11 @@ jobs: -- A table with non-default starting auto_increment value CREATE TABLE items ( id INT AUTO_INCREMENT PRIMARY KEY, + v BIGINT check (v > 0), name VARCHAR(100) ) AUTO_INCREMENT=1000; - INSERT INTO items (name) VALUES ('item1'), ('item2'), ('item3'); + INSERT INTO items (v, name) VALUES (1, 'item1'), (2, 'item2'), (3, 'item3'); " - name: Build and start MyDuck Server diff --git a/backend/executor.go b/backend/executor.go index 44cd74b..29c22bd 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -80,13 +80,12 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row case *plan.InsertInto: insert := n.(*plan.InsertInto) - // For AUTO_INCREMENT column, we fallback to the framework if the column is specified. - if dst, err := plan.GetInsertable(insert.Destination); err == nil && dst.Schema().HasAutoIncrement() { - if len(insert.ColumnNames) == 0 || len(insert.ColumnNames) == len(dst.Schema()) { - return b.base.Build(ctx, root, r) - } - } - + // The handling of auto_increment reset and check constraints is not supported by DuckDB. + // We need to fallback to the framework for these cases. 
+ // But we want to rewrite LOAD DATA to be handled by DuckDB, + // as it is a common way to import data into the database. + // Therefore, we ignore auto_increment and check constraints for LOAD DATA. + // So rewriting LOAD DATA is done eagerly here. src := insert.Source if proj, ok := src.(*plan.Project); ok { src = proj.Child } @@ -97,6 +96,20 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row } return b.base.Build(ctx, root, r) } + + if dst, err := plan.GetInsertable(insert.Destination); err == nil { + // For AUTO_INCREMENT column, we fallback to the framework if the column is specified. + // if dst.Schema().HasAutoIncrement() && (0 == len(insert.ColumnNames) || len(insert.ColumnNames) == len(dst.Schema())) { + if dst.Schema().HasAutoIncrement() { + return b.base.Build(ctx, root, r) + } + // For table with check constraints, we fallback to the framework. + if ct, ok := dst.(sql.CheckTable); ok { + if checks, err := ct.GetChecks(ctx); err == nil && len(checks) > 0 { + return b.base.Build(ctx, root, r) + } + } + } } // Fallback to the base builder if the plan contains system/user variables or is not a pure data query. 
diff --git a/catalog/database.go b/catalog/database.go index bf95a24..13ee654 100644 --- a/catalog/database.go +++ b/catalog/database.go @@ -229,7 +229,7 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim b.WriteString(")") // Add comment to the table - info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName} + info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName, nil} b.WriteString(fmt.Sprintf( "; COMMENT ON TABLE %s IS '%s'", fullTableName, diff --git a/catalog/table.go b/catalog/table.go index fd46e57..f92c8ca 100644 --- a/catalog/table.go +++ b/catalog/table.go @@ -27,6 +27,7 @@ type ExtraTableInfo struct { PkOrdinals []int Replicated bool Sequence string + Checks []sql.CheckDefinition } type ColumnInfo struct { @@ -37,6 +38,7 @@ type ColumnInfo struct { ColumnDefault stdsql.NullString Comment stdsql.NullString } + type IndexedTable struct { *Table Lookup sql.IndexLookup @@ -54,6 +56,8 @@ var _ sql.TruncateableTable = (*Table)(nil) var _ sql.ReplaceableTable = (*Table)(nil) var _ sql.CommentedTable = (*Table)(nil) var _ sql.AutoIncrementTable = (*Table)(nil) +var _ sql.CheckTable = (*Table)(nil) +var _ sql.CheckAlterableTable = (*Table)(nil) func NewTable(name string, db *Database) *Table { return &Table{ @@ -707,6 +711,9 @@ func (t *Table) PreciseMatch() bool { // Comment implements sql.CommentedTable. func (t *Table) Comment() string { + t.mu.RLock() + defer t.mu.RUnlock() + return t.comment.Text } @@ -761,10 +768,16 @@ func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup // PeekNextAutoIncrementValue implements sql.AutoIncrementTable. 
func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) { + t.mu.RLock() + defer t.mu.RUnlock() + if t.comment.Meta.Sequence == "" { return 0, sql.ErrNoAutoIncrementCol } + return t.getNextAutoIncrementValue(ctx) +} +func (t *Table) getNextAutoIncrementValue(ctx *sql.Context) (uint64, error) { // For PeekNextAutoIncrementValue, we want to see what the next value would be // without actually incrementing. We can do this by getting currval + 1. var val uint64 @@ -788,12 +801,20 @@ func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) { } // GetNextAutoIncrementValue implements sql.AutoIncrementTable. -func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) { +func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal any) (uint64, error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.comment.Meta.Sequence == "" { return 0, sql.ErrNoAutoIncrementCol } - // If insertVal is provided and greater than current sequence value, update sequence + nextVal, err := t.getNextAutoIncrementValue(ctx) + if err != nil { + return 0, err + } + + // If insertVal is provided and greater than the next sequence value, update sequence if insertVal != nil { var start uint64 switch v := insertVal.(type) { @@ -804,7 +825,7 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{ start = uint64(v) } } - if start > 0 { + if start > 0 && start > nextVal { err := t.setAutoIncrementValue(ctx, start) if err != nil { return 0, err @@ -815,7 +836,7 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{ // Get next value from sequence var val uint64 - err := adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val) + err = adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val) if err != nil { return 0, ErrDuckDB.New(err) } @@ -885,14 +906,12 @@ func (t *Table) setAutoIncrementValue(ctx 
*sql.Context, value uint64) error { // } // Update the table comment with the new sequence name - tableInfo := t.comment.Meta - tableInfo.Sequence = fullSequenceName - comment := NewCommentWithMeta(t.comment.Text, tableInfo) - if _, err = adapter.Exec(ctx, `COMMENT ON TABLE `+FullTableName(t.db.catalog, t.db.name, t.name)+` IS '`+comment.Encode()+`'`); err != nil { - return ErrDuckDB.New(err) + if err = t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) { + info.Sequence = fullSequenceName + }); err != nil { + return err } - t.comment.Meta.Sequence = fullSequenceName return t.withSchema(ctx) } @@ -910,6 +929,62 @@ func (s *autoIncrementSetter) Close(ctx *sql.Context) error { } func (s *autoIncrementSetter) AcquireAutoIncrementLock(ctx *sql.Context) (func(), error) { - // DuckDB handles sequence synchronization internally - return func() {}, nil + s.t.mu.Lock() + return s.t.mu.Unlock, nil +} + +func (t *Table) updateExtraTableInfo(ctx *sql.Context, updater func(*ExtraTableInfo)) error { + tableInfo := t.comment.Meta + updater(&tableInfo) + comment := NewCommentWithMeta(t.comment.Text, tableInfo) + _, err := adapter.Exec(ctx, `COMMENT ON TABLE `+FullTableName(t.db.catalog, t.db.name, t.name)+` IS '`+comment.Encode()+`'`) + if err != nil { + return ErrDuckDB.New(err) + } + t.comment.Meta = tableInfo // Update the in-memory metadata + return nil +} + +// GetChecks implements sql.CheckTable. +func (t *Table) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.comment.Meta.Checks, nil +} + +// CreateCheck implements sql.CheckAlterableTable. +func (t *Table) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error { + t.mu.Lock() + defer t.mu.Unlock() + + // TODO(fan): Implement this once DuckDB supports modifying check constraints. 
+ // https://duckdb.org/docs/sql/statements/alter_table.html#add--drop-constraint + // https://github.com/duckdb/duckdb/issues/57 + // Just record the check constraint for now. + return t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) { + info.Checks = append(info.Checks, *check) + }) +} + +// DropCheck implements sql.CheckAlterableTable. +func (t *Table) DropCheck(ctx *sql.Context, checkName string) error { + t.mu.Lock() + defer t.mu.Unlock() + + checks := make([]sql.CheckDefinition, 0, max(len(t.comment.Meta.Checks)-1, 0)) + found := false + for i, check := range t.comment.Meta.Checks { + if check.Name == checkName { + found = true + continue + } + checks = append(checks, t.comment.Meta.Checks[i]) + } + if !found { + return sql.ErrUnknownConstraint.New(checkName) + } + return t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) { + info.Checks = checks + }) } diff --git a/main_test.go b/main_test.go index b183f06..178ebe3 100644 --- a/main_test.go +++ b/main_test.go @@ -1130,15 +1130,10 @@ func TestCreateTable(t *testing.T) { "CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable", "display_width_for_numeric_types", "SHOW_FULL_FIELDS_FROM_numericDisplayWidthTest;", - "Validate_that_CREATE_LIKE_preserves_checks", "datetime_precision", "CREATE_TABLE_tt_(pk_int_primary_key,_d_datetime(6)_default_current_timestamp(6))", "Identifier_lengths", - "create_table_b_(a_int_primary_key,_constraint_abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijkl_check_(a_>_0))", - "create_table_d_(a_int_primary_key,_constraint_abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijkl_foreign_key_(a)_references_parent(a))", "table_charset_options", - "show_create_table_t1", - "show_create_table_t2", "show_create_table_t3", "show_create_table_t4", "create_table_with_select_preserves_default", @@ -1158,17 +1153,10 @@ func TestCreateTable(t *testing.T) { "CREATE_EVENT_foo_ON_SCHEDULE_EVERY_1_YEAR_DO_CREATE_TABLE_bar_AS_SELECT_1;", 
"trigger_contains_CREATE_TABLE_AS", "CREATE_TRIGGER_foo_AFTER_UPDATE_ON_t_FOR_EACH_ROW_BEGIN_CREATE_TABLE_bar_AS_SELECT_1;_END;", - "insert_into_t1_(b)_values_(1),_(2)", - "show_create_table_t1", - "select_*_from_t1_order_by_b", - "insert_into_t1_(b)_values_(1),_(2)", - "show_create_table_t1", - "select_*_from_t1_order_by_b", } // Patch auto-generated queries that are known to fail waitForFixQueries = append(waitForFixQueries, []string{ - "CREATE TABLE t1 (pk int primary key, test_score int, height int CHECK (height < 10) , CONSTRAINT mycheck CHECK (test_score >= 50))", "create table a (i int primary key, j int default 100);", // skip the case "create table with select preserves default" since there is no support for CREATE TABLE SELECT }...)