Skip to content

Commit

Permalink
fix: incorrect reference to catalog (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException authored Dec 31, 2024
1 parent 1749dfe commit 2631948
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 38 deletions.
10 changes: 10 additions & 0 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type ConnectionHolder interface {
GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)
GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
TryGetTxn() *stdsql.Tx
GetCurrentCatalog() string
GetCurrentSchema() string
CloseTxn()
CloseConn()
}
Expand Down Expand Up @@ -42,6 +44,14 @@ func TryGetTxn(ctx *sql.Context) *stdsql.Tx {
return ctx.Session.(ConnectionHolder).TryGetTxn()
}

func GetCurrentCatalog(ctx *sql.Context) string {
return ctx.Session.(ConnectionHolder).GetCurrentCatalog()
}

func GetCurrentSchema(ctx *sql.Context) string {
return ctx.Session.(ConnectionHolder).GetCurrentSchema()
}

func CloseTxn(ctx *sql.Context) {
ctx.Session.(ConnectionHolder).CloseTxn()
}
Expand Down
3 changes: 2 additions & 1 deletion backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
stdsql "database/sql"
"fmt"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
"github.com/apecloud/myduckserver/transpiler"
"github.com/dolthub/go-mysql-server/sql"
Expand Down Expand Up @@ -124,7 +125,7 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row

switch node := n.(type) {
case *plan.Use:
useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name())
useStmt := "USE " + catalog.FullSchemaName(adapter.GetCurrentCatalog(ctx), node.Database().Name())
if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(node.Database().Name())
Expand Down
10 changes: 10 additions & 0 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ func (sess *Session) TryGetTxn() *stdsql.Tx {
return sess.db.Pool().TryGetTxn(sess.ID())
}

// GetCurrentCatalog implements adapter.ConnectionHolder.
func (sess *Session) GetCurrentCatalog() string {
return sess.db.Pool().CurrentCatalog(sess.ID())
}

// GetCurrentSchema implements adapter.ConnectionHolder.
func (sess *Session) GetCurrentSchema() string {
return sess.db.Pool().CurrentSchema(sess.ID())
}

// CloseTxn implements adapter.ConnectionHolder.
func (sess *Session) CloseTxn() {
sess.db.Pool().CloseTxn(sess.ID())
Expand Down
30 changes: 22 additions & 8 deletions catalog/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@ import (
type ConnectionPool struct {
*stdsql.DB
connector *duckdb.Connector
catalog string
conns sync.Map // concurrent-safe map[uint32]*stdsql.Conn
txns sync.Map // concurrent-safe map[uint32]*stdsql.Tx
}

func NewConnectionPool(catalog string, connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
func NewConnectionPool(connector *duckdb.Connector, db *stdsql.DB) *ConnectionPool {
return &ConnectionPool{
DB: db,
connector: connector,
catalog: catalog,
}
}

Expand All @@ -57,13 +55,30 @@ func (p *ConnectionPool) CurrentSchema(id uint32) string {
}
conn := entry.(*stdsql.Conn)
var schema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&schema); err != nil {
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&schema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return ""
}
return schema
}

// CurrentCatalog retrieves the current catalog of the connection.
// Returns an empty string if the connection is not established
// or the catalog cannot be retrieved.
func (p *ConnectionPool) CurrentCatalog(id uint32) string {
entry, ok := p.conns.Load(id)
if !ok {
return ""
}
conn := entry.(*stdsql.Conn)
var catalog string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_CATALOG").Scan(&catalog); err != nil {
logrus.WithError(err).Error("Failed to get current catalog")
return ""
}
return catalog
}

