Skip to content

Commit

Permalink
[REF] treat repeatable migrations as versioned migrations with name a…
Browse files Browse the repository at this point in the history
…s version
  • Loading branch information
JohnBra committed Jan 24, 2025
1 parent 34169f2 commit 322044b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 85 deletions.
5 changes: 0 additions & 5 deletions internal/migration/apply/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ func MigrateAndSeed(ctx context.Context, version string, conn *pgx.Conn, fsys af
if err != nil {
return err
}
repeatableMigrations, err := list.LoadRepeatableMigrations(fsys)
if err != nil {
return err
}
migrations = append(migrations, repeatableMigrations...)
if err := migration.ApplyMigrations(ctx, migrations, conn, afero.NewIOFS(fsys)); err != nil {
return err
}
Expand Down
41 changes: 36 additions & 5 deletions internal/migration/list/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math"
"strconv"
"strings"

"github.com/charmbracelet/glamour"
"github.com/go-errors/errors"
Expand Down Expand Up @@ -68,6 +69,40 @@ func makeTable(remoteMigrations, localMigrations []string) string {
j++
}
}

for i, j := 0, 0; i < len(remoteMigrations) || j < len(localMigrations); {
if i < len(remoteMigrations) && !strings.HasPrefix(remoteMigrations[i], "r_") {
i++
continue
}

if j < len(localMigrations) && !strings.HasPrefix(localMigrations[j], "r_") {
j++
continue
}

// Append repeatable migrations to table
if i >= len(remoteMigrations) {
table += fmt.Sprintf("|`%s`|` `|` `|\n", localMigrations[j])
j++
} else if j >= len(localMigrations) {
table += fmt.Sprintf("|` `|`%s`|` `|\n", remoteMigrations[i])
i++
} else {
if localMigrations[j] < remoteMigrations[i] {
table += fmt.Sprintf("|`%s`|` `|` `|\n", localMigrations[j])
j++
} else if remoteMigrations[i] < localMigrations[j] {
table += fmt.Sprintf("|` `|`%s`|` `|\n", remoteMigrations[i])
i++
} else {
table += fmt.Sprintf("|`%s`|`%s`|` `|\n", localMigrations[j], remoteMigrations[i])
i++
j++
}
}
}

return table
}

Expand Down Expand Up @@ -99,11 +134,7 @@ func LoadLocalVersions(fsys afero.Fs) ([]string, error) {

func LoadPartialMigrations(version string, fsys afero.Fs) ([]string, error) {
filter := func(v string) bool {
return version == "" || v <= version
return version == "" || strings.HasPrefix(version, "r_") || v <= version
}
return migration.ListLocalMigrations(utils.MigrationsDir, afero.NewIOFS(fsys), filter)
}

func LoadRepeatableMigrations(fsys afero.Fs) ([]string, error) {
return migration.ListRepeatableMigrations(utils.MigrationsDir, afero.NewIOFS(fsys))
}
8 changes: 0 additions & 8 deletions internal/migration/up/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ func GetPendingMigrations(ctx context.Context, includeAll bool, conn *pgx.Conn,
}
utils.CmdSuggestion = suggestIgnoreFlag(diff)
}
if err != nil {
return diff, err
}
repeatableMigrations, err := migration.ListRepeatableMigrations(utils.MigrationsDir, afero.NewIOFS(fsys))
if err != nil {
return nil, err
}
diff = append(diff, repeatableMigrations...)
return diff, err
}

Expand Down
6 changes: 5 additions & 1 deletion pkg/migration/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type MigrationFile struct {
Statements []string
}

var migrateFilePattern = regexp.MustCompile(`^([0-9]+)_(.*)\.sql$`)
var migrateFilePattern = regexp.MustCompile(`^([0-9]+|r)_(.*)\.sql$`)

func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) {
lines, err := parseFile(path, fsys)
Expand All @@ -38,6 +38,10 @@ func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) {
file.Version = matches[1]
file.Name = matches[2]
}
// Repeatable migration version => r_name
if file.Version == "r" {
file.Version += "_" + file.Name
}
return &file, nil
}

