diff --git a/internal/database/dbutil/database.go b/internal/database/dbutil/database.go
index 0d1f61259..723036a1a 100644
--- a/internal/database/dbutil/database.go
+++ b/internal/database/dbutil/database.go
@@ -2,17 +2,19 @@ package dbutil
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/go-sql-driver/mysql"
 	"github.com/jackc/pgerrcode"
 	"github.com/jmoiron/sqlx"
 	"github.com/lib/pq"
-	"github.com/oom-ai/oomstore/pkg/errdefs"
-	"github.com/oom-ai/oomstore/pkg/oomstore/types"
 	"github.com/snowflakedb/gosnowflake"
 	"google.golang.org/api/googleapi"
 	"modernc.org/sqlite"
 	sqlite3 "modernc.org/sqlite/lib"
+
+	"github.com/oom-ai/oomstore/pkg/errdefs"
+	"github.com/oom-ai/oomstore/pkg/oomstore/types"
 )
 
 const (
@@ -72,7 +74,7 @@ func IsTableNotFoundError(err error, backend types.BackendType) (bool, error) {
 		}
 	case types.BackendSnowflake:
 		if e2, ok := err.(*gosnowflake.SnowflakeError); ok {
-			return e2.Number == gosnowflake.ErrObjectNotExistOrAuthorized, nil
+			return strings.Contains(e2.Error(), "does not exist or not authorized"), nil
 		}
 	// https://cloud.google.com/bigquery/docs/error-messages
 	case types.BackendBigQuery:
diff --git a/internal/database/dbutil/sql.go b/internal/database/dbutil/sql.go
index ceffae9a8..0d8e4a541 100644
--- a/internal/database/dbutil/sql.go
+++ b/internal/database/dbutil/sql.go
@@ -90,6 +90,32 @@ func QuoteFn(backendType types.BackendType) func(...string) string {
 	}
 }
 
+// UnQuoteFn returns a function that removes one pair of enclosing identifier
+// quotes from a string, using the quote character of the given backend
+// (a backtick for MySQL and BigQuery, a double quote otherwise). Strings that
+// are not wrapped in that character on both ends are returned unchanged. It
+// panics on an unsupported backend type.
+func UnQuoteFn(backendType types.BackendType) func(string) string {
+	var quote byte
+	switch backendType {
+	case types.BackendMySQL, types.BackendBigQuery:
+		quote = '`'
+	case types.BackendPostgres, types.BackendSnowflake, types.BackendRedshift, types.BackendCassandra, types.BackendSQLite:
+		quote = '"'
+	default:
+		panic(fmt.Sprintf("unsupported backend type %s", backendType))
+	}
+
+	return func(s string) string {
+		// Stripping needs both an opening and a closing quote; anything
+		// shorter (including a lone quote character) is returned as is.
+		if len(s) < 2 || s[0] != quote || s[len(s)-1] != quote {
+			return s
+		}
+		return s[1 : len(s)-1]
+	}
+}
+
 func DropTable(ctx context.Context,
dbOpt DBOpt, tableName string) error {
 	query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, tableName)
 	err := dbOpt.ExecContext(ctx, query)
