Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setting context on migration transactions #195

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions down.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package goose

import (
"context"
"database/sql"
"fmt"
)

// Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
// DownCtx rolls back a single migration from the current version.
func DownCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -21,7 +22,7 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion := migrations[len(migrations)-1].Version
// Migrate only the latest migration down.
return downToNoVersioning(db, migrations, currentVersion-1)
return downToNoVersioning(ctx, db, migrations, currentVersion-1)
}
currentVersion, err := GetDBVersion(db)
if err != nil {
Expand All @@ -31,11 +32,18 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
if err != nil {
return fmt.Errorf("no migration %v", currentVersion)
}
return current.Down(db)
return current.DownCtx(ctx, db)
}

// DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// Down rolls back a single migration from the current version.
//
// Down uses context.Background internally; to specify the context, use DownCtx.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
return DownCtx(context.Background(), db, dir, opts...)
}

// DownToCtx rolls back migrations to a specific version.
func DownToCtx(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -45,7 +53,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return err
}
if option.noVersioning {
return downToNoVersioning(db, migrations, version)
return downToNoVersioning(ctx, db, migrations, version)
}

for {
Expand All @@ -69,23 +77,30 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return nil
}

if err = current.Down(db); err != nil {
if err = current.DownCtx(ctx, db); err != nil {
return err
}
}
}

// DownTo rolls back migrations to a specific version.
//
// DownTo uses context.Background internally; to specify the context, use DownToCtx.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return DownToCtx(context.Background(), db, dir, version, opts...)
}

// downToNoVersioning applies down migrations down to, but not including, the
// target version.
func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64
for i := len(migrations) - 1; i >= 0; i-- {
if version >= migrations[i].Version {
finalVersion = migrations[i].Version
break
}
migrations[i].noVersioning = true
if err := migrations[i].Down(db); err != nil {
if err := migrations[i].DownCtx(ctx, db); err != nil {
return err
}
}
Expand Down
35 changes: 25 additions & 10 deletions migration.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goose

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -33,23 +34,37 @@ func (m *Migration) String() string {
return fmt.Sprintf(m.Source)
}

// UpCtx runs an up migration.
func (m *Migration) UpCtx(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, true); err != nil {
return err
}
return nil
}

// Up runs an up migration.
//
// Up uses context.Background internally; to specify the context, use UpCtx.
func (m *Migration) Up(db *sql.DB) error {
if err := m.run(db, true); err != nil {
return m.UpCtx(context.Background(), db)
}

// DownCtx runs a down migration.
func (m *Migration) DownCtx(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, false); err != nil {
return err
}
return nil
}

// Down runs a down migration.
//
// Down uses context.Background internally; to specify the context, use DownCtx.
func (m *Migration) Down(db *sql.DB) error {
if err := m.run(db, false); err != nil {
return err
}
return nil
return m.DownCtx(context.Background(), db)
}

func (m *Migration) run(db *sql.DB, direction bool) error {
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
f, err := baseFS.Open(m.Source)
Expand All @@ -63,7 +78,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err)
}

if err := runSQLMigration(db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err)
}

Expand All @@ -77,7 +92,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
if !m.Registered {
return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
}
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("ERROR failed to begin transaction: %w", err)
}
Expand All @@ -96,12 +111,12 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
}
if !m.noVersioning {
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
if _, err := tx.ExecContext(ctx, GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR failed to execute transaction: %w", err)
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
if _, err := tx.ExecContext(ctx, GetDialect().deleteVersionSQL(), m.Version); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR failed to execute transaction: %w", err)
}
Expand Down
23 changes: 12 additions & 11 deletions migration_sql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goose

import (
"context"
"database/sql"
"fmt"
"regexp"
Expand All @@ -15,20 +16,20 @@ import (
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error {
func runSQLMigration(ctx context.Context, db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error {
if useTx {
// TRANSACTION.

verboseInfo("Begin transaction")

tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

for _, query := range statements {
verboseInfo("Executing statement: %s\n", clearStatement(query))
if err = execQuery(tx.Exec, query); err != nil {
if err = execQuery(ctx, tx.ExecContext, query); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err)
Expand All @@ -37,13 +38,13 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc

if !noVersioning {
if direction {
if err := execQuery(tx.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil {
if err := execQuery(ctx, tx.ExecContext, GetDialect().insertVersionSQL(), v, direction); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return fmt.Errorf("failed to insert new goose version: %w", err)
}
} else {
if err := execQuery(tx.Exec, GetDialect().deleteVersionSQL(), v); err != nil {
if err := execQuery(ctx, tx.ExecContext, GetDialect().deleteVersionSQL(), v); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return fmt.Errorf("failed to delete goose version: %w", err)
Expand All @@ -62,17 +63,17 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc
// NO TRANSACTION.
for _, query := range statements {
verboseInfo("Executing statement: %s", clearStatement(query))
if err := execQuery(db.Exec, query); err != nil {
if err := execQuery(ctx, db.ExecContext, query); err != nil {
return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err)
}
}
if !noVersioning {
if direction {
if err := execQuery(db.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil {
if err := execQuery(ctx, db.ExecContext, GetDialect().insertVersionSQL(), v, direction); err != nil {
return fmt.Errorf("failed to insert new goose version: %w", err)
}
} else {
if err := execQuery(db.Exec, GetDialect().deleteVersionSQL(), v); err != nil {
if err := execQuery(ctx, db.ExecContext, GetDialect().deleteVersionSQL(), v); err != nil {
return fmt.Errorf("failed to delete goose version: %w", err)
}
}
Expand All @@ -81,16 +82,16 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc
return nil
}

func execQuery(fn func(string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error {
func execQuery(ctx context.Context, fn func(context.Context, string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error {
if !verbose {
_, err := fn(query, args...)
_, err := fn(ctx, query, args...)
return err
}

ch := make(chan error)

go func() {
_, err := fn(query, args...)
_, err := fn(ctx, query, args...)
ch <- err
}()

Expand Down
16 changes: 12 additions & 4 deletions redo.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package goose

import (
"context"
"database/sql"
)

// Redo rolls back the most recently applied migration, then runs it again.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
// RedoCtx rolls back the most recently applied migration, then runs it again.
func RedoCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand Down Expand Up @@ -34,11 +35,18 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
current.noVersioning = option.noVersioning

if err := current.Down(db); err != nil {
if err := current.DownCtx(ctx, db); err != nil {
return err
}
if err := current.Up(db); err != nil {
if err := current.UpCtx(ctx, db); err != nil {
return err
}
return nil
}

// Redo rolls back the most recently applied migration, then runs it again.
//
// Redo uses context.Background internally; to specify the context, use RedoCtx.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
return RedoCtx(context.Background(), db, dir, opts...)
}
14 changes: 11 additions & 3 deletions reset.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package goose

import (
"context"
"database/sql"
"fmt"
"sort"
)

// Reset rolls back all migrations
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
// ResetCtx rolls back all migrations
func ResetCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -30,14 +31,21 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
if !statuses[migration.Version] {
continue
}
if err = migration.Down(db); err != nil {
if err = migration.DownCtx(ctx, db); err != nil {
return fmt.Errorf("failed to db-down: %w", err)
}
}

return nil
}

// Reset rolls back all migrations
//
// Reset uses context.Background internally; to specify the context, use ResetCtx.
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
return ResetCtx(context.Background(), db, dir, opts...)
}

func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) {
rows, err := GetDialect().dbVersionQuery(db)
if err != nil {
Expand Down
Loading