diff --git a/provider.go b/provider.go index 24a9eb5a7..64dd6acf0 100644 --- a/provider.go +++ b/provider.go @@ -107,7 +107,7 @@ func newProvider( // feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return // an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so // we should make it optional. - filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions) + filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions, cfg.recursive) if err != nil { return nil, err } diff --git a/provider_collect.go b/provider_collect.go index 6e230926c..86f3cd338 100644 --- a/provider_collect.go +++ b/provider_collect.go @@ -15,6 +15,54 @@ type fileSources struct { goSources []Source } +func checkFile(fullpath string, strict bool, excludePaths map[string]bool, excludeVersions map[int64]bool, versionToBaseLookup map[int64]string) (Source, bool, error) { + base := filepath.Base(fullpath) + if strings.HasSuffix(base, "_test.go") { + return Source{}, false, nil + } + if excludePaths[base] { + // TODO(mf): log this? + return Source{}, false, nil + } + // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use + // that as the version. Otherwise, ignore it. This allows users to have arbitrary + // filenames, but still have versioned migrations within the same directory. For + // example, a user could have a helpers.go file which contains unexported helper + // functions for migrations. + version, err := NumericComponent(base) + if err != nil { + if strict { + return Source{}, false, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) + } + return Source{}, false, nil + } + if excludeVersions[version] { + // TODO: log this? + return Source{}, false, nil + } + // Ensure there are no duplicate versions. + if existing, ok := versionToBaseLookup[version]; ok { + return Source{}, false, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", + version, + existing, + base, + ) + } + source := Source{Path: fullpath, Version: version} + switch filepath.Ext(base) { + case ".sql": + source.Type = TypeSQL + case ".go": + source.Type = TypeGo + default: + // Should never happen since we already filtered out all other file types. + return Source{}, false, fmt.Errorf("invalid file extension: %q", base) + } + // Add the version to the lookup map. + versionToBaseLookup[version] = base + return source, true, nil +} + // collectFilesystemSources scans the file system for migration files that have a numeric prefix // (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may // be nil, in which case an empty fileSources is returned. @@ -29,6 +77,7 @@ func collectFilesystemSources( strict bool, excludePaths map[string]bool, excludeVersions map[int64]bool, + recursive bool, ) (*fileSources, error) { if fsys == nil { return new(fileSources), nil @@ -39,65 +88,69 @@ func collectFilesystemSources( "*.sql", "*.go", } { - files, err := fs.Glob(fsys, pattern) + files, err := func() ([]string, error) { + if recursive { + var files []string + err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + subFs, err := fs.Sub(fsys, path) + if err != nil { + return err + } + dirFiles, err := fs.Glob(subFs, pattern) + if err != nil { + return err + } + for _, file := range dirFiles { + files = append(files, filepath.Join(path, file)) + } + } + return nil + }) + if err != nil { + return nil, err + } + return files, nil + } else { + files, err := fs.Glob(fsys, pattern) + if err != nil { + return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err) + } + return files, nil + } + }() if err != nil { - return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err) + return nil, err } for _, fullpath := range files { - base := filepath.Base(fullpath) - if strings.HasSuffix(base, "_test.go") { - continue - } - if excludePaths[base] { - // TODO(mf): log this? - continue - } - // If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use - // that as the version. Otherwise, ignore it. This allows users to have arbitrary - // filenames, but still have versioned migrations within the same directory. For - // example, a user could have a helpers.go file which contains unexported helper - // functions for migrations. - version, err := NumericComponent(base) + source, isValid, err := checkFile( + fullpath, + strict, + excludePaths, + excludeVersions, + versionToBaseLookup, + ) if err != nil { - if strict { - return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err) - } - continue + return nil, err } - if excludeVersions[version] { - // TODO: log this? + if !isValid { continue } - // Ensure there are no duplicate versions. - if existing, ok := versionToBaseLookup[version]; ok { - return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v", - version, - existing, - base, - ) - } - switch filepath.Ext(base) { - case ".sql": - sources.sqlSources = append(sources.sqlSources, Source{ - Type: TypeSQL, - Path: fullpath, - Version: version, - }) - case ".go": - sources.goSources = append(sources.goSources, Source{ - Type: TypeGo, - Path: fullpath, - Version: version, - }) + switch source.Type { + case TypeSQL: + sources.sqlSources = append(sources.sqlSources, source) + case TypeGo: + sources.goSources = append(sources.goSources, source) default: - // Should never happen since we already filtered out all other file types. - return nil, fmt.Errorf("invalid file extension: %q", base) + return nil, errors.New("unreachable") } - // Add the version to the lookup map. - versionToBaseLookup[version] = base } } return sources, nil + } func newSQLMigration(source Source) *Migration { diff --git a/provider_collect_test.go b/provider_collect_test.go index f53722721..7c77aa043 100644 --- a/provider_collect_test.go +++ b/provider_collect_test.go @@ -12,21 +12,21 @@ import ( func TestCollectFileSources(t *testing.T) { t.Parallel() t.Run("nil_fsys", func(t *testing.T) { - sources, err := collectFilesystemSources(nil, false, nil, nil) + sources, err := collectFilesystemSources(nil, false, nil, nil, false) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) t.Run("noop_fsys", func(t *testing.T) { - sources, err := collectFilesystemSources(noopFS{}, false, nil, nil) + sources, err := collectFilesystemSources(noopFS{}, false, nil, nil, false) check.NoError(t, err) check.Bool(t, sources != nil, true) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) }) t.Run("empty_fsys", func(t *testing.T) { - sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil) + sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil, false) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) @@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) { "00000_foo.sql": sqlMapFile, } // strict disable - should not error - sources, err := collectFilesystemSources(mapFS, false, nil, nil) + sources, err := collectFilesystemSources(mapFS, false, nil, nil, false) check.NoError(t, err) check.Number(t, len(sources.goSources), 0) check.Number(t, len(sources.sqlSources), 0) // strict enabled - should error - _, err = collectFilesystemSources(mapFS, true, nil, nil) + _, err = collectFilesystemSources(mapFS, true, nil, nil, false) check.HasError(t, err) check.Contains(t, err.Error(), "migration version must be greater than zero") }) t.Run("collect", func(t *testing.T) { fsys, err := fs.Sub(newSQLOnlyFS(), "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil, false) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 4) check.Number(t, len(sources.goSources), 0) @@ -77,6 +77,7 @@ func TestCollectFileSources(t *testing.T) { "00110_qux.sql": true, }, nil, + false, ) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 2) @@ -97,7 +98,7 @@ func TestCollectFileSources(t *testing.T) { mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")} fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - _, err = collectFilesystemSources(fsys, true, nil, nil) + _, err = collectFilesystemSources(fsys, true, nil, nil, false) check.HasError(t, err) check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`) }) @@ -109,7 +110,7 @@ func TestCollectFileSources(t *testing.T) { "4_qux.sql": sqlMapFile, "5_foo_test.go": {Data: []byte(`package goose_test`)}, } - sources, err := collectFilesystemSources(mapFS, false, nil, nil) + sources, err := collectFilesystemSources(mapFS, false, nil, nil, false) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 4) check.Number(t, len(sources.goSources), 0) @@ -124,7 +125,7 @@ func TestCollectFileSources(t *testing.T) { "no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)}, "some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)}, } - sources, err := collectFilesystemSources(mapFS, false, nil, nil) + sources, err := collectFilesystemSources(mapFS, false, nil, nil, false) check.NoError(t, err) check.Number(t, len(sources.sqlSources), 2) check.Number(t, len(sources.goSources), 1) @@ -143,7 +144,8 @@ func TestCollectFileSources(t *testing.T) { "001_foo.sql": sqlMapFile, "01_bar.sql": sqlMapFile, } - _, err := collectFilesystemSources(mapFS, false, nil, nil) + + _, err := collectFilesystemSources(mapFS, false, nil, nil, false) check.HasError(t, err) check.Contains(t, err.Error(), "found duplicate migration version 1") }) @@ -159,7 +161,7 @@ func TestCollectFileSources(t *testing.T) { t.Helper() f, err := fs.Sub(mapFS, dirpath) check.NoError(t, err) - got, err := collectFilesystemSources(f, false, nil, nil) + got, err := collectFilesystemSources(f, false, nil, nil, false) check.NoError(t, err) check.Number(t, len(got.sqlSources), len(sqlSources)) check.Number(t, len(got.goSources), 0) @@ -180,6 +182,25 @@ func TestCollectFileSources(t *testing.T) { }) assertDirpath("dir3", nil) }) + t.Run("recursive", func(t *testing.T) { + mapFS := fstest.MapFS{ + "876_a.sql": sqlMapFile, + "dir1/101_a.sql": sqlMapFile, + "dir1/102_b.sql": sqlMapFile, + "dir1/103_c.sql": sqlMapFile, + "dir2/201_a.sql": sqlMapFile, + "dir2/dir3/301_a.sql": sqlMapFile, + } + sources, err := collectFilesystemSources(mapFS, false, nil, nil, true) + check.NoError(t, err) + check.Equal(t, len(sources.sqlSources), 6) + check.Equal(t, sources.sqlSources[0].Path, "876_a.sql") + check.Equal(t, sources.sqlSources[1].Path, "dir1/101_a.sql") + check.Equal(t, sources.sqlSources[2].Path, "dir1/102_b.sql") + check.Equal(t, sources.sqlSources[3].Path, "dir1/103_c.sql") + check.Equal(t, sources.sqlSources[4].Path, "dir2/201_a.sql") + check.Equal(t, sources.sqlSources[5].Path, "dir2/dir3/301_a.sql") + }) } func TestMerge(t *testing.T) { @@ -195,7 +216,7 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil, false) check.NoError(t, err) check.Equal(t, len(sources.sqlSources), 1) check.Equal(t, len(sources.goSources), 2) @@ -243,7 +264,7 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil, false) check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*Migration{ @@ -267,7 +288,7 @@ func TestMerge(t *testing.T) { } fsys, err := fs.Sub(mapFS, "migrations") check.NoError(t, err) - sources, err := collectFilesystemSources(fsys, false, nil, nil) + sources, err := collectFilesystemSources(fsys, false, nil, nil, false) check.NoError(t, err) t.Run("unregistered_all", func(t *testing.T) { migrations, err := merge(sources, map[int64]*Migration{ diff --git a/provider_options.go b/provider_options.go index 15ee99006..fdbfad269 100644 --- a/provider_options.go +++ b/provider_options.go @@ -165,6 +165,13 @@ func WithDisableVersioning(b bool) ProviderOption { }) } +func WithRecursive(b bool) ProviderOption { + return configFunc(func(c *config) error { + c.recursive = b + return nil + }) +} + type config struct { store database.Store @@ -184,6 +191,7 @@ type config struct { disableVersioning bool allowMissing bool disableGlobalRegistry bool + recursive bool // Let's not expose the Logger just yet. Ideally we consolidate on the std lib slog package // added in go1.21 and then expose that (if that's even necessary). For now, just use the std