Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for managing multiple databases #307

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:

- name: Test packages
run: |
go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver | tee packages.log
go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver ./catalog | tee packages.log
cat packages.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}'
cat packages.log | grep -q "FAIL" && exit 1 || exit 0

Expand Down
8 changes: 3 additions & 5 deletions backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (

type DuckBuilder struct {
base sql.NodeExecBuilder
pool *ConnectionPool

provider *catalog.DatabaseProvider

Expand All @@ -40,10 +39,9 @@ type DuckBuilder struct {

var _ sql.NodeExecBuilder = (*DuckBuilder)(nil)

func NewDuckBuilder(base sql.NodeExecBuilder, pool *ConnectionPool, provider *catalog.DatabaseProvider) *DuckBuilder {
func NewDuckBuilder(base sql.NodeExecBuilder, provider *catalog.DatabaseProvider) *DuckBuilder {
return &DuckBuilder{
base: base,
pool: pool,
provider: provider,
}
}
Expand Down Expand Up @@ -106,14 +104,14 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
return b.base.Build(ctx, root, r)
}

conn, err := b.pool.GetConnForSchema(ctx, ctx.ID(), ctx.GetCurrentDatabase())
conn, err := b.provider.Pool().GetConnForSchema(ctx, ctx.ID(), ctx.GetCurrentDatabase())
if err != nil {
return nil, err
}

switch node := n.(type) {
case *plan.Use:
useStmt := "USE " + catalog.FullSchemaName(b.pool.catalog, node.Database().Name())
useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name())
if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(node.Database().Name())
Expand Down
13 changes: 7 additions & 6 deletions backend/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package backend
import (
"context"
"fmt"
"github.com/apecloud/myduckserver/catalog"

"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/vitess/go/mysql"
Expand All @@ -25,16 +26,16 @@ import (

type MyHandler struct {
*server.Handler
pool *ConnectionPool
provider *catalog.DatabaseProvider
}

func (h *MyHandler) ConnectionClosed(c *mysql.Conn) {
h.pool.CloseConn(c.ConnectionID)
h.provider.Pool().CloseConn(c.ConnectionID)
h.Handler.ConnectionClosed(c)
}

func (h *MyHandler) ComInitDB(c *mysql.Conn, schemaName string) error {
_, err := h.pool.GetConnForSchema(context.Background(), c.ConnectionID, schemaName)
_, err := h.provider.Pool().GetConnForSchema(context.Background(), c.ConnectionID, schemaName)
if err != nil {
return err
}
Expand Down Expand Up @@ -78,16 +79,16 @@ func (h *MyHandler) ComQuery(
return h.Handler.ComQuery(ctx, c, query, wrapResultCallback(callback, modifiers...))
}

func WrapHandler(pool *ConnectionPool) server.HandlerWrapper {
func WrapHandler(provider *catalog.DatabaseProvider) server.HandlerWrapper {
return func(h mysql.Handler) (mysql.Handler, error) {
handler, ok := h.(*server.Handler)
if !ok {
return nil, fmt.Errorf("expected *server.Handler, got %T", h)
}

return &MyHandler{
Handler: handler,
pool: pool,
Handler: handler,
provider: provider,
}, nil
}
}
29 changes: 14 additions & 15 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ import (

type Session struct {
*memory.Session
db *catalog.DatabaseProvider
pool *ConnectionPool
db *catalog.DatabaseProvider
}

func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *ConnectionPool) *Session {
return &Session{base, provider, pool}
func NewSession(base *memory.Session, provider *catalog.DatabaseProvider) *Session {
return &Session{base, provider}
}

// Provider returns the database provider for the session.
Expand All @@ -45,11 +44,11 @@ func (sess *Session) Provider() *catalog.DatabaseProvider {
}

func (sess *Session) CurrentSchemaOfUnderlyingConn() string {
return sess.pool.CurrentSchema(sess.ID())
return sess.db.Pool().CurrentSchema(sess.ID())
}

// NewSessionBuilder returns a session builder for the given database provider.
func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
func NewSessionBuilder(provider *catalog.DatabaseProvider) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
Expand All @@ -63,13 +62,13 @@ func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool)
baseSession := sql.NewBaseSessionWithClientServer(addr, client, conn.ConnectionID)
memSession := memory.NewSession(baseSession, provider)

schema := pool.CurrentSchema(conn.ConnectionID)
schema := provider.Pool().CurrentSchema(conn.ConnectionID)
if schema != "" {
logrus.Traceln("SessionBuilder: new session: current schema:", schema)
memSession.SetCurrentDatabase(schema)
}

return &Session{memSession, provider, pool}, nil
return &Session{memSession, provider}, nil
}
}

Expand Down Expand Up @@ -203,37 +202,37 @@ func (sess *Session) GetPersistedValue(k string) (interface{}, error) {

// GetConn implements adapter.ConnectionHolder.
func (sess *Session) GetConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase())
return sess.db.Pool().GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase())
}

// GetCatalogConn implements adapter.ConnectionHolder.
func (sess *Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConn(ctx, sess.ID())
return sess.db.Pool().GetConn(ctx, sess.ID())
}

// GetTxn implements adapter.ConnectionHolder.
func (sess *Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return sess.pool.GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
return sess.db.Pool().GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
}

// GetCatalogTxn implements adapter.ConnectionHolder.
func (sess *Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return sess.pool.GetTxn(ctx, sess.ID(), "", options)
return sess.db.Pool().GetTxn(ctx, sess.ID(), "", options)
}

// TryGetTxn implements adapter.ConnectionHolder.
func (sess *Session) TryGetTxn() *stdsql.Tx {
return sess.pool.TryGetTxn(sess.ID())
return sess.db.Pool().TryGetTxn(sess.ID())
}

// CloseTxn implements adapter.ConnectionHolder.
func (sess *Session) CloseTxn() {
sess.pool.CloseTxn(sess.ID())
sess.db.Pool().CloseTxn(sess.ID())
}

// CloseConn implements adapter.ConnectionHolder.
func (sess *Session) CloseConn() {
sess.pool.CloseConn(sess.ID())
sess.db.Pool().CloseConn(sess.ID())
}

func (sess *Session) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
Expand Down
7 changes: 3 additions & 4 deletions backend/connpool.go → catalog/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package backend
package catalog

import (
"context"
Expand All @@ -22,7 +22,6 @@ import (
"strings"
"sync"

"github.com/apecloud/myduckserver/catalog"
"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -93,8 +92,8 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema
logrus.WithError(err).Error("Failed to get current schema")
return nil, err
} else if currentSchema != schemaName {
if _, err := conn.ExecContext(context.Background(), "USE "+catalog.FullSchemaName(p.catalog, schemaName)); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil {
if IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(schemaName)
}
logrus.WithField("schema", schemaName).WithError(err).Error("Failed to switch schema")
Expand Down
31 changes: 10 additions & 21 deletions catalog/internal_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ type InternalTable struct {
}

func (it *InternalTable) QualifiedName() string {
//if it.Schema == "__sys__" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unnecessary comment here

// return it.Name
//}
return it.Schema + "." + it.Name
}

Expand Down Expand Up @@ -58,9 +61,7 @@ func (it *InternalTable) UpsertStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("INSERT OR REPLACE INTO ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" VALUES (?")
for range it.KeyColumns[1:] {
b.WriteString(", ?")
Expand All @@ -76,9 +77,7 @@ func (it *InternalTable) DeleteStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -93,9 +92,7 @@ func (it *InternalTable) DeleteAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand All @@ -109,9 +106,7 @@ func (it *InternalTable) SelectStmt() string {
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -133,9 +128,7 @@ func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string {
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -151,9 +144,7 @@ func (it *InternalTable) SelectAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("SELECT * FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand All @@ -162,9 +153,7 @@ func (it *InternalTable) CountAllStmt() string {
b.Grow(128)
b.WriteString("SELECT COUNT(*)")
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand Down
Loading
Loading