Skip to content

Commit

Permalink
add methods for working with context
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-shalom committed May 8, 2023
1 parent 3717a9e commit 273ded4
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 73 deletions.
29 changes: 21 additions & 8 deletions down.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package goose

import (
"context"
"database/sql"
"fmt"
)

// Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
ctx := context.Background()
return DownContext(ctx, db, dir, opts...)
}

// DownContext rolls back a single migration from the current version.
func DownContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -21,21 +28,27 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion := migrations[len(migrations)-1].Version
// Migrate only the latest migration down.
return downToNoVersioning(db, migrations, currentVersion-1)
return downToNoVersioning(ctx, db, migrations, currentVersion-1)
}
currentVersion, err := GetDBVersion(db)
currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}
current, err := migrations.Current(currentVersion)
if err != nil {
return fmt.Errorf("no migration %v", currentVersion)
}
return current.Down(db)
return current.DownContext(ctx, db)
}

// DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
ctx := context.Background()
return DownToContext(ctx, db, dir, version, opts...)
}

// DownToContext rolls back migrations to a specific version.
func DownToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -45,11 +58,11 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return err
}
if option.noVersioning {
return downToNoVersioning(db, migrations, version)
return downToNoVersioning(ctx, db, migrations, version)
}

for {
currentVersion, err := GetDBVersion(db)
currentVersion, err := GetDBVersionContext(ctx, db)
if err != nil {
return err
}
Expand All @@ -69,23 +82,23 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return nil
}

if err = current.Down(db); err != nil {
if err = current.DownContext(ctx, db); err != nil {
return err
}
}
}

