From fbe30ff6a41f09b523e32d0a952d2e084275a74d Mon Sep 17 00:00:00 2001 From: Ori Shalom Date: Mon, 8 May 2023 04:06:26 +0300 Subject: [PATCH 1/3] add methods for working with context --- down.go | 29 +++++++++---- goose.go | 39 +++++++++++------ migrate.go | 109 +++++++++++++++++++++++++++++++++++++++-------- migration.go | 50 +++++++++++++++------- migration_sql.go | 2 +- redo.go | 13 ++++-- reset.go | 9 +++- status.go | 7 ++- up.go | 32 ++++++++++---- version.go | 9 +++- 10 files changed, 229 insertions(+), 70 deletions(-) diff --git a/down.go b/down.go index c58c2144c..bd63804a7 100644 --- a/down.go +++ b/down.go @@ -1,12 +1,19 @@ package goose import ( + "context" "database/sql" "fmt" ) // Down rolls back a single migration from the current version. func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return DownContext(ctx, db, dir, opts...) +} + +// DownContext rolls back a single migration from the current version. +func DownContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -21,9 +28,9 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion := migrations[len(migrations)-1].Version // Migrate only the latest migration down. - return downToNoVersioning(db, migrations, currentVersion-1) + return downToNoVersioning(ctx, db, migrations, currentVersion-1) } - currentVersion, err := GetDBVersion(db) + currentVersion, err := GetDBVersionContext(ctx, db) if err != nil { return err } @@ -31,11 +38,17 @@ func Down(db *sql.DB, dir string, opts ...OptionsFunc) error { if err != nil { return fmt.Errorf("no migration %v", currentVersion) } - return current.Down(db) + return current.DownContext(ctx, db) } // DownTo rolls back migrations to a specific version. func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + ctx := context.Background() + return DownToContext(ctx, db, dir, version, opts...) +} + +// DownToContext rolls back migrations to a specific version. +func DownToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -45,11 +58,11 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return err } if option.noVersioning { - return downToNoVersioning(db, migrations, version) + return downToNoVersioning(ctx, db, migrations, version) } for { - currentVersion, err := GetDBVersion(db) + currentVersion, err := GetDBVersionContext(ctx, db) if err != nil { return err } @@ -69,7 +82,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { return nil } - if err = current.Down(db); err != nil { + if err = current.DownContext(ctx, db); err != nil { return err } } @@ -77,7 +90,7 @@ func DownTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // downToNoVersioning applies down migrations down to, but not including, the // target version. -func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func downToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for i := len(migrations) - 1; i >= 0; i-- { if version >= migrations[i].Version { @@ -85,7 +98,7 @@ func downToNoVersioning(db *sql.DB, migrations Migrations, version int64) error break } migrations[i].noVersioning = true - if err := migrations[i].Down(db); err != nil { + if err := migrations[i].DownContext(ctx, db); err != nil { return err } } diff --git a/goose.go b/goose.go index 0dbfd67da..e952041b0 100644 --- a/goose.go +++ b/goose.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "fmt" "io/fs" @@ -39,22 +40,34 @@ func SetBaseFS(fsys fs.FS) { // Run runs a goose command. func Run(command string, db *sql.DB, dir string, args ...string) error { - return run(command, db, dir, args) + ctx := context.Background() + return RunContext(ctx, command, db, dir, args...) } -// Run runs a goose command with options. +// RunContext runs a goose command. +func RunContext(ctx context.Context, command string, db *sql.DB, dir string, args ...string) error { + return run(ctx, command, db, dir, args) +} + +// RunWithOptions runs a goose command with options. func RunWithOptions(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { - return run(command, db, dir, args, options...) + ctx := context.Background() + return RunWithOptionsContext(ctx, command, db, dir, args, options...) +} + +// RunWithOptionsContext runs a goose command with options. +func RunWithOptionsContext(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { + return run(ctx, command, db, dir, args, options...) } -func run(command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { +func run(ctx context.Context, command string, db *sql.DB, dir string, args []string, options ...OptionsFunc) error { switch command { case "up": - if err := Up(db, dir, options...); err != nil { + if err := UpContext(ctx, db, dir, options...); err != nil { return err } case "up-by-one": - if err := UpByOne(db, dir, options...); err != nil { + if err := UpByOneContext(ctx, db, dir, options...); err != nil { return err } case "up-to": @@ -66,7 +79,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio if err != nil { return fmt.Errorf("version must be a number (got '%s')", args[0]) } - if err := UpTo(db, dir, version, options...); err != nil { + if err := UpToContext(ctx, db, dir, version, options...); err != nil { return err } case "create": @@ -82,7 +95,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio return err } case "down": - if err := Down(db, dir, options...); err != nil { + if err := DownContext(ctx, db, dir, options...); err != nil { return err } case "down-to": @@ -94,7 +107,7 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio if err != nil { return fmt.Errorf("version must be a number (got '%s')", args[0]) } - if err := DownTo(db, dir, version, options...); err != nil { + if err := DownToContext(ctx, db, dir, version, options...); err != nil { return err } case "fix": @@ -102,19 +115,19 @@ func run(command string, db *sql.DB, dir string, args []string, options ...Optio return err } case "redo": - if err := Redo(db, dir, options...); err != nil { + if err := RedoContext(ctx, db, dir, options...); err != nil { return err } case "reset": - if err := Reset(db, dir, options...); err != nil { + if err := ResetContext(ctx, db, dir, options...); err != nil { return err } case "status": - if err := Status(db, dir, options...); err != nil { + if err := StatusContext(ctx, db, dir, options...); err != nil { return err } case "version": - if err := Version(db, dir, options...); err != nil { + if err := VersionContext(ctx, db, dir, options...); err != nil { return err } default: diff --git a/migrate.go b/migrate.go index 0cf96a860..eae14775b 100644 --- a/migrate.go +++ b/migrate.go @@ -128,17 +128,63 @@ func (ms Migrations) String() string { // GoMigration is a Go migration func that is run within a transaction. type GoMigration func(tx *sql.Tx) error +// GoMigrationContext is a Go migration func that is run within a transaction and receives a context. +type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error + // GoMigrationNoTx is a Go migration func that is run outside a transaction. type GoMigrationNoTx func(db *sql.DB) error +// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context. +type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error + +// withoutContext wraps a GoMigration to make it a GoMigrationContext. +func (gm GoMigrationContext) withoutContext() GoMigration { + return func(tx *sql.Tx) error { + return gm(context.Background(), tx) + } +} + +// withContext wraps a GoMigration to make it a GoMigrationContext. +func (gm GoMigrationNoTxContext) withoutContext() GoMigrationNoTx { + return func(db *sql.DB) error { + return gm(context.Background(), db) + } +} + +// withContext wraps a GoMigration to make it a GoMigrationContext. +func (gm GoMigration) withContext() GoMigrationContext { + return func(_ context.Context, tx *sql.Tx) error { + return gm(tx) + } +} + +// withContext wraps a GoMigrationNoTx to make it a GoMigrationNoTxContext. +func (gm GoMigrationNoTx) withContext() GoMigrationNoTxContext { + return func(_ context.Context, db *sql.DB) error { + return gm(db) + } +} + // AddMigration adds Go migrations. func AddMigration(up, down GoMigration) { _, filename, _, _ := runtime.Caller(1) - AddNamedMigration(filename, up, down) + // intentionally don't call to AddMigrationContext so each of these functions can calculate the filename correctly + AddNamedMigrationContext(filename, up.withContext(), down.withContext()) +} + +// AddMigration adds Go migrations. +func AddMigrationContext(up, down GoMigrationContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, up, down) } // AddNamedMigration adds named Go migrations. func AddNamedMigration(filename string, up, down GoMigration) { + AddNamedMigrationContext(filename, up.withContext(), down.withContext()) +} + +// AddNamedMigrationContext adds named Go migrations. +func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { if err := register(filename, true, up, down, nil, nil); err != nil { panic(err) } @@ -146,12 +192,22 @@ func AddNamedMigration(filename string, up, down GoMigration) { // AddMigrationNoTx adds Go migrations that will be run outside transaction. func AddMigrationNoTx(up, down GoMigrationNoTx) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationNoTx(filename, up, down) + AddMigrationNoTxContext(up.withContext(), down.withContext()) +} + +// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. +func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { + _, filename, _, _ := runtime.Caller(2) + AddNamedMigrationNoTxContext(filename, up, down) } // AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { + AddNamedMigrationNoTxContext(filename, up.withContext(), down.withContext()) +} + +// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. +func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { if err := register(filename, false, nil, nil, up, down); err != nil { panic(err) } @@ -160,8 +216,8 @@ func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { func register( filename string, useTx bool, - up, down GoMigration, - upNoTx, downNoTx GoMigrationNoTx, + up, down GoMigrationContext, + upNoTx, downNoTx GoMigrationNoTxContext, ) error { // Sanity check caller did not mix tx and non-tx based functions. if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { @@ -177,16 +233,23 @@ func register( } // Add to global as a registered migration. registeredGoMigrations[v] = &Migration{ - Version: v, - Next: -1, - Previous: -1, - Registered: true, - Source: filename, - UseTx: useTx, - UpFn: up, - DownFn: down, - UpFnNoTx: upNoTx, - DownFnNoTx: downNoTx, + Version: v, + Next: -1, + Previous: -1, + Registered: true, + Source: filename, + UseTx: useTx, + UpFnContext: up, + DownFnContext: down, + UpFnNoTxContext: upNoTx, + DownFnNoTxContext: downNoTx, + // These are deprecated and will be removed in the future. + // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. + // Goose does not use these internally anymore and instead uses the context versions. + UpFn: up.withoutContext(), + DownFn: down.withoutContext(), + UpFnNoTx: upNoTx.withoutContext(), + DownFnNoTx: downNoTx.withoutContext(), } return nil } @@ -296,6 +359,12 @@ func versionFilter(v, current, target int64) bool { // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { ctx := context.Background() + return EnsureDBVersionContext(ctx, db) +} + +// EnsureDBVersionContext retrieves the current version for this DB. +// Create and initialize the DB version table if it doesn't exist. +func EnsureDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { dbMigrations, err := store.ListMigrations(ctx, db, TableName()) if err != nil { return 0, createVersionTable(ctx, db) @@ -332,7 +401,7 @@ func EnsureDBVersion(db *sql.DB) (int64, error) { // createVersionTable creates the db version table and inserts the // initial 0 value into it. func createVersionTable(ctx context.Context, db *sql.DB) error { - txn, err := db.Begin() + txn, err := db.BeginTx(ctx, nil) if err != nil { return err } @@ -349,7 +418,13 @@ func createVersionTable(ctx context.Context, db *sql.DB) error { // GetDBVersion is an alias for EnsureDBVersion, but returns -1 in error. func GetDBVersion(db *sql.DB) (int64, error) { - version, err := EnsureDBVersion(db) + ctx := context.Background() + return GetDBVersionContext(ctx, db) +} + +// GetDBVersionContext is an alias for EnsureDBVersion, but returns -1 in error. +func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { + version, err := EnsureDBVersionContext(ctx, db) if err != nil { return -1, err } diff --git a/migration.go b/migration.go index 727ecc3d9..dcf0c6118 100644 --- a/migration.go +++ b/migration.go @@ -22,15 +22,23 @@ type MigrationRecord struct { // Migration struct. type Migration struct { - Version int64 - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - Source string // path to .sql script or go file - Registered bool - UseTx bool + Version int64 + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + Source string // path to .sql script or go file + Registered bool + UseTx bool + + // These are deprecated and will be removed in the future. + // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. + // Goose does not use these internally anymore and instead uses the context versions. UpFn, DownFn GoMigration UpFnNoTx, DownFnNoTx GoMigrationNoTx - noVersioning bool + + // New functions with context + UpFnContext, DownFnContext GoMigrationContext + UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext + noVersioning bool } func (m *Migration) String() string { @@ -40,6 +48,11 @@ func (m *Migration) String() string { // Up runs an up migration. func (m *Migration) Up(db *sql.DB) error { ctx := context.Background() + return m.UpContext(ctx, db) +} + +// UpContext runs an up migration. +func (m *Migration) UpContext(ctx context.Context, db *sql.DB) error { if err := m.run(ctx, db, true); err != nil { return err } @@ -49,6 +62,11 @@ func (m *Migration) Up(db *sql.DB) error { // Down runs a down migration. func (m *Migration) Down(db *sql.DB) error { ctx := context.Background() + return m.DownContext(ctx, db) +} + +// DownContext runs a down migration. +func (m *Migration) DownContext(ctx context.Context, db *sql.DB) error { if err := m.run(ctx, db, false); err != nil { return err } @@ -89,9 +107,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { var empty bool if m.UseTx { // Run go-based migration inside a tx. - fn := m.DownFn + fn := m.DownFnContext if direction { - fn = m.UpFn + fn = m.UpFnContext } empty = (fn == nil) if err := runGoMigration( @@ -106,9 +124,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { } } else { // Run go-based migration outside a tx. - fn := m.DownFnNoTx + fn := m.DownFnNoTxContext if direction { - fn = m.UpFnNoTx + fn = m.UpFnNoTxContext } empty = (fn == nil) if err := runGoMigrationNoTx( @@ -135,14 +153,14 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { func runGoMigrationNoTx( ctx context.Context, db *sql.DB, - fn GoMigrationNoTx, + fn GoMigrationNoTxContext, version int64, direction bool, recordVersion bool, ) error { if fn != nil { // Run go migration function. - if err := fn(db); err != nil { + if err := fn(ctx, db); err != nil { return fmt.Errorf("failed to run go migration: %w", err) } } @@ -155,7 +173,7 @@ func runGoMigrationNoTx( func runGoMigration( ctx context.Context, db *sql.DB, - fn GoMigration, + fn GoMigrationContext, version int64, direction bool, recordVersion bool, @@ -163,13 +181,13 @@ func runGoMigration( if fn == nil && !recordVersion { return nil } - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } if fn != nil { // Run go migration function. - if err := fn(tx); err != nil { + if err := fn(ctx, tx); err != nil { _ = tx.Rollback() return fmt.Errorf("failed to run go migration: %w", err) } diff --git a/migration_sql.go b/migration_sql.go index f74b70d75..1c6d4d0bf 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -29,7 +29,7 @@ func runSQLMigration( verboseInfo("Begin transaction") - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } diff --git a/redo.go b/redo.go index c485f9f67..ed3ff677f 100644 --- a/redo.go +++ b/redo.go @@ -1,11 +1,18 @@ package goose import ( + "context" "database/sql" ) // Redo rolls back the most recently applied migration, then runs it again. func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return RedoContext(ctx, db, dir, opts...) +} + +// RedoContext rolls back the most recently applied migration, then runs it again. +func RedoContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -23,7 +30,7 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } currentVersion = migrations[len(migrations)-1].Version } else { - if currentVersion, err = GetDBVersion(db); err != nil { + if currentVersion, err = GetDBVersionContext(ctx, db); err != nil { return err } } @@ -34,10 +41,10 @@ func Redo(db *sql.DB, dir string, opts ...OptionsFunc) error { } current.noVersioning = option.noVersioning - if err := current.Down(db); err != nil { + if err := current.DownContext(ctx, db); err != nil { return err } - if err := current.Up(db); err != nil { + if err := current.UpContext(ctx, db); err != nil { return err } return nil diff --git a/reset.go b/reset.go index e14d36d22..274c539d4 100644 --- a/reset.go +++ b/reset.go @@ -10,6 +10,11 @@ import ( // Reset rolls back all migrations func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { ctx := context.Background() + return ResetContext(ctx, db, dir, opts...) +} + +// ResetContext rolls back all migrations +func ResetContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -19,7 +24,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return fmt.Errorf("failed to collect migrations: %w", err) } if option.noVersioning { - return DownTo(db, dir, minVersion, opts...) + return DownToContext(ctx, db, dir, minVersion, opts...) } statuses, err := dbMigrationsStatus(ctx, db) @@ -32,7 +37,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { if !statuses[migration.Version] { continue } - if err = migration.Down(db); err != nil { + if err = migration.DownContext(ctx, db); err != nil { return fmt.Errorf("failed to db-down: %w", err) } } diff --git a/status.go b/status.go index dd1f16c4c..4f0943a3d 100644 --- a/status.go +++ b/status.go @@ -12,6 +12,11 @@ import ( // Status prints the status of all migrations. func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { ctx := context.Background() + return StatusContext(ctx, db, dir, opts...) +} + +// StatusContext prints the status of all migrations. +func StatusContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -30,7 +35,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { } // must ensure that the version table exists if we're running on a pristine DB - if _, err := EnsureDBVersion(db); err != nil { + if _, err := EnsureDBVersionContext(ctx, db); err != nil { return fmt.Errorf("failed to ensure DB version: %w", err) } diff --git a/up.go b/up.go index 3dc011d03..64b88dd4c 100644 --- a/up.go +++ b/up.go @@ -35,6 +35,10 @@ func withApplyUpByOne() OptionsFunc { // UpTo migrates up to a specific version. func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { ctx := context.Background() + return UpToContext(ctx, db, dir, version, opts...) +} + +func UpToContext(ctx context.Context, db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -53,10 +57,10 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // migration over and over. version = foundMigrations[0].Version } - return upToNoVersioning(db, foundMigrations, version) + return upToNoVersioning(ctx, db, foundMigrations, version) } - if _, err := EnsureDBVersion(db); err != nil { + if _, err := EnsureDBVersionContext(ctx, db); err != nil { return err } dbMigrations, err := listAllDBVersions(ctx, db) @@ -103,7 +107,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { var current int64 for _, m := range migrationsToApply { - if err := m.Up(db); err != nil { + if err := m.UpContext(ctx, db); err != nil { return err } if option.applyUpByOne { @@ -112,7 +116,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { current = m.Version } if len(migrationsToApply) == 0 { - current, err = GetDBVersion(db) + current, err = GetDBVersionContext(ctx, db) if err != nil { return err } @@ -130,14 +134,14 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { // upToNoVersioning applies up migrations up to, and including, the // target version. -func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { +func upToNoVersioning(ctx context.Context, db *sql.DB, migrations Migrations, version int64) error { var finalVersion int64 for _, current := range migrations { if current.Version > version { break } current.noVersioning = true - if err := current.Up(db); err != nil { + if err := current.UpContext(ctx, db); err != nil { return err } finalVersion = current.Version @@ -148,13 +152,25 @@ func upToNoVersioning(db *sql.DB, migrations Migrations, version int64) error { // Up applies all available migrations. func Up(db *sql.DB, dir string, opts ...OptionsFunc) error { - return UpTo(db, dir, maxVersion, opts...) + ctx := context.Background() + return UpContext(ctx, db, dir, opts...) +} + +// UpContext applies all available migrations. +func UpContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { + return UpToContext(ctx, db, dir, maxVersion, opts...) } // UpByOne migrates up by a single version. func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return UpByOneContext(ctx, db, dir, opts...) +} + +// UpByOneContext migrates up by a single version. +func UpByOneContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { opts = append(opts, withApplyUpByOne()) - return UpTo(db, dir, maxVersion, opts...) + return UpToContext(ctx, db, dir, maxVersion, opts...) } // listAllDBVersions returns a list of all migrations, ordered ascending. diff --git a/version.go b/version.go index 47765f728..89d0dccd1 100644 --- a/version.go +++ b/version.go @@ -1,12 +1,19 @@ package goose import ( + "context" "database/sql" "fmt" ) // Version prints the current version of the database. func Version(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() + return VersionContext(ctx, db, dir, opts...) +} + +// VersionContext prints the current version of the database. +func VersionContext(ctx context.Context, db *sql.DB, dir string, opts ...OptionsFunc) error { option := &options{} for _, f := range opts { f(option) @@ -24,7 +31,7 @@ func Version(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } - current, err := GetDBVersion(db) + current, err := GetDBVersionContext(ctx, db) if err != nil { return err } From a2c6375db9036bdb8cecfa161a33a397e45ad0c0 Mon Sep 17 00:00:00 2001 From: Ori Shalom Date: Mon, 8 May 2023 11:40:35 +0300 Subject: [PATCH 2/3] modify template for go migration --- create.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/create.go b/create.go index 1a8bb74c0..eeca796fb 100644 --- a/create.go +++ b/create.go @@ -99,20 +99,21 @@ SELECT 'down SQL query'; var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations import ( + "context" "database/sql" "github.com/pressly/goose/v3" ) func init() { - goose.AddMigration(up{{.CamelName}}, down{{.CamelName}}) + goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) } -func up{{.CamelName}}(tx *sql.Tx) error { +func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { // This code is executed when the migration is applied. return nil } -func down{{.CamelName}}(tx *sql.Tx) error { +func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { // This code is executed when the migration is rolled back. return nil } From ff8f84911f0fd6fe754ef11be5610d087fdcc136 Mon Sep 17 00:00:00 2001 From: Ori Shalom Date: Sat, 10 Jun 2023 23:28:31 +0300 Subject: [PATCH 3/3] revert go migration changes --- create.go | 7 ++-- migrate.go | 93 +++++++++------------------------------------------- migration.go | 38 +++++++++------------ 3 files changed, 33 insertions(+), 105 deletions(-) diff --git a/create.go b/create.go index eeca796fb..1a8bb74c0 100644 --- a/create.go +++ b/create.go @@ -99,21 +99,20 @@ SELECT 'down SQL query'; var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations import ( - "context" "database/sql" "github.com/pressly/goose/v3" ) func init() { - goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}}) + goose.AddMigration(up{{.CamelName}}, down{{.CamelName}}) } -func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { +func up{{.CamelName}}(tx *sql.Tx) error { // This code is executed when the migration is applied. return nil } -func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error { +func down{{.CamelName}}(tx *sql.Tx) error { // This code is executed when the migration is rolled back. return nil } diff --git a/migrate.go b/migrate.go index eae14775b..fb1500af1 100644 --- a/migrate.go +++ b/migrate.go @@ -128,63 +128,17 @@ func (ms Migrations) String() string { // GoMigration is a Go migration func that is run within a transaction. type GoMigration func(tx *sql.Tx) error -// GoMigrationContext is a Go migration func that is run within a transaction and receives a context. -type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error - // GoMigrationNoTx is a Go migration func that is run outside a transaction. type GoMigrationNoTx func(db *sql.DB) error -// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context. -type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error - -// withoutContext wraps a GoMigration to make it a GoMigrationContext. -func (gm GoMigrationContext) withoutContext() GoMigration { - return func(tx *sql.Tx) error { - return gm(context.Background(), tx) - } -} - -// withContext wraps a GoMigration to make it a GoMigrationContext. -func (gm GoMigrationNoTxContext) withoutContext() GoMigrationNoTx { - return func(db *sql.DB) error { - return gm(context.Background(), db) - } -} - -// withContext wraps a GoMigration to make it a GoMigrationContext. -func (gm GoMigration) withContext() GoMigrationContext { - return func(_ context.Context, tx *sql.Tx) error { - return gm(tx) - } -} - -// withContext wraps a GoMigrationNoTx to make it a GoMigrationNoTxContext. -func (gm GoMigrationNoTx) withContext() GoMigrationNoTxContext { - return func(_ context.Context, db *sql.DB) error { - return gm(db) - } -} - // AddMigration adds Go migrations. func AddMigration(up, down GoMigration) { _, filename, _, _ := runtime.Caller(1) - // intentionally don't call to AddMigrationContext so each of these functions can calculate the filename correctly - AddNamedMigrationContext(filename, up.withContext(), down.withContext()) -} - -// AddMigration adds Go migrations. -func AddMigrationContext(up, down GoMigrationContext) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationContext(filename, up, down) + AddNamedMigration(filename, up, down) } // AddNamedMigration adds named Go migrations. func AddNamedMigration(filename string, up, down GoMigration) { - AddNamedMigrationContext(filename, up.withContext(), down.withContext()) -} - -// AddNamedMigrationContext adds named Go migrations. -func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { if err := register(filename, true, up, down, nil, nil); err != nil { panic(err) } @@ -192,22 +146,12 @@ func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { // AddMigrationNoTx adds Go migrations that will be run outside transaction. func AddMigrationNoTx(up, down GoMigrationNoTx) { - AddMigrationNoTxContext(up.withContext(), down.withContext()) -} - -// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. -func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { - _, filename, _, _ := runtime.Caller(2) - AddNamedMigrationNoTxContext(filename, up, down) + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationNoTx(filename, up, down) } // AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { - AddNamedMigrationNoTxContext(filename, up.withContext(), down.withContext()) -} - -// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. -func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { if err := register(filename, false, nil, nil, up, down); err != nil { panic(err) } @@ -216,8 +160,8 @@ func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxConte func register( filename string, useTx bool, - up, down GoMigrationContext, - upNoTx, downNoTx GoMigrationNoTxContext, + up, down GoMigration, + upNoTx, downNoTx GoMigrationNoTx, ) error { // Sanity check caller did not mix tx and non-tx based functions. if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { @@ -233,23 +177,16 @@ func register( } // Add to global as a registered migration. registeredGoMigrations[v] = &Migration{ - Version: v, - Next: -1, - Previous: -1, - Registered: true, - Source: filename, - UseTx: useTx, - UpFnContext: up, - DownFnContext: down, - UpFnNoTxContext: upNoTx, - DownFnNoTxContext: downNoTx, - // These are deprecated and will be removed in the future. - // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. - // Goose does not use these internally anymore and instead uses the context versions. - UpFn: up.withoutContext(), - DownFn: down.withoutContext(), - UpFnNoTx: upNoTx.withoutContext(), - DownFnNoTx: downNoTx.withoutContext(), + Version: v, + Next: -1, + Previous: -1, + Registered: true, + Source: filename, + UseTx: useTx, + UpFn: up, + DownFn: down, + UpFnNoTx: upNoTx, + DownFnNoTx: downNoTx, } return nil } diff --git a/migration.go b/migration.go index dcf0c6118..8d4362c01 100644 --- a/migration.go +++ b/migration.go @@ -22,23 +22,15 @@ type MigrationRecord struct { // Migration struct. type Migration struct { - Version int64 - Next int64 // next version, or -1 if none - Previous int64 // previous version, -1 if none - Source string // path to .sql script or go file - Registered bool - UseTx bool - - // These are deprecated and will be removed in the future. - // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. - // Goose does not use these internally anymore and instead uses the context versions. + Version int64 + Next int64 // next version, or -1 if none + Previous int64 // previous version, -1 if none + Source string // path to .sql script or go file + Registered bool + UseTx bool UpFn, DownFn GoMigration UpFnNoTx, DownFnNoTx GoMigrationNoTx - - // New functions with context - UpFnContext, DownFnContext GoMigrationContext - UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext - noVersioning bool + noVersioning bool } func (m *Migration) String() string { @@ -107,9 +99,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { var empty bool if m.UseTx { // Run go-based migration inside a tx. - fn := m.DownFnContext + fn := m.DownFn if direction { - fn = m.UpFnContext + fn = m.UpFn } empty = (fn == nil) if err := runGoMigration( @@ -124,9 +116,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { } } else { // Run go-based migration outside a tx. - fn := m.DownFnNoTxContext + fn := m.DownFnNoTx if direction { - fn = m.UpFnNoTxContext + fn = m.UpFnNoTx } empty = (fn == nil) if err := runGoMigrationNoTx( @@ -153,14 +145,14 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { func runGoMigrationNoTx( ctx context.Context, db *sql.DB, - fn GoMigrationNoTxContext, + fn GoMigrationNoTx, version int64, direction bool, recordVersion bool, ) error { if fn != nil { // Run go migration function. - if err := fn(ctx, db); err != nil { + if err := fn(db); err != nil { return fmt.Errorf("failed to run go migration: %w", err) } } @@ -173,7 +165,7 @@ func runGoMigrationNoTx( func runGoMigration( ctx context.Context, db *sql.DB, - fn GoMigrationContext, + fn GoMigration, version int64, direction bool, recordVersion bool, @@ -187,7 +179,7 @@ func runGoMigration( } if fn != nil { // Run go migration function. - if err := fn(ctx, tx); err != nil { + if err := fn(tx); err != nil { _ = tx.Rollback() return fmt.Errorf("failed to run go migration: %w", err) }