Skip to content

Commit

Permalink
feat: support check constraints (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 authored Dec 27, 2024
1 parent 0283fe8 commit 062442f
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 40 deletions.
11 changes: 3 additions & 8 deletions .github/workflows/mysql-copy-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,14 @@ jobs:
with:
python-version: '3.13'

- name: Install system packages
uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: libnsl2 # required by MySQL Shell
version: 1.1

- name: Install dependencies
run: |
go get .
pip3 install "sqlglot[rs]"
curl -LJO https://dev.mysql.com/get/Downloads/MySQL-Shell/mysql-shell_9.1.0-1debian12_amd64.deb
sudo dpkg -i ./mysql-shell_9.1.0-1debian12_amd64.deb
sudo apt-get install -y ./mysql-shell_9.1.0-1debian12_amd64.deb
- name: Setup test data in source MySQL
run: |
Expand All @@ -67,10 +61,11 @@ jobs:
-- A table with non-default starting auto_increment value
CREATE TABLE items (
id INT AUTO_INCREMENT PRIMARY KEY,
v BIGINT check (v > 0),
name VARCHAR(100)
) AUTO_INCREMENT=1000;
INSERT INTO items (name) VALUES ('item1'), ('item2'), ('item3');
INSERT INTO items (v, name) VALUES (1, 'item1'), (2, 'item2'), (3, 'item3');
"
- name: Build and start MyDuck Server
Expand Down
27 changes: 20 additions & 7 deletions backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,12 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
case *plan.InsertInto:
insert := n.(*plan.InsertInto)

// For AUTO_INCREMENT column, we fallback to the framework if the column is specified.
if dst, err := plan.GetInsertable(insert.Destination); err == nil && dst.Schema().HasAutoIncrement() {
if len(insert.ColumnNames) == 0 || len(insert.ColumnNames) == len(dst.Schema()) {
return b.base.Build(ctx, root, r)
}
}

// The handling of auto_increment reset and check constraints is not supported by DuckDB.
// We need to fallback to the framework for these cases.
// But we want to rewrite LOAD DATA to be handled by DuckDB,
// as it is a common way to import data into the database.
// Therefore, we ignoring auto_increment and check constraints for LOAD DATA.
// So rewriting LOAD DATA is done eagerly here.
src := insert.Source
if proj, ok := src.(*plan.Project); ok {
src = proj.Child
Expand All @@ -97,6 +96,20 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
}
return b.base.Build(ctx, root, r)
}

if dst, err := plan.GetInsertable(insert.Destination); err == nil {
// For AUTO_INCREMENT column, we fallback to the framework if the column is specified.
// if dst.Schema().HasAutoIncrement() && (0 == len(insert.ColumnNames) || len(insert.ColumnNames) == len(dst.Schema())) {
if dst.Schema().HasAutoIncrement() {
return b.base.Build(ctx, root, r)
}
// For table with check constraints, we fallback to the framework.
if ct, ok := dst.(sql.CheckTable); ok {
if checks, err := ct.GetChecks(ctx); err == nil && len(checks) > 0 {
return b.base.Build(ctx, root, r)
}
}
}
}

