Skip to content

Commit

Permalink
fix: adapt to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException committed Dec 23, 2024
1 parent 56d4a68 commit 041f248
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ type DatabaseProvider struct {
defaultTimeZone string
connector *duckdb.Connector
storage *stdsql.DB
pool *ConnectionPool
catalogName string // database name in postgres
pool *ConnectionPool // TODO(Noy): Merge into the provider
catalogName string // database name in postgres
dataDir string
dbFile string
dsn string
Expand All @@ -53,6 +53,7 @@ func NewDBProvider(defaultTimeZone, dataDir, dbFile string) (*DatabaseProvider,
mu: &sync.RWMutex{},
defaultTimeZone: defaultTimeZone,
externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface
dsn: "N/A",
}
err := prov.CreateCatalog(dataDir, dbFile)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion pgserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ type Server struct {
NewInternalCtx func() *sql.Context
}

func NewServer(provider *catalog.DatabaseProvider, host string, port int, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) {
func NewServer(provider *catalog.DatabaseProvider, host string, port int, password string, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) {
InitSuperuser(password)
addr := fmt.Sprintf("%s:%d", host, port)
l, err := server.NewListener("tcp", addr, "")
if err != nil {
Expand Down
7 changes: 3 additions & 4 deletions pgtest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ import (

func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pgserver.Server, conn *pgx.Conn, close func() error, err error) {
provider := catalog.NewInMemoryDBProvider()
pool := catalog.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage())

// Postgres tables are created in the `public` schema by default.
// Create the `public` schema if it doesn't exist.
_, err = pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public")
_, err = provider.Pool().ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public")
if err != nil {
return nil, nil, nil, nil, err
}

engine := sqle.NewDefault(provider)

builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, pool, provider)
builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, provider)
engine.Analyzer.ExecBuilder = builder

config := server.Config{
Expand Down Expand Up @@ -74,7 +73,7 @@ func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pg
close = func() error {
pgServer.Listener.Close()
return errors.Join(
pool.Close(),
provider.Pool().Close(),
provider.Close(),
)
}
Expand Down

0 comments on commit 041f248

Please sign in to comment.