diff --git a/internal/database/offline/bigquery/join.go b/internal/database/offline/bigquery/join.go
index e490af227..a746d5f8c 100644
--- a/internal/database/offline/bigquery/join.go
+++ b/internal/database/offline/bigquery/join.go
@@ -2,7 +2,6 @@ package bigquery
 
 import (
 	"context"
-	"fmt"
 	"strings"
 
 	"cloud.google.com/go/bigquery"
@@ -43,13 +42,9 @@ func bigqueryQueryResults(ctx context.Context, dbOpt dbutil.DBOpt, query string,
 	data := make(chan types.JoinRecord)
 	go func() {
 		defer func() {
-			if err = dropTemporaryTables(ctx, dbOpt.BigQueryDB, dropTableNames); err != nil {
-				select {
-				case data <- types.JoinRecord{Error: err}:
-					// nothing to do
-				default:
-				}
-			}
+			// Temporary-table cleanup is best-effort and must not fail the main operation, so the error is deliberately discarded.
+			// TODO: log the discarded error in the cloud service version of oomstore
+			_ = sqlutil.DropTemporaryTables(ctx, dbOpt, dropTableNames)
 			close(data)
 		}()
 
@@ -99,24 +94,6 @@ func bigqueryQueryResults(ctx context.Context, dbOpt dbutil.DBOpt, query string,
 	}, nil
 }
 
-func dropTemporaryTables(ctx context.Context, db *bigquery.Client, tableNames []string) error {
-	var err error
-	for _, tableName := range tableNames {
-		if tmpErr := dropTable(ctx, db, tableName); tmpErr != nil {
-			err = tmpErr
-		}
-	}
-	return err
-}
-
-func dropTable(ctx context.Context, db *bigquery.Client, tableName string) error {
-	query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, tableName)
-	if _, err := db.Query(query).Read(ctx); err != nil {
-		return errdefs.WithStack(err)
-	}
-	return nil
-}
-
 const READ_JOIN_RESULT_QUERY = `
 SELECT
 	{{ qt .EntityRowsTableName }}.{{ qt .EntityKey }},
diff --git a/internal/database/offline/bigquery/store.go b/internal/database/offline/bigquery/store.go
index 64c6febae..433be1048 100644
--- a/internal/database/offline/bigquery/store.go
+++ b/internal/database/offline/bigquery/store.go
@@ -2,14 +2,17 @@ package bigquery
 
 import (
 	"context"
+	"fmt"
 
 	"cloud.google.com/go/bigquery"
-	"github.com/oom-ai/oomstore/pkg/errdefs"
+	"github.com/spf13/cast"
+	"google.golang.org/api/iterator"
 	"google.golang.org/api/option"
 
 	"github.com/oom-ai/oomstore/internal/database/dbutil"
 	"github.com/oom-ai/oomstore/internal/database/offline"
 	"github.com/oom-ai/oomstore/internal/database/offline/sqlutil"
+	"github.com/oom-ai/oomstore/pkg/errdefs"
 	"github.com/oom-ai/oomstore/pkg/oomstore/types"
 )
 
@@ -62,3 +65,51 @@ func (db *DB) Push(ctx context.Context, opt offline.PushOpt) error {
 	}
 	return nil
 }
+
+// DropTemporaryTable drops the given temporary tables and deletes their
+// bookkeeping records by delegating to sqlutil.DropTemporaryTables.
+func (db *DB) DropTemporaryTable(ctx context.Context, tableNames []string) error {
+	dbOpt := dbutil.DBOpt{Backend: Backend, BigQueryDB: db.Client, DatasetID: &db.datasetID}
+	return sqlutil.DropTemporaryTables(ctx, dbOpt, tableNames)
+}
+
+// GetTemporaryTables returns the names of the recorded temporary tables whose
+// create_time is earlier than unixMilli (Unix milliseconds). It returns a nil
+// slice and no error when the record table does not exist yet.
+func (db *DB) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	qt := dbutil.QuoteFn(Backend)
+	query := fmt.Sprintf("SELECT table_name FROM %s.%s WHERE create_time < %s",
+		db.datasetID, qt(offline.TemporaryTableRecordTable), cast.ToString(unixMilli))
+
+	rows, err := db.Query(query).Read(ctx)
+	if err != nil {
+		tableNotFound, notFoundErr := dbutil.IsTableNotFoundError(err, Backend)
+		if notFoundErr != nil {
+			return nil, notFoundErr
+		}
+		if tableNotFound {
+			return nil, nil
+		}
+		return nil, err
+	}
+
+	var tableNames []string
+	for {
+		recordMap := make(map[string]bigquery.Value)
+		err = rows.Next(&recordMap)
+		if err == iterator.Done {
+			break
+		}
+		// Any other iterator error is terminal; without this check the loop
+		// would never end and recordMap would stay empty.
+		if err != nil {
+			return nil, errdefs.WithStack(err)
+		}
+		tableName, ok := recordMap["table_name"].(string)
+		if !ok {
+			return nil, errdefs.WithStack(fmt.Errorf("unexpected table_name value %v", recordMap["table_name"]))
+		}
+		tableNames = append(tableNames, tableName)
+	}
+	return tableNames, nil
+}
diff --git a/internal/database/offline/mock_offline/store.go b/internal/database/offline/mock_offline/store.go
index 2ecbc4a23..7977e7d71 100644
--- a/internal/database/offline/mock_offline/store.go
+++ b/internal/database/offline/mock_offline/store.go
@@ -64,6 +64,20 @@ func (mr *MockStoreMockRecorder) CreateTable(ctx, opt interface{}) *gomock.Call
 	return
mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTable", reflect.TypeOf((*MockStore)(nil).CreateTable), ctx, opt) } +// DropTemporaryTable mocks base method. +func (m *MockStore) DropTemporaryTable(ctx context.Context, tableNames []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DropTemporaryTable", ctx, tableNames) + ret0, _ := ret[0].(error) + return ret0 +} + +// DropTemporaryTable indicates an expected call of DropTemporaryTable. +func (mr *MockStoreMockRecorder) DropTemporaryTable(ctx, tableNames interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropTemporaryTable", reflect.TypeOf((*MockStore)(nil).DropTemporaryTable), ctx, tableNames) +} + // Export mocks base method. func (m *MockStore) Export(ctx context.Context, opt offline.ExportOpt) (*types.ExportResult, error) { m.ctrl.T.Helper() @@ -79,6 +93,21 @@ func (mr *MockStoreMockRecorder) Export(ctx, opt interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Export", reflect.TypeOf((*MockStore)(nil).Export), ctx, opt) } +// GetTemporaryTables mocks base method. +func (m *MockStore) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTemporaryTables", ctx, unixMilli) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTemporaryTables indicates an expected call of GetTemporaryTables. +func (mr *MockStoreMockRecorder) GetTemporaryTables(ctx, unixMilli interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTemporaryTables", reflect.TypeOf((*MockStore)(nil).GetTemporaryTables), ctx, unixMilli) +} + // Import mocks base method. 
func (m *MockStore) Import(ctx context.Context, opt offline.ImportOpt) (int64, error) {
 	m.ctrl.T.Helper()
diff --git a/internal/database/offline/mysql/store.go b/internal/database/offline/mysql/store.go
index 9e1599f19..ca9a58e51 100644
--- a/internal/database/offline/mysql/store.go
+++ b/internal/database/offline/mysql/store.go
@@ -70,3 +70,15 @@ func (db *DB) Push(ctx context.Context, opt offline.PushOpt) error {
 	}
 	return nil
 }
+
+// DropTemporaryTable drops the given temporary tables and deletes their
+// bookkeeping records by delegating to sqlutil.DropTemporaryTables.
+func (db *DB) DropTemporaryTable(ctx context.Context, tableNames []string) error {
+	return sqlutil.DropTemporaryTables(ctx, dbutil.DBOpt{Backend: Backend, SqlxDB: db.DB}, tableNames)
+}
+
+// GetTemporaryTables lists the recorded temporary tables whose create_time is
+// earlier than unixMilli by delegating to sqlutil.GetTemporaryTables.
+func (db *DB) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	return sqlutil.GetTemporaryTables(ctx, db.DB, Backend, unixMilli)
+}
diff --git a/internal/database/offline/postgres/store.go b/internal/database/offline/postgres/store.go
index 0aa502ff9..abf0cee36 100644
--- a/internal/database/offline/postgres/store.go
+++ b/internal/database/offline/postgres/store.go
@@ -68,3 +68,15 @@ func (db *DB) Push(ctx context.Context, opt offline.PushOpt) error {
 	}
 	return nil
 }
+
+// DropTemporaryTable drops the given temporary tables and deletes their
+// bookkeeping records by delegating to sqlutil.DropTemporaryTables.
+func (db *DB) DropTemporaryTable(ctx context.Context, tableNames []string) error {
+	return sqlutil.DropTemporaryTables(ctx, dbutil.DBOpt{Backend: Backend, SqlxDB: db.DB}, tableNames)
+}
+
+// GetTemporaryTables lists the recorded temporary tables whose create_time is
+// earlier than unixMilli by delegating to sqlutil.GetTemporaryTables.
+func (db *DB) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	return sqlutil.GetTemporaryTables(ctx, db.DB, Backend, unixMilli)
+}
diff --git a/internal/database/offline/redshift/store.go b/internal/database/offline/redshift/store.go
index 0c787e3a8..12255812e 100644
--- a/internal/database/offline/redshift/store.go
+++ b/internal/database/offline/redshift/store.go
@@ -70,3 +70,15 @@ func (db *DB) Push(ctx context.Context, opt offline.PushOpt) error {
 	}
 	return nil
 }
+
+// DropTemporaryTable drops the given temporary tables and deletes their
+// bookkeeping records by delegating to sqlutil.DropTemporaryTables.
+func (db *DB) DropTemporaryTable(ctx context.Context, tableNames []string) error {
+	return sqlutil.DropTemporaryTables(ctx, dbutil.DBOpt{Backend: Backend, SqlxDB: db.DB}, tableNames)
+}
+
+// GetTemporaryTables lists the recorded temporary tables whose create_time is
+// earlier than unixMilli by delegating to sqlutil.GetTemporaryTables.
+func (db *DB) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	return sqlutil.GetTemporaryTables(ctx, db.DB, Backend, unixMilli)
+}
diff --git a/internal/database/offline/snowflake/store.go b/internal/database/offline/snowflake/store.go
index 49591e315..580489d8e 100644
--- a/internal/database/offline/snowflake/store.go
+++ b/internal/database/offline/snowflake/store.go
@@ -84,3 +84,15 @@ func (db *DB) Push(ctx context.Context, opt offline.PushOpt) error {
 	}
 	return nil
 }
+
+// DropTemporaryTable drops the given temporary tables and deletes their
+// bookkeeping records by delegating to sqlutil.DropTemporaryTables.
+func (db *DB) DropTemporaryTable(ctx context.Context, tableNames []string) error {
+	return sqlutil.DropTemporaryTables(ctx, dbutil.DBOpt{Backend: Backend, SqlxDB: db.DB}, tableNames)
+}
+
+// GetTemporaryTables lists the recorded temporary tables whose create_time is
+// earlier than unixMilli by delegating to sqlutil.GetTemporaryTables.
+func (db *DB) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	return sqlutil.GetTemporaryTables(ctx, db.DB, Backend, unixMilli)
+}
diff --git a/internal/database/offline/sqlite/store.go b/internal/database/offline/sqlite/store.go
index 0af6cff58..9d38a28e2 100644
--- a/internal/database/offline/sqlite/store.go
+++ b/internal/database/offline/sqlite/store.go
@@ -78,3 +78,15 @@ func (db *DB) Push(ctx context.Context, opt offline.PushOpt) error {
 	}
 	return nil
 }
+
+// DropTemporaryTable drops the given temporary tables and deletes their
+// bookkeeping records by delegating to sqlutil.DropTemporaryTables.
+func (db *DB) DropTemporaryTable(ctx context.Context, tableNames []string) error {
+	return sqlutil.DropTemporaryTables(ctx, dbutil.DBOpt{Backend: Backend, SqlxDB: db.DB}, tableNames)
+}
+
+// GetTemporaryTables lists the recorded temporary tables whose create_time is
+// earlier than unixMilli by delegating to sqlutil.GetTemporaryTables.
+func (db *DB) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	return sqlutil.GetTemporaryTables(ctx, db.DB, Backend, unixMilli)
+}
diff --git a/internal/database/offline/sqlutil/export_helper.go b/internal/database/offline/sqlutil/export_helper.go
index bc2a4838f..543637020 100644
--- a/internal/database/offline/sqlutil/export_helper.go
+++ b/internal/database/offline/sqlutil/export_helper.go
@@ -143,6 +143,11 @@ func buildExportQuery(params exportQueryParams) (string,
error) {
 func prepareEntityTable(ctx context.Context, dbOpt dbutil.DBOpt, opt offline.ExportOpt, snapshotTables, cdcTables []string) (string, error) {
 	// Step 1: create table export_entity
 	tableName := dbutil.TempTable("export_entity")
+
+	// Temporary-table bookkeeping is best-effort and must not fail the main operation, so the error is deliberately discarded.
+	// TODO: log the discarded error in the cloud service version of oomstore
+	_ = AddTemporaryTableRecord(ctx, dbOpt, buildTableName(dbOpt, tableName))
+
 	qtTableName, columnDefs, err := prepareTableSchema(dbOpt, prepareTableSchemaParams{
 		tableName:  tableName,
 		entityName: opt.EntityName,
diff --git a/internal/database/offline/sqlutil/join.go b/internal/database/offline/sqlutil/join.go
index 80acc0dca..b24b8643a 100644
--- a/internal/database/offline/sqlutil/join.go
+++ b/internal/database/offline/sqlutil/join.go
@@ -287,7 +287,12 @@ func readJoinedTable(ctx context.Context, dbOpt dbutil.DBOpt, opt readJoinedTabl
 
 	dropTableNames := []string{buildTableName(dbOpt, opt.EntityRowsTableName)}
 	for _, tableName := range opt.AllTableNames {
-		dropTableNames = append(dropTableNames, buildTableName(dbOpt, tableName))
+		dropTableName := buildTableName(dbOpt, tableName)
+		dropTableNames = append(dropTableNames, dropTableName)
+
+		// Temporary-table bookkeeping is best-effort and must not fail the main operation, so the error is deliberately discarded.
+		// TODO: log the discarded error in the cloud service version of oomstore
+		_ = AddTemporaryTableRecord(ctx, dbOpt, dropTableName)
 	}
 
 	// Step 2: read joined results
@@ -308,14 +313,9 @@ func sqlxQueryResults(ctx context.Context, dbOpt dbutil.DBOpt, query string, hea
 	data := make(chan types.JoinRecord)
 	go func() {
 		defer func() {
-			if err := dropTemporaryTables(ctx, dbOpt.SqlxDB, dropTableNames); err != nil {
-				select {
-				case data <- types.JoinRecord{Error: err}:
-					// nothing to do
-				default:
-				}
-			}
-
+			// Temporary-table cleanup is best-effort and must not fail the main operation, so the error is deliberately discarded.
+			// TODO: log the discarded error in the cloud service version of oomstore
+			_ = DropTemporaryTables(ctx, dbOpt, dropTableNames)
 			rows.Close()
 			close(data)
 		}()
diff --git a/internal/database/offline/sqlutil/join_helper.go b/internal/database/offline/sqlutil/join_helper.go
index ee97e28cf..0e98cb7c5 100644
--- a/internal/database/offline/sqlutil/join_helper.go
+++ b/internal/database/offline/sqlutil/join_helper.go
@@ -7,7 +7,6 @@ import (
 	"strings"
 	"text/template"
 
-	"github.com/jmoiron/sqlx"
 	"github.com/oom-ai/oomstore/pkg/errdefs"
 
 	"github.com/oom-ai/oomstore/internal/database/dbutil"
@@ -127,6 +126,11 @@ func prepareEntityRowsTable(ctx context.Context,
 ) (string, error) {
 	// Step 1: create table entity_rows
 	tableName := dbutil.TempTable("entity_rows")
+
+	// Temporary-table bookkeeping is best-effort and must not fail the main operation, so the error is deliberately discarded.
+	// TODO: log the discarded error in the cloud service version of oomstore
+	_ = AddTemporaryTableRecord(ctx, dbOpt, buildTableName(dbOpt, tableName))
+
 	qtTableName, columnDefs, err := prepareTableSchema(dbOpt, prepareTableSchemaParams{
 		tableName:  tableName,
 		entityName: "entity_key",
@@ -248,22 +252,6 @@ func insertEntityRows(ctx context.Context,
 	return dbutil.InsertRecordsToTable(ctx, dbOpt, tableName, records, columns)
 }
 
-func dropTemporaryTables(ctx context.Context, db *sqlx.DB, tableNames []string) error {
-	var err error
-	for _, tableName := range tableNames {
-		if tmpErr := dropTable(ctx, db, tableName); tmpErr != nil {
-			err = tmpErr
-		}
-	}
-	return err
-}
-
-func dropTable(ctx context.Context, db *sqlx.DB, tableName string) error {
-	query := fmt.Sprintf(`DROP TABLE IF EXISTS %s;`, tableName)
-	_, err := db.ExecContext(ctx, query)
-	return errdefs.WithStack(err)
-}
-
 func supportIndex(backendType types.BackendType) bool {
 	for _, b := range []types.BackendType{types.BackendSnowflake, types.BackendRedshift, types.BackendBigQuery} {
 		if b == backendType {
diff --git a/internal/database/offline/sqlutil/temporary_table.go
b/internal/database/offline/sqlutil/temporary_table.go
new file mode 100644
index 000000000..e6a15e7b9
--- /dev/null
+++ b/internal/database/offline/sqlutil/temporary_table.go
@@ -0,0 +1,147 @@
+package sqlutil
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/jmoiron/sqlx"
+
+	"github.com/oom-ai/oomstore/internal/database/dbutil"
+	"github.com/oom-ai/oomstore/internal/database/offline"
+	"github.com/oom-ai/oomstore/pkg/oomstore/types"
+)
+
+// AddTemporaryTableRecord records tableName, together with the current time in
+// Unix milliseconds, in the temporary-table record table so that the table
+// can later be listed by GetTemporaryTables and garbage collected. The record
+// table is created first if it does not exist.
+func AddTemporaryTableRecord(ctx context.Context, dbOpt dbutil.DBOpt, tableName string) error {
+	if err := createTemporaryTableRecordTable(ctx, dbOpt); err != nil {
+		return err
+	}
+	unQt := dbutil.UnQuoteFn(dbOpt.Backend)
+	tableName = unQt(tableName)
+	// NOTE(review): BigQuery values are wrapped in double quotes here,
+	// presumably because that backend's ExecContext inlines arguments into the
+	// query text -- confirm against dbutil.DBOpt.ExecContext.
+	if dbOpt.Backend == types.BackendBigQuery {
+		tableName = fmt.Sprintf(`"%s"`, tableName)
+	}
+	query := fmt.Sprintf(`INSERT INTO %s (table_name, create_time) VALUES(?,?)`, buildTableName(dbOpt, offline.TemporaryTableRecordTable))
+	return dbOpt.ExecContext(ctx, query, tableName, time.Now().UnixMilli())
+}
+
+// createTemporaryTableRecordTable creates the temporary-table record table
+// (columns table_name and create_time) if it does not already exist.
+func createTemporaryTableRecordTable(ctx context.Context, dbOpt dbutil.DBOpt) error {
+	tableNameDBType, err := dbutil.DBValueType(dbOpt.Backend, types.String)
+	if err != nil {
+		return err
+	}
+
+	createTimeDBType, err := dbutil.DBValueType(dbOpt.Backend, types.Int64)
+	if err != nil {
+		return err
+	}
+
+	query := fmt.Sprintf(`
+CREATE TABLE IF NOT EXISTS %s (
+	table_name %s,
+	create_time %s
+)
+`, buildTableName(dbOpt, offline.TemporaryTableRecordTable), tableNameDBType, createTimeDBType)
+	return dbOpt.ExecContext(ctx, query)
+}
+
+// GetTemporaryTables returns the names of all recorded temporary tables whose
+// create_time is earlier than unixMill (Unix milliseconds). It returns a nil
+// slice and no error when the record table does not exist yet.
+func GetTemporaryTables(ctx context.Context, db *sqlx.DB, backend types.BackendType, unixMill int64) ([]string, error) {
+	var tableName string
+	if backend == types.BackendSnowflake {
+		tableName = fmt.Sprintf(`PUBLIC."%s"`, offline.TemporaryTableRecordTable)
+	} else {
+		tableName = offline.TemporaryTableRecordTable
+	}
+	query := fmt.Sprintf("SELECT table_name FROM %s WHERE create_time < ?", tableName)
+
+	rows, err := db.QueryContext(ctx, db.Rebind(query), unixMill)
+	if err != nil {
+		tableNotFound, notFoundErr := dbutil.IsTableNotFoundError(err, backend)
+		if notFoundErr != nil {
+			return nil, notFoundErr
+		}
+		if tableNotFound {
+			return nil, nil
+		}
+		return nil, err
+	}
+	defer rows.Close()
+
+	var tableNames []string
+	for rows.Next() {
+		var tableName string
+		if err := rows.Scan(&tableName); err != nil {
+			return nil, err
+		}
+		tableNames = append(tableNames, tableName)
+	}
+	// rows.Next also returns false when iteration fails part-way; surface that
+	// error instead of returning a silently truncated list.
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return tableNames, nil
+}
+
+// DropTemporaryTables drops every table in tableNames and then deletes the
+// matching rows from the temporary-table record table. A failed drop does not
+// stop the remaining tables from being processed; the record of a table that
+// could not be dropped is kept so it can be retried later, and the first drop
+// error is returned. The tableNames slice is not modified.
+func DropTemporaryTables(ctx context.Context, db dbutil.DBOpt, tableNames []string) error {
+	if len(tableNames) == 0 {
+		return nil
+	}
+
+	unQt := dbutil.UnQuoteFn(db.Backend)
+	var firstErr error
+	recordNames := make([]string, 0, len(tableNames))
+	for _, tableName := range tableNames {
+		query := fmt.Sprintf(`DROP TABLE IF EXISTS %s`, tableName)
+		if err := db.ExecContext(ctx, query); err != nil {
+			if firstErr == nil {
+				firstErr = err
+			}
+			continue
+		}
+		// Records are stored unquoted (and double-quoted for BigQuery), so
+		// apply the same normalization AddTemporaryTableRecord uses.
+		recordName := unQt(tableName)
+		if db.Backend == types.BackendBigQuery {
+			recordName = fmt.Sprintf(`"%s"`, recordName)
+		}
+		recordNames = append(recordNames, recordName)
+	}
+	if len(recordNames) == 0 {
+		return firstErr
+	}
+
+	cond, args, err := dbutil.BuildConditions(nil, map[string]interface{}{
+		"table_name": recordNames,
+	})
+	if err != nil {
+		return err
+	}
+	if len(cond) > 0 {
+		query := fmt.Sprintf("DELETE FROM %s WHERE %s",
+			buildTableName(db, offline.TemporaryTableRecordTable),
+			strings.Join(cond, " AND "))
+		if err := db.ExecContext(ctx, query, args...); err != nil {
+			return err
+		}
+	}
+	return firstErr
+}
diff --git a/internal/database/offline/store.go b/internal/database/offline/store.go
index d675cd904..e33db4c52 100644
--- a/internal/database/offline/store.go
+++ b/internal/database/offline/store.go
@@ -17,6 +17,9 @@ type Store interface {
 	TableSchema(ctx context.Context, opt TableSchemaOpt) (*types.DataTableSchema, error)
 	Snapshot(ctx context.Context, opt SnapshotOpt) error
 
+	GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error)
+	DropTemporaryTable(ctx context.Context, tableNames []string) error
+
 	Ping(ctx context.Context) error
 	io.Closer
 }
diff --git a/internal/database/offline/test_impl/export.go b/internal/database/offline/test_impl/export.go
index 1ad95865c..04a6dae9b 100644
--- a/internal/database/offline/test_impl/export.go
+++ b/internal/database/offline/test_impl/export.go
@@ -4,10 +4,12 @@ import (
 	"bufio"
 	"strings"
 	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
 
 	"github.com/oom-ai/oomstore/internal/database/offline"
 	"github.com/oom-ai/oomstore/pkg/oomstore/types"
-	"github.com/stretchr/testify/assert"
 )
 
 func TestExport(t *testing.T, prepareStore PrepareStoreFn, destroyStore DestroyStoreFn) {
@@ -88,9 +90,12 @@ func TestExport(t *testing.T, prepareStore PrepareStoreFn, destroyStore DestroyS
 		},
 	}
 
-	for _, tc := range testCases {
+	t0 := time.Now().UnixMilli()
+	for i, tc := range testCases {
 		t.Run(tc.description, func(t *testing.T) {
 			result, err := store.Export(ctx, tc.opt)
+			assert.NoError(t, err)
+
 			values := make([][]interface{}, 0)
 			for row := range result.Data {
 				assert.NoError(t, row.Error)
@@ -98,8 +103,28 @@ func TestExport(t *testing.T, prepareStore PrepareStoreFn, destroyStore DestroyS
 			}
 			assert.ElementsMatch(t, tc.expected, values)
 			assert.NoError(t, err)
+
+			tempTables, err := store.GetTemporaryTables(ctx, t0)
+			assert.NoError(t, err)
+			assert.Equal(t, 0, len(tempTables))
+
+			tempTables, err = store.GetTemporaryTables(ctx, time.Now().UnixMilli())
+			assert.NoError(t, err)
+			assert.Equal(t, i+1,
len(tempTables))
 		})
 	}
+
+	t.Run("drop temporary table", func(t *testing.T) {
+		tempTables, err := store.GetTemporaryTables(ctx, time.Now().UnixMilli())
+		assert.NoError(t, err)
+		assert.Equal(t, len(testCases), len(tempTables))
+
+		assert.NoError(t, store.DropTemporaryTable(ctx, tempTables))
+
+		tempTables, err = store.GetTemporaryTables(ctx, time.Now().UnixMilli())
+		assert.NoError(t, err)
+		assert.Equal(t, 0, len(tempTables))
+	})
 }
 
 func prepareFeaturesForExport() (batchFeatures types.FeatureList, streamFeatures types.FeatureList) {
diff --git a/internal/database/offline/test_impl/import.go b/internal/database/offline/test_impl/import.go
index a0f1dfd4f..7b3f6ee74 100644
--- a/internal/database/offline/test_impl/import.go
+++ b/internal/database/offline/test_impl/import.go
@@ -7,10 +7,10 @@ import (
 	"testing"
 
 	"github.com/spf13/cast"
+	"github.com/stretchr/testify/assert"
 
 	"github.com/oom-ai/oomstore/internal/database/offline"
 	"github.com/oom-ai/oomstore/pkg/oomstore/types"
-	"github.com/stretchr/testify/assert"
 )
 
 func TestImport(t *testing.T, prepareStore PrepareStoreFn, destroyStore DestroyStoreFn) {
@@ -67,6 +67,8 @@ func TestImport(t *testing.T, prepareStore PrepareStoreFn, destroyStore DestroyS
 				{Name: "price", ValueType: types.Int64, Group: group},
 			}},
 		})
+		assert.NoError(t, err)
+
 		records := make([][]interface{}, 0)
 		for row := range result.Data {
 			assert.NoError(t, row.Error)
diff --git a/internal/database/offline/test_impl/join.go b/internal/database/offline/test_impl/join.go
index 4cc4fe9f3..1f7e1b7e9 100644
--- a/internal/database/offline/test_impl/join.go
+++ b/internal/database/offline/test_impl/join.go
@@ -7,6 +7,7 @@ import (
 	"sort"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/spf13/cast"
 
@@ -144,6 +145,10 @@ func TestJoin(t *testing.T, prepareStore PrepareStoreFn, destroyStore DestroySto
 			for i := range expectedValues {
 				assert.ElementsMatch(t, expectedValues[i], actualValues[i])
 			}
+
+			tempTables, err := store.GetTemporaryTables(ctx, time.Now().UnixMilli())
+			assert.NoError(t, err)
+			assert.Equal(t, 0, len(tempTables))
 		})
 	}
 }
diff --git a/internal/database/offline/types.go b/internal/database/offline/types.go
index ad5a97820..b467a49fc 100644
--- a/internal/database/offline/types.go
+++ b/internal/database/offline/types.go
@@ -6,6 +6,9 @@ import (
 	"github.com/oom-ai/oomstore/pkg/oomstore/types"
 )
 
+// TemporaryTableRecordTable is the name of the table that records temporary tables.
+const TemporaryTableRecordTable = "temporary_table_records_table"
+
 type ExportOpt struct {
 	SnapshotTables map[int]string
 	CdcTables      map[int]string
diff --git a/oomcli/cmd/gc.go b/oomcli/cmd/gc.go
new file mode 100644
index 000000000..f96169189
--- /dev/null
+++ b/oomcli/cmd/gc.go
@@ -0,0 +1,65 @@
+package cmd
+
+import (
+	"context"
+	"fmt"
+	"os"
+
+	"github.com/spf13/cobra"
+)
+
+// GcOption holds the command-line flags of the gc command.
+type GcOption struct {
+	// Force makes gc drop the listed tables; without it they are only printed.
+	Force     bool
+	// UnixMilli is the cutoff, in Unix milliseconds, passed to GetTemporaryTables.
+	UnixMilli int64
+}
+
+// gcOpt receives the parsed flags of gcCmd.
+var gcOpt GcOption
+
+// gcCmd lists the recorded temporary tables created before --unix-milli and,
+// when --force is set, drops them.
+var gcCmd = &cobra.Command{
+	Use:   "gc",
+	Short: "gc temporary table",
+	Run: func(cmd *cobra.Command, args []string) {
+		ctx := context.Background()
+		oomStore := mustOpenOomStore(ctx, oomStoreCfg)
+		defer oomStore.Close()
+
+		tableNames, err := oomStore.GetTemporaryTables(ctx, gcOpt.UnixMilli)
+		if err != nil {
+			exitf("gc failed: %+v", err)
+		}
+		if len(tableNames) == 0 {
+			return
+		}
+
+		// With --force the candidate tables are actually dropped.
+		if gcOpt.Force {
+			if err := oomStore.DropTemporaryTables(ctx, tableNames); err != nil {
+				exitf("gc failed: %+v", err)
+			}
+			return
+		}
+
+		// Without --force the candidates are only reported on stderr.
+		fmt.Fprintln(os.Stderr, "The following tables will be deleted:")
+		for _, name := range tableNames {
+			fmt.Fprintln(os.Stderr, name)
+		}
+	},
+}
+
+func init() {
+	rootCmd.AddCommand(gcCmd)
+
+	flags := gcCmd.Flags()
+
+	flags.Int64VarP(&gcOpt.UnixMilli, "unix-milli", "u", 0, "any temporary tables before this time will be deleted")
+	_ = gcCmd.MarkFlagRequired("unix-milli")
+
+	flags.BoolVar(&gcOpt.Force, "force", false, "force run gc")
+}
diff --git a/oomcli/test/test_temporary_table.sh b/oomcli/test/test_temporary_table.sh
new file mode 100755
index 000000000..e4a9c33e8
--- /dev/null
+++ b/oomcli/test/test_temporary_table.sh
@@ -0,0
+1,85 @@
+#!/usr/bin/env bash
+
+SDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) && cd "$SDIR" || exit 1
+source ./util.sh
+
+init_store
+register_features
+
+import_student_sample
+oomcli push --entity-key 1 --group user-click --feature last_5_click_posts=1,2,3,4,5 --feature number_of_user_starred_posts=10
+
+# Runs two exports, then checks what `oomcli gc` reports (one header line plus
+# one line per table) and drops for the cutoffs t0 < t1 < t2.
+oomcli_export_temporary_table() {
+    case="temporary table in export"
+
+    t0=${1:-$(perl -MTime::HiRes=time -E 'say int(time * 1000)')}
+    oomcli export --feature student.name,student.gender,student.age,user-click.last_5_click_posts,user-click.number_of_user_starred_posts --unix-milli "$t0" -o csv > /dev/null
+
+    t1=${1:-$(perl -MTime::HiRes=time -E 'say int(time * 1000)')}
+    oomcli export --feature student.name,student.gender,student.age,user-click.last_5_click_posts,user-click.number_of_user_starred_posts --unix-milli "$t0" -o csv > /dev/null
+
+    t2=${1:-$(perl -MTime::HiRes=time -E 'say int(time * 1000)')}
+
+    actual=$(oomcli gc --unix-milli "$t0" 2>&1 |wc -l)
+    assert_eq "$case" 0 "$actual"
+
+    actual=$(oomcli gc --unix-milli "$t1" 2>&1 |wc -l)
+    assert_eq "$case" 2 "$actual"
+
+    actual=$(oomcli gc --unix-milli "$t2" 2>&1 |wc -l)
+    assert_eq "$case" 3 "$actual"
+
+    oomcli gc --unix-milli "$t0" --force
+    actual=$(oomcli gc --unix-milli "$t2" 2>&1 |wc -l)
+    assert_eq "$case" 3 "$actual"
+
+    oomcli gc --unix-milli "$t1" --force
+    actual=$(oomcli gc --unix-milli "$t2" 2>&1 |wc -l)
+    assert_eq "$case" 2 "$actual"
+
+    oomcli gc --unix-milli "$t2" --force
+    actual=$(oomcli gc --unix-milli "$t2" 2>&1 |wc -l)
+    assert_eq "$case" 0 "$actual"
+}
+
+oomcli_join_temporary_table() {
+    t0=${1:-$(perl -MTime::HiRes=time -E 'say int(time * 1000)')}
+
+    # clean up the tmp file
+    trap 'command rm -rf entity_rows.csv' EXIT INT TERM HUP
+    cat <<-EOF > entity_rows.csv
+entity_key,unix_milli
+1,$t0
+2,$t0
+EOF
+
+    case="temporary table in join"
+    oomcli join \
+        --feature student.name,student.gender,student.age \
+        --input-file entity_rows.csv \
+        --output csv >
/dev/null
+
+    oomcli join \
+        --feature student.name,student.gender,student.age \
+        --input-file entity_rows.csv \
+        --output csv > /dev/null
+
+    oomcli join \
+        --feature student.name,student.gender,student.age \
+        --input-file entity_rows.csv \
+        --output csv > /dev/null
+
+
+    t3=${1:-$(perl -MTime::HiRes=time -E 'say int(time * 1000)')}
+    actual=$(oomcli gc --unix-milli "$t3" 2>&1 |wc -l)
+    assert_eq "$case" 0 "$actual"
+}
+
+main() {
+    oomcli_export_temporary_table
+    oomcli_join_temporary_table
+}
+
+main
diff --git a/pkg/oomstore/gc.go b/pkg/oomstore/gc.go
new file mode 100644
index 000000000..7e0dfa0ca
--- /dev/null
+++ b/pkg/oomstore/gc.go
@@ -0,0 +1,17 @@
+package oomstore
+
+import (
+	"context"
+)
+
+// DropTemporaryTables asks the offline store to drop the given temporary
+// tables; it delegates to the offline store's DropTemporaryTable.
+func (o *OomStore) DropTemporaryTables(ctx context.Context, tableNames []string) error {
+	return o.offline.DropTemporaryTable(ctx, tableNames)
+}
+
+// GetTemporaryTables returns what the offline store's GetTemporaryTables
+// reports for the cutoff unixMilli (Unix milliseconds).
+func (o *OomStore) GetTemporaryTables(ctx context.Context, unixMilli int64) ([]string, error) {
+	return o.offline.GetTemporaryTables(ctx, unixMilli)
+}