diff --git a/adapter/adapter.go b/adapter/adapter.go index 1bbacd0..14bffd1 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -14,6 +14,8 @@ type ConnectionHolder interface { GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) TryGetTxn() *stdsql.Tx + GetCurrentCatalog() string + GetCurrentSchema() string CloseTxn() CloseConn() } @@ -42,6 +44,14 @@ func TryGetTxn(ctx *sql.Context) *stdsql.Tx { return ctx.Session.(ConnectionHolder).TryGetTxn() } +func GetCurrentCatalog(ctx *sql.Context) string { + return ctx.Session.(ConnectionHolder).GetCurrentCatalog() +} + +func GetCurrentSchema(ctx *sql.Context) string { + return ctx.Session.(ConnectionHolder).GetCurrentSchema() +} + func CloseTxn(ctx *sql.Context) { ctx.Session.(ConnectionHolder).CloseTxn() } diff --git a/backend/executor.go b/backend/executor.go index 29c22bd..51c8095 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -17,6 +17,7 @@ import ( stdsql "database/sql" "fmt" + "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/catalog" "github.com/apecloud/myduckserver/transpiler" "github.com/dolthub/go-mysql-server/sql" @@ -124,7 +125,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row switch node := n.(type) { case *plan.Use: - useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name()) + useStmt := "USE " + catalog.FullSchemaName(adapter.GetCurrentCatalog(ctx), node.Database().Name()) if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil { if catalog.IsDuckDBSetSchemaNotFoundError(err) { return nil, sql.ErrDatabaseNotFound.New(node.Database().Name()) diff --git a/backend/session.go b/backend/session.go index 52dfc61..54e0059 100644 --- a/backend/session.go +++ b/backend/session.go @@ -225,6 +225,16 @@ func (sess *Session) TryGetTxn() *stdsql.Tx { return sess.db.Pool().TryGetTxn(sess.ID()) } +// GetCurrentCatalog implements adapter.ConnectionHolder. +func (sess *Session) GetCurrentCatalog() string { + return sess.db.Pool().CurrentCatalog(sess.ID()) +} + +// GetCurrentSchema implements adapter.ConnectionHolder. +func (sess *Session) GetCurrentSchema() string { + return sess.db.Pool().CurrentSchema(sess.ID()) +} + // CloseTxn implements adapter.ConnectionHolder. func (sess *Session) CloseTxn() { sess.db.Pool().CloseTxn(sess.ID()) diff --git a/catalog/connpool.go b/catalog/connpool.go index b89d974..177bab9 100644 --- a/catalog/connpool.go +++ b/catalog/connpool.go @@ -30,16 +30,14 @@ import ( type ConnectionPool struct { *stdsql.DB connector *duckdb.Connector - catalog string conns sync.Map // concurrent-safe map[uint32]*stdsql.Conn txns sync.Map // concurrent-safe map[uint32]*stdsql.Tx } -func NewConnectionPool(catalog string, connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool { +func NewConnectionPool(connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool { return &ConnectionPool{ DB: db, connector: connector, - catalog: catalog, } } @@ -57,13 +55,30 @@ func (p *ConnectionPool) CurrentSchema(id uint32) string { } conn := entry.(*stdsql.Conn) var schema string - if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&schema); err != nil { + if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&schema); err != nil { logrus.WithError(err).Error("Failed to get current schema") return "" } return schema } +// CurrentCatalog retrieves the current catalog of the connection. +// Returns an empty string if the connection is not established +// or the catalog cannot be retrieved. +func (p *ConnectionPool) CurrentCatalog(id uint32) string { + entry, ok := p.conns.Load(id) + if !ok { + return "" + } + conn := entry.(*stdsql.Conn) + var catalog string + if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_CATALOG").Scan(&catalog); err != nil { + logrus.WithError(err).Error("Failed to get current catalog") + return "" + } + return catalog +} + func (p *ConnectionPool) GetConn(ctx context.Context, id uint32) (*stdsql.Conn, error) { var conn *stdsql.Conn entry, ok := p.conns.Load(id) @@ -88,11 +103,11 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema if schemaName != "" { var currentSchema string - if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(¤tSchema); err != nil { + if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(¤tSchema); err != nil { logrus.WithError(err).Error("Failed to get current schema") return nil, err } else if currentSchema != schemaName { - if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil { + if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.CurrentCatalog(id), schemaName)); err != nil { if IsDuckDBSetSchemaNotFoundError(err) { return nil, sql.ErrDatabaseNotFound.New(schemaName) } @@ -187,7 +202,7 @@ func (p *ConnectionPool) Close() error { return errors.Join(lastErr, p.DB.Close()) } -func (p *ConnectionPool) Reset(catalog string, connector *duckdb.Connector, db *stdsql.DB) error { +func (p *ConnectionPool) Reset(connector *duckdb.Connector, db *stdsql.DB) error { err := p.Close() if err != nil { return fmt.Errorf("failed to close connection pool: %w", err) @@ -195,7 +210,6 @@ func (p *ConnectionPool) Reset(catalog string, connector *duckdb.Connector, db * p.conns.Clear() p.txns.Clear() - p.catalog = catalog p.DB = db p.connector = connector diff --git a/catalog/provider.go b/catalog/provider.go index 2fe15db..b99ba15 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -14,7 +14,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/marcboeker/go-duckdb" - _ "github.com/marcboeker/go-duckdb" "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/configuration" @@ -27,7 +26,7 @@ type DatabaseProvider struct { connector *duckdb.Connector storage *stdsql.DB pool *ConnectionPool - catalogName string // database name in postgres + defaultCatalogName string // default database name in postgres dataDir string dbFile string dsn string @@ -60,11 +59,11 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr shouldInit := true if defaultDB == "" || defaultDB == "memory" { - prov.catalogName = "memory" + prov.defaultCatalogName = "memory" prov.dbFile = "" prov.dsn = "" } else { - prov.catalogName = defaultDB + prov.defaultCatalogName = defaultDB prov.dbFile = defaultDB + ".db" prov.dsn = filepath.Join(prov.dataDir, prov.dbFile) _, err = os.Stat(prov.dsn) @@ -76,7 +75,7 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr return nil, err } prov.storage = stdsql.OpenDB(prov.connector) - prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage) + prov.pool = NewConnectionPool(prov.connector, prov.storage) bootQueries := []string{ "INSTALL arrow", @@ -353,8 +352,8 @@ func (prov *DatabaseProvider) Pool() *ConnectionPool { return prov.pool } -func (prov *DatabaseProvider) CatalogName() string { - return prov.catalogName +func (prov *DatabaseProvider) DefaultCatalogName() string { + return prov.defaultCatalogName } func (prov *DatabaseProvider) DataDir() string { @@ -380,7 +379,8 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database { prov.mu.RLock() defer prov.mu.RUnlock() - rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName) + catalogName := adapter.GetCurrentCatalog(ctx) + rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", catalogName) if err != nil { panic(ErrDuckDB.New(err)) } @@ -398,7 +398,7 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database { continue } - all = append(all, NewDatabase(schemaName, prov.catalogName)) + all = append(all, NewDatabase(schemaName, catalogName)) } sort.Slice(all, func(i, j int) bool { @@ -413,13 +413,14 @@ func (prov *DatabaseProvider) Database(ctx *sql.Context, name string) (sql.Datab prov.mu.RLock() defer prov.mu.RUnlock() - ok, err := hasDatabase(ctx, prov.catalogName, name) + catalogName := adapter.GetCurrentCatalog(ctx) + ok, err := hasDatabase(ctx, catalogName, name) if err != nil { return nil, err } if ok { - return NewDatabase(name, prov.catalogName), nil + return NewDatabase(name, catalogName), nil } return nil, sql.ErrDatabaseNotFound.New(name) } @@ -429,7 +430,7 @@ func (prov *DatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool { prov.mu.RLock() defer prov.mu.RUnlock() - ok, err := hasDatabase(ctx, prov.catalogName, name) + ok, err := hasDatabase(ctx, adapter.GetCurrentCatalog(ctx), name) if err != nil { panic(err) } @@ -451,7 +452,8 @@ func (prov *DatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro prov.mu.Lock() defer prov.mu.Unlock() - _, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name))) + _, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, + FullSchemaName(adapter.GetCurrentCatalog(ctx), name))) if err != nil { return ErrDuckDB.New(err) } @@ -464,7 +466,8 @@ func (prov *DatabaseProvider) DropDatabase(ctx *sql.Context, name string) error prov.mu.Lock() defer prov.mu.Unlock() - _, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name))) + _, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, + FullSchemaName(adapter.GetCurrentCatalog(ctx), name))) if err != nil { return ErrDuckDB.New(err) } @@ -494,5 +497,5 @@ func (prov *DatabaseProvider) Restart(readOnly bool) error { prov.connector = connector prov.storage = storage - return nil + return prov.pool.Reset(connector, storage) } diff --git a/pgserver/backup_handler.go b/pgserver/backup_handler.go index 65b0ae0..5bdb02b 100644 --- a/pgserver/backup_handler.go +++ b/pgserver/backup_handler.go @@ -89,12 +89,6 @@ func parseBackupSQL(sql string) (*BackupConfig, error) { } func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, error) { - // TODO(neo.zty): Add support for backing up multiple databases once MyDuck Server supports multi-database functionality. - if backupConfig.DbName != h.server.Provider.CatalogName() { - return "", fmt.Errorf("backup database name %s does not match server database name %s", - backupConfig.DbName, h.server.Provider.CatalogName()) - } - sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") if err != nil { return "", fmt.Errorf("failed to create context for query: %w", err) @@ -114,7 +108,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e } msg, err := backupConfig.StorageConfig.UploadFile( - h.server.Provider.DataDir(), h.server.Provider.DbFile(), backupConfig.RemotePath) + h.server.Provider.DataDir(), backupConfig.DbName+".db", backupConfig.RemotePath) if err != nil { return "", err } @@ -133,12 +127,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e func (h *ConnectionHandler) restartServer(readOnly bool) error { provider := h.server.Provider - err := provider.Restart(readOnly) - if err != nil { - return err - } - - return h.server.Provider.Pool().Reset(provider.CatalogName(), provider.Connector(), provider.Storage()) + return provider.Restart(readOnly) } func doCheckpoint(sqlCtx *sql.Context) error { diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 16a1545..b4de311 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -313,7 +313,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start } if db == "postgres" || db == "mysql" { if provider := h.duckHandler.GetCatalogProvider(); provider != nil { - db = provider.CatalogName() + db = provider.DefaultCatalogName() } }