diff --git a/catalog/database.go b/catalog/database.go index 6839e260..8e985885 100644 --- a/catalog/database.go +++ b/catalog/database.go @@ -78,7 +78,7 @@ func (d *Database) tablesInsensitive(ctx *sql.Context, pattern string) ([]*Table } func (d *Database) findTables(ctx *sql.Context, pattern string) ([]*Table, error) { - rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT table_name, comment FROM duckdb_tables() where database_name = ? and schema_name = ? and table_name ILIKE ?", d.catalog, d.name, pattern) + rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT table_name, comment FROM duckdb_tables() where (database_name = ? and schema_name = ? and table_name ILIKE ?) or (database_name = 'temp' and schema_name = 'main' and table_name ILIKE ?)", d.catalog, d.name, pattern, pattern) if err != nil { return nil, ErrDuckDB.New(err) } @@ -106,18 +106,110 @@ func (d *Database) Name() string { return d.name } +func (d *Database) CreateAllTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string, is_temp bool) error { + + var columns []string + var columnCommentSQLs []string + var fullTableName string + + if is_temp { + fullTableName = FullTableName("temp", "main", name) + } else { + fullTableName = FullTableName(d.catalog, d.name, name) + } + + for _, col := range schema.Schema { + typ, err := DuckdbDataType(col.Type) + if err != nil { + return err + } + colDef := fmt.Sprintf(`"%s" %s`, col.Name, typ.name) + if col.Nullable { + colDef += " NULL" + } else { + colDef += " NOT NULL" + } + + if col.Default != nil { + columnDefault, err := typ.mysql.withDefault(col.Default.String()) + if err != nil { + return err + } + colDef += " DEFAULT " + columnDefault + } + + columns = append(columns, colDef) + + var fullColumnName string + + if is_temp { + fullColumnName = FullColumnName("temp", "main", name, col.Name) + } else { + fullColumnName = FullColumnName(d.catalog, d.name, name, col.Name) + } + + if col.Comment != "" || typ.mysql.Name != "" || col.Default != nil { + columnCommentSQLs = append(columnCommentSQLs, + fmt.Sprintf(`COMMENT ON COLUMN %s IS '%s'`, fullColumnName, + NewCommentWithMeta[MySQLType](col.Comment, typ.mysql).Encode())) + } + } + + var sqlsBuild strings.Builder + + if is_temp { + sqlsBuild.WriteString(fmt.Sprintf(`CREATE TEMP TABLE %s (%s`, name, strings.Join(columns, ", "))) + } else { + sqlsBuild.WriteString(fmt.Sprintf(`CREATE TABLE %s (%s`, fullTableName, strings.Join(columns, ", "))) + } + + var primaryKeys []string + for _, pkord := range schema.PkOrdinals { + primaryKeys = append(primaryKeys, schema.Schema[pkord].Name) + } + + if len(primaryKeys) > 0 { + sqlsBuild.WriteString(fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", "))) + } + + sqlsBuild.WriteString(")") + + // Add comment to the table + if comment != "" { + sqlsBuild.WriteString(fmt.Sprintf("; COMMENT ON TABLE %s IS '%s'", fullTableName, NewComment[any](comment).Encode())) + } + + // Add column comments + for _, s := range columnCommentSQLs { + sqlsBuild.WriteString(";") + sqlsBuild.WriteString(s) + } + + _, err := adapter.Exec(ctx, sqlsBuild.String()) + if err != nil { + if IsDuckDBTableAlreadyExistsError(err) { + return sql.ErrTableAlreadyExists.New(name) + } + return ErrDuckDB.New(err) + } + + // TODO: support collation + + return nil +} + // CreateTable implements sql.TableCreator. func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error { d.mu.Lock() defer d.mu.Unlock() - return d.CreateAllTable(ctx, name, schema, collation, comment) + return d.CreateAllTable(ctx, name, schema, collation, comment, false) } // CreateTemporaryTable implements sql.CreateTemporaryTable. func (d *Database) CreateTemporaryTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID) error { d.mu.Lock() defer d.mu.Unlock() - return d.CreateAllTable(ctx, name, schema, collation, "") + return d.CreateAllTable(ctx, name, schema, collation, "", true) } // DropTable implements sql.TableDropper. @@ -249,79 +341,6 @@ func (d *Database) DropView(ctx *sql.Context, name string) error { return nil } -// CreateTable implements sql.TableCreator. -func (d *Database) CreateAllTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error { - - var columns []string - var columnCommentSQLs []string - for _, col := range schema.Schema { - typ, err := DuckdbDataType(col.Type) - if err != nil { - return err - } - colDef := fmt.Sprintf(`"%s" %s`, col.Name, typ.name) - if col.Nullable { - colDef += " NULL" - } else { - colDef += " NOT NULL" - } - - if col.Default != nil { - columnDefault, err := typ.mysql.withDefault(col.Default.String()) - if err != nil { - return err - } - colDef += " DEFAULT " + columnDefault - } - - columns = append(columns, colDef) - - if col.Comment != "" || typ.mysql.Name != "" || col.Default != nil { - columnCommentSQLs = append(columnCommentSQLs, - fmt.Sprintf(`COMMENT ON COLUMN %s IS '%s'`, FullColumnName(d.catalog, d.name, name, col.Name), - NewCommentWithMeta[MySQLType](col.Comment, typ.mysql).Encode())) - } - } - - var sqlsBuild strings.Builder - - sqlsBuild.WriteString(fmt.Sprintf(`CREATE TABLE %s (%s`, FullTableName(d.catalog, d.name, name), strings.Join(columns, ", "))) - - var primaryKeys []string - for _, pkord := range schema.PkOrdinals { - primaryKeys = append(primaryKeys, schema.Schema[pkord].Name) - } - - if len(primaryKeys) > 0 { - sqlsBuild.WriteString(fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", "))) - } - - sqlsBuild.WriteString(")") - - // Add comment to the table - if comment != "" { - sqlsBuild.WriteString(fmt.Sprintf("; COMMENT ON TABLE %s IS '%s'", FullTableName(d.catalog, d.name, name), NewComment[any](comment).Encode())) - } - - // Add column comments - for _, s := range columnCommentSQLs { - sqlsBuild.WriteString(";") - sqlsBuild.WriteString(s) - } - - _, err := adapter.Exec(ctx, sqlsBuild.String()) - if err != nil { - if IsDuckDBTableAlreadyExistsError(err) { - return sql.ErrTableAlreadyExists.New(name) - } - return ErrDuckDB.New(err) - } - - // TODO: support collation - - return nil -} - // CreateTrigger implements sql.TriggerDatabase. func (d *Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinition) error { return sql.ErrTriggersNotSupported.New(d.name) diff --git a/catalog/table.go b/catalog/table.go index dc9cdc8b..aa4b0fad 100644 --- a/catalog/table.go +++ b/catalog/table.go @@ -145,8 +145,8 @@ func (t *Table) PrimaryKeySchema() sql.PrimaryKeySchema { func getPrimaryKeyOrdinals(ctx *sql.Context, catalogName, dbName, tableName string) []int { rows, err := adapter.QueryCatalog(ctx, ` - SELECT constraint_column_indexes FROM duckdb_constraints() WHERE database_name = ? AND schema_name = ? AND table_name = ? AND constraint_type = 'PRIMARY KEY' LIMIT 1 - `, catalogName, dbName, tableName) + SELECT constraint_column_indexes FROM duckdb_constraints() WHERE ((database_name = ? AND schema_name = ? AND table_name = ?) OR (database_name = 'temp' AND schema_name = 'main' AND table_name = ?)) AND constraint_type = 'PRIMARY KEY' LIMIT 1 + `, catalogName, dbName, tableName, tableName) if err != nil { panic(ErrDuckDB.New(err)) } @@ -420,8 +420,8 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) { defer t.mu.RUnlock() // Query to get the indexes for the table - rows, err := adapter.QueryCatalog(ctx, `SELECT index_name, is_unique, comment, sql FROM duckdb_indexes() WHERE database_name = ? AND schema_name = ? AND table_name = ?`, - t.db.catalog, t.db.name, t.name) + rows, err := adapter.QueryCatalog(ctx, `SELECT index_name, is_unique, comment, sql FROM duckdb_indexes() WHERE (database_name = ? AND schema_name = ? AND table_name = ?) or (database_name = 'temp' AND schema_name = 'main' AND table_name = ?)`, + t.db.catalog, t.db.name, t.name, t.name) if err != nil { return nil, ErrDuckDB.New(err) } @@ -500,9 +500,9 @@ func (t *Table) Comment() string { func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) ([]*ColumnInfo, error) { rows, err := adapter.QueryCatalog(ctx, ` SELECT column_name, column_index, data_type, is_nullable, column_default, comment, numeric_precision, numeric_scale - FROM duckdb_columns() - WHERE database_name = ? AND schema_name = ? AND table_name = ? - `, catalogName, schemaName, tableName) + FROM duckdb_columns() + WHERE (database_name = ? AND schema_name = ? AND table_name = ?) OR (database_name = 'temp' AND schema_name = 'main' AND table_name = ?) + `, catalogName, schemaName, tableName, tableName) if err != nil { return nil, err }