Skip to content

Commit

Permalink
Add passwdCmd config option
Browse files Browse the repository at this point in the history
It iso now possible to define a command that is executed in order to
optain the password. This is useful in situations, where your database
is manged by a cloud provider and you are using the cloud providers IAM
solution to gain the password.

Fixes #154
  • Loading branch information
patrickpichler committed May 6, 2024
1 parent eb695ac commit c9c5e05
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 20 deletions.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ connections:
params:
autocommit: "true"
tls: skip-verify
- alias: managed_mysql
driver: mysql
proto: tcp
user: root
passwdCmd: ["echo", "-n", "super_secure"]
host: 127.0.0.1
port: 13306
dbName: world
- alias: mysql_via_ssh
driver: mysql
proto: tcp
Expand Down Expand Up @@ -232,7 +240,9 @@ The first setting in `connections` is the default connection.

### connections

`dataSourceName` takes precedence over the value set in `proto`, `user`, `passwd`, `host`, `port`, `dbName`, `params`.
`dataSourceName` takes precedence over the value set in `proto`, `user`, `passwd`, `passwdCmd`, `host`, `port`, `dbName`, `params`.

`passwdCmd` takes precedence over the value set in `passwd`.

| Key | Description |
| -------------- | ------------------------------------------- |
Expand All @@ -242,6 +252,7 @@ The first setting in `connections` is the default connection.
| proto | `tcp`, `udp`, `unix`. |
| user | User name |
| passwd | Password |
| passwdCmd | Command to be executed to get password (Array) |
| host | Host |
| port | Port |
| path | unix socket path |
Expand Down
11 changes: 11 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ func TestGetConfig(t *testing.T) {
DBName: "world",
Params: map[string]string{"autocommit": "true", "tls": "skip-verify"},
},
{
Alias: "sqls_mysql",
Driver: "mysql",
Proto: "tcp",
User: "root",
PasswdCmd: []string{"echo", "topsecret"},
Host: "127.0.0.1",
Port: 13306,
DBName: "world",
Params: map[string]string{"autocommit": "true", "tls": "skip-verify"},
},
{
Alias: "sqls_sqlite3",
Driver: "sqlite3",
Expand Down
16 changes: 13 additions & 3 deletions internal/database/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,15 @@ func genClickhouseDsn(dbConfig *DBConfig) (string, error) {
return "", fmt.Errorf("unsupported protocol %s", dbConfig.Proto)
}

if dbConfig.Passwd == "" {
passwd, err := dbConfig.ResolvePassword()
if err != nil {
return "", err
}

if passwd == "" {
u.User = url.User(dbConfig.User)
} else {
u.User = url.UserPassword(dbConfig.User, dbConfig.Passwd)
u.User = url.UserPassword(dbConfig.User, passwd)
}

u.Host = fmt.Sprintf("%s:%d", dbConfig.Host, dbConfig.Port)
Expand All @@ -131,8 +136,13 @@ func genClickhouseConfig(dbConfig *DBConfig) (*clickhouse.Options, error) {

cfg := &clickhouse.Options{}

passwd, err := dbConfig.ResolvePassword()
if err != nil {
return nil, err
}

cfg.Auth.Username = dbConfig.User
cfg.Auth.Password = dbConfig.Passwd
cfg.Auth.Password = passwd
cfg.Auth.Database = dbConfig.DBName

switch dbConfig.Proto {
Expand Down
22 changes: 19 additions & 3 deletions internal/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"os"
"os/exec"

"github.com/sqls-server/sqls/dialect"
"golang.org/x/crypto/ssh"
Expand All @@ -25,6 +26,7 @@ type DBConfig struct {
Proto Proto `json:"proto" yaml:"proto"`
User string `json:"user" yaml:"user"`
Passwd string `json:"passwd" yaml:"passwd"`
PasswdCmd []string `json:"passwdCmd" yaml:"passwdCmd"`
Host string `json:"host" yaml:"host"`
Port int `json:"port" yaml:"port"`
Path string `json:"path" yaml:"path"`
Expand All @@ -33,6 +35,20 @@ type DBConfig struct {
SSHCfg *SSHConfig `json:"sshConfig" yaml:"sshConfig"`
}

func (c *DBConfig) ResolvePassword() (string, error) {
if len(c.PasswdCmd) == 0 {
return c.Passwd, nil
}

cmd := exec.Command(c.PasswdCmd[0], c.PasswdCmd[1:]...)

Check failure on line 43 in internal/database/config.go

View workflow job for this annotation

GitHub Actions / test

G204: Subprocess launched with a potential tainted input or cmd arguments (gosec)
data, err := cmd.Output()
if err != nil {
return "", err
}

return string(data), nil
}

func (c *DBConfig) Validate() error {
if c.Driver == "" {
return errors.New("required: connections[].driver")
Expand Down Expand Up @@ -101,8 +117,8 @@ func (c *DBConfig) Validate() error {
if c.User == "" {
return errors.New("required: connections[].user")
}
if c.Passwd == "" {
return errors.New("required: connections[].Passwd")
if len(c.PasswdCmd) == 0 && c.Passwd == "" {
return errors.New("required: connections[].PasswdCmd or connections[].Passwd")
}
if c.Host == "" {
return errors.New("required: connections[].Host")
Expand All @@ -129,7 +145,7 @@ func (c *DBConfig) Validate() error {
return errors.New("required: connections[].host")
}
case ProtoUDP, ProtoUnix:
default:
default:
return errors.New("invalid: connections[].proto")
}
if c.SSHCfg != nil {
Expand Down
64 changes: 64 additions & 0 deletions internal/database/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package database

import (
"testing"

"github.com/google/go-cmp/cmp"
)

func TestResolvePassword(t *testing.T) {
type testCase struct {
title string
dbConfig DBConfig
want string
wantErr bool
}

testCases := []testCase{
{
title: "only password",
dbConfig: DBConfig{
Passwd: "test",
},
want: "test",
},
{
title: "only command",
dbConfig: DBConfig{
PasswdCmd: []string{"echo", "-n", "secure"},
},
want: "secure",
},
{
title: "password and command",
dbConfig: DBConfig{
Passwd: "test",
PasswdCmd: []string{"echo", "-n", "secure"},
},
want: "secure",
},
{
title: "failing command",
dbConfig: DBConfig{
Passwd: "test",
PasswdCmd: []string{"false"},
},
wantErr: true,
},
}

for _, tt := range testCases {
t.Run(tt.title, func(t *testing.T) {
got, err := tt.dbConfig.ResolvePassword()
if err != nil {
if !tt.wantErr {
t.Errorf("ResolvePassword() error = %v, wantErr %v", err, tt.wantErr)
return
}
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("unmatch (- want, + got):\n%s", diff)
}
})
}
}
19 changes: 12 additions & 7 deletions internal/database/mssql.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
package database

import (
"os"
"context"
"database/sql"
"fmt"
"log"
"net/url"
"os"
"strconv"

_ "github.com/denisenkom/go-mssqldb"
"github.com/sqls-server/sqls/dialect"
"github.com/jfcote87/sshdb"
"github.com/jfcote87/sshdb/mssql"
"github.com/sqls-server/sqls/dialect"
"golang.org/x/crypto/ssh"
)

Expand All @@ -23,7 +23,7 @@ func init() {

func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) {
var (
conn *sql.DB
conn *sql.DB
)
dsn, err := genMssqlConfig(dbConnCfg)
if err != nil {
Expand All @@ -41,9 +41,9 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) {
return nil, fmt.Errorf("unable to decrypt private key")
}

cfg := &ssh.ClientConfig {
cfg := &ssh.ClientConfig{
User: dbConnCfg.SSHCfg.User,
Auth: []ssh.AuthMethod {
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Expand Down Expand Up @@ -77,7 +77,7 @@ func mssqlOpen(dbConnCfg *DBConfig) (*DBConnection, error) {
conn.SetMaxOpenConns(DefaultMaxOpenConns)

return &DBConnection{
Conn: conn,
Conn: conn,
}, nil
}

Expand Down Expand Up @@ -376,9 +376,14 @@ func genMssqlConfig(connCfg *DBConfig) (string, error) {
return connCfg.DataSourceName, nil
}

passwd, err := connCfg.ResolvePassword()
if err != nil {
return "", err
}

q := url.Values{}
q.Set("user", connCfg.User)
q.Set("password", connCfg.Passwd)
q.Set("password", passwd)
q.Set("database", connCfg.DBName)

switch connCfg.Proto {
Expand Down
9 changes: 7 additions & 2 deletions internal/database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,13 @@ func genMysqlConfig(connCfg *DBConfig) (*mysql.Config, error) {
return mysql.ParseDSN(connCfg.DataSourceName)
}

passwd, err := connCfg.ResolvePassword()
if err != nil {
return nil, err
}

cfg.User = connCfg.User
cfg.Passwd = connCfg.Passwd
cfg.Passwd = passwd
cfg.DBName = connCfg.DBName

switch connCfg.Proto {
Expand All @@ -116,7 +121,7 @@ func genMysqlConfig(connCfg *DBConfig) (*mysql.Config, error) {
}
cfg.Addr = connCfg.Path
cfg.Net = string(connCfg.Proto)
case ProtoHTTP:
case ProtoHTTP:
default:
return nil, fmt.Errorf("default addr for network %s unknown", connCfg.Proto)
}
Expand Down
9 changes: 7 additions & 2 deletions internal/database/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ func genOracleConfig(connCfg *DBConfig) (string, error) {
if port == 0 {
port = 1521
}
DSName := connCfg.User + "/" + connCfg.Passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName
passwd, err := connCfg.ResolvePassword()
if err != nil {
return "", err
}

DSName := connCfg.User + "/" + passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName
return DSName, nil
}

Expand Down Expand Up @@ -106,7 +111,7 @@ func (db *OracleDBRepository) SchemaTables(ctx context.Context) (map[string][]st
ctx,
`
SELECT OWNER, TABLE_NAME
FROM SYS.ALL_TABLES
FROM SYS.ALL_TABLES
ORDER BY OWNER, TABLE_NAME
`)
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion internal/database/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,14 @@ func genPostgresConfig(connCfg *DBConfig) (string, error) {
return connCfg.DataSourceName, nil
}

passwd, err := connCfg.ResolvePassword()
if err != nil {
return "", err
}

q := url.Values{}
q.Set("user", connCfg.User)
q.Set("password", connCfg.Passwd)
q.Set("password", passwd)
q.Set("dbname", connCfg.DBName)

switch connCfg.Proto {
Expand Down
7 changes: 6 additions & 1 deletion internal/database/vertica.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ func genVerticaConfig(connCfg *DBConfig) (string, error) {
port = 5433
}

DSName := connCfg.User + "/" + connCfg.Passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName
passwd, err := connCfg.ResolvePassword()
if err != nil {
return "", err
}

DSName := connCfg.User + "/" + passwd + "@" + host + ":" + strconv.Itoa(port) + "/" + connCfg.DBName
return DSName, nil
}

Expand Down

0 comments on commit c9c5e05

Please sign in to comment.