From 85a560fb8ff5004b8828d0afbf052933a79adbaf Mon Sep 17 00:00:00 2001 From: Noy <130386570+NoyException@users.noreply.github.com> Date: Wed, 25 Dec 2024 14:41:53 +0800 Subject: [PATCH] feat: add support for managing multiple databases (#307) --- .github/workflows/go.yml | 2 +- backend/executor.go | 8 +- backend/handler.go | 13 +- backend/session.go | 29 ++-- {backend => catalog}/connpool.go | 7 +- catalog/internal_tables.go | 28 +--- catalog/provider.go | 231 ++++++++++++++++++++++++++----- catalog/provider_test.go | 154 +++++++++++++++++++++ harness/duck_harness.go | 15 +- main.go | 36 +---- pgserver/backup_handler.go | 2 +- pgserver/connection_handler.go | 8 +- pgserver/duck_handler.go | 44 +++++- pgserver/server.go | 7 +- pgserver/stmt.go | 4 +- pgtest/server.go | 13 +- replica/replication.go | 8 +- replica/updater.go | 6 +- 18 files changed, 460 insertions(+), 155 deletions(-) rename {backend => catalog}/connpool.go (95%) create mode 100644 catalog/provider_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2a05ab4c..952cff71 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -44,7 +44,7 @@ jobs: - name: Test packages run: | - go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver | tee packages.log + go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver ./catalog | tee packages.log cat packages.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}' cat packages.log | grep -q "FAIL" && exit 1 || exit 0 diff --git a/backend/executor.go b/backend/executor.go index 5b7fe7d3..44cd74b8 100644 --- a/backend/executor.go +++ b/backend/executor.go @@ -31,7 +31,6 @@ import ( type DuckBuilder struct { base sql.NodeExecBuilder - pool *ConnectionPool provider *catalog.DatabaseProvider @@ -40,10 +39,9 @@ type DuckBuilder struct { var _ sql.NodeExecBuilder = (*DuckBuilder)(nil) -func NewDuckBuilder(base sql.NodeExecBuilder, pool *ConnectionPool, provider *catalog.DatabaseProvider) *DuckBuilder { +func NewDuckBuilder(base sql.NodeExecBuilder, provider *catalog.DatabaseProvider) *DuckBuilder { return &DuckBuilder{ base: base, - pool: pool, provider: provider, } } @@ -106,14 +104,14 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row return b.base.Build(ctx, root, r) } - conn, err := b.pool.GetConnForSchema(ctx, ctx.ID(), ctx.GetCurrentDatabase()) + conn, err := b.provider.Pool().GetConnForSchema(ctx, ctx.ID(), ctx.GetCurrentDatabase()) if err != nil { return nil, err } switch node := n.(type) { case *plan.Use: - useStmt := "USE " + catalog.FullSchemaName(b.pool.catalog, node.Database().Name()) + useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name()) if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil { if catalog.IsDuckDBSetSchemaNotFoundError(err) { return nil, sql.ErrDatabaseNotFound.New(node.Database().Name()) diff --git a/backend/handler.go b/backend/handler.go index 18416849..0cfd0927 100644 --- a/backend/handler.go +++ b/backend/handler.go @@ -17,6 +17,7 @@ package backend import ( "context" "fmt" + "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/vitess/go/mysql" @@ -25,16 +26,16 @@ import ( type MyHandler struct { *server.Handler - pool *ConnectionPool + provider *catalog.DatabaseProvider } func (h *MyHandler) ConnectionClosed(c *mysql.Conn) { - h.pool.CloseConn(c.ConnectionID) + h.provider.Pool().CloseConn(c.ConnectionID) h.Handler.ConnectionClosed(c) } func (h *MyHandler) ComInitDB(c *mysql.Conn, schemaName string) error { - _, err := h.pool.GetConnForSchema(context.Background(), c.ConnectionID, schemaName) + _, err := h.provider.Pool().GetConnForSchema(context.Background(), c.ConnectionID, schemaName) if err != nil { return err } @@ -78,7 +79,7 @@ func (h *MyHandler) ComQuery( return h.Handler.ComQuery(ctx, c, query, wrapResultCallback(callback, modifiers...)) } -func WrapHandler(pool *ConnectionPool) server.HandlerWrapper { +func WrapHandler(provider *catalog.DatabaseProvider) server.HandlerWrapper { return func(h mysql.Handler) (mysql.Handler, error) { handler, ok := h.(*server.Handler) if !ok { @@ -86,8 +87,8 @@ func WrapHandler(pool *ConnectionPool) server.HandlerWrapper { } return &MyHandler{ - Handler: handler, - pool: pool, + Handler: handler, + provider: provider, }, nil } } diff --git a/backend/session.go b/backend/session.go index afb9d940..52dfc612 100644 --- a/backend/session.go +++ b/backend/session.go @@ -31,12 +31,11 @@ import ( type Session struct { *memory.Session - db *catalog.DatabaseProvider - pool *ConnectionPool + db *catalog.DatabaseProvider } -func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *ConnectionPool) *Session { - return &Session{base, provider, pool} +func NewSession(base *memory.Session, provider *catalog.DatabaseProvider) *Session { + return &Session{base, provider} } // Provider returns the database provider for the session. @@ -45,11 +44,11 @@ func (sess *Session) Provider() *catalog.DatabaseProvider { } func (sess *Session) CurrentSchemaOfUnderlyingConn() string { - return sess.pool.CurrentSchema(sess.ID()) + return sess.db.Pool().CurrentSchema(sess.ID()) } // NewSessionBuilder returns a session builder for the given database provider. -func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { +func NewSessionBuilder(provider *catalog.DatabaseProvider) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) { host := "" user := "" @@ -63,13 +62,13 @@ func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) baseSession := sql.NewBaseSessionWithClientServer(addr, client, conn.ConnectionID) memSession := memory.NewSession(baseSession, provider) - schema := pool.CurrentSchema(conn.ConnectionID) + schema := provider.Pool().CurrentSchema(conn.ConnectionID) if schema != "" { logrus.Traceln("SessionBuilder: new session: current schema:", schema) memSession.SetCurrentDatabase(schema) } - return &Session{memSession, provider, pool}, nil + return &Session{memSession, provider}, nil } } @@ -203,37 +202,37 @@ func (sess *Session) GetPersistedValue(k string) (interface{}, error) { // GetConn implements adapter.ConnectionHolder. func (sess *Session) GetConn(ctx context.Context) (*stdsql.Conn, error) { - return sess.pool.GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase()) + return sess.db.Pool().GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase()) } // GetCatalogConn implements adapter.ConnectionHolder. func (sess *Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) { - return sess.pool.GetConn(ctx, sess.ID()) + return sess.db.Pool().GetConn(ctx, sess.ID()) } // GetTxn implements adapter.ConnectionHolder. func (sess *Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) { - return sess.pool.GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options) + return sess.db.Pool().GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options) } // GetCatalogTxn implements adapter.ConnectionHolder. func (sess *Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) { - return sess.pool.GetTxn(ctx, sess.ID(), "", options) + return sess.db.Pool().GetTxn(ctx, sess.ID(), "", options) } // TryGetTxn implements adapter.ConnectionHolder. func (sess *Session) TryGetTxn() *stdsql.Tx { - return sess.pool.TryGetTxn(sess.ID()) + return sess.db.Pool().TryGetTxn(sess.ID()) } // CloseTxn implements adapter.ConnectionHolder. func (sess *Session) CloseTxn() { - sess.pool.CloseTxn(sess.ID()) + sess.db.Pool().CloseTxn(sess.ID()) } // CloseConn implements adapter.ConnectionHolder. func (sess *Session) CloseConn() { - sess.pool.CloseConn(sess.ID()) + sess.db.Pool().CloseConn(sess.ID()) } func (sess *Session) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { diff --git a/backend/connpool.go b/catalog/connpool.go similarity index 95% rename from backend/connpool.go rename to catalog/connpool.go index 45232e92..b89d9748 100644 --- a/backend/connpool.go +++ b/catalog/connpool.go @@ -11,7 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -package backend +package catalog import ( "context" @@ -22,7 +22,6 @@ import ( "strings" "sync" - "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/sql" "github.com/marcboeker/go-duckdb" "github.com/sirupsen/logrus" @@ -93,8 +92,8 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema logrus.WithError(err).Error("Failed to get current schema") return nil, err } else if currentSchema != schemaName { - if _, err := conn.ExecContext(context.Background(), "USE "+catalog.FullSchemaName(p.catalog, schemaName)); err != nil { - if catalog.IsDuckDBSetSchemaNotFoundError(err) { + if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil { + if IsDuckDBSetSchemaNotFoundError(err) { return nil, sql.ErrDatabaseNotFound.New(schemaName) } logrus.WithField("schema", schemaName).WithError(err).Error("Failed to switch schema") diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go index ff424ba4..14943186 100644 --- a/catalog/internal_tables.go +++ b/catalog/internal_tables.go @@ -58,9 +58,7 @@ func (it *InternalTable) UpsertStmt() string { var b strings.Builder b.Grow(128) b.WriteString("INSERT OR REPLACE INTO ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" VALUES (?") for range it.KeyColumns[1:] { b.WriteString(", ?") @@ -76,9 +74,7 @@ func (it *InternalTable) DeleteStmt() string { var b strings.Builder b.Grow(128) b.WriteString("DELETE FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" WHERE ") b.WriteString(it.KeyColumns[0]) b.WriteString(" = ?") @@ -93,9 +89,7 @@ func (it *InternalTable) DeleteAllStmt() string { var b strings.Builder b.Grow(128) b.WriteString("DELETE FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) return b.String() } @@ -109,9 +103,7 @@ func (it *InternalTable) SelectStmt() string { b.WriteString(c) } b.WriteString(" FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" WHERE ") b.WriteString(it.KeyColumns[0]) b.WriteString(" = ?") @@ -133,9 +125,7 @@ func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string { b.WriteString(c) } b.WriteString(" FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" WHERE ") b.WriteString(it.KeyColumns[0]) b.WriteString(" = ?") @@ -151,9 +141,7 @@ func (it *InternalTable) SelectAllStmt() string { var b strings.Builder b.Grow(128) b.WriteString("SELECT * FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) return b.String() } @@ -162,9 +150,7 @@ func (it *InternalTable) CountAllStmt() string { b.Grow(128) b.WriteString("SELECT COUNT(*)") b.WriteString(" FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) return b.String() } diff --git a/catalog/provider.go b/catalog/provider.go index b9367341..50ca5b0d 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -3,6 +3,8 @@ package catalog import ( "context" "fmt" + "github.com/sirupsen/logrus" + "os" "path/filepath" "sort" "strings" @@ -20,13 +22,16 @@ import ( type DatabaseProvider struct { mu *sync.RWMutex + defaultTimeZone string connector *duckdb.Connector storage *stdsql.DB - catalogName string + pool *ConnectionPool + catalogName string // database name in postgres dataDir string dbFile string dsn string externalProcedureRegistry sql.ExternalStoredProcedureRegistry + ready bool } var _ sql.DatabaseProvider = (*DatabaseProvider)(nil) @@ -37,31 +42,40 @@ var _ configuration.DataDirProvider = (*DatabaseProvider)(nil) const readOnlySuffix = "?access_mode=read_only" func NewInMemoryDBProvider() *DatabaseProvider { - prov, err := NewDBProvider(".", "") + prov, err := NewDBProvider("", ".", "") if err != nil { panic(err) } return prov } -func NewDBProvider(dataDir, dbFile string) (*DatabaseProvider, error) { - dbFile = strings.TrimSpace(dbFile) - name := "" - dsn := "" - if dbFile == "" { - // in-memory mode, mainly for testing - name = "memory" +func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (prov *DatabaseProvider, err error) { + prov = &DatabaseProvider{ + mu: &sync.RWMutex{}, + defaultTimeZone: defaultTimeZone, + externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface + dataDir: dataDir, + } + + shouldInit := true + if defaultDB == "" || defaultDB == "memory" { + prov.catalogName = "memory" + prov.dbFile = "" + prov.dsn = "" } else { - name = strings.Split(dbFile, ".")[0] - dsn = filepath.Join(dataDir, dbFile) + prov.catalogName = defaultDB + prov.dbFile = defaultDB + ".db" + prov.dsn = filepath.Join(prov.dataDir, prov.dbFile) + _, err = os.Stat(prov.dsn) + shouldInit = os.IsNotExist(err) } - connector, err := duckdb.NewConnector(dsn, nil) + prov.connector, err = duckdb.NewConnector(prov.dsn, nil) if err != nil { return nil, err } - - storage := stdsql.OpenDB(connector) + prov.storage = stdsql.OpenDB(prov.connector) + prov.pool = NewConnectionPool(prov.catalogName, prov.connector, prov.storage) bootQueries := []string{ "INSTALL arrow", @@ -71,57 +85,200 @@ func NewDBProvider(dataDir, dbFile string) (*DatabaseProvider, error) { "INSTALL postgres_scanner", "LOAD postgres_scanner", } + for _, q := range bootQueries { - if _, err := storage.ExecContext(context.Background(), q); err != nil { - storage.Close() - connector.Close() + if _, err := prov.storage.ExecContext(context.Background(), q); err != nil { + prov.storage.Close() + prov.connector.Close() return nil, fmt.Errorf("failed to execute boot query %q: %w", q, err) } } + if shouldInit { + err = prov.initCatalog() + if err != nil { + return nil, err + } + } + + err = prov.attachCatalogs() + if err != nil { + return nil, err + } + + prov.ready = true + return prov, nil +} + +func (prov *DatabaseProvider) initCatalog() error { + for _, t := range internalSchemas { - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), "CREATE SCHEMA IF NOT EXISTS "+t.Schema, ); err != nil { - return nil, fmt.Errorf("failed to create internal schema %q: %w", t.Schema, err) + return fmt.Errorf("failed to create internal schema %q: %w", t.Schema, err) } } for _, t := range internalTables { - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), "CREATE SCHEMA IF NOT EXISTS "+t.Schema, ); err != nil { - return nil, fmt.Errorf("failed to create internal schema %q: %w", t.Schema, err) + return fmt.Errorf("failed to create internal schema %q: %w", t.Schema, err) } - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), "CREATE TABLE IF NOT EXISTS "+t.QualifiedName()+"("+t.DDL+")", ); err != nil { - return nil, fmt.Errorf("failed to create internal table %q: %w", t.Name, err) + return fmt.Errorf("failed to create internal table %q: %w", t.Name, err) } for _, row := range t.InitialData { - if _, err := storage.ExecContext( + if _, err := prov.storage.ExecContext( context.Background(), t.UpsertStmt(), row..., ); err != nil { - return nil, fmt.Errorf("failed to insert initial data into internal table %q: %w", t.Name, err) + return fmt.Errorf("failed to insert initial data into internal table %q: %w", t.Name, err) } } } - return &DatabaseProvider{ - mu: &sync.RWMutex{}, - connector: connector, - storage: storage, - catalogName: name, - dataDir: dataDir, - dbFile: dbFile, - dsn: dsn, - externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface - }, nil + if _, err := prov.pool.ExecContext(context.Background(), "PRAGMA enable_checkpoint_on_shutdown"); err != nil { + logrus.WithError(err).Fatalln("Failed to enable checkpoint on shutdown") + } + + if prov.defaultTimeZone != "" { + _, err := prov.pool.ExecContext(context.Background(), fmt.Sprintf(`SET TimeZone = '%s'`, prov.defaultTimeZone)) + if err != nil { + logrus.WithError(err).Fatalln("Failed to set the default time zone") + } + } + + // Postgres tables are created in the `public` schema by default. + // Create the `public` schema if it doesn't exist. + _, err := prov.pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") + if err != nil { + logrus.WithError(err).Fatalln("Failed to create the `public` schema") + } + return nil +} + +func (prov *DatabaseProvider) IsReady() bool { + return prov.ready +} + +func (prov *DatabaseProvider) HasCatalog(name string) bool { + name = strings.TrimSpace(name) + // in memory database does not need to be created + if name == "" || name == "memory" { + return true + } + + dsn := filepath.Join(prov.dataDir, name+".db") + // if already exists, return error + _, err := os.Stat(dsn) + return os.IsExist(err) +} + +// attachCatalogs attaches all the databases in the data directory +func (prov *DatabaseProvider) attachCatalogs() error { + files, err := os.ReadDir(prov.dataDir) + if err != nil { + return fmt.Errorf("failed to read data directory: %w", err) + } + for _, file := range files { + if file.IsDir() { + continue + } + if !strings.HasSuffix(file.Name(), ".db") { + continue + } + name := strings.TrimSuffix(file.Name(), ".db") + if _, err := prov.storage.ExecContext(context.Background(), "ATTACH IF NOT EXISTS '"+filepath.Join(prov.dataDir, file.Name())+"' AS "+name); err != nil { + logrus.WithError(err).Errorf("Failed to attach database %s", name) + } + } + return nil +} + +func (prov *DatabaseProvider) CreateCatalog(name string, ifNotExists bool) error { + name = strings.TrimSpace(name) + // in memory database does not need to be created + if name == "" || name == "memory" { + return nil + } + dsn := filepath.Join(prov.dataDir, name+".db") + + _, err := os.Stat(dsn) + shouldInit := os.IsNotExist(err) + + // attach + attachSQL := "ATTACH" + if ifNotExists { + attachSQL += " IF NOT EXISTS" + } + attachSQL += " '" + dsn + "' AS " + name + _, err = prov.storage.ExecContext(context.Background(), attachSQL) + if err != nil { + return err + } + + if shouldInit { + res, err := prov.storage.QueryContext(context.Background(), "SELECT current_catalog") + if err != nil { + return fmt.Errorf("failed to init catalog: %w", err) + } + lastCatalog := "" + for res.Next() { + if err := res.Scan(&lastCatalog); err != nil { + return fmt.Errorf("failed to init catalog: %w", err) + } + } + + if _, err := prov.storage.ExecContext(context.Background(), "USE "+name); err != nil { + return fmt.Errorf("failed to switch to the new catalog: %w", err) + } + + defer func() { + if _, err := prov.storage.ExecContext(context.Background(), "USE "+lastCatalog); err != nil { + logrus.WithError(err).Errorln("Failed to switch back to the old catalog") + } + }() + err = prov.initCatalog() + if err != nil { + return err + } + } + return nil +} + +func (prov *DatabaseProvider) DropCatalog(name string, ifExists bool) error { + name = strings.TrimSpace(name) + // in memory database does not need to be created + if name == "" || name == "memory" { + return fmt.Errorf("cannot drop the in-memory catalog") + } + dsn := filepath.Join(prov.dataDir, name+".db") + // if file does not exist, return error + _, err := os.Stat(dsn) + if os.IsNotExist(err) { + if ifExists { + return nil + } + return fmt.Errorf("database file %s does not exist", dsn) + } + // detach + if _, err := prov.storage.ExecContext(context.Background(), "DETACH "+name); err != nil { + return fmt.Errorf("failed to detach catalog %w", err) + } + // delete the file + err = os.Remove(dsn) + if err != nil { + return fmt.Errorf("failed to delete database file %s: %w", dsn, err) + } + return nil } func (prov *DatabaseProvider) Close() error { @@ -137,6 +294,10 @@ func (prov *DatabaseProvider) Storage() *stdsql.DB { return prov.storage } +func (prov *DatabaseProvider) Pool() *ConnectionPool { + return prov.pool +} + func (prov *DatabaseProvider) CatalogName() string { return prov.catalogName } diff --git a/catalog/provider_test.go b/catalog/provider_test.go new file mode 100644 index 00000000..ff68b144 --- /dev/null +++ b/catalog/provider_test.go @@ -0,0 +1,154 @@ +package catalog + +import ( + "context" + "github.com/apecloud/myduckserver/testutil" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + "strconv" + "testing" +) + +type Execution struct { + SQL string + Expected string + WantErr bool +} + +func TestCreateCatalog(t *testing.T) { + tests := []struct { + name string + executions []Execution + }{ + { + name: "create database", + executions: []Execution{ + { + SQL: "CREATE DATABASE testdb1;", + Expected: "CREATE DATABASE", + }, + // Can not create the database with the same name + { + SQL: "CREATE DATABASE testdb1;", + WantErr: true, + }, + { + SQL: "CREATE DATABASE IF NOT EXISTS testdb1;", + Expected: "CREATE DATABASE", + }, + { + SQL: "CREATE DATABASE testdb2;", + Expected: "CREATE DATABASE", + }, + }, + }, + { + name: "switch database", + executions: []Execution{ + { + SQL: "USE testdb1;", + Expected: "SET", + }, + { + SQL: "CREATE SCHEMA testschema1;", + Expected: "CREATE SCHEMA", + }, + { + SQL: "USE testdb1.testschema1;", + Expected: "SET", + }, + { + SQL: "USE testdb2;", + Expected: "SET", + }, + // Can not drop the schema as it is not in the current database + { + SQL: "DROP SCHEMA testschema1;", + WantErr: true, + }, + { + SQL: "USE testdb1;", + Expected: "SET", + }, + { + SQL: "DROP SCHEMA testschema1;", + Expected: "DROP SCHEMA", + }, + // Can not switch to the schema that does not exist + { + SQL: "USE testdb1.testschema1;", + WantErr: true, + }, + }, + }, + { + name: "drop database", + executions: []Execution{ + //{ + // SQL: "USE testdb1;", + // Expected: "SET", + //}, + //// Can not drop the database when the current database is the one to be dropped + //{ + // SQL: "DROP DATABASE testdb1;", + // WantErr: true, + //}, + { + SQL: "USE testdb2;", + Expected: "SET", + }, + { + SQL: "DROP DATABASE testdb1;", + Expected: "DROP DATABASE", + }, + { + SQL: "DROP DATABASE testdb1;", + WantErr: true, + }, + { + SQL: "DROP DATABASE IF EXISTS testdb1;", + Expected: "DROP DATABASE", + }, + }, + }, + } + testDir := testutil.CreateTestDir(t) + testEnv := testutil.NewTestEnv() + err := testutil.StartDuckSqlServer(t, testDir, nil, testEnv) + require.NoError(t, err) + defer testutil.StopDuckSqlServer(t, testEnv.DuckProcess) + dsn := "postgresql://postgres@localhost:" + strconv.Itoa(testEnv.DuckPgPort) + "/postgres" + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Connect to MyDuck Server + db, err := pgx.Connect(context.Background(), dsn) + if err != nil { + t.Errorf("Connect failed! dsn = %v, err: %v", dsn, err) + return + } + defer db.Close(context.Background()) + + for _, execution := range tt.executions { + func() { + tag, err := db.Exec(context.Background(), execution.SQL) + if execution.WantErr { + if err != nil { + return + } + t.Errorf("Test expectes error but got none! sql: %v", execution.SQL) + return + } + if err != nil { + t.Errorf("Query failed! sql: %v, err: %v", execution.SQL, err) + return + } + // check whether the result is as expected + if tag.String() != execution.Expected { + t.Errorf("sql: %v, got %v, want %v", execution.SQL, tag.String(), execution.Expected) + } + }() + } + }) + } +} diff --git a/harness/duck_harness.go b/harness/duck_harness.go index a2de8053..4f9530d4 100644 --- a/harness/duck_harness.go +++ b/harness/duck_harness.go @@ -49,7 +49,7 @@ type DuckHarness struct { numTablePartitions int // readonly bool provider sql.DatabaseProvider - pool *backend.ConnectionPool + pool *catalog.ConnectionPool indexDriverInitializer IndexDriverInitializer driver sql.IndexDriver nativeIndexSupport bool @@ -122,7 +122,7 @@ func (m *DuckHarness) SessionBuilder() server.SessionBuilder { client := sql.Client{Address: host, User: user, Capabilities: c.Capabilities} baseSession := sql.NewBaseSessionWithClientServer(addr, client, c.ConnectionID) memSession := memory.NewSession(baseSession, m.getProvider()) - return backend.NewSession(memSession, m.getProvider().(*catalog.DatabaseProvider), m.pool), nil + return backend.NewSession(memSession, m.getProvider().(*catalog.DatabaseProvider)), nil } } @@ -285,7 +285,7 @@ func (m *DuckHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) { type DuckTestEngine struct { enginetest.QueryEngine - pool *backend.ConnectionPool + pool *catalog.ConnectionPool } func (e *DuckTestEngine) Close() error { @@ -297,13 +297,12 @@ func (e *DuckTestEngine) Close() error { func NewEngine(t *testing.T, harness enginetest.Harness, dbProvider sql.DatabaseProvider, setupData []setup.SetupScript, statsProvider sql.StatsProvider, server bool) (enginetest.QueryEngine, error) { // Create the connection pool first, as it is needed by `NewEngineWithProvider` provider := dbProvider.(*catalog.DatabaseProvider) - pool := backend.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage()) - harness.(*DuckHarness).pool = pool + harness.(*DuckHarness).pool = provider.Pool() e := enginetest.NewEngineWithProvider(t, harness, dbProvider) e.Analyzer.Catalog.StatsProvider = statsProvider - builder := backend.NewDuckBuilder(e.Analyzer.ExecBuilder, pool, provider) + builder := backend.NewDuckBuilder(e.Analyzer.ExecBuilder, provider) e.Analyzer.ExecBuilder = builder ctx := enginetest.NewContext(harness) @@ -329,7 +328,7 @@ func NewEngine(t *testing.T, harness enginetest.Harness, dbProvider sql.Database return nil, err } } - return &DuckTestEngine{qe, pool}, nil + return &DuckTestEngine{qe, provider.Pool()}, nil } func (m *DuckHarness) SupportsNativeIndexCreation() bool { @@ -365,7 +364,7 @@ func (m *DuckHarness) newSession() sql.Session { if m.driver != nil { session.GetIndexRegistry().RegisterIndexDriver(m.driver) } - return backend.NewSession(session, m.getProvider().(*catalog.DatabaseProvider), m.pool) + return backend.NewSession(session, m.getProvider().(*catalog.DatabaseProvider)) } func (m *DuckHarness) NewContextWithClient(client sql.Client) *sql.Context { diff --git a/main.go b/main.go index 17ef9bae..6b6fec3e 100644 --- a/main.go +++ b/main.go @@ -52,7 +52,6 @@ var ( socket string defaultDb = "myduck" dataDirectory = "." - dbFileName string logLevel = int(logrus.InfoLevel) replicaOptions replica.ReplicaOptions @@ -112,7 +111,6 @@ func ensureSQLTranslate() { func main() { flag.Parse() // Parse all flags - dbFileName = defaultDb + ".db" if replicaOptions.ReportPort == 0 { replicaOptions.ReportPort = port @@ -130,31 +128,18 @@ func main() { return } - provider, err := catalog.NewDBProvider(dataDirectory, dbFileName) + provider, err := catalog.NewDBProvider(defaultTimeZone, dataDirectory, defaultDb) if err != nil { logrus.Fatalln("Failed to open the database:", err) } defer provider.Close() - pool := backend.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage()) - - if _, err := pool.ExecContext(context.Background(), "PRAGMA enable_checkpoint_on_shutdown"); err != nil { - logrus.WithError(err).Fatalln("Failed to enable checkpoint on shutdown") - } - - if defaultTimeZone != "" { - _, err := pool.ExecContext(context.Background(), fmt.Sprintf(`SET TimeZone = '%s'`, defaultTimeZone)) - if err != nil { - logrus.WithError(err).Fatalln("Failed to set the default time zone") - } - } - // Clear the pipes directory on startup. backend.RemoveAllPipes(dataDirectory) engine := sqle.NewDefault(provider) - builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, pool, provider) + builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, provider) engine.Analyzer.ExecBuilder = builder engine.Analyzer.Catalog.RegisterFunction(sql.NewContext(context.Background()), myfunc.ExtraBuiltIns...) engine.Analyzer.Catalog.MySQLDb.SetPlugins(plugin.AuthPlugins) @@ -164,32 +149,25 @@ func main() { } replica.RegisterReplicaOptions(&replicaOptions) - replica.RegisterReplicaController(provider, engine, pool, builder) + replica.RegisterReplicaController(provider, engine, builder) serverConfig := server.Config{ Protocol: "tcp", Address: fmt.Sprintf("%s:%d", address, port), Socket: socket, } - myServer, err := server.NewServerWithHandler(serverConfig, engine, backend.NewSessionBuilder(provider, pool), nil, backend.WrapHandler(pool)) + myServer, err := server.NewServerWithHandler(serverConfig, engine, backend.NewSessionBuilder(provider), nil, backend.WrapHandler(provider)) if err != nil { logrus.WithError(err).Fatalln("Failed to create MySQL-protocol server") } if postgresPort > 0 { - // Postgres tables are created in the `public` schema by default. - // Create the `public` schema if it doesn't exist. - _, err := pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") - if err != nil { - logrus.WithError(err).Fatalln("Failed to create the `public` schema") - } - pgServer, err := pgserver.NewServer( - provider, pool, + provider, address, postgresPort, superuserPassword, func() *sql.Context { - session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider, pool) + session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider) return sql.NewContext(context.Background(), sql.WithSession(session)) }, pgserver.WithEngine(myServer.Engine), @@ -263,7 +241,7 @@ func executeRestoreIfNeeded() { msg, err := pgserver.ExecuteRestore( defaultDb, dataDirectory, - dbFileName, + defaultDb+".db", restoreFile, restoreEndpoint, restoreAccessKeyId, diff --git a/pgserver/backup_handler.go b/pgserver/backup_handler.go index 5fec6770..65b0ae04 100644 --- a/pgserver/backup_handler.go +++ b/pgserver/backup_handler.go @@ -138,7 +138,7 @@ func (h *ConnectionHandler) restartServer(readOnly bool) error { return err } - return h.server.ConnPool.Reset(provider.CatalogName(), provider.Connector(), provider.Storage()) + return h.server.Provider.Pool().Reset(provider.CatalogName(), provider.Connector(), provider.Storage()) } func doCheckpoint(sqlCtx *sql.Context) error { diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index ec558bc3..8ecf5d3d 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -31,8 +31,6 @@ import ( "sync/atomic" "github.com/apecloud/myduckserver/adapter" - "github.com/apecloud/myduckserver/catalog" - "github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" gms "github.com/dolthub/go-mysql-server" @@ -313,8 +311,8 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start if !dbSpecified { db = h.mysqlConn.User } - if db == "postgres" { - if provider, ok := h.duckHandler.e.Analyzer.Catalog.DbProvider.(*catalog.DatabaseProvider); ok { + if db == "postgres" || db == "mysql" { + if provider := h.duckHandler.GetCatalogProvider(); provider != nil { db = provider.CatalogName() } } @@ -334,7 +332,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage *pgproto3.Start _ = h.send(&pgproto3.ErrorResponse{ Severity: string(ErrorResponseSeverity_Fatal), Code: "3D000", - Message: fmt.Sprintf(`"database "%s" does not exist"`, db), + Message: fmt.Sprintf(`"database "%s" does not exist, err: %v"`, db, err), Routine: "InitPostgres", }) return err diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index 2b3c7375..a2ddd69b 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -20,6 +20,7 @@ import ( "database/sql/driver" "encoding/base64" "fmt" + "github.com/apecloud/myduckserver/catalog" "io" "os" "regexp" @@ -87,6 +88,14 @@ type DuckHandler struct { var _ Handler = &DuckHandler{} +func (h *DuckHandler) GetCatalogProvider() *catalog.DatabaseProvider { + provider, ok := h.e.Analyzer.Catalog.DbProvider.(*catalog.DatabaseProvider) + if !ok { + return nil + } + return provider +} + // ComBind implements the Handler interface. func (h *DuckHandler) ComBind(ctx context.Context, c *mysql.Conn, prepared PreparedStatementData, bindVars []any) ([]pgproto3.FieldDescription, error) { vars := make([]driver.NamedValue, len(bindVars)) @@ -409,8 +418,8 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S // Consequently, the following classification is not perfect. switch parsed.(type) { case *tree.BeginTransaction, *tree.CommitTransaction, *tree.RollbackTransaction, - *tree.SetVar, *tree.CreateTable, *tree.DropTable, *tree.AlterTable, *tree.CreateIndex, *tree.DropIndex, - *tree.Insert, *tree.Update, *tree.Delete, *tree.Truncate, *tree.CopyFrom, *tree.CopyTo: + *tree.CreateTable, *tree.DropTable, *tree.AlterTable, *tree.CreateIndex, *tree.DropIndex, + *tree.Insert, *tree.Update, *tree.Delete, *tree.Truncate, *tree.CopyFrom, *tree.CopyTo, *tree.SetVar: result, err = adapter.Exec(ctx, query) if err != nil { break @@ -422,9 +431,36 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S RowsAffected: uint64(affected), InsertID: uint64(insertId), })) - + case *tree.CreateDatabase: + provider := h.GetCatalogProvider() + if provider == nil { + err = fmt.Errorf("database provider not found") + break + } + p := parsed.(*tree.CreateDatabase) + dbName := p.Name.String() + err = provider.CreateCatalog(dbName, p.IfNotExists) + if err != nil { + break + } + schema = types.OkResultSchema + iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{})) + case *tree.DropDatabase: + provider := h.GetCatalogProvider() + if provider == nil { + err = fmt.Errorf("database provider not found") + break + } + p := parsed.(*tree.DropDatabase) + dbName := parsed.(*tree.DropDatabase).Name.String() + err = provider.DropCatalog(dbName, p.IfExists) + if err != nil { + break + } + schema = types.OkResultSchema + iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{})) default: - rows, err = adapter.QueryCatalog(ctx, ConvertToSys(query)) + rows, err = adapter.QueryCatalog(ctx, query) if err != nil { break } diff --git a/pgserver/server.go b/pgserver/server.go index 1a9b5771..b1c5827d 100644 --- a/pgserver/server.go +++ b/pgserver/server.go @@ -2,8 +2,6 @@ package pgserver import ( "fmt" - - "github.com/apecloud/myduckserver/backend" "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/server" "github.com/dolthub/go-mysql-server/sql" @@ -13,11 +11,10 @@ import ( type Server struct { Listener *Listener Provider *catalog.DatabaseProvider - ConnPool *backend.ConnectionPool NewInternalCtx func() *sql.Context } -func NewServer(provider *catalog.DatabaseProvider, connPool *backend.ConnectionPool, host string, port int, password string, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) { +func NewServer(provider *catalog.DatabaseProvider, host string, port int, password string, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) { InitSuperuser(password) addr := fmt.Sprintf("%s:%d", host, port) l, err := server.NewListener("tcp", addr, "") @@ -35,7 +32,7 @@ func NewServer(provider *catalog.DatabaseProvider, connPool *backend.ConnectionP if err != nil { return nil, err } - return &Server{Listener: listener, Provider: provider, ConnPool: connPool, NewInternalCtx: newCtx}, nil + return &Server{Listener: listener, Provider: provider, NewInternalCtx: newCtx}, nil } func (s *Server) Start() { diff --git a/pgserver/stmt.go b/pgserver/stmt.go index 27a4c8d1..c88e44e5 100644 --- a/pgserver/stmt.go +++ b/pgserver/stmt.go @@ -278,11 +278,11 @@ func getPgCatalogRegex() *regexp.Regexp { } tableNames = append(tableNames, table.Name) } - pgCatalogRegex = regexp.MustCompile(`(?i)\b(?:FROM|JOIN)\s+(?:pg_catalog\.)?(` + strings.Join(tableNames, "|") + `)`) + pgCatalogRegex = regexp.MustCompile(`(?i)\b(FROM|JOIN|INTO)\s+(?:pg_catalog\.)?(` + strings.Join(tableNames, "|") + `)`) }) return pgCatalogRegex } func ConvertToSys(sql string) string { - return getPgCatalogRegex().ReplaceAllString(RemoveComments(sql), " __sys__.$1") + return getPgCatalogRegex().ReplaceAllString(RemoveComments(sql), "$1 __sys__.$2") } diff --git a/pgtest/server.go b/pgtest/server.go index 92204bf7..8b432535 100644 --- a/pgtest/server.go +++ b/pgtest/server.go @@ -20,25 +20,24 @@ import ( func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pgserver.Server, conn *pgx.Conn, close func() error, err error) { provider := catalog.NewInMemoryDBProvider() - pool := backend.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage()) // Postgres tables are created in the `public` schema by default. // Create the `public` schema if it doesn't exist. - _, err = pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") + _, err = provider.Pool().ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") if err != nil { return nil, nil, nil, nil, err } engine := sqle.NewDefault(provider) - builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, pool, provider) + builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, provider) engine.Analyzer.ExecBuilder = builder config := server.Config{ Address: fmt.Sprintf("127.0.0.1:%d", port-1), // Unused } - sb := backend.NewSessionBuilder(provider, pool) + sb := backend.NewSessionBuilder(provider) tracer := sql.NoopTracer sm := server.NewSessionManager( @@ -52,11 +51,11 @@ func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pg var connID atomic.Uint32 pgServer, err = pgserver.NewServer( - provider, pool, + provider, "127.0.0.1", port, "", func() *sql.Context { - session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider, pool) + session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider) return sql.NewContext(context.Background(), sql.WithSession(session)) }, pgserver.WithEngine(engine), @@ -74,7 +73,7 @@ func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pg close = func() error { pgServer.Listener.Close() return errors.Join( - pool.Close(), + provider.Pool().Close(), provider.Close(), ) } diff --git a/replica/replication.go b/replica/replication.go index 5a039b84..8b6671e6 100644 --- a/replica/replication.go +++ b/replica/replication.go @@ -33,19 +33,19 @@ import ( // registerReplicaController registers the replica controller into the engine // to handle the replication commands, such as START REPLICA, STOP REPLICA, etc. -func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle.Engine, pool *backend.ConnectionPool, builder *backend.DuckBuilder) { +func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle.Engine, builder *backend.DuckBuilder) { replica := binlogreplication.MyBinlogReplicaController replica.SetEngine(engine) stdctx := context.Background() stdctx = mycontext.WithQueryOrigin(stdctx, mycontext.MySQLReplicationQueryOrigin) - session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider, pool) + session := backend.NewSession(memory.NewSession(sql.NewBaseSession(), provider), provider) ctx := sql.NewContext(stdctx, sql.WithSession(session)) ctx.SetCurrentDatabase("mysql") replica.SetExecutionContext(ctx) - twp := &tableWriterProvider{pool: pool} + twp := &tableWriterProvider{provider: provider} twp.controller = delta.NewController() replica.SetTableWriterProvider(twp) @@ -60,7 +60,7 @@ func RegisterReplicaController(provider *catalog.DatabaseProvider, engine *sqle. } type tableWriterProvider struct { - pool *backend.ConnectionPool + provider *catalog.DatabaseProvider controller *delta.DeltaController } diff --git a/replica/updater.go b/replica/updater.go index bf982f0b..eef5dea2 100644 --- a/replica/updater.go +++ b/replica/updater.go @@ -3,9 +3,9 @@ package replica import ( stdsql "database/sql" "errors" + "github.com/apecloud/myduckserver/catalog" "strings" - "github.com/apecloud/myduckserver/backend" "github.com/apecloud/myduckserver/binlog" "github.com/apecloud/myduckserver/binlogreplication" "github.com/dolthub/go-mysql-server/sql" @@ -121,7 +121,7 @@ func (twp *tableWriterProvider) newTableUpdater( } return &tableUpdater{ - pool: twp.pool, + provider: twp.provider, stmt: stmt, replace: replace, cleanup: cleanup, @@ -230,7 +230,7 @@ func buildUpdateTemplate(tableName string, columnCount int, schema sql.Schema, p } type tableUpdater struct { - pool *backend.ConnectionPool + provider *catalog.DatabaseProvider tx *stdsql.Tx stmt *stdsql.Stmt replace bool