Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support recursive collection of source files for provider using fs.FS #706

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func newProvider(
// feat(mf): we could add a flag to parse SQL migrations eagerly. This would allow us to return
// an error if there are any SQL parsing errors. This adds a bit overhead to startup though, so
// we should make it optional.
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions)
filesystemSources, err := collectFilesystemSources(fsys, false, cfg.excludePaths, cfg.excludeVersions, cfg.recursive)
if err != nil {
return nil, err
}
Expand Down
147 changes: 100 additions & 47 deletions provider_collect.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,54 @@ type fileSources struct {
goSources []Source
}

func checkFile(fullpath string, strict bool, excludePaths map[string]bool, excludeVersions map[int64]bool, versionToBaseLookup map[int64]string) (Source, bool, error) {
base := filepath.Base(fullpath)
if strings.HasSuffix(base, "_test.go") {
return Source{}, false, nil
}
if excludePaths[base] {
// TODO(mf): log this?
return Source{}, false, nil
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := NumericComponent(base)
if err != nil {
if strict {
return Source{}, false, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
return Source{}, false, nil
}
if excludeVersions[version] {
// TODO: log this?
return Source{}, false, nil
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return Source{}, false, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing,
base,
)
}
source := Source{Path: fullpath, Version: version}
switch filepath.Ext(base) {
case ".sql":
source.Type = TypeSQL
case ".go":
source.Type = TypeGo
default:
// Should never happen since we already filtered out all other file types.
return Source{}, false, fmt.Errorf("invalid file extension: %q", base)
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
return source, true, nil
}

// collectFilesystemSources scans the file system for migration files that have a numeric prefix
// (greater than one) followed by an underscore and a file extension of either .go or .sql. fsys may
// be nil, in which case an empty fileSources is returned.
Expand All @@ -29,6 +77,7 @@ func collectFilesystemSources(
strict bool,
excludePaths map[string]bool,
excludeVersions map[int64]bool,
recursive bool,
) (*fileSources, error) {
if fsys == nil {
return new(fileSources), nil
Expand All @@ -39,65 +88,69 @@ func collectFilesystemSources(
"*.sql",
"*.go",
} {
files, err := fs.Glob(fsys, pattern)
files, err := func() ([]string, error) {
if recursive {
var files []string
err := fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
subFs, err := fs.Sub(fsys, path)
if err != nil {
return err
}
dirFiles, err := fs.Glob(subFs, pattern)
if err != nil {
return err
}
for _, file := range dirFiles {
files = append(files, filepath.Join(path, file))
}
}
return nil
})
if err != nil {
return nil, err
}
return files, nil
} else {
files, err := fs.Glob(fsys, pattern)
if err != nil {
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
}
return files, nil
}
}()
if err != nil {
return nil, fmt.Errorf("failed to glob pattern %q: %w", pattern, err)
return nil, err
}
for _, fullpath := range files {
base := filepath.Base(fullpath)
if strings.HasSuffix(base, "_test.go") {
continue
}
if excludePaths[base] {
// TODO(mf): log this?
continue
}
// If the filename has a valid looking version of the form: NUMBER_.{sql,go}, then use
// that as the version. Otherwise, ignore it. This allows users to have arbitrary
// filenames, but still have versioned migrations within the same directory. For
// example, a user could have a helpers.go file which contains unexported helper
// functions for migrations.
version, err := NumericComponent(base)
source, isValid, err := checkFile(
fullpath,
strict,
excludePaths,
excludeVersions,
versionToBaseLookup,
)
if err != nil {
if strict {
return nil, fmt.Errorf("failed to parse numeric component from %q: %w", base, err)
}
continue
return nil, err
}
if excludeVersions[version] {
// TODO: log this?
if !isValid {
continue
}
// Ensure there are no duplicate versions.
if existing, ok := versionToBaseLookup[version]; ok {
return nil, fmt.Errorf("found duplicate migration version %d:\n\texisting:%v\n\tcurrent:%v",
version,
existing,
base,
)
}
switch filepath.Ext(base) {
case ".sql":
sources.sqlSources = append(sources.sqlSources, Source{
Type: TypeSQL,
Path: fullpath,
Version: version,
})
case ".go":
sources.goSources = append(sources.goSources, Source{
Type: TypeGo,
Path: fullpath,
Version: version,
})
switch source.Type {
case TypeSQL:
sources.sqlSources = append(sources.sqlSources, source)
case TypeGo:
sources.goSources = append(sources.goSources, source)
default:
// Should never happen since we already filtered out all other file types.
return nil, fmt.Errorf("invalid file extension: %q", base)
return nil, errors.New("unreachable")
}
// Add the version to the lookup map.
versionToBaseLookup[version] = base
}
}
return sources, nil

}

func newSQLMigration(source Source) *Migration {
Expand Down
49 changes: 35 additions & 14 deletions provider_collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ import (
func TestCollectFileSources(t *testing.T) {
t.Parallel()
t.Run("nil_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(nil, false, nil, nil)
sources, err := collectFilesystemSources(nil, false, nil, nil, false)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("noop_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil)
sources, err := collectFilesystemSources(noopFS{}, false, nil, nil, false)
check.NoError(t, err)
check.Bool(t, sources != nil, true)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
})
t.Run("empty_fsys", func(t *testing.T) {
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil)
sources, err := collectFilesystemSources(fstest.MapFS{}, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
Expand All @@ -37,19 +37,19 @@ func TestCollectFileSources(t *testing.T) {
"00000_foo.sql": sqlMapFile,
}
// strict disable - should not error
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.goSources), 0)
check.Number(t, len(sources.sqlSources), 0)
// strict enabled - should error
_, err = collectFilesystemSources(mapFS, true, nil, nil)
_, err = collectFilesystemSources(mapFS, true, nil, nil, false)
check.HasError(t, err)
check.Contains(t, err.Error(), "migration version must be greater than zero")
})
t.Run("collect", func(t *testing.T) {
fsys, err := fs.Sub(newSQLOnlyFS(), "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
Expand Down Expand Up @@ -77,6 +77,7 @@ func TestCollectFileSources(t *testing.T) {
"00110_qux.sql": true,
},
nil,
false,
)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
Expand All @@ -97,7 +98,7 @@ func TestCollectFileSources(t *testing.T) {
mapFS["migrations/not_valid.sql"] = &fstest.MapFile{Data: []byte("invalid")}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
_, err = collectFilesystemSources(fsys, true, nil, nil)
_, err = collectFilesystemSources(fsys, true, nil, nil, false)
check.HasError(t, err)
check.Contains(t, err.Error(), `failed to parse numeric component from "not_valid.sql"`)
})
Expand All @@ -109,7 +110,7 @@ func TestCollectFileSources(t *testing.T) {
"4_qux.sql": sqlMapFile,
"5_foo_test.go": {Data: []byte(`package goose_test`)},
}
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 4)
check.Number(t, len(sources.goSources), 0)
Expand All @@ -124,7 +125,7 @@ func TestCollectFileSources(t *testing.T) {
"no_a_real_migration.sql": {Data: []byte(`SELECT 1;`)},
"some/other/dir/2_foo.sql": {Data: []byte(`SELECT 1;`)},
}
sources, err := collectFilesystemSources(mapFS, false, nil, nil)
sources, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(sources.sqlSources), 2)
check.Number(t, len(sources.goSources), 1)
Expand All @@ -143,7 +144,8 @@ func TestCollectFileSources(t *testing.T) {
"001_foo.sql": sqlMapFile,
"01_bar.sql": sqlMapFile,
}
_, err := collectFilesystemSources(mapFS, false, nil, nil)

_, err := collectFilesystemSources(mapFS, false, nil, nil, false)
check.HasError(t, err)
check.Contains(t, err.Error(), "found duplicate migration version 1")
})
Expand All @@ -159,7 +161,7 @@ func TestCollectFileSources(t *testing.T) {
t.Helper()
f, err := fs.Sub(mapFS, dirpath)
check.NoError(t, err)
got, err := collectFilesystemSources(f, false, nil, nil)
got, err := collectFilesystemSources(f, false, nil, nil, false)
check.NoError(t, err)
check.Number(t, len(got.sqlSources), len(sqlSources))
check.Number(t, len(got.goSources), 0)
Expand All @@ -180,6 +182,25 @@ func TestCollectFileSources(t *testing.T) {
})
assertDirpath("dir3", nil)
})
t.Run("recursive", func(t *testing.T) {
mapFS := fstest.MapFS{
"876_a.sql": sqlMapFile,
"dir1/101_a.sql": sqlMapFile,
"dir1/102_b.sql": sqlMapFile,
"dir1/103_c.sql": sqlMapFile,
"dir2/201_a.sql": sqlMapFile,
"dir2/dir3/301_a.sql": sqlMapFile,
}
sources, err := collectFilesystemSources(mapFS, false, nil, nil, true)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 6)
check.Equal(t, sources.sqlSources[0].Path, "876_a.sql")
check.Equal(t, sources.sqlSources[1].Path, "dir1/101_a.sql")
check.Equal(t, sources.sqlSources[2].Path, "dir1/102_b.sql")
check.Equal(t, sources.sqlSources[3].Path, "dir1/103_c.sql")
check.Equal(t, sources.sqlSources[4].Path, "dir2/201_a.sql")
check.Equal(t, sources.sqlSources[5].Path, "dir2/dir3/301_a.sql")
})
}

