Skip to content

Commit 1133fc1

Browse files
committed
migrate: re-run migration task callbacks on error
This commit modifies the migration framework so that a migration task which errors during a migration is re-attempted on the next run. Previously, if a migration task failed but its associated SQL migration succeeded, the database version would be left in a dirty state, requiring manual intervention to reset the SQL migration and re-attempt both it and the migration task. The new re-attempt mechanism is achieved by introducing the concept of a "migration task" version. A migration task version is the corresponding SQL migration version offset by +1000000000. During the execution of a migration task, the migration task version is persisted as the database's version. That way, if the migration task errors, the database's version will still be the migration task version on the next startup, and the migration task will be re-attempted before proceeding with the next SQL migration.
1 parent 35f3e3e commit 1133fc1

File tree

3 files changed

+455
-46
lines changed

3 files changed

+455
-46
lines changed

migrate.go

Lines changed: 193 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -232,15 +232,11 @@ func (m *Migrate) Migrate(version uint) error {
232232
return err
233233
}
234234

235-
curVersion, dirty, err := m.databaseDrv.Version()
235+
curVersion, err := m.ensureCleanCurrentSQLVersion()
236236
if err != nil {
237237
return m.unlockErr(err)
238238
}
239239

240-
if dirty {
241-
return m.unlockErr(ErrDirty{curVersion})
242-
}
243-
244240
ret := make(chan interface{}, m.PrefetchMigrations)
245241
go m.read(curVersion, int(version), ret)
246242

@@ -258,15 +254,11 @@ func (m *Migrate) Steps(n int) error {
258254
return err
259255
}
260256

261-
curVersion, dirty, err := m.databaseDrv.Version()
257+
curVersion, err := m.ensureCleanCurrentSQLVersion()
262258
if err != nil {
263259
return m.unlockErr(err)
264260
}
265261

266-
if dirty {
267-
return m.unlockErr(ErrDirty{curVersion})
268-
}
269-
270262
ret := make(chan interface{}, m.PrefetchMigrations)
271263

272264
if n > 0 {
@@ -285,15 +277,11 @@ func (m *Migrate) Up() error {
285277
return err
286278
}
287279

288-
curVersion, dirty, err := m.databaseDrv.Version()
280+
curVersion, err := m.ensureCleanCurrentSQLVersion()
289281
if err != nil {
290282
return m.unlockErr(err)
291283
}
292284

293-
if dirty {
294-
return m.unlockErr(ErrDirty{curVersion})
295-
}
296-
297285
ret := make(chan interface{}, m.PrefetchMigrations)
298286

299287
go m.readUp(curVersion, -1, ret)
@@ -307,15 +295,11 @@ func (m *Migrate) Down() error {
307295
return err
308296
}
309297

310-
curVersion, dirty, err := m.databaseDrv.Version()
298+
curVersion, err := m.ensureCleanCurrentSQLVersion()
311299
if err != nil {
312300
return m.unlockErr(err)
313301
}
314302

315-
if dirty {
316-
return m.unlockErr(ErrDirty{curVersion})
317-
}
318-
319303
ret := make(chan interface{}, m.PrefetchMigrations)
320304
go m.readDown(curVersion, -1, ret)
321305
return m.unlockErr(m.runMigrations(ret))
@@ -345,15 +329,11 @@ func (m *Migrate) Run(migration ...*Migration) error {
345329
return err
346330
}
347331

348-
curVersion, dirty, err := m.databaseDrv.Version()
332+
_, err := m.ensureCleanCurrentSQLVersion()
349333
if err != nil {
350334
return m.unlockErr(err)
351335
}
352336

353-
if dirty {
354-
return m.unlockErr(ErrDirty{curVersion})
355-
}
356-
357337
ret := make(chan interface{}, m.PrefetchMigrations)
358338

359339
go func() {
@@ -542,6 +522,54 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) {
542522
}
543523
}
544524

