diff --git a/internal/database/dbutil/database.go b/internal/database/dbutil/database.go index 0d1f61259..a25abd886 100644 --- a/internal/database/dbutil/database.go +++ b/internal/database/dbutil/database.go @@ -2,6 +2,7 @@ package dbutil import ( "fmt" + "strings" "github.com/go-sql-driver/mysql" "github.com/jackc/pgerrcode" @@ -72,8 +73,9 @@ func IsTableNotFoundError(err error, backend types.BackendType) (bool, error) { } case types.BackendSnowflake: if e2, ok := err.(*gosnowflake.SnowflakeError); ok { - return e2.Number == gosnowflake.ErrObjectNotExistOrAuthorized, nil + return strings.Contains(e2.Error(), "does not exist or not authorized"), nil } + // https://cloud.google.com/bigquery/docs/error-messages case types.BackendBigQuery: if e2, ok := err.(*googleapi.Error); ok { diff --git a/internal/database/dbutil/sql.go b/internal/database/dbutil/sql.go index 5a8cf2a03..0d8e4a541 100644 --- a/internal/database/dbutil/sql.go +++ b/internal/database/dbutil/sql.go @@ -91,19 +91,25 @@ func QuoteFn(backendType types.BackendType) func(...string) string { } func UnQuoteFn(backendType types.BackendType) func(string) string { - var q string + var q byte switch backendType { case types.BackendPostgres, types.BackendSnowflake, types.BackendRedshift, types.BackendCassandra, types.BackendSQLite: - q = `"` + q = '"' case types.BackendMySQL, types.BackendBigQuery: - q = "`" + q = '`' default: panic(fmt.Sprintf("unsupported backend type %s", backendType)) } return func(s string) string { - if strings.HasPrefix(s, q) && strings.HasSuffix(s, q) { - return strings.Trim(s, q) + if s == "" { + return s + } + if s[0] == q && s[len(s)-1] == q { + if s == string(q) { + return s + } + return s[1 : len(s)-1] } return s } diff --git a/internal/database/offline/snowflake/store.go b/internal/database/offline/snowflake/store.go index 580489d8e..dec528f43 100644 --- a/internal/database/offline/snowflake/store.go +++ b/internal/database/offline/snowflake/store.go @@ -4,12 +4,13 @@ import ( "context" "github.com/jmoiron/sqlx" + "github.com/snowflakedb/gosnowflake" + "github.com/oom-ai/oomstore/internal/database/dbutil" "github.com/oom-ai/oomstore/internal/database/offline" "github.com/oom-ai/oomstore/internal/database/offline/sqlutil" "github.com/oom-ai/oomstore/pkg/errdefs" "github.com/oom-ai/oomstore/pkg/oomstore/types" - "github.com/snowflakedb/gosnowflake" ) const ( diff --git a/internal/database/offline/sqlutil/temporary_table.go b/internal/database/offline/sqlutil/temporary_table.go index eeec84968..9fdca52da 100644 --- a/internal/database/offline/sqlutil/temporary_table.go +++ b/internal/database/offline/sqlutil/temporary_table.go @@ -23,10 +23,7 @@ func AddTemporaryTableRecord(ctx context.Context, dbOpt dbutil.DBOpt, tableName tableName = fmt.Sprintf(`"%s"`, tableName) } query := fmt.Sprintf(`INSERT INTO %s (table_name, create_time) VALUES(?,?)`, buildTableName(dbOpt, offline.TemporaryTableRecordTable)) - if err := dbOpt.ExecContext(ctx, query, tableName, time.Now().UnixMilli()); err != nil { - return err - } - return nil + return dbOpt.ExecContext(ctx, query, tableName, time.Now().UnixMilli()) } func createTemporaryTableRecordTable(ctx context.Context, dbOpt dbutil.DBOpt) error { @@ -50,8 +47,13 @@ CREATE TABLE IF NOT EXISTS %s ( } func GetTemporaryTables(ctx context.Context, db *sqlx.DB, backend types.BackendType, unixMill int64) ([]string, error) { - query := fmt.Sprintf("SELECT table_name FROM %s WHERE create_time < ?", - offline.TemporaryTableRecordTable) + var tableName string + if backend == types.BackendSnowflake { + tableName = fmt.Sprintf(`PUBLIC."%s"`, offline.TemporaryTableRecordTable) + } else { + tableName = offline.TemporaryTableRecordTable + } + query := fmt.Sprintf("SELECT table_name FROM %s WHERE create_time < ?", tableName) rows, err := db.QueryContext(ctx, db.Rebind(query), unixMill) if err != nil { @@ -88,6 +90,10 @@ func DropTemporaryTables(ctx context.Context, db dbutil.DBOpt, tableNames []stri unQt := dbutil.UnQuoteFn(db.Backend) for i := 0; i < len(tableNames); i++ { tableNames[i] = unQt(tableNames[i]) + + if db.Backend == types.BackendBigQuery { + tableNames[i] = fmt.Sprintf(`"%s"`, tableNames[i]) + } } cond, args, err := dbutil.BuildConditions(nil, map[string]interface{}{ "table_name": tableNames, @@ -99,9 +105,7 @@ func DropTemporaryTables(ctx context.Context, db dbutil.DBOpt, tableNames []stri query := fmt.Sprintf("DELETE FROM %s WHERE %s", buildTableName(db, offline.TemporaryTableRecordTable), strings.Join(cond, " AND ")) - if err := db.ExecContext(ctx, query, args...); err != nil { - return err - } + return db.ExecContext(ctx, query, args...) } return nil }