Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow setting context #517

Merged
merged 3 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions down.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package goose

import (
"context"
"database/sql"
"fmt"
)

// Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return DownContext(ctx, db, dir, opts...)
}

// DownContext rolls back a single migration from the current version.
func DownContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -21,21 +28,27 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion := migrations[len(migrations)-1].Version
// Migrate only the latest migration down.
return downToNoVersioning(db, migrations, currentVersion-1)
return downToNoVersioning(ctx, db, migrations, currentVersion-1)
}
currentVersion, err := GetDBVersion(db)
currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}
current, err := migrations.Current(currentVersion)
if err != nil {
return fmt.Errorf("no migration %v", currentVersion)
}
return current.Down(db)
return current.DownContext(ctx, db)
}

// DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
ctx := context.Background()
return DownToContext(ctx, db, dir, version, opts...)
}

// DownToContext rolls back migrations to a specific version.
func DownToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -45,11 +58,11 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return err
}
if option.noVersioning {
return downToNoVersioning(db, migrations, version)
return downToNoVersioning(ctx, db, migrations, version)
}

for {
currentVersion, err := GetDBVersion(db)
currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}
Expand All @@ -69,23 +82,23 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return nil
}

if err = current.Down(db); err != nil {
if err = current.DownContext(ctx, db); err != nil {
return err
}
}
}

