Skip to content

Commit

Permalink
expose context for goose sql migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
diogogmt committed Aug 2, 2022
1 parent 0a72970 commit d5cfd14
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 49 deletions.
35 changes: 25 additions & 10 deletions down.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package goose

import (
"context"
"database/sql"
"fmt"
)

// Down rolls back a single migration from the current version.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
// DownCtx rolls back a single migration from the current version.
func DownCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -21,7 +22,7 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
currentVersion := migrations[len(migrations)-1].Version
// Migrate only the latest migration down.
return downToNoVersioning(db, migrations, currentVersion-1)
return downToNoVersioning(ctx, db, migrations, currentVersion-1)
}
currentVersion, err := GetDBVersion(db)
if err != nil {
Expand All @@ -31,11 +32,18 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
if err != nil {
return fmt.Errorf("no migration %v", currentVersion)
}
return current.Down(db)
return current.DownCtx(ctx, db)
}

// DownTo rolls back migrations to a specific version.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
// Down rolls back a single migration from the current version.
//
// Down uses context.Background internally; to specify the context, use DownCtx.
func Down(db *sql.DB, dir string, opts ...OptionsFunc) error {
return DownCtx(context.Background(), db, dir, opts...)
}

// DownToCtx rolls back migrations to a specific version.
func DownToCtx(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -45,7 +53,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return err
}
if option.noVersioning {
return downToNoVersioning(db, migrations, version)
return downToNoVersioning(ctx, db, migrations, version)
}

for {
Expand All @@ -69,23 +77,30 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return nil
}

if err = current.Down(db); err != nil {
if err = current.DownCtx(ctx, db); err != nil {
return err
}
}
}

// DownTo rolls back migrations to a specific version.
//
// DownTo uses context.Background internally; to specify the context, use DownToCtx.
func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error {
return DownToCtx(context.Background(), db, dir, version, opts...)
}

