Skip to content

Commit

Permalink
cmd/atlascmd: add migrate apply command (#849)
Browse files Browse the repository at this point in the history
* cmd/atlascmd: add migrate apply command

* make linter happy

* go generate

* go mod tidy

* reword description

* get rid of dialect specific logic in cmd

* compact pretty output
  • Loading branch information
masseelch authored Jun 8, 2022
1 parent 39ab904 commit 82a0084
Show file tree
Hide file tree
Showing 17 changed files with 477 additions and 68 deletions.
203 changes: 184 additions & 19 deletions cmd/atlascmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,45 @@ import (
"io"
"io/ioutil"
"path/filepath"
"strconv"
"strings"
"time"

entmigrate "ariga.io/atlas/cmd/atlascmd/migrate"
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
"ariga.io/atlas/sql/sqlclient"
"ariga.io/atlas/sql/sqltool"
"github.com/fatih/color"
"github.com/spf13/pflag"

"github.com/spf13/cobra"
)

const (
migrateFlagDevURL = "dev-url"
migrateFlagDir = "dir"
migrateFlagForce = "force"
migrateFlagFormat = "format"
migrateFlagSchema = "schema"
migrateDiffFlagTo = "to"
migrateDiffFlagVerbose = "verbose"
migrateFlagDevURL = "dev-url"
migrateFlagDir = "dir"
migrateFlagForce = "force"
migrateFlagFormat = "format"
migrateFlagLog = "log"
migrateFlagRevisionsSchema = "revisions-schema"
migrateFlagTo = "to"
migrateFlagSchema = "schema"
migrateDiffFlagVerbose = "verbose"
)

var (
// MigrateFlags are the flags used in MigrateCmd (and sub-commands).
MigrateFlags struct {
DirURL string
DevURL string
ToURL string
Schemas []string
Format string
Force bool
Verbose bool
DirURL string
DevURL string
ToURL string
Schemas []string
Format string
LogFormat string
RevisionSchema string
Force bool
Verbose bool
}
// MigrateCmd represents the migrate command. It wraps several other sub-commands.
MigrateCmd = &cobra.Command{
Expand Down Expand Up @@ -75,6 +83,21 @@ to re-hash the contents and resolve the error
return nil
},
}
// MigrateApplyCmd represents the 'atlas migrate apply' subcommand.
MigrateApplyCmd = &cobra.Command{
Use: "apply",
Short: "Applies pending migration files on the connected database.",
Long: `'atlas migrate apply' reads the migration state of the connected database and computes what migrations are pending.
It then attempts to apply the pending migration files in the correct order onto the database.
The first argument denotes the maximum number of migration files to apply.
As a safety measure 'atlas migrate apply' will abort with an error, if:
- the migration directory is not in sync with the 'atlas.sum' file
- the migration and database history do not match each other`,
Example: ` atlas migrate apply --to mysql://user:pass@localhost:3306/dbname
atlas migrate apply 1 --dir file:///path/to/migration/directory --to mysql://user:pass@localhost:3306/dbname`,
Args: cobra.MaximumNArgs(1),
RunE: CmdMigrateApplyRun,
}
// MigrateDiffCmd represents the 'atlas migrate diff' subcommand.
MigrateDiffCmd = &cobra.Command{
Use: "diff",
Expand Down Expand Up @@ -116,8 +139,8 @@ This command should be used whenever a manual change in the migration directory
the atlas.sum file. If there is a mismatch it will be reported. If the --dev-url flag is given, the migration files are
executed on the connected database in order to validate SQL semantics.`,
Example: ` atlas migrate validate
atlas migrate validate --dir /path/to/migration/directory
atlas migrate validate --dir /path/to/migration/directory --dev-url mysql://user:pass@localhost:3306/dev`,
atlas migrate validate --dir file:///path/to/migration/directory
atlas migrate validate --dir file:///path/to/migration/directory --dev-url mysql://user:pass@localhost:3306/dev`,
PreRunE: migrateFlagsFromEnv,
RunE: CmdMigrateValidateRun,
}
Expand All @@ -126,6 +149,7 @@ executed on the connected database in order to validate SQL semantics.`,
func init() {
// Add sub-commands.
Root.AddCommand(MigrateCmd)
MigrateCmd.AddCommand(MigrateApplyCmd)
MigrateCmd.AddCommand(MigrateDiffCmd)
MigrateCmd.AddCommand(MigrateHashCmd)
MigrateCmd.AddCommand(MigrateNewCmd)
Expand All @@ -134,24 +158,98 @@ func init() {
devURL := func(set *pflag.FlagSet) {
set.StringVarP(&MigrateFlags.DevURL, migrateFlagDevURL, "", "", "[driver://username:password@address/dbname?param=value] select a data source using the URL format")
}
toURL := func(set *pflag.FlagSet) {
set.StringVarP(&MigrateFlags.ToURL, migrateFlagTo, "", "", "[driver://username:password@address/dbname?param=value] select a data source using the URL format")
}
// Global flags.
MigrateCmd.PersistentFlags().StringVarP(&MigrateFlags.DirURL, migrateFlagDir, "", "file://migrations", "select migration directory using URL format")
MigrateCmd.PersistentFlags().StringSliceVarP(&MigrateFlags.Schemas, migrateFlagSchema, "", nil, "set schema names")
MigrateCmd.PersistentFlags().StringVarP(&MigrateFlags.Format, migrateFlagFormat, "", formatAtlas, "set migration file format")
MigrateCmd.PersistentFlags().BoolVarP(&MigrateFlags.Force, migrateFlagForce, "", false, "force a command to run on a broken migration directory state")
MigrateCmd.PersistentFlags().SortFlags = false
// Apply flags.
MigrateApplyCmd.Flags().StringVarP(&MigrateFlags.LogFormat, migrateFlagLog, "", logFormatTTY, "log format to use")
MigrateApplyCmd.Flags().StringVarP(&MigrateFlags.RevisionSchema, migrateFlagRevisionsSchema, "", "", "schema name where the revisions table is to be created")
toURL(MigrateApplyCmd.Flags())
// Diff flags.
devURL(MigrateDiffCmd.Flags())
MigrateDiffCmd.Flags().StringVarP(&MigrateFlags.ToURL, migrateDiffFlagTo, "", "", "[driver://username:password@address/dbname?param=value] select a data source using the URL format")
toURL(MigrateDiffCmd.Flags())
MigrateDiffCmd.Flags().BoolVarP(&MigrateFlags.Verbose, migrateDiffFlagVerbose, "", false, "enable verbose logging")
MigrateDiffCmd.Flags().SortFlags = false
cobra.CheckErr(MigrateDiffCmd.MarkFlagRequired(migrateFlagDevURL))
cobra.CheckErr(MigrateDiffCmd.MarkFlagRequired(migrateDiffFlagTo))
cobra.CheckErr(MigrateDiffCmd.MarkFlagRequired(migrateFlagTo))
// Validate flags.
devURL(MigrateValidateCmd.Flags())
receivesEnv(MigrateCmd)
}

// CmdMigrateApplyRun is the command executed when running the CLI with 'migrate apply' args.
func CmdMigrateApplyRun(cmd *cobra.Command, args []string) error {
var (
n int
err error
)
if len(args) > 1 {
n, err = strconv.Atoi(args[1])
if err != nil {
return err
}
}
// Open the migration directory.
dir, err := dir()
if err != nil {
return err
}
// Open a client to the database.
target, err := sqlclient.Open(cmd.Context(), MigrateFlags.ToURL)
if err != nil {
return err
}
// Get the correct log format and destination. Currently, only os.Stdout is supported.
l, err := logFormat(cmd.OutOrStdout())
if err != nil {
return err
}
// Currently, only in DB revisions are supported.
opts := []entmigrate.Option{entmigrate.WithSchema(MigrateFlags.RevisionSchema)}
rrw, err := entmigrate.NewEntRevisions(target, opts...)
if err != nil {
return err
}
if err := rrw.Init(cmd.Context()); err != nil {
return err
}
// Wrap the whole execution in one transaction. This behaviour will change once
// there are insights about migration files available.
tx, err := target.Tx(cmd.Context(), nil)
if err != nil {
return err
}
// Get the executor.
ex, err := migrate.NewExecutor(tx.Driver, dir, rrw, migrate.WithLogger(l))
if err != nil {
return err
}
defer func(rrw *entmigrate.EntRevisions, ctx context.Context) {
if err2 := rrw.Flush(ctx); err2 != nil {
if err != nil {
err = fmt.Errorf("%v: %w", err2, err)
} else {
err = err2
}
}
}(rrw, cmd.Context())
err = ex.ExecuteN(cmd.Context(), n)
if err != nil {
if err2 := tx.Rollback(); err2 != nil {
err = fmt.Errorf("%v: %w", err2, err)
}
return err
}
err = tx.Commit()
return err
}

// CmdMigrateDiffRun is the command executed when running the CLI with 'migrate diff' args.
func CmdMigrateDiffRun(cmd *cobra.Command, args []string) error {
// Open a dev driver.
Expand Down Expand Up @@ -351,6 +449,73 @@ func formatter() (migrate.Formatter, error) {
}
}

const (
logFormatTTY = "tty"
)

// LogTTY is a migrate.Logger that pretty prints execution progress.
// If the connected out is not a tty, it will fall back to a non-colorful output.
type LogTTY struct {
out io.Writer
start time.Time
fileStart time.Time
fileCounter int
stmtCounter int
}

var (
cyan = color.CyanString
yellow = color.YellowString
dash = yellow("--")
arr = cyan("->")
indent2 = strings.Repeat(" ", 2)
indent4 = strings.Repeat(indent2, 2)
)

// Log implements the migrate.Logger interface.
func (l *LogTTY) Log(e migrate.LogEntry) {
switch e := e.(type) {
case migrate.LogExecution:
l.start = time.Now()
fmt.Fprintf(l.out, "Migrating to version %v", cyan(e.To))
if e.From != "" {
fmt.Fprintf(l.out, " from %v", cyan(e.From))
}
fmt.Fprintf(l.out, " (%d migrations in total):\n", len(e.Files))
case migrate.LogFile:
l.fileCounter++
if !l.fileStart.IsZero() {
l.reportFileEnd()
}
l.fileStart = time.Now()
fmt.Fprintf(l.out, "\n%s%v migrating version %v\n", indent2, dash, cyan(e.Version))
case migrate.LogStmt:
l.stmtCounter++
fmt.Fprintf(l.out, "%s%v %s\n", indent4, arr, e.SQL)
case migrate.LogDone:
l.reportFileEnd()
fmt.Fprintf(l.out, "\n%s%v\n", indent2, cyan(strings.Repeat("-", 25)))
fmt.Fprintf(l.out, "%s%v %v\n", indent2, dash, time.Since(l.start))
fmt.Fprintf(l.out, "%s%v %v migrations\n", indent2, dash, l.fileCounter)
fmt.Fprintf(l.out, "%s%v %v sql statements\n", indent2, dash, l.stmtCounter)
default:
fmt.Fprintf(l.out, "%v", e)
}
}

func (l *LogTTY) reportFileEnd() {
fmt.Fprintf(l.out, "%s%v ok (%v)\n", indent2, dash, yellow("%s", time.Since(l.fileStart)))
}

func logFormat(out io.Writer) (migrate.Logger, error) {
switch MigrateFlags.LogFormat {
case logFormatTTY:
return &LogTTY{out: out}, nil
default:
return nil, fmt.Errorf("unknown log-format %q", MigrateFlags.LogFormat)
}
}

func migrateFlagsFromEnv(cmd *cobra.Command, _ []string) error {
activeEnv, err := selectEnv(GlobalFlags.SelectedEnv)
if err != nil {
Expand All @@ -373,7 +538,7 @@ func migrateFlagsFromEnv(cmd *cobra.Command, _ []string) error {
}
toURL = "file://" + toURL
}
if err := maySetFlag(cmd, migrateDiffFlagTo, toURL); err != nil {
if err := maySetFlag(cmd, migrateFlagTo, toURL); err != nil {
return err
}
if s := "[" + strings.Join(activeEnv.Schemas, "") + "]"; len(activeEnv.Schemas) > 0 {
Expand Down
44 changes: 40 additions & 4 deletions cmd/atlascmd/migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@ import (
"ariga.io/atlas/sql/migrate"
"ariga.io/atlas/sql/schema"
"ariga.io/atlas/sql/sqlclient"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
entschema "entgo.io/ent/dialect/sql/schema"
)

type (
// A EntRevisions provides implementation for the migrate.RevisionReadWriter interface.
EntRevisions struct {
ac *sqlclient.Client // underlying Atlas client
sc *sqlclient.Client // underlying Atlas client connected to the named schema
ec *ent.Client // underlying Ent client
schema string // name of the schema the revision table resides in
ac *sqlclient.Client // underlying Atlas client
sc *sqlclient.Client // underlying Atlas client connected to the named schema
ec *ent.Client // underlying Ent client
schema string // name of the schema the revision table resides in
cache []migrate.Revision // cache stores writes to Ent for blocked connections (like in SQLite).
}

// Option allows to configure EntRevisions by using functional arguments.
Expand Down Expand Up @@ -88,6 +90,8 @@ func (r *EntRevisions) Init(ctx context.Context) error {
}

// ReadRevisions reads the revisions from the revisions table.
//
// ReadRevisions will not return results only saved to cache.
func (r *EntRevisions) ReadRevisions(ctx context.Context) (migrate.Revisions, error) {
revs, err := r.ec.Revision.Query().Order(ent.Asc(revision.FieldID)).All(ctx)
if err != nil {
Expand All @@ -111,6 +115,32 @@ func (r *EntRevisions) ReadRevisions(ctx context.Context) (migrate.Revisions, er

// WriteRevision writes a revision to the revisions table.
func (r *EntRevisions) WriteRevision(ctx context.Context, rev *migrate.Revision) error {
if r.useCache() {
// Do not store the pointer since we want to maintain the order for writes to the database.
r.cache = append(r.cache, *rev)
return nil
}
return r.write(ctx, rev)
}

// Flush writes the changes saved in memory to the database.
//
// This method exists to support both execution of migration in a transaction and saving revision for SQLite flavors,
// since attempting to write to the database while in a transaction will fail there.
func (r *EntRevisions) Flush(ctx context.Context) error {
if !r.useCache() {
return nil
}
for i := range r.cache {
if err := r.write(ctx, &r.cache[i]); err != nil {
return err
}
}
return nil
}

// write attempts to write the given revision to the database.
func (r *EntRevisions) write(ctx context.Context, rev *migrate.Revision) error {
return r.ec.Revision.Create().
SetID(rev.Version).
SetDescription(rev.Description).
Expand All @@ -125,4 +155,10 @@ func (r *EntRevisions) WriteRevision(ctx context.Context, rev *migrate.Revision)
Exec(ctx)
}

func (r *EntRevisions) useCache() bool {
// For SQLite dialect and flavors we have to enable the revision write cache to postpone writing to
// the database until the transaction wrapping the migration execution has been committed.
return r.ac.Name == dialect.SQLite
}

var _ migrate.RevisionReadWriter = (*EntRevisions)(nil)
Loading

0 comments on commit 82a0084

Please sign in to comment.