Skip to content

Commit

Permalink
support temporay table
Browse files Browse the repository at this point in the history
  • Loading branch information
d h authored and d h committed Nov 28, 2024
1 parent 48be2ee commit 338c7ca
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 83 deletions.
171 changes: 95 additions & 76 deletions catalog/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (d *Database) tablesInsensitive(ctx *sql.Context, pattern string) ([]*Table
}

func (d *Database) findTables(ctx *sql.Context, pattern string) ([]*Table, error) {
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT table_name, comment FROM duckdb_tables() where database_name = ? and schema_name = ? and table_name ILIKE ?", d.catalog, d.name, pattern)
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT table_name, comment FROM duckdb_tables() where (database_name = ? and schema_name = ? and table_name ILIKE ?) or (database_name = 'temp' and schema_name = 'main' and table_name ILIKE ?)", d.catalog, d.name, pattern, pattern)
if err != nil {
return nil, ErrDuckDB.New(err)
}
Expand Down Expand Up @@ -106,18 +106,110 @@ func (d *Database) Name() string {
return d.name
}

func (d *Database) CreateAllTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string, is_temp bool) error {

var columns []string
var columnCommentSQLs []string
var fullTableName string

if is_temp {
fullTableName = FullTableName("temp", "main", name)
} else {
fullTableName = FullTableName(d.catalog, d.name, name)
}

for _, col := range schema.Schema {
typ, err := DuckdbDataType(col.Type)
if err != nil {
return err
}
colDef := fmt.Sprintf(`"%s" %s`, col.Name, typ.name)
if col.Nullable {
colDef += " NULL"
} else {
colDef += " NOT NULL"
}

if col.Default != nil {
columnDefault, err := typ.mysql.withDefault(col.Default.String())
if err != nil {
return err
}
colDef += " DEFAULT " + columnDefault
}

columns = append(columns, colDef)

var fullColumnName string

if is_temp {
fullColumnName = FullColumnName("temp", "main", name, col.Name)
} else {
fullColumnName = FullColumnName(d.catalog, d.name, name, col.Name)
}

if col.Comment != "" || typ.mysql.Name != "" || col.Default != nil {
columnCommentSQLs = append(columnCommentSQLs,
fmt.Sprintf(`COMMENT ON COLUMN %s IS '%s'`, fullColumnName,
NewCommentWithMeta[MySQLType](col.Comment, typ.mysql).Encode()))
}
}

var sqlsBuild strings.Builder

if is_temp {
sqlsBuild.WriteString(fmt.Sprintf(`CREATE TEMP TABLE %s (%s`, name, strings.Join(columns, ", ")))
} else {
sqlsBuild.WriteString(fmt.Sprintf(`CREATE TABLE %s (%s`, fullTableName, strings.Join(columns, ", ")))
}

var primaryKeys []string
for _, pkord := range schema.PkOrdinals {
primaryKeys = append(primaryKeys, schema.Schema[pkord].Name)
}

if len(primaryKeys) > 0 {
sqlsBuild.WriteString(fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}

sqlsBuild.WriteString(")")

// Add comment to the table
if comment != "" {
sqlsBuild.WriteString(fmt.Sprintf("; COMMENT ON TABLE %s IS '%s'", fullTableName, NewComment[any](comment).Encode()))
}

// Add column comments
for _, s := range columnCommentSQLs {
sqlsBuild.WriteString(";")
sqlsBuild.WriteString(s)
}

_, err := adapter.Exec(ctx, sqlsBuild.String())
if err != nil {
if IsDuckDBTableAlreadyExistsError(err) {
return sql.ErrTableAlreadyExists.New(name)
}
return ErrDuckDB.New(err)
}

// TODO: support collation

return nil
}

// CreateTable implements sql.TableCreator.
func (d *Database) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.CreateAllTable(ctx, name, schema, collation, comment)
return d.CreateAllTable(ctx, name, schema, collation, comment, false)
}

// CreateTemporaryTable implements sql.CreateTemporaryTable.
func (d *Database) CreateTemporaryTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.CreateAllTable(ctx, name, schema, collation, "")
return d.CreateAllTable(ctx, name, schema, collation, "", true)
}

// DropTable implements sql.TableDropper.
Expand Down Expand Up @@ -249,79 +341,6 @@ func (d *Database) DropView(ctx *sql.Context, name string) error {
return nil
}

