diff --git a/down.go b/down.go index c58c2144c..bd63804a7 100644 --- a/down.go +++ b/down.go @@ -1,12 +1,19 @@ package goose import ( + "context" "database/sql" "fmt" ) // Down rolls back a single migration from the current version. func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return DownContext(ctx, db, dir, opts...) +} + +// DownContext rolls back a single migration from the current version. +func DownContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -21,9 +28,9 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion := migrations[len(migrations)-1].Version // Migrate only the latest migration down. - return downToNoVersioning(db, migrations, currentVersion-1) + return downToNoVersioning(ctx, db, migrations, currentVersion-1) } - currentVersion, err := GetDBVersion(db) + currentVersion, err := GetDBVersionContext(ctx, db) if err != nil { return err } @@ -31,11 +38,17 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { if err != nil { return fmt.Errorf("no migration %v", currentVersion) } - return current.Down(db) + return current.DownContext(ctx, db) } // DownTo rolls back migrations to a specific version. func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + ctx := context.Background() + return DownToContext(ctx, db, dir, version, opts...) +} + +// DownToContext rolls back migrations to a specific version. +func DownToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -45,11 +58,11 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return err } if option.noVersioning { - return downToNoVersioning(db, migrations, version) + return downToNoVersioning(ctx, db, migrations, version) } for { - currentVersion, err := GetDBVersion(db) + currentVersion, err := GetDBVersionContext(ctx, db) if err != nil { return err } @@ -69,7 +82,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return nil } - if err = current.Down(db); err != nil { + if err = current.DownContext(ctx, db); err != nil { return err } } @@ -77,7 +90,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // downToNoVersioning applies down migrations down to, but not including, the // target version. -func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for i := len(migrations) - 1; i >= 0; i-- { if version >= migrations[i].Version { @@ -85,7 +98,7 @@ func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error break } migrations[i].noVersioning = true - if err := migrations[i].Down(db); err != nil { + if err := migrations[i].DownContext(ctx, db); err != nil { return err } } diff --git a/goose.go b/goose.go index 0dbfd67da..e952041b0 100644 --- a/goose.go +++ b/goose.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "fmt" "io/fs" @@ -39,22 +40,34 @@ func SetBaseFS(fsys fs.FS) { // Run runs a goose command. func Run(command string, db *sql.DB, dir string, args ...string) error { - return run(command, db, dir, args) + ctx := context.Background() + return RunContext(ctx, command, db, dir, args...) } -// Run runs a goose command with options. +// RunContext runs a goose command. +func RunContext(ctx context.Context, command string, db *sql.DB, dir string, args ...string) error { + return run(ctx, command, db, dir, args) +} + +// RunWithOptions runs a goose command with options. func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { - return run(command, db, dir, args, options...) + ctx := context.Background() + return RunWithOptionsContext(ctx, command, db, dir, args, options...) +} + +// RunWithOptionsContext runs a goose command with options. +func RunWithOptionsContext(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { + return run(ctx, command, db, dir, args, options...) } -func run(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { +func run(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { switch command { case "up": - if err := Up(db, dir, options...); err != nil { + if err := UpContext(ctx, db, dir, options...); err != nil { return err } case "up-by-one": - if err := UpByOne(db, dir, options...); err != nil { + if err := UpByOneContext(ctx, db, dir, options...); err != nil { return err } case "up-to": @@ -66,7 +79,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio if err != nil { return fmt.Errorf("version must be a number (got '%s')", args[0]) } - if err := UpTo(db, dir, version, options...); err != nil { + if err := UpToContext(ctx, db, dir, version, options...); err != nil { return err } case "create": @@ -82,7 +95,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio return err } case "down": - if err := Down(db, dir, options...); err != nil { + if err := DownContext(ctx, db, dir, options...); err != nil { return err } case "down-to": @@ -94,7 +107,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio if err != nil { return fmt.Errorf("version must be a number (got '%s')", args[0]) } - if err := DownTo(db, dir, version, options...); err != nil { + if err := DownToContext(ctx, db, dir, version, options...); err != nil { return err } case "fix": @@ -102,19 +115,19 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio return err } case "redo": - if err := Redo(db, dir, options...); err != nil { + if err := RedoContext(ctx, db, dir, options...); err != nil { return err } case "reset": - if err := Reset(db, dir, options...); err != nil { + if err := ResetContext(ctx, db, dir, options...); err != nil { return err } case "status": - if err := Status(db, dir, options...); err != nil { + if err := StatusContext(ctx, db, dir, options...); err != nil { return err } case "version": - if err := Version(db, dir, options...); err != nil { + if err := VersionContext(ctx, db, dir, options...); err != nil { return err } default: diff --git a/migrate.go b/migrate.go index 0cf96a860..fb1500af1 100644 --- a/migrate.go +++ b/migrate.go @@ -296,6 +296,12 @@ func versionFilter(v, current, target int64) bool { // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { ctx := context.Background() + return EnsureDBVersionContext(ctx, db) +} + +// EnsureDBVersionContext retrieves the current version for this DB. +// Create and initialize the DB version table if it doesn't exist. +func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { dbMigrations, err := store.ListMigrations(ctx, db, TableName()) if err != nil { return 0, createVersionTable(ctx, db) @@ -332,7 +338,7 @@ func EnsureDBVersion(db *sql.DB) (int64, error) { // createVersionTable creates the db version table and inserts the // initial 0 value into it. func createVersionTable(ctx context.Context, db *sql.DB) error { - txn, err := db.Begin() + txn, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -349,7 +355,13 @@ func createVersionTable(ctx context.Context, db *sql.DB) error { // GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error. func GetDBVersion(db *sql.DB) (int64, error) { - version, err := EnsureDBVersion(db) + ctx := context.Background() + return GetDBVersionContext(ctx, db) +} + +// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error. +func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { + version, err := EnsureDBVersionContext(ctx, db) if err != nil { return -1, err } diff --git a/migration.go b/migration.go index 727ecc3d9..8d4362c01 100644 --- a/migration.go +++ b/migration.go @@ -40,6 +40,11 @@ func (m *Migration) String() string { // Up runs an up migration. func (m *Migration) Up(db *sql.DB) error { ctx := context.Background() + return m.UpContext(ctx, db) +} + +// UpContext runs an up migration. +func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error { if err := m.run(ctx, db, true); err != nil { return err } @@ -49,6 +54,11 @@ func (m *Migration) Up(db *sql.DB) error { // Down runs a down migration. func (m *Migration) Down(db *sql.DB) error { ctx := context.Background() + return m.DownContext(ctx, db) +} + +// DownContext runs a down migration. +func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error { if err := m.run(ctx, db, false); err != nil { return err } @@ -163,7 +173,7 @@ func runGoMigration( if fn == nil && !recordVersion { return nil } - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } diff --git a/migration_sql.go b/migration_sql.go index f74b70d75..1c6d4d0bf 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -29,7 +29,7 @@ func runSQLMigration( verboseInfo("Begin transaction") - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } diff --git a/redo.go b/redo.go index c485f9f67..ed3ff677f 100644 --- a/redo.go +++ b/redo.go @@ -1,11 +1,18 @@ package goose import ( + "context" "database/sql" ) // Redo rolls back the most recently applied migration, then runs it again. func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return RedoContext(ctx, db, dir, opts...) +} + +// RedoContext rolls back the most recently applied migration, then runs it again. +func RedoContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -23,7 +30,7 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion = migrations[len(migrations)-1].Version } else { - if currentVersion, err = GetDBVersion(db); err != nil { + if currentVersion, err = GetDBVersionContext(ctx, db); err != nil { return err } } @@ -34,10 +41,10 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } current.noVersioning = option.noVersioning - if err := current.Down(db); err != nil { + if err := current.DownContext(ctx, db); err != nil { return err } - if err := current.Up(db); err != nil { + if err := current.UpContext(ctx, db); err != nil { return err } return nil diff --git a/reset.go b/reset.go index e14d36d22..274c539d4 100644 --- a/reset.go +++ b/reset.go @@ -10,6 +10,11 @@ import ( // Reset rolls back all migrations func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { ctx := context.Background() + return ResetContext(ctx, db, dir, opts...) +} + +// ResetContext rolls back all migrations +func ResetContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -19,7 +24,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return fmt.Errorf("failed to collect migrations: %w", err) } if option.noVersioning { - return DownTo(db, dir, minVersion, opts...) + return DownToContext(ctx, db, dir, minVersion, opts...) } statuses, err := dbMigrationsStatus(ctx, db) @@ -32,7 +37,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { if !statuses[migration.Version] { continue } - if err = migration.Down(db); err != nil { + if err = migration.DownContext(ctx, db); err != nil { return fmt.Errorf("failed to db-down: %w", err) } } diff --git a/status.go b/status.go index dd1f16c4c..4f0943a3d 100644 --- a/status.go +++ b/status.go @@ -12,6 +12,11 @@ import ( // Status prints the status of all migrations. func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { ctx := context.Background() + return StatusContext(ctx, db, dir, opts...) +} + +// StatusContext prints the status of all migrations. +func StatusContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -30,7 +35,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { } // must ensure that the version table exists if we're running on a pristine DB - if _, err := EnsureDBVersion(db); err != nil { + if _, err := EnsureDBVersionContext(ctx, db); err != nil { return fmt.Errorf("failed to ensure DB version: %w", err) } diff --git a/up.go b/up.go index 3dc011d03..64b88dd4c 100644 --- a/up.go +++ b/up.go @@ -35,6 +35,10 @@ func withApplyUpByOne() OptionsFunc { // UpTo migrates up to a specific version. func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { ctx := context.Background() + return UpToContext(ctx, db, dir, version, opts...) +} + +func UpToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -53,10 +57,10 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // migration over and over. version = foundMigrations[0].Version } - return upToNoVersioning(db, foundMigrations, version) + return upToNoVersioning(ctx, db, foundMigrations, version) } - if _, err := EnsureDBVersion(db); err != nil { + if _, err := EnsureDBVersionContext(ctx, db); err != nil { return err } dbMigrations, err := listAllDBVersions(ctx, db) @@ -103,7 +107,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { var current int64 for _, m := range migrationsToApply { - if err := m.Up(db); err != nil { + if err := m.UpContext(ctx, db); err != nil { return err } if option.applyUpByOne { @@ -112,7 +116,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { current = m.Version } if len(migrationsToApply) == 0 { - current, err = GetDBVersion(db) + current, err = GetDBVersionContext(ctx, db) if err != nil { return err } @@ -130,14 +134,14 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // upToNoVersioning applies up migrations up to, and including, the // target version. -func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func upToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for _, current := range migrations { if current.Version > version { break } current.noVersioning = true - if err := current.Up(db); err != nil { + if err := current.UpContext(ctx, db); err != nil { return err } finalVersion = current.Version @@ -148,13 +152,25 @@ func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { // Up applies all available migrations. func Up(db *sql.DB, dir string, opts ...OptionsFunc) error { - return UpTo(db, dir, maxVersion, opts...) + ctx := context.Background() + return UpContext(ctx, db, dir, opts...) +} + +// UpContext applies all available migrations. +func UpContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { + return UpToContext(ctx, db, dir, maxVersion, opts...) } // UpByOne migrates up by a single version. func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return UpByOneContext(ctx, db, dir, opts...) +} + +// UpByOneContext migrates up by a single version. +func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { opts = append(opts, withApplyUpByOne()) - return UpTo(db, dir, maxVersion, opts...) + return UpToContext(ctx, db, dir, maxVersion, opts...) } // listAllDBVersions returns a list of all migrations, ordered ascending. diff --git a/version.go b/version.go index 47765f728..89d0dccd1 100644 --- a/version.go +++ b/version.go @@ -1,12 +1,19 @@ package goose import ( + "context" "database/sql" "fmt" ) // Version prints the current version of the database. func Version(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return VersionContext(ctx, db, dir, opts...) +} + +// VersionContext prints the current version of the database. +func VersionContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -24,7 +31,7 @@ func Version(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } - current, err := GetDBVersion(db) + current, err := GetDBVersionContext(ctx, db) if err != nil { return err }