func (p *ConnectionPool) GetConn(ctx context.Context, id uint32) (*stdsql.Conn, error) {
var conn *stdsql.Conn
entry, ok := p.conns.Load(id)
Expand All @@ -88,11 +103,11 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema

if schemaName != "" {
var currentSchema string
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA()").Scan(&currentSchema); err != nil {
if err := conn.QueryRowContext(context.Background(), "SELECT CURRENT_SCHEMA").Scan(&currentSchema); err != nil {
logrus.WithError(err).Error("Failed to get current schema")
return nil, err
} else if currentSchema != schemaName {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.CurrentCatalog(id), schemaName)); err != nil {
if IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(schemaName)
}
Expand Down Expand Up @@ -187,15 +202,14 @@ func (p *ConnectionPool) Close() error {
return errors.Join(lastErr, p.DB.Close())
}

func (p *ConnectionPool) Reset(catalog string, connector *duckdb.Connector, db *stdsql.DB) error {
func (p *ConnectionPool) Reset(connector *duckdb.Connector, db *stdsql.DB) error {
err := p.Close()
if err != nil {
return fmt.Errorf("failed to close connection pool: %w", err)
}

p.conns.Clear()
p.txns.Clear()
p.catalog = catalog
p.DB = db
p.connector = connector

Expand Down
33 changes: 18 additions & 15 deletions catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
_ "github.com/marcboeker/go-duckdb"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/configuration"
Expand All @@ -27,7 +26,7 @@ type DatabaseProvider struct {
connector *duckdb.Connector
storage *stdsql.DB
pool *ConnectionPool
catalogName string // database name in postgres
defaultCatalogName string // default database name in postgres
dataDir string
dbFile string
dsn string
Expand Down Expand Up @@ -60,11 +59,11 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr

shouldInit := true
if defaultDB == "" || defaultDB == "memory" {
prov.catalogName = "memory"
prov.defaultCatalogName = "memory"
prov.dbFile = ""
prov.dsn = ""
} else {
prov.catalogName = defaultDB
prov.defaultCatalogName = defaultDB
prov.dbFile = defaultDB + ".db"
prov.dsn = filepath.Join(prov.dataDir, prov.dbFile)
_, err = os.Stat(prov.dsn)
Expand All @@ -76,7 +75,7 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabasePr
return nil, err
}
prov.storage = stdsql.OpenDB(prov.connector)
prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage)
prov.pool = NewConnectionPool(prov.connector, prov.storage)

bootQueries := []string{
"INSTALL arrow",
Expand Down Expand Up @@ -353,8 +352,8 @@ func (prov *DatabaseProvider) Pool() *ConnectionPool {
return prov.pool
}

func (prov *DatabaseProvider) CatalogName() string {
return prov.catalogName
func (prov *DatabaseProvider) DefaultCatalogName() string {
return prov.defaultCatalogName
}

func (prov *DatabaseProvider) DataDir() string {
Expand All @@ -380,7 +379,8 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
prov.mu.RLock()
defer prov.mu.RUnlock()

rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", prov.catalogName)
catalogName := adapter.GetCurrentCatalog(ctx)
rows, err := adapter.QueryCatalog(ctx, "SELECT DISTINCT schema_name FROM information_schema.schemata WHERE catalog_name = ?", catalogName)
if err != nil {
panic(ErrDuckDB.New(err))
}
Expand All @@ -398,7 +398,7 @@ func (prov *DatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Database {
continue
}

all = append(all, NewDatabase(schemaName, prov.catalogName))
all = append(all, NewDatabase(schemaName, catalogName))
}

sort.Slice(all, func(i, j int) bool {
Expand All @@ -413,13 +413,14 @@ func (prov *DatabaseProvider) Database(ctx *sql.Context, name string) (sql.Datab
prov.mu.RLock()
defer prov.mu.RUnlock()

ok, err := hasDatabase(ctx, prov.catalogName, name)
catalogName := adapter.GetCurrentCatalog(ctx)
ok, err := hasDatabase(ctx, catalogName, name)
if err != nil {
return nil, err
}

if ok {
return NewDatabase(name, prov.catalogName), nil
return NewDatabase(name, catalogName), nil
}
return nil, sql.ErrDatabaseNotFound.New(name)
}
Expand All @@ -429,7 +430,7 @@ func (prov *DatabaseProvider) HasDatabase(ctx *sql.Context, name string) bool {
prov.mu.RLock()
defer prov.mu.RUnlock()

ok, err := hasDatabase(ctx, prov.catalogName, name)
ok, err := hasDatabase(ctx, adapter.GetCurrentCatalog(ctx), name)
if err != nil {
panic(err)
}
Expand All @@ -451,7 +452,8 @@ func (prov *DatabaseProvider) CreateDatabase(ctx *sql.Context, name string) erro
prov.mu.Lock()
defer prov.mu.Unlock()

_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`, FullSchemaName(prov.catalogName, name)))
_, err := adapter.ExecCatalog(ctx, fmt.Sprintf(`CREATE SCHEMA %s`,
FullSchemaName(adapter.GetCurrentCatalog(ctx), name)))
if err != nil {
return ErrDuckDB.New(err)
}
Expand All @@ -464,7 +466,8 @@ func (prov *DatabaseProvider) DropDatabase(ctx *sql.Context, name string) error
prov.mu.Lock()
defer prov.mu.Unlock()

_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`, FullSchemaName(prov.catalogName, name)))
_, err := adapter.Exec(ctx, fmt.Sprintf(`DROP SCHEMA %s CASCADE`,
FullSchemaName(adapter.GetCurrentCatalog(ctx), name)))
if err != nil {
return ErrDuckDB.New(err)
}
Expand Down Expand Up @@ -494,5 +497,5 @@ func (prov *DatabaseProvider) Restart(readOnly bool) error {
prov.connector = connector
prov.storage = storage

return nil
return prov.pool.Reset(connector, storage)
}
15 changes: 2 additions & 13 deletions pgserver/backup_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ func parseBackupSQL(sql string) (*BackupConfig, error) {
}

func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, error) {
// TODO(neo.zty): Add support for backing up multiple databases once MyDuck Server supports multi-database functionality.
if backupConfig.DbName != h.server.Provider.CatalogName() {
return "", fmt.Errorf("backup database name %s does not match server database name %s",
backupConfig.DbName, h.server.Provider.CatalogName())
}

sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "")
if err != nil {
return "", fmt.Errorf("failed to create context for query: %w", err)
Expand All @@ -114,7 +108,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e
}

msg, err := backupConfig.StorageConfig.UploadFile(
h.server.Provider.DataDir(), h.server.Provider.DbFile(), backupConfig.RemotePath)
h.server.Provider.DataDir(), backupConfig.DbName+".db", backupConfig.RemotePath)
if err != nil {
return "", err
}
Expand All @@ -133,12 +127,7 @@ func (h *ConnectionHandler) executeBackup(backupConfig *BackupConfig) (string, e

func (h *ConnectionHandler) restartServer(readOnly bool) error {
provider := h.server.Provider
err := provider.Restart(readOnly)
if err != nil {
return err
}

return h.server.Provider.Pool().Reset(provider.CatalogName(), provider.Connector(), provider.Storage())
return provider.Restart(readOnly)
}

func doCheckpoint(sqlCtx *sql.Context) error {
Expand Down
2 changes: 1 addition & 1 deletion pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start
}
if db == "postgres" || db == "mysql" {
if provider := h.duckHandler.GetCatalogProvider(); provider != nil {
db = provider.CatalogName()
db = provider.DefaultCatalogName()
}
}

Expand Down

0 comments on commit 2631948

Please sign in to comment.