Skip to content

Commit

Permalink
fix: adapt to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NoyException committed Dec 23, 2024
1 parent a48ab95 commit df6a64d
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 31 deletions.
31 changes: 10 additions & 21 deletions catalog/internal_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ type InternalTable struct {
}

func (it *InternalTable) QualifiedName() string {
//if it.Schema == "__sys__" {
// return it.Name
//}
return it.Schema + "." + it.Name
}

Expand Down Expand Up @@ -58,9 +61,7 @@ func (it *InternalTable) UpsertStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("INSERT OR REPLACE INTO ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" VALUES (?")
for range it.KeyColumns[1:] {
b.WriteString(", ?")
Expand All @@ -76,9 +77,7 @@ func (it *InternalTable) DeleteStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -93,9 +92,7 @@ func (it *InternalTable) DeleteAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand All @@ -109,9 +106,7 @@ func (it *InternalTable) SelectStmt() string {
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -133,9 +128,7 @@ func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string {
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
Expand All @@ -151,9 +144,7 @@ func (it *InternalTable) SelectAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("SELECT * FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand All @@ -162,9 +153,7 @@ func (it *InternalTable) CountAllStmt() string {
b.Grow(128)
b.WriteString("SELECT COUNT(*)")
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(it.QualifiedName())
return b.String()
}

Expand Down
5 changes: 4 additions & 1 deletion catalog/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ func (prov *DatabaseProvider) DropCatalog(dataDir, dbFile string) error {
func (prov *DatabaseProvider) CreateCatalog(dataDir, dbFile string) (bool, error) {
dbFile = strings.TrimSpace(dbFile)
dsn := ""
if dbFile != "" {
if dbFile == "" || dbFile == "memory.db" {
dsn = "memory"
} else {
dsn = filepath.Join(dataDir, dbFile)
// if already exists, return error
_, err := os.Stat(dsn)
Expand Down Expand Up @@ -168,6 +170,7 @@ func (prov *DatabaseProvider) SwitchCatalog(dataDir, dbFile string) error {
if dbFile == "" || dbFile == "memory.db" {
// in-memory mode, mainly for testing
name = "memory"
dsn = "memory"
} else {
name = strings.Split(dbFile, ".")[0]
dsn = filepath.Join(dataDir, dbFile)
Expand Down
7 changes: 3 additions & 4 deletions harness/duck_harness.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,12 @@ func (e *DuckTestEngine) Close() error {
func NewEngine(t *testing.T, harness enginetest.Harness, dbProvider sql.DatabaseProvider, setupData []setup.SetupScript, statsProvider sql.StatsProvider, server bool) (enginetest.QueryEngine, error) {
// Create the connection pool first, as it is needed by `NewEngineWithProvider`
provider := dbProvider.(*catalog.DatabaseProvider)
pool := catalog.NewConnectionPool(provider.CatalogName(), provider.Connector(), provider.Storage())
harness.(*DuckHarness).pool = pool
harness.(*DuckHarness).pool = provider.Pool()

e := enginetest.NewEngineWithProvider(t, harness, dbProvider)
e.Analyzer.Catalog.StatsProvider = statsProvider

builder := backend.NewDuckBuilder(e.Analyzer.ExecBuilder, pool, provider)
builder := backend.NewDuckBuilder(e.Analyzer.ExecBuilder, provider)
e.Analyzer.ExecBuilder = builder

ctx := enginetest.NewContext(harness)
Expand All @@ -329,7 +328,7 @@ func NewEngine(t *testing.T, harness enginetest.Harness, dbProvider sql.Database
return nil, err
}
}
return &DuckTestEngine{qe, pool}, nil
return &DuckTestEngine{qe, provider.Pool()}, nil
}

func (m *DuckHarness) SupportsNativeIndexCreation() bool {
Expand Down
2 changes: 1 addition & 1 deletion pgserver/duck_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S
schema = types.OkResultSchema
iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{}))
default:
rows, err = adapter.QueryCatalog(ctx, ConvertToSys(query))
rows, err = adapter.QueryCatalog(ctx, query)
if err != nil {
break
}
Expand Down
6 changes: 3 additions & 3 deletions pgserver/pg_catalog_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ var pgCatalogHandlers = map[string]PGCatalogHandler{
return h.handleCurrentSetting(query)
}
//if pgCatalogRegex.MatchString(sql) {
//if isSpecialPgCatalog(query) {
// return h.handlePgCatalog(query)
//}
if isSpecialPgCatalog(query) {
return h.handlePgCatalog(query)
}
return false, nil
},
},
Expand Down
2 changes: 1 addition & 1 deletion pgserver/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func getPgCatalogRegex() *regexp.Regexp {
}
tableNames = append(tableNames, table.Name)
}
pgCatalogRegex = regexp.MustCompile(`(?i)\b(FROM|JOIN)\s+(?:pg_catalog\.)?(` + strings.Join(tableNames, "|") + `)`)
pgCatalogRegex = regexp.MustCompile(`(?i)\b(FROM|JOIN|INTO)\s+(?:pg_catalog\.)?(` + strings.Join(tableNames, "|") + `)`)
})
return pgCatalogRegex
}
Expand Down

0 comments on commit df6a64d

Please sign in to comment.