// downToNoVersioning applies down migrations down to, but not including, the
// target version.
func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error {
func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error {
var finalVersion int64
for i := len(migrations) - 1; i >= 0; i-- {
if version >= migrations[i].Version {
finalVersion = migrations[i].Version
break
}
migrations[i].noVersioning = true
if err := migrations[i].Down(db); err != nil {
if err := migrations[i].DownCtx(ctx, db); err != nil {
return err
}
}
Expand Down
35 changes: 25 additions & 10 deletions migration.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goose

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -33,23 +34,37 @@ func (m *Migration) String() string {
return fmt.Sprintf(m.Source)
}

// UpCtx runs an up migration.
func (m *Migration) UpCtx(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, true); err != nil {
return err
}
return nil
}

// Up runs an up migration.
//
// Up uses context.Background internally; to specify the context, use UpCtx.
func (m *Migration) Up(db *sql.DB) error {
if err := m.run(db, true); err != nil {
return m.UpCtx(context.Background(), db)
}

// DownCtx runs a down migration.
func (m *Migration) DownCtx(ctx context.Context, db *sql.DB) error {
if err := m.run(ctx, db, false); err != nil {
return err
}
return nil
}

// Down runs a down migration.
//
// Down uses context.Background internally; to specify the context, use DownCtx.
func (m *Migration) Down(db *sql.DB) error {
if err := m.run(db, false); err != nil {
return err
}
return nil
return m.DownCtx(context.Background(), db)
}

func (m *Migration) run(db *sql.DB, direction bool) error {
func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
switch filepath.Ext(m.Source) {
case ".sql":
f, err := baseFS.Open(m.Source)
Expand All @@ -63,7 +78,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
return fmt.Errorf("ERROR %v: failed to parse SQL migration file: %w", filepath.Base(m.Source), err)
}

if err := runSQLMigration(db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil {
return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err)
}

Expand All @@ -77,7 +92,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
if !m.Registered {
return fmt.Errorf("ERROR %v: failed to run Go migration: Go functions must be registered and built into a custom binary (see https://github.com/pressly/goose/tree/master/examples/go-migrations)", m.Source)
}
tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("ERROR failed to begin transaction: %w", err)
}
Expand All @@ -96,12 +111,12 @@ func (m *Migration) run(db *sql.DB, direction bool) error {
}
if !m.noVersioning {
if direction {
if _, err := tx.Exec(GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
if _, err := tx.ExecContext(ctx, GetDialect().insertVersionSQL(), m.Version, direction); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR failed to execute transaction: %w", err)
}
} else {
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), m.Version); err != nil {
if _, err := tx.ExecContext(ctx, GetDialect().deleteVersionSQL(), m.Version); err != nil {
tx.Rollback()
return fmt.Errorf("ERROR failed to execute transaction: %w", err)
}
Expand Down
23 changes: 12 additions & 11 deletions migration_sql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package goose

import (
"context"
"database/sql"
"fmt"
"regexp"
Expand All @@ -15,20 +16,20 @@ import (
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error {
func runSQLMigration(ctx context.Context, db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error {
if useTx {
// TRANSACTION.

verboseInfo("Begin transaction")

tx, err := db.Begin()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

for _, query := range statements {
verboseInfo("Executing statement: %s\n", clearStatement(query))
if err = execQuery(tx.Exec, query); err != nil {
if err = execQuery(ctx, tx.ExecContext, query); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err)
Expand All @@ -37,13 +38,13 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc

if !noVersioning {
if direction {
if err := execQuery(tx.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil {
if err := execQuery(ctx, tx.ExecContext, GetDialect().insertVersionSQL(), v, direction); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return fmt.Errorf("failed to insert new goose version: %w", err)
}
} else {
if err := execQuery(tx.Exec, GetDialect().deleteVersionSQL(), v); err != nil {
if err := execQuery(ctx, tx.ExecContext, GetDialect().deleteVersionSQL(), v); err != nil {
verboseInfo("Rollback transaction")
tx.Rollback()
return fmt.Errorf("failed to delete goose version: %w", err)
Expand All @@ -62,17 +63,17 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc
// NO TRANSACTION.
for _, query := range statements {
verboseInfo("Executing statement: %s", clearStatement(query))
if err := execQuery(db.Exec, query); err != nil {
if err := execQuery(ctx, db.ExecContext, query); err != nil {
return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err)
}
}
if !noVersioning {
if direction {
if err := execQuery(db.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil {
if err := execQuery(ctx, db.ExecContext, GetDialect().insertVersionSQL(), v, direction); err != nil {
return fmt.Errorf("failed to insert new goose version: %w", err)
}
} else {
if err := execQuery(db.Exec, GetDialect().deleteVersionSQL(), v); err != nil {
if err := execQuery(ctx, db.ExecContext, GetDialect().deleteVersionSQL(), v); err != nil {
return fmt.Errorf("failed to delete goose version: %w", err)
}
}
Expand All @@ -81,16 +82,16 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc
return nil
}

func execQuery(fn func(string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error {
func execQuery(ctx context.Context, fn func(context.Context, string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error {
if !verbose {
_, err := fn(query, args...)
_, err := fn(ctx, query, args...)
return err
}

ch := make(chan error)

go func() {
_, err := fn(query, args...)
_, err := fn(ctx, query, args...)
ch <- err
}()

Expand Down
16 changes: 12 additions & 4 deletions redo.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package goose

import (
"context"
"database/sql"
)

// Redo rolls back the most recently applied migration, then runs it again.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
// RedoCtx rolls back the most recently applied migration, then runs it again.
func RedoCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand Down Expand Up @@ -34,11 +35,18 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
}
current.noVersioning = option.noVersioning

if err := current.Down(db); err != nil {
if err := current.DownCtx(ctx, db); err != nil {
return err
}
if err := current.Up(db); err != nil {
if err := current.UpCtx(ctx, db); err != nil {
return err
}
return nil
}

// Redo rolls back the most recently applied migration, then runs it again.
//
// Redo uses context.Background internally; to specify the context, use RedoCtx.
func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error {
return RedoCtx(context.Background(), db, dir, opts...)
}
14 changes: 11 additions & 3 deletions reset.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package goose

import (
"context"
"database/sql"
"fmt"
"sort"
)

// Reset rolls back all migrations
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
// ResetCtx rolls back all migrations
func ResetCtx(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error {
option := &options{}
for _, f := range opts {
f(option)
Expand All @@ -30,14 +31,21 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
if !statuses[migration.Version] {
continue
}
if err = migration.Down(db); err != nil {
if err = migration.DownCtx(ctx, db); err != nil {
return fmt.Errorf("failed to db-down: %w", err)
}
}

return nil
}

// Reset rolls back all migrations
//
// Reset uses context.Background internally; to specify the context, use ResetCtx.
func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error {
return ResetCtx(context.Background(), db, dir, opts...)
}

func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) {
rows, err := GetDialect().dbVersionQuery(db)
if err != nil {
Expand Down
Loading

0 comments on commit d5cfd14

Please sign in to comment.