diff --git a/catalog/provider.go b/catalog/provider.go index c60e733..2caf567 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -74,16 +74,6 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr prov.storage = stdsql.OpenDB(prov.connector) prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage) - err = prov.initCatalog() - if err != nil { - return nil, err - } - - prov.ready = true - return prov, nil -} - -func (prov *DatabaseProvider) initCatalog() error { bootQueries := []string{ "INSTALL arrow", "LOAD arrow", @@ -97,10 +87,21 @@ func (prov *DatabaseProvider) initCatalog() error { if _, err := prov.storage.ExecContext(context.Background(), q); err != nil { prov.storage.Close() prov.connector.Close() - return fmt.Errorf("failed to execute boot query %q: %w", q, err) + return nil, fmt.Errorf("failed to execute boot query %q: %w", q, err) } } + err = prov.initCatalog() + if err != nil { + return nil, err + } + + prov.ready = true + return prov, nil +} + +func (prov *DatabaseProvider) initCatalog() error { + for _, t := range internalSchemas { if _, err := prov.storage.ExecContext( context.Background(), @@ -158,7 +159,7 @@ func (prov *DatabaseProvider) IsReady() bool { return prov.ready } -func (prov *DatabaseProvider) ExistCatalog(name string) bool { +func (prov *DatabaseProvider) ExistsCatalog(name string) bool { name = strings.TrimSpace(name) // in memory database does not need to be created if name == "" || name == "memory" { @@ -178,38 +179,46 @@ func (prov *DatabaseProvider) CreateCatalog(name string, ifNotExists bool) error return nil } dsn := filepath.Join(prov.dataDir, name+".db") + + _, err := os.Stat(dsn) + shouldInit := os.IsNotExist(err) + // attach attachSQL := "ATTACH" if ifNotExists { attachSQL += " IF NOT EXISTS" } attachSQL += " '" + dsn + "' AS " + name - res, err := prov.storage.ExecContext(context.Background(), attachSQL) - if err != nil { - return err - } - rows, err := res.RowsAffected() + _, err = prov.storage.ExecContext(context.Background(), attachSQL) if err != nil { - logrus.Errorf("Failed to get rows affected: %v", err) return err } - if rows <= 0 { - return nil - } - // if newly created, initialize the catalog - if _, err := prov.storage.ExecContext(context.Background(), "USE "+name); err != nil { - return err - } - defer func() { - if _, err := prov.storage.ExecContext(context.Background(), "USE "+prov.catalogName); err != nil { - logrus.WithError(err).Errorln("Failed to switch back to the old database") + if shouldInit { + res, err := prov.storage.QueryContext(context.Background(), "SELECT current_catalog") + if err != nil { + return fmt.Errorf("failed to init catalog: %w", err) + } + lastCatalog := "" + for res.Next() { + if err := res.Scan(&lastCatalog); err != nil { + return fmt.Errorf("failed to init catalog: %w", err) + } } - }() - err = prov.initCatalog() - if err != nil { - return err + if _, err := prov.storage.ExecContext(context.Background(), "USE "+name); err != nil { + return fmt.Errorf("failed to switch to the new catalog: %w", err) + } + + defer func() { + if _, err := prov.storage.ExecContext(context.Background(), "USE "+lastCatalog); err != nil { + logrus.WithError(err).Errorln("Failed to switch back to the old catalog") + } + }() + err = prov.initCatalog() + if err != nil { + return err + } } return nil }