// downToNoVersioning applies down migrations down to, but not including, the
// target version.
func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64
for i := len(migrations) - 1; i >= 0; i-- {
if version >= migrations[i].Version {
finalVersion = migrations[i].Version
break
}
migrations[i].noVersioning = true
if err := migrations[i].Down(db); err != nil {
if err := migrations[i].DownContext(ctx, db); err != nil {
return err
}
}
Expand Down
39 changes: 26 additions & 13 deletions goose.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goose

import (
"context"
"database/sql"
"fmt"
"io/fs"
Expand Down Expand Up @@ -39,22 +40,34 @@ func SetBaseFS(fsys fs.FS) {

// Run runs a goose command.
func Run(command string, db *sql.DB, dir string, args ...string) error {
return run(command, db, dir, args)
ctx := context.Background()
return RunContext(ctx, command, db, dir, args...)
}

// Run runs a goose command with options.
// RunContext runs a goose command.
func RunContext(ctx context.Context, command string, db *sql.DB, dir string, args ...string) error {
return run(ctx, command, db, dir, args)
}

// RunWithOptions runs a goose command with options.
func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(command, db, dir, args, options...)
ctx := context.Background()
return RunWithOptionsContext(ctx, command, db, dir, args, options...)
}

// RunWithOptionsContext runs a goose command with options.
func RunWithOptionsContext(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
return run(ctx, command, db, dir, args, options...)
}

func run(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
func run(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error {
switch command {
case "up":
if err := Up(db, dir, options...); err != nil {
if err := UpContext(ctx, db, dir, options...); err != nil {
return err
}
case "up-by-one":
if err := UpByOne(db, dir, options...); err != nil {
if err := UpByOneContext(ctx, db, dir, options...); err != nil {
return err
}
case "up-to":
Expand All @@ -66,7 +79,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0])
}
if err := UpTo(db, dir, version, options...); err != nil {
if err := UpToContext(ctx, db, dir, version, options...); err != nil {
return err
}
case "create":
Expand All @@ -82,7 +95,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
return err
}
case "down":
if err := Down(db, dir, options...); err != nil {
if err := DownContext(ctx, db, dir, options...); err != nil {
return err
}
case "down-to":
Expand All @@ -94,27 +107,27 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio
if err != nil {
return fmt.Errorf("version must be a number (got '%s')", args[0])
}
if err := DownTo(db, dir, version, options...); err != nil {
if err := DownToContext(ctx, db, dir, version, options...); err != nil {
return err
}
case "fix":
if err := Fix(dir); err != nil {
return err
}
case "redo":
if err := Redo(db, dir, options...); err != nil {
if err := RedoContext(ctx, db, dir, options...); err != nil {
return err
}
case "reset":
if err := Reset(db, dir, options...); err != nil {
if err := ResetContext(ctx, db, dir, options...); err != nil {
return err
}
case "status":
if err := Status(db, dir, options...); err != nil {
if err := StatusContext(ctx, db, dir, options...); err != nil {
return err
}
case "version":
if err := Version(db, dir, options...); err != nil {
if err := VersionContext(ctx, db, dir, options...); err != nil {
return err
}
default:
Expand Down
109 changes: 92 additions & 17 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,30 +128,86 @@ func (ms Migrations) String() string {
// GoMigration is a Go migration func that is run within a transaction.
type GoMigration func(tx *sql.Tx) error

// GoMigrationContext is a Go migration func that is run within a transaction and receives a context.
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error

// GoMigrationNoTx is a Go migration func that is run outside a transaction.
type GoMigrationNoTx func(db *sql.DB) error

// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context.
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error

// withoutContext wraps a GoMigration to make it a GoMigrationContext.
func (gm GoMigrationContext) withoutContext() GoMigration {
return func(tx *sql.Tx) error {
return gm(context.Background(), tx)
}
}

// withContext wraps a GoMigration to make it a GoMigrationContext.
func (gm GoMigrationNoTxContext) withoutContext() GoMigrationNoTx {
return func(db *sql.DB) error {
return gm(context.Background(), db)
}
}

// withContext wraps a GoMigration to make it a GoMigrationContext.
func (gm GoMigration) withContext() GoMigrationContext {
return func(_ context.Context, tx *sql.Tx) error {
return gm(tx)
}
}

// withContext wraps a GoMigrationNoTx to make it a GoMigrationNoTxContext.
func (gm GoMigrationNoTx) withContext() GoMigrationNoTxContext {
return func(_ context.Context, db *sql.DB) error {
return gm(db)
}
}

// AddMigration adds Go migrations.
func AddMigration(up, down GoMigration) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigration(filename, up, down)
// intentionally don't call to AddMigrationContext so each of these functions can calculate the filename correctly
AddNamedMigrationContext(filename, up.withContext(), down.withContext())
}

// AddMigration adds Go migrations.
func AddMigrationContext(up, down GoMigrationContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, up, down)
}

// AddNamedMigration adds named Go migrations.
func AddNamedMigration(filename string, up, down GoMigration) {
AddNamedMigrationContext(filename, up.withContext(), down.withContext())
}

// AddNamedMigrationContext adds named Go migrations.
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
if err := register(filename, true, up, down, nil, nil); err != nil {
panic(err)
}
}

// AddMigrationNoTx adds Go migrations that will be run outside transaction.
func AddMigrationNoTx(up, down GoMigrationNoTx) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTx(filename, up, down)
AddMigrationNoTxContext(up.withContext(), down.withContext())
}

// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
_, filename, _, _ := runtime.Caller(2)
AddNamedMigrationNoTxContext(filename, up, down)
}

// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
AddNamedMigrationNoTxContext(filename, up.withContext(), down.withContext())
}

// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
if err := register(filename, false, nil, nil, up, down); err != nil {
panic(err)
}
Expand All @@ -160,8 +216,8 @@ func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
func register(
filename string,
useTx bool,
up, down GoMigration,
upNoTx, downNoTx GoMigrationNoTx,
up, down GoMigrationContext,
upNoTx, downNoTx GoMigrationNoTxContext,
) error {
// Sanity check caller did not mix tx and non-tx based functions.
if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) {
Expand All @@ -177,16 +233,23 @@ func register(
}
// Add to global as a registered migration.
registeredGoMigrations[v] = &Migration{
Version: v,
Next: -1,
Previous: -1,
Registered: true,
Source: filename,
UseTx: useTx,
UpFn: up,
DownFn: down,
UpFnNoTx: upNoTx,
DownFnNoTx: downNoTx,
Version: v,
Next: -1,
Previous: -1,
Registered: true,
Source: filename,
UseTx: useTx,
UpFnContext: up,
DownFnContext: down,
UpFnNoTxContext: upNoTx,
DownFnNoTxContext: downNoTx,
// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
UpFn: up.withoutContext(),
DownFn: down.withoutContext(),
UpFnNoTx: upNoTx.withoutContext(),
DownFnNoTx: downNoTx.withoutContext(),
}
return nil
}
Expand Down Expand Up @@ -296,6 +359,12 @@ func versionFilter(v, current, target int64) bool {
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(db *sql.DB) (int64, error) {
ctx := context.Background()
return EnsureDBVersionContext(ctx, db)
}

// EnsureDBVersionContext retrieves the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
dbMigrations, err := store.ListMigrations(ctx, db, TableName())
if err != nil {
return 0, createVersionTable(ctx, db)
Expand Down Expand Up @@ -332,7 +401,7 @@ func EnsureDBVersion(db *sql.DB) (int64, error) {
// createVersionTable creates the db version table and inserts the
// initial 0 value into it.
func createVersionTable(ctx context.Context, db *sql.DB) error {
txn, err := db.Begin()
txn, err := db.BeginTx(ctx, nil)
if err != nil {
return err
}
Expand All @@ -349,7 +418,13 @@ func createVersionTable(ctx context.Context, db *sql.DB) error {

// GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersion(db *sql.DB) (int64, error) {
version, err := EnsureDBVersion(db)
ctx := context.Background()
return GetDBVersionContext(ctx, db)
}

// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error.
func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {
version, err := EnsureDBVersionContext(ctx, db)
if err != nil {
return -1, err
}
Expand Down
Loading

0 comments on commit 273ded4

Please sign in to comment.