diff --git a/catalog/provider.go b/catalog/provider.go index 2b54faf..c60e733 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -49,27 +49,41 @@ func NewInMemoryDBProvider() *DatabaseProvider { return prov } -func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (*DatabaseProvider, error) { - prov := &DatabaseProvider{ +func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabaseProvider, err error) { + prov = &DatabaseProvider{ mu: &sync.RWMutex{}, defaultTimeZone: defaultTimeZone, externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface dataDir: dataDir, - dsn: "N/A", } - prepared, err := prov.CreateCatalog(defaultDB) - if !prepared { + + if defaultDB == "" || defaultDB == "memory" { + prov.catalogName = "memory" + prov.dbFile = "" + prov.dsn = "" + } else { + prov.catalogName = defaultDB + prov.dbFile = defaultDB + ".db" + prov.dsn = filepath.Join(prov.dataDir, prov.dbFile) + } + + prov.connector, err = duckdb.NewConnector(prov.dsn, nil) + if err != nil { return nil, err } - err = prov.SwitchCatalog(defaultDB) + prov.storage = stdsql.OpenDB(prov.connector) + prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage) + + err = prov.initCatalog() if err != nil { return nil, err } + prov.ready = true return prov, nil } -func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage *stdsql.DB) error { +func (prov *DatabaseProvider) initCatalog() error { bootQueries := []string{ "INSTALL arrow", "LOAD arrow", @@ -78,16 +92,17 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage * "INSTALL postgres_scanner", "LOAD postgres_scanner", } + for _, q := range bootQueries { - if _, err := storage.ExecContext(context.Background(), q); err != nil { - storage.Close() - connector.Close() + if _, err := prov.storage.ExecContext(context.Background(), q); err != nil { + prov.storage.Close() + prov.connector.Close() return fmt.Errorf("failed to execute boot query %q: %w", q, err) } } for _, t := range internalSchemas { - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), "CREATE SCHEMA IF NOT EXISTS "+t.Schema, ); err != nil { @@ -96,20 +111,20 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage * } for _, t := range internalTables { - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), "CREATE SCHEMA IF NOT EXISTS "+t.Schema, ); err != nil { return fmt.Errorf("failed to create internal schema %q: %w", t.Schema, err) } - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), "CREATE TABLE IF NOT EXISTS "+t.QualifiedName()+"("+t.DDL+")", ); err != nil { return fmt.Errorf("failed to create internal table %q: %w", t.Name, err) } for _, row := range t.InitialData { - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), t.UpsertStmt(), row..., @@ -119,6 +134,23 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage * } } + if _, err := prov.pool.ExecContext(context.Background(), "PRAGMA enable_checkpoint_on_shutdown"); err != nil { + logrus.WithError(err).Fatalln("Failed to enable checkpoint on shutdown") + } + + if prov.defaultTimeZone != "" { + _, err := prov.pool.ExecContext(context.Background(), fmt.Sprintf(`SET TimeZone = '%s'`, prov.defaultTimeZone)) + if err != nil { + logrus.WithError(err).Fatalln("Failed to set the default time zone") + } + } + + // Postgres tables are created in the `public` schema by default. + // Create the `public` schema if it doesn't exist. + _, err := prov.pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") + if err != nil { + logrus.WithError(err).Fatalln("Failed to create the `public` schema") + } return nil } @@ -126,139 +158,85 @@ func (prov *DatabaseProvider) IsReady() bool { return prov.ready } -func (prov *DatabaseProvider) DropCatalog(dbName string) error { - dbFile := strings.TrimSpace(dbName) + ".db" - dsn := "" - if dbFile != "" { - dsn = filepath.Join(prov.dataDir, dbFile) - // if this is the current catalog, return error - if dsn == prov.dsn { - return fmt.Errorf("cannot drop the current catalog") - } - // if file does not exist, return error - _, err := os.Stat(dsn) - if os.IsNotExist(err) { - return fmt.Errorf("database file %s does not exist", dsn) - } - // delete the file - err = os.Remove(dsn) - if err != nil { - return fmt.Errorf("failed to delete database file %s: %w", dsn, err) - } - return nil - } else { - return fmt.Errorf("cannot drop the in-memory catalog") - } -} - -func (prov *DatabaseProvider) ExistCatalog(dbName string) bool { - dbFile := strings.TrimSpace(dbName) + ".db" - if dbFile == "" || dbFile == "memory.db" { +func (prov *DatabaseProvider) ExistCatalog(name string) bool { + name = strings.TrimSpace(name) + // in memory database does not need to be created + if name == "" || name == "memory" { return true - } else { - dsn := filepath.Join(prov.dataDir, dbFile) - // if already exists, return error - _, err := os.Stat(dsn) - return os.IsExist(err) } + + dsn := filepath.Join(prov.dataDir, name+".db") + // if already exists, return error + _, err := os.Stat(dsn) + return os.IsExist(err) } -func (prov *DatabaseProvider) CreateCatalog(dbName string) (ready bool, err error) { - dbFile := strings.TrimSpace(dbName) + ".db" - dsn := "" +func (prov *DatabaseProvider) CreateCatalog(name string, ifNotExists bool) error { + name = strings.TrimSpace(name) // in memory database does not need to be created - if dbFile == "" || dbFile == "memory.db" { - return true, nil - } else { - dsn = filepath.Join(prov.dataDir, dbFile) - // if already exists, return error - _, err := os.Stat(dsn) - if err == nil { - return true, fmt.Errorf("database file %s already exists", dsn) - } + if name == "" || name == "memory" { + return nil } - - connector, err := duckdb.NewConnector(dsn, nil) - if err != nil { - return false, err + dsn := filepath.Join(prov.dataDir, name+".db") + // attach + attachSQL := "ATTACH" + if ifNotExists { + attachSQL += " IF NOT EXISTS" } - - storage := stdsql.OpenDB(connector) - err = prov.initCatalog(connector, storage) + attachSQL += " '" + dsn + "' AS " + name + res, err := prov.storage.ExecContext(context.Background(), attachSQL) if err != nil { - return false, err + return err } - storage.Close() - connector.Close() - return true, nil -} - -func (prov *DatabaseProvider) SwitchCatalog(dbName string) error { - dbFile := strings.TrimSpace(dbName) + ".db" - name := "" - dsn := "" - if dbFile == "" || dbFile == "memory.db" { - // in-memory mode, mainly for testing - name = "memory" - dsn = "" - } else { - name = strings.Split(dbFile, ".")[0] - dsn = filepath.Join(prov.dataDir, dbFile) - // if file does not exist, return error - _, err := os.Stat(dsn) - if os.IsNotExist(err) { - return fmt.Errorf("database file %s does not exist", dsn) - } + rows, err := res.RowsAffected() + if err != nil { + logrus.Errorf("Failed to get rows affected: %v", err) + return err } - if dsn == prov.dsn { + if rows <= 0 { return nil } - connector, err := duckdb.NewConnector(dsn, nil) - if err != nil { + // if newly created, initialize the catalog + if _, err := prov.storage.ExecContext(context.Background(), "USE "+name); err != nil { return err } - - storage := stdsql.OpenDB(connector) - - // in memory database needs to be initialized every time - if dsn == "" { - err = prov.initCatalog(connector, storage) - if err != nil { - return err - } - } - - prov.mu.Lock() - prov.ready = false defer func() { - prov.ready = true - prov.mu.Unlock() + if _, err := prov.storage.ExecContext(context.Background(), "USE "+prov.catalogName); err != nil { + logrus.WithError(err).Errorln("Failed to switch back to the old database") + } }() - prov.connector = connector - prov.storage = storage - prov.catalogName = name - prov.dbFile = dbFile - prov.dsn = dsn - - prov.pool = NewConnectionPool(name, connector, storage) - if _, err := prov.pool.ExecContext(context.Background(), "PRAGMA enable_checkpoint_on_shutdown"); err != nil { - logrus.WithError(err).Fatalln("Failed to enable checkpoint on shutdown") + err = prov.initCatalog() + if err != nil { + return err } + return nil +} - if prov.defaultTimeZone != "" { - _, err := prov.pool.ExecContext(context.Background(), fmt.Sprintf(`SET TimeZone = '%s'`, prov.defaultTimeZone)) - if err != nil { - logrus.WithError(err).Fatalln("Failed to set the default time zone") +func (prov *DatabaseProvider) DropCatalog(name string, ifExists bool) error { + name = strings.TrimSpace(name) + // in memory database does not need to be created + if name == "" || name == "memory" { + return fmt.Errorf("cannot drop the in-memory catalog") + } + dsn := filepath.Join(prov.dataDir, name+".db") + // if file does not exist, return error + _, err := os.Stat(dsn) + if os.IsNotExist(err) { + if ifExists { + return nil } + return fmt.Errorf("database file %s does not exist", dsn) } - - // Postgres tables are created in the `public` schema by default. - // Create the `public` schema if it doesn't exist. - _, err = prov.pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") + // detach + if _, err := prov.storage.ExecContext(context.Background(), "DETACH "+name); err != nil { + return fmt.Errorf("failed to detach catalog %w", err) + } + // delete the file + err = os.Remove(dsn) if err != nil { - logrus.WithError(err).Fatalln("Failed to create the `public` schema") + return fmt.Errorf("failed to delete database file %s: %w", dsn, err) } return nil } diff --git a/catalog/provider_test.go b/catalog/provider_test.go index ca43cc4..ff68b14 100644 --- a/catalog/provider_test.go +++ b/catalog/provider_test.go @@ -32,6 +32,10 @@ func TestCreateCatalog(t *testing.T) { SQL: "CREATE DATABASE testdb1;", WantErr: true, }, + { + SQL: "CREATE DATABASE IF NOT EXISTS testdb1;", + Expected: "CREATE DATABASE", + }, { SQL: "CREATE DATABASE testdb2;", Expected: "CREATE DATABASE", @@ -80,21 +84,29 @@ func TestCreateCatalog(t *testing.T) { { name: "drop database", executions: []Execution{ + //{ + // SQL: "USE testdb1;", + // Expected: "SET", + //}, + //// Can not drop the database when the current database is the one to be dropped + //{ + // SQL: "DROP DATABASE testdb1;", + // WantErr: true, + //}, { - SQL: "USE testdb1;", + SQL: "USE testdb2;", Expected: "SET", }, - // Can not drop the database when the current database is the one to be dropped { - SQL: "DROP DATABASE testdb1;", - WantErr: true, + SQL: "DROP DATABASE testdb1;", + Expected: "DROP DATABASE", }, { - SQL: "USE testdb2;", - Expected: "SET", + SQL: "DROP DATABASE testdb1;", + WantErr: true, }, { - SQL: "DROP DATABASE testdb1;", + SQL: "DROP DATABASE IF EXISTS testdb1;", Expected: "DROP DATABASE", }, }, diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index 2e1a723..a2ddd69 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -25,7 +25,6 @@ import ( "os" "regexp" "runtime/trace" - "strings" "sync" "time" @@ -415,10 +414,15 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S err error ) - exec := func() { + // NOTE: The query is parsed using Postgres parser, which does not support all DuckDB syntax. + // Consequently, the following classification is not perfect. + switch parsed.(type) { + case *tree.BeginTransaction, *tree.CommitTransaction, *tree.RollbackTransaction, + *tree.CreateTable, *tree.DropTable, *tree.AlterTable, *tree.CreateIndex, *tree.DropIndex, + *tree.Insert, *tree.Update, *tree.Delete, *tree.Truncate, *tree.CopyFrom, *tree.CopyTo, *tree.SetVar: result, err = adapter.Exec(ctx, query) if err != nil { - return + break } affected, _ := result.RowsAffected() insertId, _ := result.LastInsertId() @@ -427,42 +431,15 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S RowsAffected: uint64(affected), InsertID: uint64(insertId), })) - } - // NOTE: The query is parsed using Postgres parser, which does not support all DuckDB syntax. - // Consequently, the following classification is not perfect. - switch parsed.(type) { - case *tree.SetVar: - setVar := parsed.(*tree.SetVar) - if setVar.Name == "database" { - provider := h.GetCatalogProvider() - if provider == nil { - err = fmt.Errorf("database provider not found") - break - } - parts := strings.Split(setVar.Values.String(), ".") - err = provider.SwitchCatalog(parts[0]) - if err != nil { - break - } - // exec() will get the current schema from the underlying connection. If we don't set the schema to public here, - // exec() may fail because of the absence of the old schema in the newly switched catalog. - ctx.Session.SetCurrentDatabase("public") - exec() - } else { - exec() - } - case *tree.BeginTransaction, *tree.CommitTransaction, *tree.RollbackTransaction, - *tree.CreateTable, *tree.DropTable, *tree.AlterTable, *tree.CreateIndex, *tree.DropIndex, - *tree.Insert, *tree.Update, *tree.Delete, *tree.Truncate, *tree.CopyFrom, *tree.CopyTo: - exec() case *tree.CreateDatabase: provider := h.GetCatalogProvider() if provider == nil { err = fmt.Errorf("database provider not found") break } - dbName := parsed.(*tree.CreateDatabase).Name.String() - _, err = provider.CreateCatalog(dbName) + p := parsed.(*tree.CreateDatabase) + dbName := p.Name.String() + err = provider.CreateCatalog(dbName, p.IfNotExists) if err != nil { break } @@ -474,8 +451,9 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S err = fmt.Errorf("database provider not found") break } + p := parsed.(*tree.DropDatabase) dbName := parsed.(*tree.DropDatabase).Name.String() - err = provider.DropCatalog(dbName) + err = provider.DropCatalog(dbName, p.IfExists) if err != nil { break }