Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 129 additions & 75 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
package migrate

import (
"bytes"
"errors"
"fmt"
"io"
"os"
"regexp"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -55,61 +59,6 @@
return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
}

// PostStepCallback is a callback function type that can be used to execute a
// Golang based migration step after a SQL based migration step has been
// executed. The callback function receives the migration and the database
// driver as arguments.
type PostStepCallback func(migr *Migration, driver database.Driver) error

// options is a set of optional options that can be set when a Migrate instance
// is created.
type options struct {
// postStepCallbacks is a map of PostStepCallback functions that can be
// used to execute a Golang based migration step after a SQL based
// migration step has been executed. The key is the migration version
// and the value is the callback function that should be run _after_ the
// step was executed (but within the same database transaction).
postStepCallbacks map[uint]PostStepCallback
}

// defaultOptions returns a new options struct with default values.
func defaultOptions() options {
return options{
postStepCallbacks: make(map[uint]PostStepCallback),
}
}

// Option is a function that can be used to set options on a Migrate instance.
type Option func(*options)

// WithPostStepCallbacks is an option that can be used to set a map of
// PostStepCallback functions that can be used to execute a Golang based
// migration step after a SQL based migration step has been executed. The key is
// the migration version and the value is the callback function that should be
// run _after_ the step was executed (but before the version is marked as
// cleanly executed). An error returned from the callback will cause the
// migration to fail and the step to be marked as dirty.
func WithPostStepCallbacks(
postStepCallbacks map[uint]PostStepCallback) Option {

return func(o *options) {
o.postStepCallbacks = postStepCallbacks
}
}

// WithPostStepCallback is an option that can be used to set a PostStepCallback
// function that can be used to execute a Golang based migration step after the
// SQL based migration step with the given version number has been executed. The
// callback is the function that should be run _after_ the step was executed
// (but before the version is marked as cleanly executed). An error returned
// from the callback will cause the migration to fail and the step to be marked
// as dirty.
func WithPostStepCallback(version uint, callback PostStepCallback) Option {
return func(o *options) {
o.postStepCallbacks[version] = callback
}
}

type Migrate struct {
sourceName string
sourceDrv source.Driver
Expand Down Expand Up @@ -787,6 +736,34 @@
}
}

// hasSQLMigration checks if the passed data contains executable statements,
// meaning that the data doesn't only contain comments/whitespace or semicolons.
func (m *Migrate) hasSQLMigration(data []byte) (bool, error) {

Check failure on line 741 in migrate.go

View workflow job for this annotation

GitHub Actions / lint

(*Migrate).hasSQLMigration - result 1 (error) is always nil (unparam)
s := string(data)

// Remove Byte Order Mark (BOM) if present in the migration file.
s = strings.TrimPrefix(s, "\uFEFF")

// Strip block comments /* ... */ (non-greedy, across lines).
reBlock := regexp.MustCompile(`(?s)/\*.*?\*/`)
s = reBlock.ReplaceAllString(s, "")

// Strip line comments -- ... (to end of line).
reLine := regexp.MustCompile(`(?m)--[^\n\r]*`)
s = reLine.ReplaceAllString(s, "")

// Trim whitespaces.
s = strings.TrimSpace(s)

// Remove any semicolons, newlines, tabs, or spaces from the beginning
// and end of the string.
s = strings.Trim(s, ";\r\n\t ")

// If the string still contains any characters, the data likely
// contains executable statements.
return len(s) > 0, nil
}

// runMigrations reads *Migration and error from a channel. Any other type
// sent on this channel will result in a panic. Each migration is then
// proxied to the database driver and run against the database.
Expand All @@ -807,34 +784,58 @@
case *Migration:
migr := r

// set version with dirty state
if err := m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
return err
}

if migr.Body != nil {
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
// Read the body so we can inspect and (re)use it.
data, err := io.ReadAll(migr.BufferedBody)
if err != nil {
return fmt.Errorf("read migration body: %w", err)
}

// Reset the reader so the driver can read it
migr.BufferedBody = bytes.NewReader(data)

// Check if the migration contains an SQL
// migration.
hasSqlMig, err := m.hasSQLMigration(data)
if err != nil {
return err
}

// If there is a post execution function for
// this migration, run it now.
cb, ok := m.opts.postStepCallbacks[migr.Version]
if ok {
m.logVerbosePrintf("Running post step "+
"callback for %v\n", migr.LogString())
// Check if the migration contains a migration
// task.
_, hasMigTask := m.opts.tasks[migr.Version]

// Execute the SQL migration or the migration
// task.
switch {
case hasSqlMig && hasMigTask:
return fmt.Errorf("migration has both " +
"a SQL migration and a " +
"migration task set")

case hasSqlMig:
if err = m.databaseDrv.SetVersion(migr.TargetVersion, true); err != nil {
return err
}

err := cb(migr, m.databaseDrv)
m.logVerbosePrintf("Read and execute %v\n", migr.LogString())
if err = m.databaseDrv.Run(migr.BufferedBody); err != nil {
return err
}

case hasMigTask:
err = m.execTask(migr)
if err != nil {
return fmt.Errorf("failed to "+
"execute post "+
"step callback: %w",
err)
return fmt.Errorf("migration "+
"task execution "+
"failed: %w", err)
}

m.logVerbosePrintf("Post step callback "+
"finished for %v\n", migr.LogString())
default:
// When the migration contains no SQL
// migration or migration task, we
// continue and set the version to the
// migr.TargetVersion.
}
}

Expand Down Expand Up @@ -863,6 +864,59 @@
return nil
}

// execTask checks if a migration task exists for the passed migration and
// proceeds to execute if one exists. If the migration task fails, the function
// will reset the database version to the version it was set to before
// attempting to execute the migration task.
func (m *Migrate) execTask(migr *Migration) error {
m.logVerbosePrintf("Running migration task for %v\n", migr.LogString())

task, ok := m.opts.tasks[migr.Version]
if !ok {
return fmt.Errorf("no migration task set for %v",
migr.LogString())
}

// Get the current database version before executing the migration task.
curVersion, dirty, err := m.databaseDrv.Version()
if err != nil {
return fmt.Errorf("unable to get current version: %w", err)
}

if dirty {
return ErrDirty{curVersion}
}

// Persist that we are at the migration version of the migration task.
if err = m.databaseDrv.SetVersion(int(migr.Version), true); err != nil {
return err
}

err = task(migr, m.databaseDrv)
if err != nil {
// Reset the version to the version set before executing the
// migration task. Therefore, the migration task will be
// re-executed on nnext startup until it succeeds.
setErr := m.databaseDrv.SetVersion(curVersion, false)
if setErr != nil {
// Note that if we error here, the database version will
// remain in a dirty state. As we cannot know if the
// migration task was executed or not in that scenario,
// manual intervention is required.
return fmt.Errorf("WARNING, failed to set migration "+
"version after migration task errored. Manual "+
"intervention needed! Migration task error: "+
"%w, version setting error : %w", err, setErr)
}

return fmt.Errorf("failed to execute migration task: %w", err)
}

m.logVerbosePrintf("Migration task finished for %v\n", migr.LogString())

return nil
}

// versionExists checks the source if either the up or down migration for
// the specified migration version exists.
func (m *Migrate) versionExists(version uint) (result error) {
Expand Down
Loading
Loading