525+
// ensureCleanCurrentSQLVersion returns the database's current SQL migration
526+
// version in a clean (non-dirty) state. If the database is dirty, it returns
527+
// ErrDirty.
528+
//
529+
// If the current version when executing this function is a clean migrate task
530+
// version (meaning a migration task previously failed after the SQL migration
531+
// applied), this method re-executes the task for the associated SQL migration
532+
// version. If successful, the function normalizes the recorded version to the
533+
// SQL target version in a clean state so subsequent migrations can proceed.
534+
//
535+
// NOTE: The caller must hold the lock when calling this method.
536+
func (m *Migrate) ensureCleanCurrentSQLVersion() (int, error) {
537+
curVersion, dirty, err := m.databaseDrv.Version()
538+
if err != nil {
539+
return curVersion, err
540+
}
541+
542+
if dirty {
543+
return curVersion, ErrDirty{curVersion}
544+
}
545+
546+
// If the current version is a clean migration task version, then we
547+
// need to rerun the task for the previous version before we can
548+
// continue with any SQL migration(s). We can be certain here that the
549+
// task was attempted to be run before, but errored. This is since
550+
// the migration function only sets the version to a **clean** (i.e. not
551+
// dirty) **task** version if the task errored on the last attempt.
552+
if InTaskVersionRange(curVersion) {
553+
sqlMigVersion := SQLMigrationVersion(curVersion)
554+
555+
err = m.execTaskAtMigVersion(sqlMigVersion)
556+
if err != nil {
557+
return curVersion, err
558+
}
559+
560+
curVersion, dirty, err = m.databaseDrv.Version()
561+
if err != nil {
562+
return curVersion, err
563+
}
564+
565+
if dirty {
566+
return curVersion, ErrDirty{curVersion}
567+
}
568+
}
569+
570+
return curVersion, nil
571+
}
572+
545573
// readUp reads up migrations from `from` limited by `limit`.
546574
// limit can be -1, implying no limit and reading until there are no more migrations.
547575
// Each migration is then written to the ret channel.
@@ -732,6 +760,30 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732760
}
733761
}
734762

763+
// readSingle reads a single migration for the given version, and sends it
764+
// over the passed channel.
765+
func (m *Migrate) readSingle(ver uint, ret chan<- interface{}) {
766+
defer close(ret)
767+
768+
if err := m.versionExists(ver); err != nil {
769+
ret <- err
770+
return
771+
}
772+
773+
migr, err := m.newMigration(ver, int(ver))
774+
if err != nil {
775+
ret <- err
776+
return
777+
}
778+
779+
ret <- migr
780+
go func() {
781+
if err := migr.Buffer(); err != nil {
782+
m.logErr(err)
783+
}
784+
}()
785+
}
786+
735787
// runMigrations reads *Migration and error from a channel. Any other type
736788
// sent on this channel will result in a panic. Each migration is then
737789
// proxied to the database driver and run against the database.
@@ -752,6 +804,12 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752804
case *Migration:
753805
migr := r
754806

807+
if migr.Version >= TaskVersionOffset {
808+
return fmt.Errorf("migration version %v is "+
809+
"invalid, must be < %v", migr.Version,
810+
TaskVersionOffset)
811+
}
812+
755813
// set version with dirty state
756814
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
757815
return err
@@ -763,23 +821,10 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
763821
return err
764822
}
765823

766-
// If there is a task function for this
767-
// migration, run it now.
768-
cb, ok := m.opts.tasks[migr.Version]
769-
if ok {
770-
m.logVerbosePrintf("Running migration "+
771-
"task for %v\n", migr.LogString())
772-
773-
err := cb(migr, m.databaseDrv)
774-
if err != nil {
775-
return fmt.Errorf("failed to "+
776-
"execute migration "+
777-
"task: %w",
778-
err)
779-
}
780-
781-
m.logVerbosePrintf("Migration task "+
782-
"finished for %v\n", migr.LogString())
824+
err := m.execTask(migr)
825+
if err != nil {
826+
return fmt.Errorf("migration task "+
827+
"error: %w", err)
783828
}
784829
}
785830

@@ -808,6 +853,112 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808853
return nil
809854
}
810855

