From 041f24813c8713f7da8bf4645bcdba0e96236bb0 Mon Sep 17 00:00:00 2001 From: Noy Date: Mon, 23 Dec 2024 13:45:00 +0800 Subject: [PATCH] fix: adapt to tests --- catalog/provider.go | 5 +++-- pgserver/server.go | 3 ++- pgtest/server.go | 7 +++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/catalog/provider.go b/catalog/provider.go index 55087b1..80ef44b 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -25,8 +25,8 @@ type DatabaseProvider struct { defaultTimeZone string connector *duckdb.Connector storage *stdsql.DB - pool *ConnectionPool - catalogName string // database name in postgres + pool *ConnectionPool // TODO(Noy): Merge into the provider + catalogName string // database name in postgres dataDir string dbFile string dsn string @@ -53,6 +53,7 @@ func NewDBProvider(defaultTimeZone, dataDir, dbFile string) (*DatabaseProvider, mu: &sync.RWMutex{}, defaultTimeZone: defaultTimeZone, externalProcedureRegistry: sql.NewExternalStoredProcedureRegistry(), // This has no effect, just to satisfy the upper layer interface + dsn: "N/A", } err := prov.CreateCatalog(dataDir, dbFile) if err != nil { diff --git a/pgserver/server.go b/pgserver/server.go index ff06176..b1c5827 100644 --- a/pgserver/server.go +++ b/pgserver/server.go @@ -14,7 +14,8 @@ type Server struct { NewInternalCtx func() *sql.Context } -func NewServer(provider *catalog.DatabaseProvider, host string, port int, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) { +func NewServer(provider *catalog.DatabaseProvider, host string, port int, password string, newCtx func() *sql.Context, options ...ListenerOpt) (*Server, error) { + InitSuperuser(password) addr := fmt.Sprintf("%s:%d", host, port) l, err := server.NewListener("tcp", addr, "") if err != nil { diff --git a/pgtest/server.go b/pgtest/server.go index a72c0c4..8b43253 100644 --- a/pgtest/server.go +++ b/pgtest/server.go @@ -20,18 +20,17 @@ import ( func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pgserver.Server, conn *pgx.Conn, close func() error, err error) { provider := catalog.NewInMemoryDBProvider() - pool := catalog.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage()) // Postgres tables are created in the `public` schema by default. // Create the `public` schema if it doesn't exist. - _, err = pool.ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") + _, err = provider.Pool().ExecContext(context.Background(), "CREATE SCHEMA IF NOT EXISTS public") if err != nil { return nil, nil, nil, nil, err } engine := sqle.NewDefault(provider) - builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, pool, provider) + builder := backend.NewDuckBuilder(engine.Analyzer.ExecBuilder, provider) engine.Analyzer.ExecBuilder = builder config := server.Config{ @@ -74,7 +73,7 @@ func CreateTestServer(t *testing.T, port int) (ctx context.Context, pgServer *pg close = func() error { pgServer.Listener.Close() return errors.Join( - pool.Close(), + provider.Pool().Close(), provider.Close(), ) }