func TestMerge(t *testing.T) {
Expand All @@ -195,7 +216,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
check.Equal(t, len(sources.sqlSources), 1)
check.Equal(t, len(sources.goSources), 2)
Expand Down Expand Up @@ -243,7 +264,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*Migration{
Expand All @@ -267,7 +288,7 @@ func TestMerge(t *testing.T) {
}
fsys, err := fs.Sub(mapFS, "migrations")
check.NoError(t, err)
sources, err := collectFilesystemSources(fsys, false, nil, nil)
sources, err := collectFilesystemSources(fsys, false, nil, nil, false)
check.NoError(t, err)
t.Run("unregistered_all", func(t *testing.T) {
migrations, err := merge(sources, map[int64]*Migration{
Expand Down
8 changes: 8 additions & 0 deletions provider_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ func WithDisableVersioning(b bool) ProviderOption {
})
}

func WithRecursive(b bool) ProviderOption {
return configFunc(func(c *config) error {
c.recursive = b
return nil
})
}

type config struct {
store database.Store

Expand All @@ -184,6 +191,7 @@ type config struct {
disableVersioning bool
allowMissing bool
disableGlobalRegistry bool
recursive bool

// Let's not expose the Logger just yet. Ideally we consolidate on the std lib slog package
// added in go1.21 and then expose that (if that's even necessary). For now, just use the std
Expand Down