From 228c00c0b031a52d9a9efe3921e2d0d3ba636200 Mon Sep 17 00:00:00 2001 From: majiayu000 <1835304752@qq.com> Date: Fri, 2 Jan 2026 00:55:29 +0800 Subject: [PATCH] feat: add -env flag to read database connection from environment variable Add support for reading the database connection string from an environment variable specified by the -env flag. This security enhancement allows users to avoid exposing connection strings directly on the command line. Usage: migrate -env DATABASE_URL -source file://path up Signed-off-by: majiayu000 <1835304752@qq.com> --- internal/cli/commands.go | 10 +++++++++ internal/cli/commands_test.go | 41 +++++++++++++++++++++++++++++++++++ internal/cli/main.go | 11 ++++++++++ 3 files changed, 62 insertions(+) diff --git a/internal/cli/commands.go b/internal/cli/commands.go index e37ca313e..425d7fe0c 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -20,6 +20,16 @@ var ( errInvalidTimeFormat = errors.New("time format may not be empty") ) +// databaseFromEnv reads the database connection string from the specified environment variable. +// Returns the value if set, or an error if the environment variable is not set or empty. +func databaseFromEnv(envName string) (string, error) { + val := os.Getenv(envName) + if val == "" { + return "", fmt.Errorf("environment variable %s is not set or empty", envName) + } + return val, nil +} + func nextSeqVersion(matches []string, seqDigits int) (string, error) { if seqDigits <= 0 { return "", errInvalidSequenceWidth diff --git a/internal/cli/commands_test.go b/internal/cli/commands_test.go index 798e2df77..a8edd0393 100644 --- a/internal/cli/commands_test.go +++ b/internal/cli/commands_test.go @@ -253,6 +253,47 @@ func (s *CreateCmdSuite) TestCreateCmd() { } } +func TestDatabaseFromEnv(t *testing.T) { + cases := []struct { + name string + envName string + envValue string + setEnv bool + expectedResult string + expectedErrStr string + }{ + {"valid env var", "TEST_DB_URL", "postgres://localhost:5432/test", true, "postgres://localhost:5432/test", ""}, + {"empty env var", "TEST_DB_URL_EMPTY", "", true, "", "environment variable TEST_DB_URL_EMPTY is not set or empty"}, + {"unset env var", "TEST_DB_URL_UNSET", "", false, "", "environment variable TEST_DB_URL_UNSET is not set or empty"}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.setEnv { + os.Setenv(c.envName, c.envValue) + defer os.Unsetenv(c.envName) + } + + result, err := databaseFromEnv(c.envName) + + if c.expectedErrStr != "" { + if err == nil { + t.Errorf("Expected error: %s but got nil", c.expectedErrStr) + } else if err.Error() != c.expectedErrStr { + t.Errorf("Expected error: %s but got: %s", c.expectedErrStr, err.Error()) + } + } else { + if err != nil { + t.Errorf("Unexpected error: %s", err.Error()) + } + if result != c.expectedResult { + t.Errorf("Expected result: %s but got: %s", c.expectedResult, result) + } + } + }) + } +} + func TestNumDownFromArgs(t *testing.T) { cases := []struct { name string diff --git a/internal/cli/main.go b/internal/cli/main.go index c7a3bd74a..052a2b135 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -68,6 +68,7 @@ func Main(version string) { pathPtr := flag.String("path", "", "") databasePtr := flag.String("database", "", "") sourcePtr := flag.String("source", "", "") + envPtr := flag.String("env", "", "") flag.Usage = func() { fmt.Fprintf(os.Stderr, @@ -78,6 +79,7 @@ Options: -source Location of the migrations (driver://url) -path Shorthand for -source=file://path -database Run migrations against this database (driver://url) + -env Read database connection string from environment variable -prefetch N Number of migrations to load in advance before executing (default 10) -lock-timeout N Allow N seconds to acquire database lock (default 15) -verbose Print verbose logging @@ -119,6 +121,15 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU *sourcePtr = fmt.Sprintf("file://%v", *pathPtr) } + // read database connection string from environment variable if -env is given + if *envPtr != "" { + dbFromEnv, err := databaseFromEnv(*envPtr) + if err != nil { + log.fatalErr(err) + } + *databasePtr = dbFromEnv + } + // initialize migrate // don't catch migraterErr here and let each command decide // how it wants to handle the error