// CreateTable implements sql.TableCreator.
func (d *Database) CreateAllTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error {

var columns []string
var columnCommentSQLs []string
for _, col := range schema.Schema {
typ, err := DuckdbDataType(col.Type)
if err != nil {
return err
}
colDef := fmt.Sprintf(`"%s" %s`, col.Name, typ.name)
if col.Nullable {
colDef += " NULL"
} else {
colDef += " NOT NULL"
}

if col.Default != nil {
columnDefault, err := typ.mysql.withDefault(col.Default.String())
if err != nil {
return err
}
colDef += " DEFAULT " + columnDefault
}

columns = append(columns, colDef)

if col.Comment != "" || typ.mysql.Name != "" || col.Default != nil {
columnCommentSQLs = append(columnCommentSQLs,
fmt.Sprintf(`COMMENT ON COLUMN %s IS '%s'`, FullColumnName(d.catalog, d.name, name, col.Name),
NewCommentWithMeta[MySQLType](col.Comment, typ.mysql).Encode()))
}
}

var sqlsBuild strings.Builder

sqlsBuild.WriteString(fmt.Sprintf(`CREATE TABLE %s (%s`, FullTableName(d.catalog, d.name, name), strings.Join(columns, ", ")))

var primaryKeys []string
for _, pkord := range schema.PkOrdinals {
primaryKeys = append(primaryKeys, schema.Schema[pkord].Name)
}

if len(primaryKeys) > 0 {
sqlsBuild.WriteString(fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(primaryKeys, ", ")))
}

sqlsBuild.WriteString(")")

// Add comment to the table
if comment != "" {
sqlsBuild.WriteString(fmt.Sprintf("; COMMENT ON TABLE %s IS '%s'", FullTableName(d.catalog, d.name, name), NewComment[any](comment).Encode()))
}

// Add column comments
for _, s := range columnCommentSQLs {
sqlsBuild.WriteString(";")
sqlsBuild.WriteString(s)
}

_, err := adapter.Exec(ctx, sqlsBuild.String())
if err != nil {
if IsDuckDBTableAlreadyExistsError(err) {
return sql.ErrTableAlreadyExists.New(name)
}
return ErrDuckDB.New(err)
}

// TODO: support collation

return nil
}

// CreateTrigger implements sql.TriggerDatabase.
func (d *Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinition) error {
return sql.ErrTriggersNotSupported.New(d.name)
Expand Down
14 changes: 7 additions & 7 deletions catalog/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func (t *Table) PrimaryKeySchema() sql.PrimaryKeySchema {

func getPrimaryKeyOrdinals(ctx *sql.Context, catalogName, dbName, tableName string) []int {
rows, err := adapter.QueryCatalog(ctx, `
SELECT constraint_column_indexes FROM duckdb_constraints() WHERE database_name = ? AND schema_name = ? AND table_name = ? AND constraint_type = 'PRIMARY KEY' LIMIT 1
`, catalogName, dbName, tableName)
SELECT constraint_column_indexes FROM duckdb_constraints() WHERE ((database_name = ? AND schema_name = ? AND table_name = ?) OR (database_name = 'temp' AND schema_name = 'main' AND table_name = ?)) AND constraint_type = 'PRIMARY KEY' LIMIT 1
`, catalogName, dbName, tableName, tableName)
if err != nil {
panic(ErrDuckDB.New(err))
}
Expand Down Expand Up @@ -420,8 +420,8 @@ func (t *Table) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
defer t.mu.RUnlock()

// Query to get the indexes for the table
rows, err := adapter.QueryCatalog(ctx, `SELECT index_name, is_unique, comment, sql FROM duckdb_indexes() WHERE database_name = ? AND schema_name = ? AND table_name = ?`,
t.db.catalog, t.db.name, t.name)
rows, err := adapter.QueryCatalog(ctx, `SELECT index_name, is_unique, comment, sql FROM duckdb_indexes() WHERE (database_name = ? AND schema_name = ? AND table_name = ?) or (database_name = 'temp' AND schema_name = 'main' AND table_name = ?)`,
t.db.catalog, t.db.name, t.name, t.name)
if err != nil {
return nil, ErrDuckDB.New(err)
}
Expand Down Expand Up @@ -500,9 +500,9 @@ func (t *Table) Comment() string {
func queryColumns(ctx *sql.Context, catalogName, schemaName, tableName string) ([]*ColumnInfo, error) {
rows, err := adapter.QueryCatalog(ctx, `
SELECT column_name, column_index, data_type, is_nullable, column_default, comment, numeric_precision, numeric_scale
FROM duckdb_columns()
WHERE database_name = ? AND schema_name = ? AND table_name = ?
`, catalogName, schemaName, tableName)
FROM duckdb_columns()
WHERE (database_name = ? AND schema_name = ? AND table_name = ?) OR (database_name = 'temp' AND schema_name = 'main' AND table_name = ?)
`, catalogName, schemaName, tableName, tableName)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 338c7ca

Please sign in to comment.