Skip to content

Commit

Permalink
feat: add support for managing multiple databases (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException authored Dec 25, 2024
1 parent c23ff9a commit 85a560f
Show file tree
Hide file tree
Showing 18 changed files with 460 additions and 155 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:

- name: Test packages
run: |
go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver | tee packages.log
go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver ./catalog | tee packages.log
cat packages.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}'
cat packages.log | grep -q "FAIL" && exit 1 || exit 0
Expand Down
8 changes: 3 additions & 5 deletions backend/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (

type DuckBuilder struct {
base sql.NodeExecBuilder
pool *ConnectionPool

provider *catalog.DatabaseProvider

Expand All @@ -40,10 +39,9 @@ type DuckBuilder struct {

var _ sql.NodeExecBuilder = (*DuckBuilder)(nil)

func NewDuckBuilder(base sql.NodeExecBuilder, pool *ConnectionPool, provider *catalog.DatabaseProvider) *DuckBuilder {
func NewDuckBuilder(base sql.NodeExecBuilder, provider *catalog.DatabaseProvider) *DuckBuilder {
return &DuckBuilder{
base: base,
pool: pool,
provider: provider,
}
}
Expand Down Expand Up @@ -106,14 +104,14 @@ func (b *DuckBuilder) Build(ctx *sql.Context, root sql.Node, r sql.Row) (sql.Row
return b.base.Build(ctx, root, r)
}

conn, err := b.pool.GetConnForSchema(ctx, ctx.ID(), ctx.GetCurrentDatabase())
conn, err := b.provider.Pool().GetConnForSchema(ctx, ctx.ID(), ctx.GetCurrentDatabase())
if err != nil {
return nil, err
}

switch node := n.(type) {
case *plan.Use:
useStmt := "USE " + catalog.FullSchemaName(b.pool.catalog, node.Database().Name())
useStmt := "USE " + catalog.FullSchemaName(b.provider.CatalogName(), node.Database().Name())
if _, err := conn.ExecContext(ctx.Context, useStmt); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(node.Database().Name())
Expand Down
13 changes: 7 additions & 6 deletions backend/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package backend
import (
"context"
"fmt"
"github.com/apecloud/myduckserver/catalog"

"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/vitess/go/mysql"
Expand All @@ -25,16 +26,16 @@ import (

type MyHandler struct {
*server.Handler
pool *ConnectionPool
provider *catalog.DatabaseProvider
}

func (h *MyHandler) ConnectionClosed(c *mysql.Conn) {
h.pool.CloseConn(c.ConnectionID)
h.provider.Pool().CloseConn(c.ConnectionID)
h.Handler.ConnectionClosed(c)
}

func (h *MyHandler) ComInitDB(c *mysql.Conn, schemaName string) error {
_, err := h.pool.GetConnForSchema(context.Background(), c.ConnectionID, schemaName)
_, err := h.provider.Pool().GetConnForSchema(context.Background(), c.ConnectionID, schemaName)
if err != nil {
return err
}
Expand Down Expand Up @@ -78,16 +79,16 @@ func (h *MyHandler) ComQuery(
return h.Handler.ComQuery(ctx, c, query, wrapResultCallback(callback, modifiers...))
}

func WrapHandler(pool *ConnectionPool) server.HandlerWrapper {
func WrapHandler(provider *catalog.DatabaseProvider) server.HandlerWrapper {
return func(h mysql.Handler) (mysql.Handler, error) {
handler, ok := h.(*server.Handler)
if !ok {
return nil, fmt.Errorf("expected *server.Handler, got %T", h)
}

return &MyHandler{
Handler: handler,
pool: pool,
Handler: handler,
provider: provider,
}, nil
}
}
29 changes: 14 additions & 15 deletions backend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ import (

type Session struct {
*memory.Session
db *catalog.DatabaseProvider
pool *ConnectionPool
db *catalog.DatabaseProvider
}

func NewSession(base *memory.Session, provider *catalog.DatabaseProvider, pool *ConnectionPool) *Session {
return &Session{base, provider, pool}
func NewSession(base *memory.Session, provider *catalog.DatabaseProvider) *Session {
return &Session{base, provider}
}

// Provider returns the database provider for the session.
Expand All @@ -45,11 +44,11 @@ func (sess *Session) Provider() *catalog.DatabaseProvider {
}

func (sess *Session) CurrentSchemaOfUnderlyingConn() string {
return sess.pool.CurrentSchema(sess.ID())
return sess.db.Pool().CurrentSchema(sess.ID())
}

// NewSessionBuilder returns a session builder for the given database provider.
func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
func NewSessionBuilder(provider *catalog.DatabaseProvider) func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
Expand All @@ -63,13 +62,13 @@ func NewSessionBuilder(provider *catalog.DatabaseProvider, pool *ConnectionPool)
baseSession := sql.NewBaseSessionWithClientServer(addr, client, conn.ConnectionID)
memSession := memory.NewSession(baseSession, provider)

schema := pool.CurrentSchema(conn.ConnectionID)
schema := provider.Pool().CurrentSchema(conn.ConnectionID)
if schema != "" {
logrus.Traceln("SessionBuilder: new session: current schema:", schema)
memSession.SetCurrentDatabase(schema)
}

return &Session{memSession, provider, pool}, nil
return &Session{memSession, provider}, nil
}
}

Expand Down Expand Up @@ -203,37 +202,37 @@ func (sess *Session) GetPersistedValue(k string) (interface{}, error) {

// GetConn implements adapter.ConnectionHolder.
func (sess *Session) GetConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase())
return sess.db.Pool().GetConnForSchema(ctx, sess.ID(), sess.GetCurrentDatabase())
}

// GetCatalogConn implements adapter.ConnectionHolder.
func (sess *Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
return sess.pool.GetConn(ctx, sess.ID())
return sess.db.Pool().GetConn(ctx, sess.ID())
}

// GetTxn implements adapter.ConnectionHolder.
func (sess *Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return sess.pool.GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
return sess.db.Pool().GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
}

// GetCatalogTxn implements adapter.ConnectionHolder.
func (sess *Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
return sess.pool.GetTxn(ctx, sess.ID(), "", options)
return sess.db.Pool().GetTxn(ctx, sess.ID(), "", options)
}

// TryGetTxn implements adapter.ConnectionHolder.
func (sess *Session) TryGetTxn() *stdsql.Tx {
return sess.pool.TryGetTxn(sess.ID())
return sess.db.Pool().TryGetTxn(sess.ID())
}

// CloseTxn implements adapter.ConnectionHolder.
func (sess *Session) CloseTxn() {
sess.pool.CloseTxn(sess.ID())
sess.db.Pool().CloseTxn(sess.ID())
}

// CloseConn implements adapter.ConnectionHolder.
func (sess *Session) CloseConn() {
sess.pool.CloseConn(sess.ID())
sess.db.Pool().CloseConn(sess.ID())
}

func (sess *Session) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) {
Expand Down
7 changes: 3 additions & 4 deletions backend/connpool.go → catalog/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package backend
package catalog

import (
"context"
Expand All @@ -22,7 +22,6 @@ import (
"strings"
"sync"

"github.com/apecloud/myduckserver/catalog"
"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -93,8 +92,8 @@ func (p *ConnectionPool) GetConnForSchema(ctx context.Context, id uint32, schema
logrus.WithError(err).Error("Failed to get current schema")
return nil, err
} else if currentSchema != schemaName {
if _, err := conn.ExecContext(context.Background(), "USE "+catalog.FullSchemaName(p.catalog, schemaName)); err != nil {
if catalog.IsDuckDBSetSchemaNotFoundError(err) {
if _, err := conn.ExecContext(context.Background(), "USE "+FullSchemaName(p.catalog, schemaName)); err != nil {
if IsDuckDBSetSchemaNotFoundError(err) {
return nil, sql.ErrDatabaseNotFound.New(schemaName)
}
logrus.WithField("schema", schemaName).WithError(err).Error("Failed to switch schema")
Expand Down
28 changes: 7 additions & 21 deletions catalog/internal_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ func (it *InternalTable) UpsertStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("INSERT OR REPLACE INTO ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" VALUES (?")
for range it.KeyColumns[1:] {
b.WriteString(", ?")
Expand All @@ -76,9 +74,7 @@ func (it *InternalTable) DeleteStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -93,9 +89,7 @@ func (it *InternalTable) DeleteAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand All @@ -109,9 +103,7 @@ func (it *InternalTable) SelectStmt() string {
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -133,9 +125,7 @@ func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string {
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -151,9 +141,7 @@ func (it *InternalTable) SelectAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("SELECT * FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand All @@ -162,9 +150,7 @@ func (it *InternalTable) CountAllStmt() string {
b.Grow(128)
b.WriteString("SELECT COUNT(*)")
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand Down
Loading

0 comments on commit 85a560f

Please sign in to comment.