// downToNoVersioning applies down migrations down to, but not including, the
// target version.
func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64
for i := len(migrations) - 1; i >= 0; i-- {
if version >= migrations[i].Version {
finalVersion = migrations[i].Version
break
}
migrations[i].noVersioning = true
if err := migrations[i].Down(db); err != nil {
if err := migrations[i].DownContext(ctx, db); err != nil {
return err
}
}
Expand Down
39 changes: 26 additions & 13 deletions goose.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goose

import (
"context"
"database/sql"
"fmt"
"io/fs"
Expand Down Expand Up @@ -39,22 +40,34 @@ func SetBaseFS(fsys fs.FS) {

// Run runs a goose command.
func Run(command string, db *sql.DB, dir string, args ...string) error {
return run(command, db, dir, args)
ctx := context.Background()
return RunContext(ctx, command, db, dir, args...)
}

// Run runs a goose command with options.
// RunContext runs a goose command.
func RunContext(ctx context.Context, command string, db *sql.DB, dir string, args ...string) error {
return run(ctx, command, db, dir, args)
}

// RunWithOptions runs a goose command with options.
func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(command, db, dir, args, options...)
ctx := context.Background()
return RunWithOptionsContext(ctx, command, db, dir, args, options...)
}

// RunWithOptionsContext runs a goose command with options.
func RunWithOptionsContext(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(ctx, command, db, dir, args, options...)
}

func run(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
func run(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
switch command {
case "up":
if err := Up(db, dir, options...); err != nil {
if err := UpContext(ctx, db, dir, options...); err != nil {
return err
}
case "up-by-one":
if err := UpByOne(db, dir, options...); err != nil {
if err := UpByOneContext(ctx, db, dir, options...); err != nil {
return err
}
case "up-to":
Expand All @@ -66,7 +79,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0])
}
if err := UpTo(db, dir, version, options...); err != nil {
if err := UpToContext(ctx, db, dir, version, options...); err != nil {
return err
}
case "create":
Expand All @@ -82,7 +95,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
return err
}
case "down":
if err := Down(db, dir, options...); err != nil {
if err := DownContext(ctx, db, dir, options...); err != nil {
return err
}
case "down-to":
Expand All @@ -94,27 +107,27 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0])
}
if err := DownTo(db, dir, version, options...); err != nil {
if err := DownToContext(ctx, db, dir, version, options...); err != nil {
return err
}
case "fix":
if err := Fix(dir); err != nil {
return err
}
case "redo":
if err := Redo(db, dir, options...); err != nil {
if err := RedoContext(ctx, db, dir, options...); err != nil {
return err
}
case "reset":
if err := Reset(db, dir, options...); err != nil {
if err := ResetContext(ctx, db, dir, options...); err != nil {
return err
}
case "status":
if err := Status(db, dir, options...); err != nil {
if err := StatusContext(ctx, db, dir, options...); err != nil {
return err
}
case "version":
if err := Version(db, dir, options...); err != nil {
if err := VersionContext(ctx, db, dir, options...); err != nil {
return err
}
default:
Expand Down
16 changes: 14 additions & 2 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@ func versionFilter(v, current, target int64) bool {
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
ctx := context.Background()
return EnsureDBVersionContext(ctx, db)
}

// EnsureDBVersionContext retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil {
return 0, createVersionTable(ctx, db)
Expand Down Expand Up @@ -332,7 +338,7 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
// createVersionTable creates the db version table and inserts the
// initial 0 value into it.
func createVersionTable(ctx context.Context, db *sql.DB) error {
txn, err := db.Begin()
txn, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
Expand All @@ -349,7 +355,13 @@ func createVersionTable(ctx context.Context, db *sql.DB) error {

// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db)
ctx := context.Background()
return GetDBVersionContext(ctx, db)
}

// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
version, err := EnsureDBVersionContext(ctx, db)
if err != nil {
return -1, err
}
Expand Down
12 changes: 11 additions & 1 deletion migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ func (m *Migration) String() string {
// Up runs an up migration.
func (m *Migration) Up(db *sql.DB) error {
ctx := context.Background()
return m.UpContext(ctx, db)
}

// UpContext runs an up migration.
func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, true); err != nil {
return err
}
Expand All @@ -49,6 +54,11 @@ func (m *Migration) Up(db *sql.DB) error {
// Down runs a down migration.
func (m *Migration) Down(db *sql.DB) error {
ctx := context.Background()
return m.DownContext(ctx, db)
}

// DownContext runs a down migration.
func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, false); err != nil {
return err
}
Expand Down Expand Up @@ -163,7 +173,7 @@ func runGoMigration(
if fn == nil && !recordVersion {
return nil
}
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion migration_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func runSQLMigration(

verboseInfo("Begin transaction")

tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
Expand Down
13 changes: 10 additions & 3 deletions redo.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package goose

import (
"context"
"database/sql"
)

// Redo rolls back the most recently applied migration, then runs it again.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return RedoContext(ctx, db, dir, opts...)
}

// RedoContext rolls back the most recently applied migration, then runs it again.
func RedoContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -23,7 +30,7 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion = migrations[len(migrations)-1].Version
} else {
if currentVersion, err = GetDBVersion(db); err != nil {
if currentVersion, err = GetDBVersionContext(ctx, db); err != nil {
return err
}
}
Expand All @@ -34,10 +41,10 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
current.noVersioning = option.noVersioning

if err := current.Down(db); err != nil {
if err := current.DownContext(ctx, db); err != nil {
return err
}
if err := current.Up(db); err != nil {
if err := current.UpContext(ctx, db); err != nil {
return err
}
return nil
Expand Down
9 changes: 7 additions & 2 deletions reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import (
// Reset rolls back all migrations
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return ResetContext(ctx, db, dir, opts...)
}

// ResetContext rolls back all migrations
func ResetContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -19,7 +24,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
return fmt.Errorf("failed to collect migrations: %w", err)
}
if option.noVersioning {
return DownTo(db, dir, minVersion, opts...)
return DownToContext(ctx, db, dir, minVersion, opts...)
}

statuses, err := dbMigrationsStatus(ctx, db)
Expand All @@ -32,7 +37,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
if !statuses[migration.Version] {
continue
}
if err = migration.Down(db); err != nil {
if err = migration.DownContext(ctx, db); err != nil {
return fmt.Errorf("failed to db-down: %w", err)
}
}
Expand Down
7 changes: 6 additions & 1 deletion status.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
// Status prints the status of all migrations.
func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return StatusContext(ctx, db, dir, opts...)
}

// StatusContext prints the status of all migrations.
func StatusContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -30,7 +35,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error {
}

// must ensure that the version table exists if we're running on a pristine DB
if _, err := EnsureDBVersion(db); err != nil {
if _, err := EnsureDBVersionContext(ctx, db); err != nil {
return fmt.Errorf("failed to ensure DB version: %w", err)
}

Expand Down
Loading