856+
// execTask checks if a migration task exists for the passed migration and
857+
// proceeds to execute if one exists.
858+
func (m *Migrate) execTask(migr *Migration) error {
859+
task, ok := m.opts.tasks[migr.Version]
860+
if !ok {
861+
m.logVerbosePrintf("No migration task set for %v\n",
862+
migr.LogString())
863+
864+
return nil
865+
}
866+
867+
m.logVerbosePrintf("Running migration task for %v\n", migr.LogString())
868+
869+
taskVersion := int(migr.Version) + TaskVersionOffset
870+
871+
// Persist that we are in the migration task phase for this version.
872+
if err := m.databaseDrv.SetVersion(taskVersion, true); err != nil {
873+
return err
874+
}
875+
876+
err := task(migr, m.databaseDrv)
877+
if err != nil {
878+
// Mark the database version as the taskVersion but in a clean
879+
// state, to indicate that the migration task errored. We will
880+
// therefore re-run the task on the next migration run.
881+
setErr := m.databaseDrv.SetVersion(taskVersion, false)
882+
if setErr != nil {
883+
// Note that if we error here, the database version will
884+
// remain in a dirty state. As we cannot know if the
885+
// migration task was executed or not in that scenario,
886+
// manual intervention is required.
887+
return fmt.Errorf("WARNING, failed to set migration "+
888+
"version after migration task errored. Manual "+
889+
"intervention needed! Migration task error: "+
890+
"%w, version setting error : %w", err, setErr)
891+
}
892+
893+
return fmt.Errorf("failed to execute migration task: %w", err)
894+
}
895+
896+
m.logVerbosePrintf("Migration task finished for %v\n", migr.LogString())
897+
898+
return nil
899+
}
900+
901+
// execTaskAtMigVersion executes only the migration task for the passed SQL
902+
// migration version.
903+
// The function can be used to re-execute the task for a SQL migration version
904+
// where the SQL migration was successfully applied, but where the task failed.
905+
func (m *Migrate) execTaskAtMigVersion(sqlMigVersion int) error {
906+
var (
907+
r interface{}
908+
migRet = make(chan interface{}, m.PrefetchMigrations)
909+
err error
910+
)
911+
912+
// Fetch the migration for the specified SQL migration version.
913+
go m.readSingle(uint(sqlMigVersion), migRet)
914+
915+
select {
916+
case r = <-migRet:
917+
case <-time.After(DefaultSingleMigReadTimeout):
918+
return fmt.Errorf("timeout waiting for single migration "+
919+
"version %v", sqlMigVersion)
920+
}
921+
922+
if m.stop() {
923+
return nil
924+
}
925+
926+
switch r := r.(type) {
927+
case *Migration:
928+
// If the migration was found, execute the migration task.
929+
migr := r
930+
931+
err = m.execTask(migr)
932+
if err != nil {
933+
return fmt.Errorf("exection of migration task for SQL "+
934+
"migration version %d failed: %w",
935+
sqlMigVersion, err)
936+
}
937+
938+
m.logVerbosePrintf("successfully re-executed migration task "+
939+
"for SQL migration version: %v\n", sqlMigVersion)
940+
941+
// After the migration task has been executed successfully, we
942+
// set the db version to the SQL migration target version with a
943+
// clean state, as we can now proceed with the next migrations,
944+
// if any.
945+
err = m.databaseDrv.SetVersion(migr.TargetVersion, false)
946+
if err != nil {
947+
return err
948+
}
949+
950+
return nil
951+
952+
case error:
953+
return fmt.Errorf("reading SQL migration at version "+
954+
"%v failed: %w", sqlMigVersion, r)
955+
956+
default:
957+
return fmt.Errorf("unknown type: %T when reading "+
958+
"single migration", r)
959+
}
960+
}
961+
811962
// versionExists checks the source if either the up or down migration for
812963
// the specified migration version exists.
813964
func (m *Migrate) versionExists(version uint) (result error) {

0 commit comments

Comments
 (0)