diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go index ff424ba..0f3d80d 100644 --- a/catalog/internal_tables.go +++ b/catalog/internal_tables.go @@ -12,6 +12,9 @@ type InternalTable struct { } func (it *InternalTable) QualifiedName() string { + //if it.Schema == "__sys__" { + // return it.Name + //} return it.Schema + "." + it.Name } @@ -58,9 +61,7 @@ func (it *InternalTable) UpsertStmt() string { var b strings.Builder b.Grow(128) b.WriteString("INSERT OR REPLACE INTO ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" VALUES (?") for range it.KeyColumns[1:] { b.WriteString(", ?") @@ -76,9 +77,7 @@ func (it *InternalTable) DeleteStmt() string { var b strings.Builder b.Grow(128) b.WriteString("DELETE FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" WHERE ") b.WriteString(it.KeyColumns[0]) b.WriteString(" = ?") @@ -93,9 +92,7 @@ func (it *InternalTable) DeleteAllStmt() string { var b strings.Builder b.Grow(128) b.WriteString("DELETE FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) return b.String() } @@ -109,9 +106,7 @@ func (it *InternalTable) SelectStmt() string { b.WriteString(c) } b.WriteString(" FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" WHERE ") b.WriteString(it.KeyColumns[0]) b.WriteString(" = ?") @@ -133,9 +128,7 @@ func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string { b.WriteString(c) } b.WriteString(" FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) b.WriteString(" WHERE ") b.WriteString(it.KeyColumns[0]) b.WriteString(" = ?") @@ -151,9 +144,7 @@ func (it *InternalTable) SelectAllStmt() string { var b strings.Builder b.Grow(128) b.WriteString("SELECT * FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) return b.String() } @@ -162,9 +153,7 @@ func (it *InternalTable) CountAllStmt() string { b.Grow(128) b.WriteString("SELECT COUNT(*)") b.WriteString(" FROM ") - b.WriteString(it.Schema) - b.WriteByte('.') - b.WriteString(it.Name) + b.WriteString(it.QualifiedName()) return b.String() } diff --git a/catalog/provider.go b/catalog/provider.go index f7dae92..c3ed9b3 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -94,7 +94,9 @@ func (prov *DatabaseProvider) DropCatalog(dataDir, dbFile string) error { func (prov *DatabaseProvider) CreateCatalog(dataDir, dbFile string) (bool, error) { dbFile = strings.TrimSpace(dbFile) dsn := "" - if dbFile != "" { + if dbFile == "" || dbFile == "memory.db" { + dsn = "memory" + } else { dsn = filepath.Join(dataDir, dbFile) // if already exists, return error _, err := os.Stat(dsn) @@ -168,6 +170,7 @@ func (prov *DatabaseProvider) SwitchCatalog(dataDir, dbFile string) error { if dbFile == "" || dbFile == "memory.db" { // in-memory mode, mainly for testing name = "memory" + dsn = "memory" } else { name = strings.Split(dbFile, ".")[0] dsn = filepath.Join(dataDir, dbFile) diff --git a/harness/duck_harness.go b/harness/duck_harness.go index 9d32443..4f9530d 100644 --- a/harness/duck_harness.go +++ b/harness/duck_harness.go @@ -297,13 +297,12 @@ func (e *DuckTestEngine) Close() error { func NewEngine(t *testing.T, harness enginetest.Harness, dbProvider sql.DatabaseProvider, setupData []setup.SetupScript, statsProvider sql.StatsProvider, server bool) (enginetest.QueryEngine, error) { // Create the connection pool first, as it is needed by `NewEngineWithProvider` provider := dbProvider.(*catalog.DatabaseProvider) - pool := catalog.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage()) - harness.(*DuckHarness).pool = pool + harness.(*DuckHarness).pool = provider.Pool() e := enginetest.NewEngineWithProvider(t, harness, dbProvider) e.Analyzer.Catalog.StatsProvider = statsProvider - builder := backend.NewDuckBuilder(e.Analyzer.ExecBuilder, pool, provider) + builder := backend.NewDuckBuilder(e.Analyzer.ExecBuilder, provider) e.Analyzer.ExecBuilder = builder ctx := enginetest.NewContext(harness) @@ -329,7 +328,7 @@ func NewEngine(t *testing.T, harness enginetest.Harness, dbProvider sql.Database return nil, err } } - return &DuckTestEngine{qe, pool}, nil + return &DuckTestEngine{qe, provider.Pool()}, nil } func (m *DuckHarness) SupportsNativeIndexCreation() bool { diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index e7909d8..ebfed80 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -479,7 +479,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S schema = types.OkResultSchema iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{})) default: - rows, err = adapter.QueryCatalog(ctx, ConvertToSys(query)) + rows, err = adapter.QueryCatalog(ctx, query) if err != nil { break } diff --git a/pgserver/pg_catalog_handler.go b/pgserver/pg_catalog_handler.go index 8dec651..84cb878 100644 --- a/pgserver/pg_catalog_handler.go +++ b/pgserver/pg_catalog_handler.go @@ -237,9 +237,9 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{ return h.handleCurrentSetting(query) } //if pgCatalogRegex.MatchString(sql) { - //if isSpecialPgCatalog(query) { - // return h.handlePgCatalog(query) - //} + if isSpecialPgCatalog(query) { + return h.handlePgCatalog(query) + } return false, nil }, }, diff --git a/pgserver/stmt.go b/pgserver/stmt.go index d419385..c88e44e 100644 --- a/pgserver/stmt.go +++ b/pgserver/stmt.go @@ -278,7 +278,7 @@ func getPgCatalogRegex() *regexp.Regexp { } tableNames = append(tableNames, table.Name) } - pgCatalogRegex = regexp.MustCompile(`(?i)\b(FROM|JOIN)\s+(?:pg_catalog\.)?(` + strings.Join(tableNames, "|") + `)`) + pgCatalogRegex = regexp.MustCompile(`(?i)\b(FROM|JOIN|INTO)\s+(?:pg_catalog\.)?(` + strings.Join(tableNames, "|") + `)`) }) return pgCatalogRegex }