Expand Down
31 changes: 6 additions & 25 deletions pkg/migration/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path/filepath"
"regexp"
"strconv"
"strings"

"github.com/go-errors/errors"
"github.com/jackc/pgconn"
Expand Down Expand Up @@ -46,43 +45,25 @@ func ListLocalMigrations(migrationsDir string, fsys fs.FS, filter ...func(string
fmt.Fprintf(os.Stderr, "Skipping migration %s... (replace \"init\" with a different file name to apply this migration)\n", filename)
continue
}
if strings.HasPrefix(filename, "r_") {
// silently skip repeatable migrations
continue
}
matches := migrateFilePattern.FindStringSubmatch(filename)
if len(matches) == 0 {
fmt.Fprintf(os.Stderr, "Skipping migration %s... (file name must match pattern \"<timestamp>_name.sql\")\n", filename)
fmt.Fprintf(os.Stderr, "Skipping migration %s... (file name must match pattern \"<timestamp>_name.sql\" or \"r_name.sql\")\n", filename)
continue
}
path := filepath.Join(migrationsDir, filename)
for _, keep := range filter {
if version := matches[1]; keep(version) {
version := matches[1]
if version == "r" && len(matches) > 2 {
version += "_" + matches[2]
}
if keep(version) {
clean = append(clean, path)
}
}
}
return clean, nil
}

func ListRepeatableMigrations(migrationsDir string, fsys fs.FS) ([]string, error) {
localMigrations, err := fs.ReadDir(fsys, migrationsDir)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, errors.Errorf("failed to read directory: %w", err)
}
var repeatable []string

for _, migration := range localMigrations {
filename := migration.Name()
if strings.HasPrefix(filename, "r_") && strings.HasSuffix(filename, ".sql") {
path := filepath.Join(migrationsDir, filename)
repeatable = append(repeatable, path)
}
}

return repeatable, nil
}

var initSchemaPattern = regexp.MustCompile(`([0-9]{14})_init\.sql`)

func shouldSkip(name string) bool {
Expand Down
42 changes: 1 addition & 41 deletions pkg/migration/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func TestLocalMigrations(t *testing.T) {
fsys := fs.MapFS{
"20211208000000_init.sql": &fs.MapFile{},
"20211208000001_invalid.ts": &fs.MapFile{},
"r_invalid.ts": &fs.MapFile{},
}
// Run test
versions, err := ListLocalMigrations(".", fsys)
Expand All @@ -90,44 +91,3 @@ func TestLocalMigrations(t *testing.T) {
assert.ErrorContains(t, err, "failed to read directory:")
})
}

func TestRepeatableMigrations(t *testing.T) {
t.Run("loads repeatable migrations", func(t *testing.T) {
// Setup in-memory fs
files := []string{
"r_test_view.sql",
"r_test_function.sql",
}
fsys := fs.MapFS{}
for _, name := range files {
fsys[name] = &fs.MapFile{}
}
// Run test
versions, err := ListRepeatableMigrations(".", fsys)
// Check error
assert.NoError(t, err)
assert.ElementsMatch(t, files, versions)
})

t.Run("ignores files without 'r_' prefix", func(t *testing.T) {
// Setup in-memory fs
fsys := fs.MapFS{
"20211208000000_init.sql": &fs.MapFile{},
"r_invalid.ts": &fs.MapFile{},
}
// Run test
versions, err := ListRepeatableMigrations(".", fsys)
// Check error
assert.NoError(t, err)
assert.Empty(t, versions)
})

t.Run("throws error on open failure", func(t *testing.T) {
// Setup in-memory fs
fsys := fs.MapFS{"migrations": &fs.MapFile{}}
// Run test
_, err := ListRepeatableMigrations("migrations", fsys)
// Check error
assert.ErrorContains(t, err, "failed to read directory:")
})
}

0 comments on commit 322044b

Please sign in to comment.