diff --git a/README.md b/README.md index c6979c1..12f1004 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,14 @@ connections: params: autocommit: "true" tls: skip-verify + - alias: managed_mysql + driver: mysql + proto: tcp + user: root + passwdCmd: ["echo", "-n", "super_secure"] + host: 127.0.0.1 + port: 13306 + dbName: world - alias: mysql_via_ssh driver: mysql proto: tcp @@ -232,7 +240,9 @@ The first setting in `connections` is the default connection. ### connections -`dataSourceName` takes precedence over the value set in `proto`, `user`, `passwd`, `host`, `port`, `dbName`, `params`. +`dataSourceName` takes precedence over the value set in `proto`, `user`, `passwd`, `passwdCmd`, `host`, `port`, `dbName`, `params`. + +`passwdCmd` takes precedence over the value set in `passwd`. | Key | Description | | -------------- | ------------------------------------------- | @@ -242,6 +252,7 @@ The first setting in `connections` is the default connection. | proto | `tcp`, `udp`, `unix`. | | user | User name | | passwd | Password | +| passwdCmd | Command to be executed to get password (Array) | | host | Host | | port | Port | | path | unix socket path | diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f0cead3..ac11a30 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,6 +39,16 @@ func TestGetConfig(t *testing.T) { DBName: "world", Params: map[string]string{"autocommit": "true", "tls": "skip-verify"}, }, + { + Alias: "sqls_managed", + Driver: "mysql", + Proto: "tcp", + User: "root", + PasswdCmd: []string{"echo", "topsecret"}, + Host: "127.0.0.1", + Port: 13306, + DBName: "world", + }, { Alias: "sqls_sqlite3", Driver: "sqlite3", diff --git a/internal/config/testdata/basic.yml b/internal/config/testdata/basic.yml index bdfc6e2..95cd304 100644 --- a/internal/config/testdata/basic.yml +++ b/internal/config/testdata/basic.yml @@ -13,6 +13,16 @@ connections: params: autocommit: "true" tls: skip-verify + - alias: sqls_managed + driver: mysql + dataSourceName: "" + proto: tcp + user: root + passwdCmd: ["echo", "topsecret"] + host: 127.0.0.1 + port: 13306 + path: "" + dbName: world - alias: sqls_sqlite3 driver: sqlite3 dataSourceName: "file:/home/sqls-server/chinook.db" diff --git a/internal/database/clickhouse.go b/internal/database/clickhouse.go index 40e2fe9..1b29e46 100644 --- a/internal/database/clickhouse.go +++ b/internal/database/clickhouse.go @@ -101,10 +101,15 @@ func genClickhouseDsn(dbConfig *DBConfig) (string, error) { return "", fmt.Errorf("unsupported protocol %s", dbConfig.Proto) } - if dbConfig.Passwd == "" { + passwd, err := dbConfig.ResolvePassword() + if err != nil { + return "", err + } + + if passwd == "" { u.User = url.User(dbConfig.User) } else { - u.User = url.UserPassword(dbConfig.User, dbConfig.Passwd) + u.User = url.UserPassword(dbConfig.User, passwd) } u.Host = fmt.Sprintf("%s:%d", dbConfig.Host, dbConfig.Port) @@ -131,8 +136,13 @@ func genClickhouseConfig(dbConfig *DBConfig) (*clickhouse.Options, error) { cfg := &clickhouse.Options{} + passwd, err := dbConfig.ResolvePassword() + if err != nil { + return nil, err + } + cfg.Auth.Username = dbConfig.User - cfg.Auth.Password = dbConfig.Passwd + cfg.Auth.Password = passwd cfg.Auth.Database = dbConfig.DBName switch dbConfig.Proto { diff --git a/internal/database/config.go b/internal/database/config.go index a95306e..4837786 100644 --- a/internal/database/config.go +++ b/internal/database/config.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "os/exec" "github.com/sqls-server/sqls/dialect" "golang.org/x/crypto/ssh" @@ -25,6 +26,7 @@ type DBConfig struct { Proto Proto `json:"proto" yaml:"proto"` User string `json:"user" yaml:"user"` Passwd string `json:"passwd" yaml:"passwd"` + PasswdCmd []string `json:"passwdCmd" yaml:"passwdCmd"` Host string `json:"host" yaml:"host"` Port int `json:"port" yaml:"port"` Path string `json:"path" yaml:"path"` @@ -33,6 +35,20 @@ type DBConfig struct { SSHCfg *SSHConfig `json:"sshConfig" yaml:"sshConfig"` } +func (c *DBConfig) ResolvePassword() (string, error) { + if len(c.PasswdCmd) == 0 { + return c.Passwd, nil + } + + cmd := exec.Command(c.PasswdCmd[0], c.PasswdCmd[1:]...) // nolint:gosec // The whole feature is allowing the user to run a provided command. + data, err := cmd.Output() + if err != nil { + return "", err + } + + return string(data), nil +} + func (c *DBConfig) Validate() error { if c.Driver == "" { return errors.New("required: connections[].driver") @@ -101,8 +117,8 @@ func (c *DBConfig) Validate() error { if c.User == "" { return errors.New("required: connections[].user") } - if c.Passwd == "" { - return errors.New("required: connections[].Passwd") + if len(c.PasswdCmd) == 0 && c.Passwd == "" { + return errors.New("required: connections[].PasswdCmd or connections[].Passwd") } if c.Host == "" { return errors.New("required: connections[].Host") @@ -129,7 +145,7 @@ func (c *DBConfig) Validate() error { return errors.New("required: connections[].host") } case ProtoUDP, ProtoUnix: - default: + default: return errors.New("invalid: connections[].proto") } if c.SSHCfg != nil { diff --git a/internal/database/config_test.go b/internal/database/config_test.go new file mode 100644 index 0000000..e3924af --- /dev/null +++ b/internal/database/config_test.go @@ -0,0 +1,64 @@ +package database + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestResolvePassword(t *testing.T) { + type testCase struct { + title string + dbConfig DBConfig + want string + wantErr bool + } + + testCases := []testCase{ + { + title: "only password", + dbConfig: DBConfig{ + Passwd: "test", + }, + want: "test", + }, + { + title: "only command", + dbConfig: DBConfig{ + PasswdCmd: []string{"echo", "-n", "secure"}, + }, + want: "secure", + }, + { + title: "password and command", + dbConfig: DBConfig{ + Passwd: "test", + PasswdCmd: []string{"echo", "-n", "secure"}, + }, + want: "secure", + }, + { + title: "failing command", + dbConfig: DBConfig{ + Passwd: "test", + PasswdCmd: []string{"false"}, + }, + wantErr: true, + }, + } + + for _, tt := range testCases { + t.Run(tt.title, func(t *testing.T) { + got, err := tt.dbConfig.ResolvePassword() + if err != nil { + if !tt.wantErr { + t.Errorf("ResolvePassword() error = %v, wantErr %v", err, tt.wantErr) + return + } + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("unmatch (- want, + got):\n%s", diff) + } + }) + } +} diff --git a/internal/database/mssql.go b/internal/database/mssql.go index dd795ab..7939a7b 100644 --- a/internal/database/mssql.go +++ b/internal/database/mssql.go @@ -1,18 +1,18 @@ package database import ( - "os" "context" "database/sql" "fmt" "log" "net/url" + "os" "strconv" _ "github.com/denisenkom/go-mssqldb" - "github.com/sqls-server/sqls/dialect" "github.com/jfcote87/sshdb" "github.com/jfcote87/sshdb/mssql" + "github.com/sqls-server/sqls/dialect" "golang.org/x/crypto/ssh" ) @@ -23,7 +23,7 @@ func init() { func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { var ( - conn *sql.DB + conn *sql.DB ) dsn, err := genMssqlConfig(dbConnCfg) if err != nil { @@ -41,9 +41,9 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { return nil, fmt.Errorf("unable to decrypt private key") } - cfg := &ssh.ClientConfig { + cfg := &ssh.ClientConfig{ User: dbConnCfg.SSHCfg.User, - Auth: []ssh.AuthMethod { + Auth: []ssh.AuthMethod{ ssh.PublicKeys(signer), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), @@ -77,7 +77,7 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) { conn.SetMaxOpenConns(DefaultMaxOpenConns) return &DBConnection{ - Conn: conn, + Conn: conn, }, nil } @@ -376,9 +376,14 @@ func genMssqlConfig(connCfg *DBConfig) (string, error) { return connCfg.DataSourceName, nil } + passwd, err := connCfg.ResolvePassword() + if err != nil { + return "", err + } + q := url.Values{} q.Set("user", connCfg.User) - q.Set("password", connCfg.Passwd) + q.Set("password", passwd) q.Set("database", connCfg.DBName) switch connCfg.Proto { diff --git a/internal/database/mysql.go b/internal/database/mysql.go index c8fb57b..8651ec7 100644 --- a/internal/database/mysql.go +++ b/internal/database/mysql.go @@ -94,8 +94,13 @@ func genMysqlConfig(connCfg *DBConfig) (*mysql.Config, error) { return mysql.ParseDSN(connCfg.DataSourceName) } + passwd, err := connCfg.ResolvePassword() + if err != nil { + return nil, err + } + cfg.User = connCfg.User - cfg.Passwd = connCfg.Passwd + cfg.Passwd = passwd cfg.DBName = connCfg.DBName switch connCfg.Proto { @@ -116,7 +121,7 @@ func genMysqlConfig(connCfg *DBConfig) (*mysql.Config, error) { } cfg.Addr = connCfg.Path cfg.Net = string(connCfg.Proto) - case ProtoHTTP: + case ProtoHTTP: default: return nil, fmt.Errorf("default addr for network %s unknown", connCfg.Proto) } diff --git a/internal/database/oracle.go b/internal/database/oracle.go index 122537f..ba90f9a 100644 --- a/internal/database/oracle.go +++ b/internal/database/oracle.go @@ -50,7 +50,12 @@ func genOracleConfig(connCfg *DBConfig) (string, error) { if port == 0 { port = 1521 } - DSName := connCfg.User + "/" + connCfg.Passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName + passwd, err := connCfg.ResolvePassword() + if err != nil { + return "", err + } + + DSName := connCfg.User + "/" + passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName return DSName, nil } @@ -106,7 +111,7 @@ func (db *OracleDBRepository) SchemaTables(ctx context.Context) (map[string][]st ctx, ` SELECT OWNER, TABLE_NAME - FROM SYS.ALL_TABLES + FROM SYS.ALL_TABLES ORDER BY OWNER, TABLE_NAME `) if err != nil { diff --git a/internal/database/postgresql.go b/internal/database/postgresql.go index 2d08c2f..2503f45 100644 --- a/internal/database/postgresql.go +++ b/internal/database/postgresql.go @@ -396,9 +396,14 @@ func genPostgresConfig(connCfg *DBConfig) (string, error) { return connCfg.DataSourceName, nil } + passwd, err := connCfg.ResolvePassword() + if err != nil { + return "", err + } + q := url.Values{} q.Set("user", connCfg.User) - q.Set("password", connCfg.Passwd) + q.Set("password", passwd) q.Set("dbname", connCfg.DBName) switch connCfg.Proto { diff --git a/internal/database/vertica.go b/internal/database/vertica.go index 10bde70..71e1d4f 100644 --- a/internal/database/vertica.go +++ b/internal/database/vertica.go @@ -51,7 +51,12 @@ func genVerticaConfig(connCfg *DBConfig) (string, error) { port = 5433 } - DSName := connCfg.User + "/" + connCfg.Passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName + passwd, err := connCfg.ResolvePassword() + if err != nil { + return "", err + } + + DSName := connCfg.User + "/" + passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName return DSName, nil }