@@ -232,15 +232,11 @@ func (m *Migrate) Migrate(version uint) error {
232232 return err
233233 }
234234
235- curVersion , dirty , err := m .databaseDrv . Version ()
235+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
236236 if err != nil {
237237 return m .unlockErr (err )
238238 }
239239
240- if dirty {
241- return m .unlockErr (ErrDirty {curVersion })
242- }
243-
244240 ret := make (chan interface {}, m .PrefetchMigrations )
245241 go m .read (curVersion , int (version ), ret )
246242
@@ -258,15 +254,11 @@ func (m *Migrate) Steps(n int) error {
258254 return err
259255 }
260256
261- curVersion , dirty , err := m .databaseDrv . Version ()
257+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
262258 if err != nil {
263259 return m .unlockErr (err )
264260 }
265261
266- if dirty {
267- return m .unlockErr (ErrDirty {curVersion })
268- }
269-
270262 ret := make (chan interface {}, m .PrefetchMigrations )
271263
272264 if n > 0 {
@@ -285,15 +277,11 @@ func (m *Migrate) Up() error {
285277 return err
286278 }
287279
288- curVersion , dirty , err := m .databaseDrv . Version ()
280+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
289281 if err != nil {
290282 return m .unlockErr (err )
291283 }
292284
293- if dirty {
294- return m .unlockErr (ErrDirty {curVersion })
295- }
296-
297285 ret := make (chan interface {}, m .PrefetchMigrations )
298286
299287 go m .readUp (curVersion , - 1 , ret )
@@ -307,15 +295,11 @@ func (m *Migrate) Down() error {
307295 return err
308296 }
309297
310- curVersion , dirty , err := m .databaseDrv . Version ()
298+ curVersion , err := m .ensureCleanCurrentSQLVersion ()
311299 if err != nil {
312300 return m .unlockErr (err )
313301 }
314302
315- if dirty {
316- return m .unlockErr (ErrDirty {curVersion })
317- }
318-
319303 ret := make (chan interface {}, m .PrefetchMigrations )
320304 go m .readDown (curVersion , - 1 , ret )
321305 return m .unlockErr (m .runMigrations (ret ))
@@ -345,15 +329,11 @@ func (m *Migrate) Run(migration ...*Migration) error {
345329 return err
346330 }
347331
348- curVersion , dirty , err := m .databaseDrv . Version ()
332+ _ , err := m .ensureCleanCurrentSQLVersion ()
349333 if err != nil {
350334 return m .unlockErr (err )
351335 }
352336
353- if dirty {
354- return m .unlockErr (ErrDirty {curVersion })
355- }
356-
357337 ret := make (chan interface {}, m .PrefetchMigrations )
358338
359339 go func () {
@@ -542,6 +522,54 @@ func (m *Migrate) read(from int, to int, ret chan<- interface{}) {
542522 }
543523}
544524
525+ // ensureCleanCurrentSQLVersion returns the database's current SQL migration
526+ // version in a clean (non-dirty) state. If the database is dirty, it returns
527+ // ErrDirty.
528+ //
529+ // If the current version when executing this function is a clean migrate task
530+ // version (meaning a migration task previously failed after the SQL migration
531+ // applied), this method re-executes the task for the associated SQL migration
532+ // version. If successful, the function normalizes the recorded version to the
533+ // SQL target version in a clean state so subsequent migrations can proceed.
534+ //
535+ // NOTE: The caller must hold the lock when calling this method.
536+ func (m * Migrate ) ensureCleanCurrentSQLVersion () (int , error ) {
537+ curVersion , dirty , err := m .databaseDrv .Version ()
538+ if err != nil {
539+ return curVersion , err
540+ }
541+
542+ if dirty {
543+ return curVersion , ErrDirty {curVersion }
544+ }
545+
546+ // If the current version is a clean migration task version, then we
547+ // need to rerun the task for the previous version before we can
548+ // continue with any SQL migration(s). We can be certain here that the
549+ // task was attempted to be run before, but errored. This is since
550+ // the migration function only sets the version to a **clean** (i.e. not
551+ // dirty) **task** version if the task errored on the last attempt.
552+ if InTaskVersionRange (curVersion ) {
553+ sqlMigVersion := SQLMigrationVersion (curVersion )
554+
555+ err = m .execTaskAtMigVersion (sqlMigVersion )
556+ if err != nil {
557+ return curVersion , err
558+ }
559+
560+ curVersion , dirty , err = m .databaseDrv .Version ()
561+ if err != nil {
562+ return curVersion , err
563+ }
564+
565+ if dirty {
566+ return curVersion , ErrDirty {curVersion }
567+ }
568+ }
569+
570+ return curVersion , nil
571+ }
572+
545573// readUp reads up migrations from `from` limited by `limit`.
546574// limit can be -1, implying no limit and reading until there are no more migrations.
547575// Each migration is then written to the ret channel.
@@ -732,6 +760,30 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) {
732760 }
733761}
734762
763+ // readSingle reads a single migration for the given version, and sends it
764+ // over the passed channel.
765+ func (m * Migrate ) readSingle (ver uint , ret chan <- interface {}) {
766+ defer close (ret )
767+
768+ if err := m .versionExists (ver ); err != nil {
769+ ret <- err
770+ return
771+ }
772+
773+ migr , err := m .newMigration (ver , int (ver ))
774+ if err != nil {
775+ ret <- err
776+ return
777+ }
778+
779+ ret <- migr
780+ go func () {
781+ if err := migr .Buffer (); err != nil {
782+ m .logErr (err )
783+ }
784+ }()
785+ }
786+
735787// runMigrations reads *Migration and error from a channel. Any other type
736788// sent on this channel will result in a panic. Each migration is then
737789// proxied to the database driver and run against the database.
@@ -752,6 +804,12 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
752804 case * Migration :
753805 migr := r
754806
807+ if migr .Version >= TaskVersionOffset {
808+ return fmt .Errorf ("migration version %v is " +
809+ "invalid, must be < %v" , migr .Version ,
810+ TaskVersionOffset )
811+ }
812+
755813 // set version with dirty state
756814 if err := m .databaseDrv .SetVersion (migr .TargetVersion , true ); err != nil {
757815 return err
@@ -763,23 +821,10 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
763821 return err
764822 }
765823
766- // If there is a task function for this
767- // migration, run it now.
768- cb , ok := m .opts .tasks [migr .Version ]
769- if ok {
770- m .logVerbosePrintf ("Running migration " +
771- "task for %v\n " , migr .LogString ())
772-
773- err := cb (migr , m .databaseDrv )
774- if err != nil {
775- return fmt .Errorf ("failed to " +
776- "execute migration " +
777- "task: %w" ,
778- err )
779- }
780-
781- m .logVerbosePrintf ("Migration task " +
782- "finished for %v\n " , migr .LogString ())
824+ err := m .execTask (migr )
825+ if err != nil {
826+ return fmt .Errorf ("migration task " +
827+ "error: %w" , err )
783828 }
784829 }
785830
@@ -808,6 +853,112 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
808853 return nil
809854}
810855
856+ // execTask checks if a migration task exists for the passed migration and
857+ // proceeds to execute if one exists.
858+ func (m * Migrate ) execTask (migr * Migration ) error {
859+ task , ok := m .opts .tasks [migr .Version ]
860+ if ! ok {
861+ m .logVerbosePrintf ("No migration task set for %v\n " ,
862+ migr .LogString ())
863+
864+ return nil
865+ }
866+
867+ m .logVerbosePrintf ("Running migration task for %v\n " , migr .LogString ())
868+
869+ taskVersion := int (migr .Version ) + TaskVersionOffset
870+
871+ // Persist that we are in the migration task phase for this version.
872+ if err := m .databaseDrv .SetVersion (taskVersion , true ); err != nil {
873+ return err
874+ }
875+
876+ err := task (migr , m .databaseDrv )
877+ if err != nil {
878+ // Mark the database version as the taskVersion but in a clean
879+ // state, to indicate that the migration task errored. We will
880+ // therefore re-run the task on the next migration run.
881+ setErr := m .databaseDrv .SetVersion (taskVersion , false )
882+ if setErr != nil {
883+ // Note that if we error here, the database version will
884+ // remain in a dirty state. As we cannot know if the
885+ // migration task was executed or not in that scenario,
886+ // manual intervention is required.
887+ return fmt .Errorf ("WARNING, failed to set migration " +
888+ "version after migration task errored. Manual " +
889+ "intervention needed! Migration task error: " +
890+ "%w, version setting error : %w" , err , setErr )
891+ }
892+
893+ return fmt .Errorf ("failed to execute migration task: %w" , err )
894+ }
895+
896+ m .logVerbosePrintf ("Migration task finished for %v\n " , migr .LogString ())
897+
898+ return nil
899+ }
900+
901+ // execTaskAtMigVersion executes only the migration task for the passed SQL
902+ // migration version.
903+ // The function can be used to re-execute the task for a SQL migration version
904+ // where the SQL migration was successfully applied, but where the task failed.
905+ func (m * Migrate ) execTaskAtMigVersion (sqlMigVersion int ) error {
906+ var (
907+ r interface {}
908+ migRet = make (chan interface {}, m .PrefetchMigrations )
909+ err error
910+ )
911+
912+ // Fetch the migration for the specified SQL migration version.
913+ go m .readSingle (uint (sqlMigVersion ), migRet )
914+
915+ select {
916+ case r = <- migRet :
917+ case <- time .After (DefaultSingleMigReadTimeout ):
918+ return fmt .Errorf ("timeout waiting for single migration " +
919+ "version %v" , sqlMigVersion )
920+ }
921+
922+ if m .stop () {
923+ return nil
924+ }
925+
926+ switch r := r .(type ) {
927+ case * Migration :
928+ // If the migration was found, execute the migration task.
929+ migr := r
930+
931+ err = m .execTask (migr )
932+ if err != nil {
933+ return fmt .Errorf ("exection of migration task for SQL " +
934+ "migration version %d failed: %w" ,
935+ sqlMigVersion , err )
936+ }
937+
938+ m .logVerbosePrintf ("successfully re-executed migration task " +
939+ "for SQL migration version: %v\n " , sqlMigVersion )
940+
941+ // After the migration task has been executed successfully, we
942+ // set the db version to the SQL migration target version with a
943+ // clean state, as we can now proceed with the next migrations,
944+ // if any.
945+ err = m .databaseDrv .SetVersion (migr .TargetVersion , false )
946+ if err != nil {
947+ return err
948+ }
949+
950+ return nil
951+
952+ case error :
953+ return fmt .Errorf ("reading SQL migration at version " +
954+ "%v failed: %w" , sqlMigVersion , r )
955+
956+ default :
957+ return fmt .Errorf ("unknown type: %T when reading " +
958+ "single migration" , r )
959+ }
960+ }
961+
811962// versionExists checks the source if either the up or down migration for
812963// the specified migration version exists.
813964func (m * Migrate ) versionExists (version uint ) (result error ) {
0 commit comments