diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 2a05ab4..952cff7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -44,7 +44,7 @@ jobs: - name: Test packages run: | - go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver | tee packages.log + go test -v -cover ./charset ./transpiler ./backend ./harness ./pgserver ./catalog | tee packages.log cat packages.log | grep -e "^--- " | sed 's/--- //g' | awk 'BEGIN {count=1} {printf "%d. %s\n", count++, $0}' cat packages.log | grep -q "FAIL" && exit 1 || exit 0 diff --git a/catalog/provider.go b/catalog/provider.go index 45d1dc9..2b54faf 100644 --- a/catalog/provider.go +++ b/catalog/provider.go @@ -69,46 +69,6 @@ func NewDBProvider(defaultTimeZone, dataDir, defaultDB string) (*DatabaseProvide return prov, nil } -func (prov *DatabaseProvider) IsReady() bool { - return prov.ready -} - -func (prov *DatabaseProvider) DropCatalog(dbFile string) error { - dbFile = strings.TrimSpace(dbFile) - dsn := "" - if dbFile != "" { - dsn = filepath.Join(prov.dataDir, dbFile) - // if this is the current catalog, return error - if dsn == prov.dsn { - return fmt.Errorf("cannot drop the current catalog") - } - // if file does not exist, return error - _, err := os.Stat(dsn) - if os.IsNotExist(err) { - return fmt.Errorf("database file %s does not exist", dsn) - } - // delete the file - err = os.Remove(dsn) - if err != nil { - return fmt.Errorf("failed to delete database file %s: %w", dsn, err) - } - return nil - } else { - return fmt.Errorf("cannot drop the in-memory catalog") - } -} - -func (prov *DatabaseProvider) ExistCatalog(dbFile string) bool { - if dbFile == "" || dbFile == "memory.db" { - return true - } else { - dsn := filepath.Join(prov.dataDir, dbFile) - // if already exists, return error - _, err := os.Stat(dsn) - return os.IsExist(err) - } -} - func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage *stdsql.DB) error { bootQueries := []string{ "INSTALL arrow", @@ -162,8 +122,49 @@ func (prov *DatabaseProvider) initCatalog(connector *duckdb.Connector, storage * return nil } -func (prov *DatabaseProvider) CreateCatalog(dbFile string) (bool, error) { - dbFile = strings.TrimSpace(dbFile) +func (prov *DatabaseProvider) IsReady() bool { + return prov.ready +} + +func (prov *DatabaseProvider) DropCatalog(dbName string) error { + dbFile := strings.TrimSpace(dbName) + ".db" + dsn := "" + if dbFile != "" { + dsn = filepath.Join(prov.dataDir, dbFile) + // if this is the current catalog, return error + if dsn == prov.dsn { + return fmt.Errorf("cannot drop the current catalog") + } + // if file does not exist, return error + _, err := os.Stat(dsn) + if os.IsNotExist(err) { + return fmt.Errorf("database file %s does not exist", dsn) + } + // delete the file + err = os.Remove(dsn) + if err != nil { + return fmt.Errorf("failed to delete database file %s: %w", dsn, err) + } + return nil + } else { + return fmt.Errorf("cannot drop the in-memory catalog") + } +} + +func (prov *DatabaseProvider) ExistCatalog(dbName string) bool { + dbFile := strings.TrimSpace(dbName) + ".db" + if dbFile == "" || dbFile == "memory.db" { + return true + } else { + dsn := filepath.Join(prov.dataDir, dbFile) + // if already exists, return error + _, err := os.Stat(dsn) + return os.IsExist(err) + } +} + +func (prov *DatabaseProvider) CreateCatalog(dbName string) (ready bool, err error) { + dbFile := strings.TrimSpace(dbName) + ".db" dsn := "" // in memory database does not need to be created if dbFile == "" || dbFile == "memory.db" { @@ -192,8 +193,8 @@ func (prov *DatabaseProvider) CreateCatalog(dbFile string) (bool, error) { return true, nil } -func (prov *DatabaseProvider) SwitchCatalog(dbFile string) error { - dbFile = strings.TrimSpace(dbFile) +func (prov *DatabaseProvider) SwitchCatalog(dbName string) error { + dbFile := strings.TrimSpace(dbName) + ".db" name := "" dsn := "" if dbFile == "" || dbFile == "memory.db" { diff --git a/catalog/provider_test.go b/catalog/provider_test.go new file mode 100644 index 0000000..ca43cc4 --- /dev/null +++ b/catalog/provider_test.go @@ -0,0 +1,142 @@ +package catalog + +import ( + "context" + "github.com/apecloud/myduckserver/testutil" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + "strconv" + "testing" +) + +type Execution struct { + SQL string + Expected string + WantErr bool +} + +func TestCreateCatalog(t *testing.T) { + tests := []struct { + name string + executions []Execution + }{ + { + name: "create database", + executions: []Execution{ + { + SQL: "CREATE DATABASE testdb1;", + Expected: "CREATE DATABASE", + }, + // Can not create the database with the same name + { + SQL: "CREATE DATABASE testdb1;", + WantErr: true, + }, + { + SQL: "CREATE DATABASE testdb2;", + Expected: "CREATE DATABASE", + }, + }, + }, + { + name: "switch database", + executions: []Execution{ + { + SQL: "USE testdb1;", + Expected: "SET", + }, + { + SQL: "CREATE SCHEMA testschema1;", + Expected: "CREATE SCHEMA", + }, + { + SQL: "USE testdb1.testschema1;", + Expected: "SET", + }, + { + SQL: "USE testdb2;", + Expected: "SET", + }, + // Can not drop the schema as it is not in the current database + { + SQL: "DROP SCHEMA testschema1;", + WantErr: true, + }, + { + SQL: "USE testdb1;", + Expected: "SET", + }, + { + SQL: "DROP SCHEMA testschema1;", + Expected: "DROP SCHEMA", + }, + // Can not switch to the schema that does not exist + { + SQL: "USE testdb1.testschema1;", + WantErr: true, + }, + }, + }, + { + name: "drop database", + executions: []Execution{ + { + SQL: "USE testdb1;", + Expected: "SET", + }, + // Can not drop the database when the current database is the one to be dropped + { + SQL: "DROP DATABASE testdb1;", + WantErr: true, + }, + { + SQL: "USE testdb2;", + Expected: "SET", + }, + { + SQL: "DROP DATABASE testdb1;", + Expected: "DROP DATABASE", + }, + }, + }, + } + testDir := testutil.CreateTestDir(t) + testEnv := testutil.NewTestEnv() + err := testutil.StartDuckSqlServer(t, testDir, nil, testEnv) + require.NoError(t, err) + defer testutil.StopDuckSqlServer(t, testEnv.DuckProcess) + dsn := "postgresql://postgres@localhost:" + strconv.Itoa(testEnv.DuckPgPort) + "/postgres" + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Connect to MyDuck Server + db, err := pgx.Connect(context.Background(), dsn) + if err != nil { + t.Errorf("Connect failed! dsn = %v, err: %v", dsn, err) + return + } + defer db.Close(context.Background()) + + for _, execution := range tt.executions { + func() { + tag, err := db.Exec(context.Background(), execution.SQL) + if execution.WantErr { + if err != nil { + return + } + t.Errorf("Test expectes error but got none! sql: %v", execution.SQL) + return + } + if err != nil { + t.Errorf("Query failed! sql: %v, err: %v", execution.SQL, err) + return + } + // check whether the result is as expected + if tag.String() != execution.Expected { + t.Errorf("sql: %v, got %v, want %v", execution.SQL, tag.String(), execution.Expected) + } + }() + } + }) + } +} diff --git a/main.go b/main.go index 5cfdc51..6b6fec3 100644 --- a/main.go +++ b/main.go @@ -52,7 +52,6 @@ var ( socket string defaultDb = "myduck" dataDirectory = "." - dbFileName string logLevel = int(logrus.InfoLevel) replicaOptions replica.ReplicaOptions @@ -112,7 +111,6 @@ func ensureSQLTranslate() { func main() { flag.Parse() // Parse all flags - dbFileName = defaultDb + ".db" if replicaOptions.ReportPort == 0 { replicaOptions.ReportPort = port @@ -130,7 +128,7 @@ func main() { return } - provider, err := catalog.NewDBProvider(defaultTimeZone, dataDirectory, dbFileName) + provider, err := catalog.NewDBProvider(defaultTimeZone, dataDirectory, defaultDb) if err != nil { logrus.Fatalln("Failed to open the database:", err) } @@ -243,7 +241,7 @@ func executeRestoreIfNeeded() { msg, err := pgserver.ExecuteRestore( defaultDb, dataDirectory, - dbFileName, + defaultDb+".db", restoreFile, restoreEndpoint, restoreAccessKeyId, diff --git a/pgserver/duck_handler.go b/pgserver/duck_handler.go index 125ed72..ccc5094 100644 --- a/pgserver/duck_handler.go +++ b/pgserver/duck_handler.go @@ -440,11 +440,17 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S break } parts := strings.Split(setVar.Values.String(), ".") - err = provider.SwitchCatalog(parts[0] + ".db") + err = provider.SwitchCatalog(parts[0]) if err != nil { break } - exec() + // If the query contains a schema name, we need to execute the query to set the schema as default. + if len(parts) > 1 { + exec() + } else { + schema = types.OkResultSchema + iter = sql.RowsToRowIter(sql.NewRow(types.OkResult{})) + } } else { exec() } @@ -459,7 +465,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S break } dbName := parsed.(*tree.CreateDatabase).Name.String() - _, err = provider.CreateCatalog(dbName + ".db") + _, err = provider.CreateCatalog(dbName) if err != nil { break } @@ -472,7 +478,7 @@ func (h *DuckHandler) executeQuery(ctx *sql.Context, query string, parsed tree.S break } dbName := parsed.(*tree.DropDatabase).Name.String() - err = provider.DropCatalog(dbName + ".db") + err = provider.DropCatalog(dbName) if err != nil { break }