Skip to content

Commit

Permalink
refactor: use attach
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException committed Dec 24, 2024
1 parent c930501 commit 32e8806
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 165 deletions.
226 changes: 102 additions & 124 deletions catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,41 @@ func NewInMemoryDBProvider() *DatabaseProvider {
return prov
}

// NewDBProvider builds a DatabaseProvider whose default catalog is defaultDB.
//
// An empty defaultDB or "memory" selects the in-memory catalog (empty DSN); any
// other value is mapped to the file <dataDir>/<defaultDB>.db. The function opens
// a DuckDB connector on that DSN, wraps it with stdsql.OpenDB and a connection
// pool, runs initCatalog, and only then marks the provider ready. It returns a
// nil provider together with the error when the connector cannot be created or
// initCatalog fails.
//
// NOTE(review): this span comes from a rendered diff without +/- markers, so
// lines of the previous implementation (the second signature, the
// CreateCatalog and SwitchCatalog calls) sit next to the current ones.
func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (*DatabaseProvider, error) {
prov := &DatabaseProvider{
func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabaseProvider, err error) {
prov = &DatabaseProvider{
mu: &sync.RWMutex{},
defaultTimeZone: defaultTimeZone,
externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface
dataDir: dataDir,
dsn: "N/A",
}
prepared, err := prov.CreateCatalog(defaultDB)
if !prepared {

// Resolve catalog name, file name and DSN from defaultDB. Unlike the other
// catalog methods, defaultDB is used as given (no TrimSpace).
if defaultDB == "" || defaultDB == "memory" {
prov.catalogName = "memory"
prov.dbFile = ""
prov.dsn = ""
} else {
prov.catalogName = defaultDB
prov.dbFile = defaultDB + ".db"
prov.dsn = filepath.Join(prov.dataDir, prov.dbFile)
}

// Open the connector first; the sql handle and the pool are both built on it.
prov.connector, err = duckdb.NewConnector(prov.dsn, nil)
if err != nil {
return nil, err
}
err = prov.SwitchCatalog(defaultDB)
prov.storage = stdsql.OpenDB(prov.connector)
prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage)

// NOTE(review): in the visible part of initCatalog only the boot-query failure
// path closes prov.storage and prov.connector; on the other error returns they
// appear to stay open while this function returns nil — confirm and close here
// if needed.
err = prov.initCatalog()
if err != nil {
return nil, err
}

prov.ready = true
return prov, nil
}

func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage *stdsql.DB) error {
func (prov *DatabaseProvider) initCatalog() error {
bootQueries := []string{
"INSTALL arrow",
"LOAD arrow",
Expand All @@ -78,16 +92,17 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage *
"INSTALL postgres_scanner",
"LOAD postgres_scanner",
}

for _, q := range bootQueries {
if _, err := storage.ExecContext(context.Background(), q); err != nil {
storage.Close()
connector.Close()
if _, err := prov.storage.ExecContext(context.Background(), q); err != nil {
prov.storage.Close()
prov.connector.Close()
return fmt.Errorf("failed to execute boot query %q: %w", q, err)
}
}

for _, t := range internalSchemas {
if _, err := storage.ExecContext(
if _, err := prov.storage.ExecContext(
context.Background(),
"CREATE SCHEMA IF NOT EXISTS "+t.Schema,
); err != nil {
Expand All @@ -96,20 +111,20 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage *
}

for _, t := range internalTables {
if _, err := storage.ExecContext(
if _, err := prov.storage.ExecContext(
context.Background(),
"CREATE SCHEMA IF NOT EXISTS "+t.Schema,
); err != nil {
return fmt.Errorf("failed to create internal schema %q: %w", t.Schema, err)
}
if _, err := storage.ExecContext(
if _, err := prov.storage.ExecContext(
context.Background(),
"CREATE TABLE IF NOT EXISTS "+t.QualifiedName()+"("+t.DDL+")",
); err != nil {
return fmt.Errorf("failed to create internal table %q: %w", t.Name, err)
}
for _, row := range t.InitialData {
if _, err := storage.ExecContext(
if _, err := prov.storage.ExecContext(
context.Background(),
t.UpsertStmt(),
row...,
Expand All @@ -119,146 +134,109 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage *
}
}

if _, err := prov.pool.ExecContext(context.Background(), "PRAGMA enable_checkpoint_on_shutdown"); err != nil {
logrus.WithError(err).Fatalln("Failed to enable checkpoint on shutdown")
}

if prov.defaultTimeZone != "" {
_, err := prov.pool.ExecContext(context.Background(), fmt.Sprintf(`SET TimeZone = '%s'`, prov.defaultTimeZone))
if err != nil {
logrus.WithError(err).Fatalln("Failed to set the default time zone")
}
}

// Postgres tables are created in the `public` schema by default.
// Create the `public` schema if it doesn't exist.
_, err := prov.pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public")
if err != nil {
logrus.WithError(err).Fatalln("Failed to create the `public` schema")
}
return nil
}

// IsReady returns the provider's ready flag.
// NOTE(review): the flag is read without taking prov.mu — confirm that no
// caller can observe it while another goroutine is writing it.
func (prov *DatabaseProvider) IsReady() bool {
return prov.ready
}

// DropCatalog deletes the on-disk database file that backs the catalog dbName.
//
// dbName is trimmed of surrounding whitespace and resolved to
// <dataDir>/<dbName>.db. An error is returned when dbName designates the
// in-memory catalog ("" or "memory"), when the resolved file is the one the
// provider is currently using, when the file does not exist or cannot be
// inspected, or when removing it fails.
func (prov *DatabaseProvider) DropCatalog(dbName string) error {
	name := strings.TrimSpace(dbName)
	// The in-memory catalog has no backing file. The test must look at the bare
	// name: once ".db" has been appended the string can never be empty.
	if name == "" || name == "memory" {
		return fmt.Errorf("cannot drop the in-memory catalog")
	}

	dsn := filepath.Join(prov.dataDir, name+".db")
	// Refuse to delete the file behind the active catalog.
	if dsn == prov.dsn {
		return fmt.Errorf("cannot drop the current catalog")
	}

	if _, err := os.Stat(dsn); err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("database file %s does not exist", dsn)
		}
		// Surface permission and I/O problems instead of falling through to Remove.
		return fmt.Errorf("failed to stat database file %s: %w", dsn, err)
	}

	if err := os.Remove(dsn); err != nil {
		return fmt.Errorf("failed to delete database file %s: %w", dsn, err)
	}
	return nil
}

// ExistCatalog reports whether the named catalog is available. The in-memory
// catalog (empty name or "memory") always counts as existing; any other name is
// resolved to <dataDir>/<name>.db and looked up on disk with os.Stat.
//
// NOTE(review): this span comes from a rendered diff without +/- markers, so
// the previous and the current version of the function appear together.
// NOTE(review): os.Stat returns a nil error when the file is present and
// os.IsExist(nil) is false, so the on-disk branch cannot report true as written;
// `err == nil` is probably what was intended — confirm with the callers.
func (prov *DatabaseProvider) ExistCatalog(dbName string) bool {
dbFile := strings.TrimSpace(dbName) + ".db"
if dbFile == "" || dbFile == "memory.db" {
func (prov *DatabaseProvider) ExistCatalog(name string) bool {
name = strings.TrimSpace(name)
// the in-memory catalog is always considered present
if name == "" || name == "memory" {
return true
} else {
dsn := filepath.Join(prov.dataDir, dbFile)
// stat the backing file; the result is turned into a bool, no error is returned
_, err := os.Stat(dsn)
return os.IsExist(err)
}

dsn := filepath.Join(prov.dataDir, name+".db")
// stat the backing file; the result is turned into a bool, no error is returned
_, err := os.Stat(dsn)
return os.IsExist(err)
}

// NOTE(review): this span comes from a rendered diff without +/- markers. It
// interleaves the previous CreateCatalog and SwitchCatalog (separate connector
// per database file) with the current CreateCatalog and DropCatalog (ATTACH /
// DETACH on the provider's existing handle). The comments below describe the
// ATTACH-based versions unless stated otherwise.
//
// CreateCatalog attaches the database file <dataDir>/<name>.db under the alias
// name, optionally with IF NOT EXISTS, and initializes it through initCatalog
// when the ATTACH statement reports affected rows. The in-memory catalog ("" or
// "memory") needs no work and yields nil.
func (prov *DatabaseProvider) CreateCatalog(dbName string) (ready bool, err error) {
dbFile := strings.TrimSpace(dbName) + ".db"
dsn := ""
func (prov *DatabaseProvider) CreateCatalog(name string, ifNotExists bool) error {
name = strings.TrimSpace(name)
// in memory database does not need to be created
if dbFile == "" || dbFile == "memory.db" {
return true, nil
} else {
dsn = filepath.Join(prov.dataDir, dbFile)
// an existing file means the database was already created
_, err := os.Stat(dsn)
if err == nil {
return true, fmt.Errorf("database file %s already exists", dsn)
}
if name == "" || name == "memory" {
return nil
}

connector, err := duckdb.NewConnector(dsn, nil)
if err != nil {
return false, err
dsn := filepath.Join(prov.dataDir, name+".db")
// Build: ATTACH [IF NOT EXISTS] '<dsn>' AS <name>
attachSQL := "ATTACH"
if ifNotExists {
attachSQL += " IF NOT EXISTS"
}

storage := stdsql.OpenDB(connector)
err = prov.initCatalog(connector, storage)
// NOTE(review): dsn and name are concatenated into the statement unescaped; a
// quote in the path or a name that is not a plain identifier changes the SQL.
// Consider quoting/validating — confirm what the caller already guarantees.
attachSQL += " '" + dsn + "' AS " + name
res, err := prov.storage.ExecContext(context.Background(), attachSQL)
if err != nil {
return false, err
return err
}
storage.Close()
connector.Close()
return true, nil
}

// SwitchCatalog (previous implementation) opened a new connector for the
// requested database file and swapped it into the provider under prov.mu.
func (prov *DatabaseProvider) SwitchCatalog(dbName string) error {
dbFile := strings.TrimSpace(dbName) + ".db"
name := ""
dsn := ""
if dbFile == "" || dbFile == "memory.db" {
// in-memory mode, mainly for testing
name = "memory"
dsn = ""
} else {
name = strings.Split(dbFile, ".")[0]
dsn = filepath.Join(prov.dataDir, dbFile)
// if file does not exist, return error
_, err := os.Stat(dsn)
if os.IsNotExist(err) {
return fmt.Errorf("database file %s does not exist", dsn)
}
// NOTE(review): the rows-affected count of ATTACH is used as the "newly
// attached" signal — presumably the driver reports 0 when IF NOT EXISTS was a
// no-op; verify against the driver.
rows, err := res.RowsAffected()
if err != nil {
logrus.Errorf("Failed to get rows affected: %v", err)
return err
}
if dsn == prov.dsn {
if rows <= 0 {
return nil
}

connector, err := duckdb.NewConnector(dsn, nil)
if err != nil {
// if newly created, initialize the catalog
// NOTE(review): prov.storage is used like a pooled handle here; if USE only
// affects the connection that ran it, the initCatalog statements below may run
// against a different catalog — confirm how the handle is configured.
if _, err := prov.storage.ExecContext(context.Background(), "USE "+name); err != nil {
return err
}

storage := stdsql.OpenDB(connector)

// in memory database needs to be initialized every time
if dsn == "" {
err = prov.initCatalog(connector, storage)
if err != nil {
return err
}
}

prov.mu.Lock()
prov.ready = false
defer func() {
prov.ready = true
prov.mu.Unlock()
// Always switch back to the provider's own catalog once initialization is
// done; a failure is logged, not returned.
if _, err := prov.storage.ExecContext(context.Background(), "USE "+prov.catalogName); err != nil {
logrus.WithError(err).Errorln("Failed to switch back to the old database")
}
}()

prov.connector = connector
prov.storage = storage
prov.catalogName = name
prov.dbFile = dbFile
prov.dsn = dsn

prov.pool = NewConnectionPool(name, connector, storage)
if _, err := prov.pool.ExecContext(context.Background(), "PRAGMA enable_checkpoint_on_shutdown"); err != nil {
logrus.WithError(err).Fatalln("Failed to enable checkpoint on shutdown")
// NOTE(review): initCatalog closes prov.storage and prov.connector when a boot
// query fails, which would take down the provider's shared handle from inside
// CreateCatalog — confirm this is intended.
err = prov.initCatalog()
if err != nil {
return err
}
return nil
}

if prov.defaultTimeZone != "" {
_, err := prov.pool.ExecContext(context.Background(), fmt.Sprintf(`SET TimeZone = '%s'`, prov.defaultTimeZone))
if err != nil {
logrus.WithError(err).Fatalln("Failed to set the default time zone")
// DropCatalog detaches the catalog name from the provider's handle and deletes
// its backing file <dataDir>/<name>.db. The in-memory catalog ("" or "memory")
// cannot be dropped. A missing file is an error unless ifExists is true, in
// which case nil is returned without detaching anything.
// NOTE(review): there is no explicit guard against dropping the catalog that is
// currently in use; the outcome then depends on how DETACH behaves — confirm.
func (prov *DatabaseProvider) DropCatalog(name string, ifExists bool) error {
name = strings.TrimSpace(name)
// the in-memory catalog cannot be dropped
if name == "" || name == "memory" {
return fmt.Errorf("cannot drop the in-memory catalog")
}
dsn := filepath.Join(prov.dataDir, name+".db")
// a missing file is an error unless the caller asked for IF EXISTS semantics;
// other Stat errors are not inspected here
_, err := os.Stat(dsn)
if os.IsNotExist(err) {
if ifExists {
return nil
}
return fmt.Errorf("database file %s does not exist", dsn)
}

// Postgres tables are created in the `public` schema by default.
// Create the `public` schema if it doesn't exist.
_, err = prov.pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public")
// detach before removing the file
// NOTE(review): name is concatenated into the statement unescaped — same
// concern as for ATTACH/USE.
if _, err := prov.storage.ExecContext(context.Background(), "DETACH "+name); err != nil {
return fmt.Errorf("failed to detach catalog %w", err)
}
// delete the file
err = os.Remove(dsn)
if err != nil {
logrus.WithError(err).Fatalln("Failed to create the `public` schema")
return fmt.Errorf("failed to delete database file %s: %w", dsn, err)
}
return nil
}
Expand Down
26 changes: 19 additions & 7 deletions catalog/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func TestCreateCatalog(t *testing.T) {
SQL: "CREATE DATABASE testdb1;",
WantErr: true,
},
{
SQL: "CREATE DATABASE IF NOT EXISTS testdb1;",
Expected: "CREATE DATABASE",
},
{
SQL: "CREATE DATABASE testdb2;",
Expected: "CREATE DATABASE",
Expand Down Expand Up @@ -80,21 +84,29 @@ func TestCreateCatalog(t *testing.T) {
{
name: "drop database",
executions: []Execution{
//{
// SQL: "USE testdb1;",
// Expected: "SET",
//},
//// Can not drop the database when the current database is the one to be dropped
//{
// SQL: "DROP DATABASE testdb1;",
// WantErr: true,
//},
{
SQL: "USE testdb1;",
SQL: "USE testdb2;",
Expected: "SET",
},
// Can not drop the database when the current database is the one to be dropped
{
SQL: "DROP DATABASE testdb1;",
WantErr: true,
SQL: "DROP DATABASE testdb1;",
Expected: "DROP DATABASE",
},
{
SQL: "USE testdb2;",
Expected: "SET",
SQL: "DROP DATABASE testdb1;",
WantErr: true,
},
{
SQL: "DROP DATABASE testdb1;",
SQL: "DROP DATABASE IF EXISTS testdb1;",
Expected: "DROP DATABASE",
},
},
Expand Down
Loading

0 comments on commit 32e8806

Please sign in to comment.