From d5cfd14e408db1ffd824caa99935b2b51f16f37d Mon Sep 17 00:00:00 2001 From: Diogo Monteiro Date: Tue, 2 Aug 2022 14:05:26 -0400 Subject: [PATCH] expose context for goose sql migrations --- down.go | 35 +++++++++++++++++++++++++---------- migration.go | 35 +++++++++++++++++++++++++---------- migration_sql.go | 23 ++++++++++++----------- redo.go | 16 ++++++++++++---- reset.go | 14 +++++++++++--- up.go | 46 +++++++++++++++++++++++++++++++++++----------- 6 files changed, 120 insertions(+), 49 deletions(-) diff --git a/down.go b/down.go index c58c2144c..c8d7d3286 100644 --- a/down.go +++ b/down.go @@ -1,12 +1,13 @@ package goose import ( + "context" "database/sql" "fmt" ) -// Down rolls back a single migration from the current version. -func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { +// DownCtx rolls back a single migration from the current version. +func DownCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -21,7 +22,7 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion := migrations[len(migrations)-1].Version // Migrate only the latest migration down. - return downToNoVersioning(db, migrations, currentVersion-1) + return downToNoVersioning(ctx, db, migrations, currentVersion-1) } currentVersion, err := GetDBVersion(db) if err != nil { @@ -31,11 +32,18 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { if err != nil { return fmt.Errorf("no migration %v", currentVersion) } - return current.Down(db) + return current.DownCtx(ctx, db) } -// DownTo rolls back migrations to a specific version. -func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { +// Down rolls back a single migration from the current version. +// +// Down uses context.Background internally; to specify the context, use DownCtx. +func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { + return DownCtx(context.Background(), db, dir, opts...) +} + +// DownToCtx rolls back migrations to a specific version. +func DownToCtx(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -45,7 +53,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return err } if option.noVersioning { - return downToNoVersioning(db, migrations, version) + return downToNoVersioning(ctx, db, migrations, version) } for { @@ -69,15 +77,22 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return nil } - if err = current.Down(db); err != nil { + if err = current.DownCtx(ctx, db); err != nil { return err } } } +// DownTo rolls back migrations to a specific version. +// +// DownTo uses context.Background internally; to specify the context, use DownToCtx. +func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + return DownToCtx(context.Background(), db, dir, version, opts...) +} + // downToNoVersioning applies down migrations down to, but not including, the // target version. -func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for i := len(migrations) - 1; i >= 0; i-- { if version >= migrations[i].Version { @@ -85,7 +100,7 @@ func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error break } migrations[i].noVersioning = true - if err := migrations[i].Down(db); err != nil { + if err := migrations[i].DownCtx(ctx, db); err != nil { return err } } diff --git a/migration.go b/migration.go index 775dc779c..455f5b7ec 100644 --- a/migration.go +++ b/migration.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "errors" "fmt" @@ -33,23 +34,37 @@ func (m *Migration) String() string { return fmt.Sprintf(m.Source) } +// UpCtx runs an up migration. +func (m *Migration) UpCtx(ctx context.Context, db *sql.DB) error { + if err := m.run(ctx, db, true); err != nil { + return err + } + return nil +} + // Up runs an up migration. +// +// Up uses context.Background internally; to specify the context, use UpCtx. func (m *Migration) Up(db *sql.DB) error { - if err := m.run(db, true); err != nil { + return m.UpCtx(context.Background(), db) +} + +// DownCtx runs a down migration. +func (m *Migration) DownCtx(ctx context.Context, db *sql.DB) error { + if err := m.run(ctx, db, false); err != nil { return err } return nil } // Down runs a down migration. +// +// Down uses context.Background internally; to specify the context, use DownCtx. func (m *Migration) Down(db *sql.DB) error { - if err := m.run(db, false); err != nil { - return err - } - return nil + return m.DownCtx(context.Background(), db) } -func (m *Migration) run(db *sql.DB, direction bool) error { +func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { switch filepath.Ext(m.Source) { case ".sql": f, err := baseFS.Open(m.Source) @@ -63,7 +78,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err) } - if err := runSQLMigration(db, statements, useTx, m.Version, direction, m.noVersioning); err != nil { + if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil { return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err) } @@ -77,7 +92,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { if !m.Registered { return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source) } - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("ERROR failed to begin transaction: %w", err) } @@ -96,12 +111,12 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } if !m.noVersioning { if direction { - if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil { + if _, err := tx.ExecContext(ctx, GetDialect().insertVersionSQL(), m.Version, direction); err != nil { tx.Rollback() return fmt.Errorf("ERROR failed to execute transaction: %w", err) } } else { - if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil { + if _, err := tx.ExecContext(ctx, GetDialect().deleteVersionSQL(), m.Version); err != nil { tx.Rollback() return fmt.Errorf("ERROR failed to execute transaction: %w", err) } diff --git a/migration_sql.go b/migration_sql.go index 359ebf6be..d9a84672a 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "fmt" "regexp" @@ -15,20 +16,20 @@ import ( // // All statements following an Up or Down directive are grouped together // until another direction directive is found. -func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error { +func runSQLMigration(ctx context.Context, db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error { if useTx { // TRANSACTION. verboseInfo("Begin transaction") - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } for _, query := range statements { verboseInfo("Executing statement: %s\n", clearStatement(query)) - if err = execQuery(tx.Exec, query); err != nil { + if err = execQuery(ctx, tx.ExecContext, query); err != nil { verboseInfo("Rollback transaction") tx.Rollback() return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err) @@ -37,13 +38,13 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc if !noVersioning { if direction { - if err := execQuery(tx.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil { + if err := execQuery(ctx, tx.ExecContext, GetDialect().insertVersionSQL(), v, direction); err != nil { verboseInfo("Rollback transaction") tx.Rollback() return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := execQuery(tx.Exec, GetDialect().deleteVersionSQL(), v); err != nil { + if err := execQuery(ctx, tx.ExecContext, GetDialect().deleteVersionSQL(), v); err != nil { verboseInfo("Rollback transaction") tx.Rollback() return fmt.Errorf("failed to delete goose version: %w", err) @@ -62,17 +63,17 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc // NO TRANSACTION. for _, query := range statements { verboseInfo("Executing statement: %s", clearStatement(query)) - if err := execQuery(db.Exec, query); err != nil { + if err := execQuery(ctx, db.ExecContext, query); err != nil { return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err) } } if !noVersioning { if direction { - if err := execQuery(db.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil { + if err := execQuery(ctx, db.ExecContext, GetDialect().insertVersionSQL(), v, direction); err != nil { return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := execQuery(db.Exec, GetDialect().deleteVersionSQL(), v); err != nil { + if err := execQuery(ctx, db.ExecContext, GetDialect().deleteVersionSQL(), v); err != nil { return fmt.Errorf("failed to delete goose version: %w", err) } } @@ -81,16 +82,16 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc return nil } -func execQuery(fn func(string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error { +func execQuery(ctx context.Context, fn func(context.Context, string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error { if !verbose { - _, err := fn(query, args...) + _, err := fn(ctx, query, args...) return err } ch := make(chan error) go func() { - _, err := fn(query, args...) + _, err := fn(ctx, query, args...) ch <- err }() diff --git a/redo.go b/redo.go index c485f9f67..c584bdbea 100644 --- a/redo.go +++ b/redo.go @@ -1,11 +1,12 @@ package goose import ( + "context" "database/sql" ) -// Redo rolls back the most recently applied migration, then runs it again. -func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { +// RedoCtx rolls back the most recently applied migration, then runs it again. +func RedoCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -34,11 +35,18 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } current.noVersioning = option.noVersioning - if err := current.Down(db); err != nil { + if err := current.DownCtx(ctx, db); err != nil { return err } - if err := current.Up(db); err != nil { + if err := current.UpCtx(ctx, db); err != nil { return err } return nil } + +// Redo rolls back the most recently applied migration, then runs it again. +// +// Redo uses context.Background internally; to specify the context, use RedoCtx. +func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { + return RedoCtx(context.Background(), db, dir, opts...) +} diff --git a/reset.go b/reset.go index 258841fad..20d01bd09 100644 --- a/reset.go +++ b/reset.go @@ -1,13 +1,14 @@ package goose import ( + "context" "database/sql" "fmt" "sort" ) -// Reset rolls back all migrations -func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { +// ResetCtx rolls back all migrations +func ResetCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -30,7 +31,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { if !statuses[migration.Version] { continue } - if err = migration.Down(db); err != nil { + if err = migration.DownCtx(ctx, db); err != nil { return fmt.Errorf("failed to db-down: %w", err) } } @@ -38,6 +39,13 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } +// Reset rolls back all migrations +// +// Reset uses context.Background internally; to specify the context, use ResetCtx. +func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { + return ResetCtx(context.Background(), db, dir, opts...) +} + func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) { rows, err := GetDialect().dbVersionQuery(db) if err != nil { diff --git a/up.go b/up.go index d8d19cfe7..6d506ea59 100644 --- a/up.go +++ b/up.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "errors" "fmt" @@ -28,8 +29,8 @@ func withApplyUpByOne() OptionsFunc { return func(o *options) { o.applyUpByOne = true } } -// UpTo migrates up to a specific version. -func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { +// UpToCtx migrates up to a specific version. +func UpToCtx(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -48,7 +49,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // migration over and over. version = foundMigrations[0].Version } - return upToNoVersioning(db, foundMigrations, version) + return upToNoVersioning(ctx, db, foundMigrations, version) } if _, err := EnsureDBVersion(db); err != nil { @@ -76,6 +77,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { if option.allowMissing { return upWithMissing( + ctx, db, missingMigrations, foundMigrations, @@ -98,7 +100,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { } return fmt.Errorf("failed to find next migration: %v", err) } - if err := next.Up(db); err != nil { + if err := next.UpCtx(ctx, db); err != nil { return err } if option.applyUpByOne { @@ -116,16 +118,23 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return nil } +// UpTo migrates up to a specific version. +// +// UpTo uses context.Background internally; to specify the context, use UpToCtx. +func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + return UpToCtx(context.Background(), db, dir, version, opts...) +} + // upToNoVersioning applies up migrations up to, and including, the // target version. -func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func upToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for _, current := range migrations { if current.Version > version { break } current.noVersioning = true - if err := current.Up(db); err != nil { + if err := current.UpCtx(ctx, db); err != nil { return err } finalVersion = current.Version @@ -135,6 +144,7 @@ func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { } func upWithMissing( + ctx context.Context, db *sql.DB, missingMigrations Migrations, foundMigrations Migrations, @@ -148,7 +158,7 @@ func upWithMissing( // Apply all missing migrations first. for _, missing := range missingMigrations { - if err := missing.Up(db); err != nil { + if err := missing.UpCtx(ctx, db); err != nil { return err } // Apply one migration and return early. @@ -183,7 +193,7 @@ func upWithMissing( if lookupApplied[found.Version] { continue } - if err := found.Up(db); err != nil { + if err := found.UpCtx(ctx, db); err != nil { return err } if option.applyUpByOne { @@ -205,15 +215,29 @@ func upWithMissing( return nil } +// UpCtx applies all available migrations. +func UpCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { + return UpToCtx(ctx, db, dir, maxVersion, opts...) +} + // Up applies all available migrations. +// +// Up uses context.Background internally; to specify the context, use UpCtx. func Up(db *sql.DB, dir string, opts ...OptionsFunc) error { - return UpTo(db, dir, maxVersion, opts...) + return UpCtx(context.Background(), db, dir, opts...) +} + +// UpByOneCtx migrates up by a single version. +func UpByOneCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { + opts = append(opts, withApplyUpByOne()) + return UpToCtx(ctx, db, dir, maxVersion, opts...) } // UpByOne migrates up by a single version. +// +// UpByOne uses context.Background internally; to specify the context, use UpByOneCtx. func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { - opts = append(opts, withApplyUpByOne()) - return UpTo(db, dir, maxVersion, opts...) + return UpByOneCtx(context.Background(), db, dir, opts...) } // listAllDBVersions returns a list of all migrations, ordered ascending.