// Fallback to the base builder if the plan contains system/user variables or is not a pure data query.
Expand Down
2 changes: 1 addition & 1 deletion catalog/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func (d *Database) createAllTable(ctx *sql.Context, name string, schema sql.Prim
b.WriteString(")")

// Add comment to the table
info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName}
info := ExtraTableInfo{schema.PkOrdinals, withoutIndex, fullSequenceName, nil}
b.WriteString(fmt.Sprintf(
"; COMMENT ON TABLE %s IS '%s'",
fullTableName,
Expand Down
99 changes: 87 additions & 12 deletions catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type ExtraTableInfo struct {
PkOrdinals []int
Replicated bool
Sequence string
Checks []sql.CheckDefinition
}

type ColumnInfo struct {
Expand All @@ -37,6 +38,7 @@ type ColumnInfo struct {
ColumnDefault stdsql.NullString
Comment stdsql.NullString
}

type IndexedTable struct {
*Table
Lookup sql.IndexLookup
Expand All @@ -54,6 +56,8 @@ var _ sql.TruncateableTable = (*Table)(nil)
var _ sql.ReplaceableTable = (*Table)(nil)
var _ sql.CommentedTable = (*Table)(nil)
var _ sql.AutoIncrementTable = (*Table)(nil)
var _ sql.CheckTable = (*Table)(nil)
var _ sql.CheckAlterableTable = (*Table)(nil)

func NewTable(name string, db *Database) *Table {
return &Table{
Expand Down Expand Up @@ -707,6 +711,9 @@ func (t *Table) PreciseMatch() bool {

// Comment implements sql.CommentedTable.
func (t *Table) Comment() string {
t.mu.RLock()
defer t.mu.RUnlock()

return t.comment.Text
}

Expand Down Expand Up @@ -761,10 +768,16 @@ func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup

// PeekNextAutoIncrementValue implements sql.AutoIncrementTable.
func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
t.mu.RLock()
defer t.mu.RUnlock()

if t.comment.Meta.Sequence == "" {
return 0, sql.ErrNoAutoIncrementCol
}
return t.getNextAutoIncrementValue(ctx)
}

func (t *Table) getNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
// For PeekNextAutoIncrementValue, we want to see what the next value would be
// without actually incrementing. We can do this by getting currval + 1.
var val uint64
Expand All @@ -788,12 +801,20 @@ func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) {
}

// GetNextAutoIncrementValue implements sql.AutoIncrementTable.
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) {
func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal any) (uint64, error) {
t.mu.Lock()
defer t.mu.Unlock()

if t.comment.Meta.Sequence == "" {
return 0, sql.ErrNoAutoIncrementCol
}

// If insertVal is provided and greater than current sequence value, update sequence
nextVal, err := t.getNextAutoIncrementValue(ctx)
if err != nil {
return 0, err
}

// If insertVal is provided and greater than the next sequence value, update sequence
if insertVal != nil {
var start uint64
switch v := insertVal.(type) {
Expand All @@ -804,7 +825,7 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{
start = uint64(v)
}
}
if start > 0 {
if start > 0 && start > nextVal {
err := t.setAutoIncrementValue(ctx, start)
if err != nil {
return 0, err
Expand All @@ -815,7 +836,7 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{

// Get next value from sequence
var val uint64
err := adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val)
err = adapter.QueryRowCatalog(ctx, `SELECT nextval('`+t.comment.Meta.Sequence+`')`).Scan(&val)
if err != nil {
return 0, ErrDuckDB.New(err)
}
Expand Down Expand Up @@ -885,14 +906,12 @@ func (t *Table) setAutoIncrementValue(ctx *sql.Context, value uint64) error {
// }

// Update the table comment with the new sequence name
tableInfo := t.comment.Meta
tableInfo.Sequence = fullSequenceName
comment := NewCommentWithMeta(t.comment.Text, tableInfo)
if _, err = adapter.Exec(ctx, `COMMENT ON TABLE `+FullTableName(t.db.catalog, t.db.name, t.name)+` IS '`+comment.Encode()+`'`); err != nil {
return ErrDuckDB.New(err)
if err = t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) {
info.Sequence = fullSequenceName
}); err != nil {
return err
}

t.comment.Meta.Sequence = fullSequenceName
return t.withSchema(ctx)
}

Expand All @@ -910,6 +929,62 @@ func (s *autoIncrementSetter) Close(ctx *sql.Context) error {
}

func (s *autoIncrementSetter) AcquireAutoIncrementLock(ctx *sql.Context) (func(), error) {
// DuckDB handles sequence synchronization internally
return func() {}, nil
s.t.mu.Lock()
return s.t.mu.Unlock, nil
}

func (t *Table) updateExtraTableInfo(ctx *sql.Context, updater func(*ExtraTableInfo)) error {
tableInfo := t.comment.Meta
updater(&tableInfo)
comment := NewCommentWithMeta(t.comment.Text, tableInfo)
_, err := adapter.Exec(ctx, `COMMENT ON TABLE `+FullTableName(t.db.catalog, t.db.name, t.name)+` IS '`+comment.Encode()+`'`)
if err != nil {
return ErrDuckDB.New(err)
}
t.comment.Meta = tableInfo // Update the in-memory metadata
return nil
}

// CheckConstraints implements sql.CheckTable.
func (t *Table) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
t.mu.RLock()
defer t.mu.RUnlock()

return t.comment.Meta.Checks, nil
}

// AddCheck implements sql.CheckAlterableTable.
func (t *Table) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
t.mu.Lock()
defer t.mu.Unlock()

// TODO(fan): Implement this once DuckDB supports modifying check constraints.
// https://duckdb.org/docs/sql/statements/alter_table.html#add--drop-constraint
// https://github.com/duckdb/duckdb/issues/57
// Just record the check constraint for now.
return t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) {
info.Checks = append(info.Checks, *check)
})
}

// DropCheck implements sql.CheckAlterableTable.
func (t *Table) DropCheck(ctx *sql.Context, checkName string) error {
t.mu.Lock()
defer t.mu.Unlock()

checks := make([]sql.CheckDefinition, 0, max(len(t.comment.Meta.Checks)-1, 0))
found := false
for i, check := range t.comment.Meta.Checks {
if check.Name == checkName {
found = true
continue
}
checks = append(checks, t.comment.Meta.Checks[i])
}
if !found {
return sql.ErrUnknownConstraint.New(checkName)
}
return t.updateExtraTableInfo(ctx, func(info *ExtraTableInfo) {
info.Checks = checks
})
}
12 changes: 0 additions & 12 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,15 +1130,10 @@ func TestCreateTable(t *testing.T) {
"CREATE_TABLE_t1_as_select_concat(\"new\",_s),_i_from_mytable",
"display_width_for_numeric_types",
"SHOW_FULL_FIELDS_FROM_numericDisplayWidthTest;",
"Validate_that_CREATE_LIKE_preserves_checks",
"datetime_precision",
"CREATE_TABLE_tt_(pk_int_primary_key,_d_datetime(6)_default_current_timestamp(6))",
"Identifier_lengths",
"create_table_b_(a_int_primary_key,_constraint_abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijkl_check_(a_>_0))",
"create_table_d_(a_int_primary_key,_constraint_abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijkl_foreign_key_(a)_references_parent(a))",
"table_charset_options",
"show_create_table_t1",
"show_create_table_t2",
"show_create_table_t3",
"show_create_table_t4",
"create_table_with_select_preserves_default",
Expand All @@ -1158,17 +1153,10 @@ func TestCreateTable(t *testing.T) {
"CREATE_EVENT_foo_ON_SCHEDULE_EVERY_1_YEAR_DO_CREATE_TABLE_bar_AS_SELECT_1;",
"trigger_contains_CREATE_TABLE_AS",
"CREATE_TRIGGER_foo_AFTER_UPDATE_ON_t_FOR_EACH_ROW_BEGIN_CREATE_TABLE_bar_AS_SELECT_1;_END;",
"insert_into_t1_(b)_values_(1),_(2)",
"show_create_table_t1",
"select_*_from_t1_order_by_b",
"insert_into_t1_(b)_values_(1),_(2)",
"show_create_table_t1",
"select_*_from_t1_order_by_b",
}

// Patch auto-generated queries that are known to fail
waitForFixQueries = append(waitForFixQueries, []string{
"CREATE TABLE t1 (pk int primary key, test_score int, height int CHECK (height < 10) , CONSTRAINT mycheck CHECK (test_score >= 50))",
"create table a (i int primary key, j int default 100);", // skip the case "create table with select preserves default" since there is no support for CREATE TABLE SELECT
}...)

Expand Down

0 comments on commit 062442f

Please sign in to comment.