Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add provider HasPending method #751

Merged
merged 12 commits into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ test-packages:
test-packages-short:
go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)

coverage-short:
go test ./ -test.short $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out
go tool cover -html=coverage.out

coverage:
go test ./ $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out
go tool cover -html=coverage.out

#
# Integration-related targets
#
Expand Down
118 changes: 118 additions & 0 deletions internal/testing/integration/postgres_locking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"hash/crc64"
"math/rand"
"os"
"sort"
"sync"
"testing"
"testing/fstest"
"time"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testing/testdb"
"github.com/pressly/goose/v3/lock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -406,6 +410,120 @@ func TestPostgresProviderLocking(t *testing.T) {
})
}

func TestPostgresHasPending(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skipping test in short mode.")
}

db, cleanup, err := testdb.NewPostgres()
require.NoError(t, err)
t.Cleanup(cleanup)

workers := 15

run := func(want bool) {
var g errgroup.Group
boolCh := make(chan bool, workers)
for i := 0; i < workers; i++ {
g.Go(func() error {
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
check.NoError(t, err)
hasPending, err := p.HasPending(context.Background())
if err != nil {
return err
}
boolCh <- hasPending
return nil

})
}
check.NoError(t, g.Wait())
close(boolCh)
// expect all values to be true
for hasPending := range boolCh {
check.Bool(t, hasPending, want)
}
}
t.Run("concurrent_has_pending", func(t *testing.T) {
run(true)
})

// apply all migrations
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)

t.Run("concurrent_no_pending", func(t *testing.T) {
run(false)
})

// Add a new migration file
last := p.ListSources()[len(p.ListSources())-1]
newVersion := fmt.Sprintf("%d_new_migration.sql", last.Version+1)
fsys := fstest.MapFS{
newVersion: &fstest.MapFile{Data: []byte(`
-- +goose Up
SELECT pg_sleep_for('4 seconds');
`)},
}
lockID := int64(crc64.Checksum([]byte(t.Name()), crc64.MakeTable(crc64.ECMA)))
// Create a new provider with the new migration file
sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID)) // Timeout 5min. Try every 1s up to 10 times.
require.NoError(t, err)
newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker))
check.NoError(t, err)
check.Number(t, len(newProvider.ListSources()), 1)
oldProvider := p
check.Number(t, len(oldProvider.ListSources()), 6)

var g errgroup.Group
g.Go(func() error {
hasPending, err := newProvider.HasPending(context.Background())
if err != nil {
return err
}
check.Bool(t, hasPending, true)
return nil
})
g.Go(func() error {
hasPending, err := oldProvider.HasPending(context.Background())
if err != nil {
return err
}
check.Bool(t, hasPending, false)
return nil
})
check.NoError(t, g.Wait())

// A new provider is running in the background with a session lock to simulate a long running
// migration. If older instances come up, they should not have any pending migrations and not be
// affected by the long running migration. Test the following scenario:
// https://github.com/pressly/goose/pull/507#discussion_r1266498077
g.Go(func() error {
_, err := newProvider.Up(context.Background())
return err
})
time.Sleep(1 * time.Second)
isLocked, err := existsPgLock(context.Background(), db, lockID)
check.NoError(t, err)
check.Bool(t, isLocked, true)
hasPending, err := oldProvider.HasPending(context.Background())
check.NoError(t, err)
check.Bool(t, hasPending, false)
// Wait for the long running migration to finish
check.NoError(t, g.Wait())
// Check that the new migration was applied
hasPending, err = newProvider.HasPending(context.Background())
check.NoError(t, err)
check.Bool(t, hasPending, false)
// The max version should be the new migration
currentVersion, err := newProvider.GetDBVersion(context.Background())
check.NoError(t, err)
check.Number(t, currentVersion, last.Version+1)
}

func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) {
q := `SELECT EXISTS(SELECT 1 FROM pg_locks WHERE locktype='advisory' AND ((classid::bigint<<32)|objid::bigint)=$1)`
row := db.QueryRowContext(ctx, q, lockID)
Expand Down
98 changes: 87 additions & 11 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ type Provider struct {
// database.
mu sync.Mutex

db *sql.DB
store database.Store
db *sql.DB
store database.Store
versionTableOnce sync.Once

fsys fs.FS
cfg config

// migrations are ordered by version in ascending order.
// migrations are ordered by version in ascending order. This list will never be empty and
// contains all migrations known to the provider.
migrations []*Migration
}

Expand All @@ -49,8 +51,6 @@ type Provider struct {
// See [ProviderOption] for more information on configuring the provider.
//
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
Expand Down Expand Up @@ -154,6 +154,14 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
return p.status(ctx)
}

// HasPending returns true if there are pending migrations to apply, otherwise, it returns false.
//
// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
// for pending migrations without blocking or being blocked by other operations.
func (p *Provider) HasPending(ctx context.Context) (bool, error) {
return p.hasPending(ctx)
}

// GetDBVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
// this method returns 4. If no migrations have been applied, it returns 0.
Expand Down Expand Up @@ -214,12 +222,26 @@ func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bo
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
// empty list and nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, nil
}
return p.up(ctx, false, math.MaxInt64)
}

// UpByOne applies the next pending migration. If there is no next migration to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
// returns [ErrNoNextVersion].
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, ErrNoNextVersion
}
res, err := p.up(ctx, true, math.MaxInt64)
if err != nil {
return nil, err
Expand Down Expand Up @@ -247,6 +269,13 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
// For example, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, nil
}
return p.up(ctx, false, version)
}

Expand Down Expand Up @@ -303,7 +332,7 @@ func (p *Provider) up(
if version < 1 {
return nil, errInvalidVersion
}
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -345,7 +374,7 @@ func (p *Provider) down(
byOne bool,
version int64,
) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -404,7 +433,7 @@ func (p *Provider) apply(
if err != nil {
return nil, err
}
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -436,8 +465,55 @@ func (p *Provider) apply(
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
}

func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
conn, cleanup, err := p.initialize(ctx, false)
if err != nil {
return false, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()

// If versioning is disabled, we always have pending migrations.
if p.cfg.disableVersioning {
return true, nil
}
if p.cfg.allowMissing {
// List all migrations from the database.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return false, err
}
// If there are no migrations in the database, we have pending migrations.
if len(dbMigrations) == 0 {
return true, nil
}
applied := make(map[int64]bool, len(dbMigrations))
for _, m := range dbMigrations {
applied[m.Version] = true
}
// Iterate over all migrations and check if any are missing.
for _, m := range p.migrations {
if !applied[m.Version] {
return true, nil
}
}
return false, nil
}
// If out-of-order migrations are not allowed, we can optimize this by only checking whether the
// last migration the provider knows about is applied.
last := p.migrations[len(p.migrations)-1]
if _, err := p.store.GetMigration(ctx, conn, last.Version); err != nil {
if errors.Is(err, database.ErrVersionNotFound) {
return true, nil
}
return false, err
}
return false, nil
}

func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -478,7 +554,7 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64
if conn == nil {
var cleanup func() error
var err error
conn, cleanup, err = p.initialize(ctx)
conn, cleanup, err = p.initialize(ctx, true)
if err != nil {
return 0, err
}
Expand Down
Loading
Loading