diff --git a/internal/config/config.go b/internal/config/config.go index a76c1f957e12..56b095043c03 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -204,6 +204,7 @@ func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig { // Map of environment variable suffix to config field pointer envVars := map[string]*string{ "HOST": &refreshed.Host, + "PORT": &refreshed.Port, "USER": &refreshed.User, "PASSWORD": &refreshed.Password, "NAME": &refreshed.Name, diff --git a/internal/config/struct.go b/internal/config/struct.go index d338492729d4..a984bf0fa34d 100644 --- a/internal/config/struct.go +++ b/internal/config/struct.go @@ -27,7 +27,7 @@ type Config struct { type PostgreSQLConfig struct { Host string `yaml:"host" env:"HOST, overwrite"` - Port int `yaml:"port" env:"PORT, overwrite"` + Port string `yaml:"port" env:"PORT, overwrite"` User string `yaml:"user" env:"USER, overwrite"` Password string `yaml:"password" env:"PASSWORD, overwrite"` Name string `yaml:"name" env:"NAME, overwrite"` diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore.go b/internal/outpost/proxyv2/postgresstore/postgresstore.go index 893404fe214b..bdd1280b96e2 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore.go @@ -4,21 +4,23 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "os" + "strconv" "strings" "time" "github.com/google/uuid" "github.com/gorilla/sessions" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/stdlib" "github.com/mitchellh/mapstructure" log "github.com/sirupsen/logrus" - _ "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -65,8 +67,8 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { if cfg.Name == "" { return nil, fmt.Errorf("PostgreSQL database name is required") } - if cfg.Port <= 0 { - return nil, fmt.Errorf("PostgreSQL port must be positive") + if cfg.Port == "" { + return nil, fmt.Errorf("PostgreSQL port is required") } // Start with a default config @@ -75,9 +77,38 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { return nil, fmt.Errorf("failed to create default config: %w", err) } - // Set connection parameters - connConfig.Host = cfg.Host - connConfig.Port = uint16(cfg.Port) + // Parse comma-separated hosts and create fallbacks + // cfg.Host can be a comma-separated list like "host1,host2,host3" + hosts := strings.Split(cfg.Host, ",") + for i, host := range hosts { + hosts[i] = strings.TrimSpace(host) + } + + // Parse and validate comma-separated ports + portStrs := strings.Split(cfg.Port, ",") + ports := make([]uint16, len(portStrs)) + for i, portStr := range portStrs { + portStr = strings.TrimSpace(portStr) + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port value %q: %w", portStr, err) + } + if port <= 0 { + return nil, fmt.Errorf("PostgreSQL port %d must be positive", port) + } + if port > 65535 { + return nil, fmt.Errorf("PostgreSQL port %d is out of valid range", port) + } + ports[i] = uint16(port) + } + + // Get port for primary host + primaryHost := hosts[0] + primaryPort := ports[0] + + // Set connection parameters for primary host + connConfig.Host = primaryHost + connConfig.Port = primaryPort connConfig.User = cfg.User connConfig.Password = cfg.Password connConfig.Database = cfg.Name @@ -123,13 +154,35 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { case "verify-full": // Verify the certificate and hostname tlsConfig.InsecureSkipVerify = false - tlsConfig.ServerName = cfg.Host + tlsConfig.ServerName = primaryHost } connConfig.TLSConfig = tlsConfig } } + // Create fallback configurations for additional hosts + if len(hosts) > 1 { + connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1) + for i, host := range hosts[1:] { + port := getPortForIndex(ports, i+1) + fallback := &pgconn.FallbackConfig{ + Host: host, + Port: port, + } + // Copy TLS config to fallback if present + if connConfig.TLSConfig != nil { + fallbackTLS := connConfig.TLSConfig.Clone() + // Update ServerName for verify-full mode + if cfg.SSLMode == "verify-full" { + fallbackTLS.ServerName = host + } + fallback.TLSConfig = fallbackTLS + } + connConfig.Fallbacks = append(connConfig.Fallbacks, fallback) + } + } + // Set runtime params if connConfig.RuntimeParams == nil { connConfig.RuntimeParams = make(map[string]string) @@ -141,23 +194,106 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) { // Parse and apply connection options if specified if cfg.ConnOptions != "" { - // Parse key=value pairs from ConnOptions - // Format: "key1=value1 key2=value2" - pairs := strings.Split(cfg.ConnOptions, " ") - for _, pair := range pairs { - if pair == "" { - continue - } - kv := strings.SplitN(pair, "=", 2) - if len(kv) == 2 { - connConfig.RuntimeParams[kv[0]] = kv[1] - } + connOpts, err := parseConnOptions(cfg.ConnOptions) + if err != nil { + return nil, fmt.Errorf("failed to parse connection options: %w", err) + } + + if err := applyConnOptions(connConfig, connOpts); err != nil { + return nil, fmt.Errorf("failed to apply connection options: %w", err) } } return connConfig, nil } +// getPortForIndex returns the port for the given host index. +// If there are fewer ports than needed, returns the last port (libpq behavior). +func getPortForIndex(ports []uint16, i int) uint16 { + if i >= len(ports) { + return ports[len(ports)-1] + } + return ports[i] +} + +// parseConnOptions decodes a base64-encoded JSON string into a map of connection options. +// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json +func parseConnOptions(encoded string) (map[string]string, error) { + // Base64 decode + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("invalid base64 encoding: %w", err) + } + + // Parse JSON + var opts map[string]interface{} + if err := json.Unmarshal(decoded, &opts); err != nil { + return nil, fmt.Errorf("invalid JSON: %w", err) + } + + // Convert all values to strings + result := make(map[string]string) + for k, v := range opts { + switch val := v.(type) { + case string: + result[k] = val + case float64: + // JSON numbers are float64 + if val == float64(int(val)) { + result[k] = strconv.Itoa(int(val)) + } else { + result[k] = strconv.FormatFloat(val, 'f', -1, 64) + } + case bool: + result[k] = strconv.FormatBool(val) + default: + result[k] = fmt.Sprintf("%v", v) + } + } + + return result, nil +} + +// applyConnOptions applies parsed connection options to the pgx.ConnConfig. +func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error { + for key, value := range opts { + // connect_timeout needs special handling as it's a connection-level timeout + if key == "connect_timeout" { + timeout, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid connect_timeout value: %w", err) + } + connConfig.ConnectTimeout = time.Duration(timeout) * time.Second + continue + } + // target_session_attrs needs special handling to set ValidateConnect function + if key == "target_session_attrs" { + switch value { + case "read-write": + connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite + case "read-only": + connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly + case "primary": + connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary + case "standby": + connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby + case "prefer-standby": + connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby + case "any": + // "any" is the default (no validation needed) + connConfig.ValidateConnect = nil + default: + return fmt.Errorf("unknown target_session_attrs value: %s", value) + } + // Do not add target_session_attrs to RuntimeParams + continue + } + // All other options go to RuntimeParams + connConfig.RuntimeParams[key] = value + } + return nil +} + // BuildDSN constructs a PostgreSQL connection string from a ConnConfig. func BuildDSN(cfg config.PostgreSQLConfig) (string, error) { connConfig, err := BuildConnConfig(cfg) diff --git a/internal/outpost/proxyv2/postgresstore/postgresstore_test.go b/internal/outpost/proxyv2/postgresstore/postgresstore_test.go index 092803c7b2a3..5f212b7d7174 100644 --- a/internal/outpost/proxyv2/postgresstore/postgresstore_test.go +++ b/internal/outpost/proxyv2/postgresstore/postgresstore_test.go @@ -6,6 +6,7 @@ import ( "crypto/rsa" "crypto/x509" "crypto/x509/pkix" + "encoding/base64" "encoding/json" "encoding/pem" "fmt" @@ -13,12 +14,15 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" + "runtime" "testing" "time" "github.com/google/uuid" "github.com/gorilla/sessions" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" @@ -33,7 +37,7 @@ import ( func SetupTestDB(t *testing.T) (*gorm.DB, *RefreshableConnPool) { cfg := config.Get().PostgreSQL - t.Logf("PostgreSQL config: Host=%s Port=%d User=%s DBName=%s SSLMode=%s", + t.Logf("PostgreSQL config: Host=%s Port=%s User=%s DBName=%s SSLMode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.SSLMode) t.Logf("Password length: %d", len(cfg.Password)) if cfg.Password == "" { @@ -485,7 +489,7 @@ func TestBuildDSN_Validation(t *testing.T) { { name: "missing host", cfg: config.PostgreSQLConfig{ - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", }, @@ -496,7 +500,7 @@ func TestBuildDSN_Validation(t *testing.T) { name: "missing user", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", Name: "testdb", }, expectError: true, @@ -506,7 +510,7 @@ func TestBuildDSN_Validation(t *testing.T) { name: "missing database name", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", }, expectError: true, @@ -516,23 +520,23 @@ func TestBuildDSN_Validation(t *testing.T) { name: "invalid port (zero)", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 0, + Port: "0", User: "testuser", Name: "testdb", }, expectError: true, - errorMsg: "PostgreSQL port must be positive", + errorMsg: "PostgreSQL port 0 must be positive", }, { name: "invalid port (negative)", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: -1, + Port: "-1", User: "testuser", Name: "testdb", }, expectError: true, - errorMsg: "PostgreSQL port must be positive", + errorMsg: "PostgreSQL port -1 must be positive", }, } @@ -560,7 +564,7 @@ func TestBuildConnConfig(t *testing.T) { name: "basic configuration", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", }, @@ -576,7 +580,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with simple password", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: "testpass", Name: "testdb", @@ -589,7 +593,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with password containing spaces", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: "my secure password", Name: "testdb", @@ -602,7 +606,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with password containing single quotes", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: "pass'word", Name: "testdb", @@ -615,7 +619,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with password containing backslashes", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: `pass\word`, Name: "testdb", @@ -628,7 +632,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with password containing special characters", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: `p@ss w0rd!#$%^&*()`, Name: "testdb", @@ -641,7 +645,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with password containing quotes and backslashes", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: `my'pass\word"here`, Name: "testdb", @@ -654,7 +658,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with passphrase (multiple spaces)", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: "the quick brown fox jumps over", Name: "testdb", @@ -667,7 +671,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with sslmode=disable", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", SSLMode: "disable", @@ -680,7 +684,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with sslmode=require (no certs)", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", SSLMode: "require", @@ -694,7 +698,7 @@ func TestBuildConnConfig(t *testing.T) { name: "with custom schema", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", DefaultSchema: "custom_schema", @@ -707,27 +711,48 @@ func TestBuildConnConfig(t *testing.T) { name: "with connection options", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", - ConnOptions: "connect_timeout=10 application_name=authentik", + ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { - assert.Equal(t, "10", cc.RuntimeParams["connect_timeout"]) + assert.Equal(t, 10*time.Second, cc.ConnectTimeout) assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) }, }, + { + name: "with target_session_attrs", + cfg: config.PostgreSQLConfig{ + Host: "localhost", + Port: "5432", + User: "testuser", + Name: "testdb", + ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + // target_session_attrs should NOT be in RuntimeParams + _, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") + // It should set ValidateConnect instead + assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set for target_session_attrs") + // Verify it's the correct validator function + expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite + assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(), + runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name()) + }, + }, { name: "full configuration with special password", cfg: config.PostgreSQLConfig{ Host: "db.example.com", - Port: 5433, + Port: "5433", User: "admin", Password: "my super secret password!@#", Name: "production", SSLMode: "require", DefaultSchema: "app_schema", - ConnOptions: "application_name=authentik", + ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "db.example.com", cc.Host) @@ -765,7 +790,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { name: "verify-full with all certificates", cfg: config.PostgreSQLConfig{ Host: "db.example.com", - Port: 5432, + Port: "5432", User: "testuser", Password: "my secure password", Name: "testdb", @@ -786,7 +811,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { name: "verify-ca with root cert only", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", SSLMode: "verify-ca", @@ -803,7 +828,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { name: "require with client cert", cfg: config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Name: "testdb", SSLMode: "require", @@ -820,7 +845,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { name: "full configuration with SSL and special password", cfg: config.PostgreSQLConfig{ Host: "db.example.com", - Port: 5433, + Port: "5433", User: "admin", Password: "my super secret password!@#", Name: "production", @@ -829,7 +854,7 @@ func TestBuildConnConfig_WithSSLCertificates(t *testing.T) { SSLCert: clientCertPath, SSLKey: clientKeyPath, DefaultSchema: "app_schema", - ConnOptions: "application_name=authentik", + ConnOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)), }, validate: func(t *testing.T, cc *pgx.ConnConfig) { assert.Equal(t, "db.example.com", cc.Host) @@ -881,7 +906,7 @@ func TestBuildDSN_WithSpecialPasswords(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cfg := config.PostgreSQLConfig{ Host: "localhost", - Port: 5432, + Port: "5432", User: "testuser", Password: tt.password, Name: "testdb", @@ -941,6 +966,221 @@ func TestPostgresStore_ConnectionPoolSettings(t *testing.T) { } } +// TestParseConnOptions tests the base64 JSON parsing of connection options +func TestParseConnOptions(t *testing.T) { + tests := []struct { + name string + input string + expected map[string]string + expectError bool + errorMsg string + }{ + { + name: "simple key-value", + input: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), + expected: map[string]string{"target_session_attrs": "read-write"}, + }, + { + name: "multiple options", + input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik"}`)), + expected: map[string]string{"connect_timeout": "10", "application_name": "authentik"}, + }, + { + name: "numeric value as number", + input: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10}`)), + expected: map[string]string{"connect_timeout": "10"}, + }, + { + name: "boolean value", + input: base64.StdEncoding.EncodeToString([]byte(`{"default_transaction_read_only":true}`)), + expected: map[string]string{"default_transaction_read_only": "true"}, + }, + { + name: "empty object", + input: base64.StdEncoding.EncodeToString([]byte(`{}`)), + expected: map[string]string{}, + }, + { + name: "invalid base64", + input: "not-valid-base64!!!", + expectError: true, + errorMsg: "invalid base64 encoding", + }, + { + name: "invalid JSON", + input: base64.StdEncoding.EncodeToString([]byte(`not json`)), + expectError: true, + errorMsg: "invalid JSON", + }, + { + name: "JSON array instead of object", + input: base64.StdEncoding.EncodeToString([]byte(`["value1", "value2"]`)), + expectError: true, + errorMsg: "invalid JSON", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseConnOptions(tt.input) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// TestApplyConnOptions tests that connection options are applied correctly to pgx.ConnConfig +func TestApplyConnOptions(t *testing.T) { + tests := []struct { + name string + opts map[string]string + validate func(*testing.T, *pgx.ConnConfig) + expectError bool + errorMsg string + }{ + { + name: "connect_timeout sets ConnectTimeout", + opts: map[string]string{"connect_timeout": "30"}, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, 30*time.Second, cc.ConnectTimeout) + }, + }, + { + name: "target_session_attrs sets ValidateConnect", + opts: map[string]string{"target_session_attrs": "read-write"}, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + // target_session_attrs should NOT be in RuntimeParams + _, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams") + // It should set ValidateConnect instead + assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set") + expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite + assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(), + runtime.FuncForPC(reflect.ValueOf(cc.ValidateConnect).Pointer()).Name()) + }, + }, + { + name: "application_name goes to RuntimeParams", + opts: map[string]string{"application_name": "my-app"}, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "my-app", cc.RuntimeParams["application_name"]) + }, + }, + { + name: "statement_timeout goes to RuntimeParams", + opts: map[string]string{"statement_timeout": "5000"}, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "5000", cc.RuntimeParams["statement_timeout"]) + }, + }, + { + name: "unknown options go to RuntimeParams", + opts: map[string]string{"custom_param": "custom_value"}, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, "custom_value", cc.RuntimeParams["custom_param"]) + }, + }, + { + name: "multiple options", + opts: map[string]string{ + "connect_timeout": "10", + "target_session_attrs": "read-write", + "application_name": "authentik", + }, + validate: func(t *testing.T, cc *pgx.ConnConfig) { + assert.Equal(t, 10*time.Second, cc.ConnectTimeout) + // target_session_attrs should NOT be in RuntimeParams + _, hasTargetSessionAttrs := cc.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not be in RuntimeParams") + // It should set ValidateConnect instead + assert.NotNil(t, cc.ValidateConnect, "ValidateConnect should be set") + assert.Equal(t, "authentik", cc.RuntimeParams["application_name"]) + }, + }, + { + name: "invalid connect_timeout", + opts: map[string]string{"connect_timeout": "not-a-number"}, + expectError: true, + errorMsg: "invalid connect_timeout value", + }, + { + name: "invalid target_session_attrs", + opts: map[string]string{"target_session_attrs": "invalid-mode"}, + expectError: true, + errorMsg: "unknown target_session_attrs value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a base config + connConfig, err := pgx.ParseConfig("") + require.NoError(t, err) + connConfig.RuntimeParams = make(map[string]string) + + err = applyConnOptions(connConfig, tt.opts) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorMsg) + } else { + require.NoError(t, err) + tt.validate(t, connConfig) + } + }) + } +} + +// TestBuildConnConfig_Base64JSONConnOptions tests the full integration of base64 JSON connection options +func TestBuildConnConfig_Base64JSONConnOptions(t *testing.T) { + t.Run("bug report scenario - target_session_attrs", func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: "localhost", + Port: "5432", + User: "authentik", + Name: "authentik", + ConnOptions: "eyJ0YXJnZXRfc2Vzc2lvbl9hdHRycyI6InJlYWQtd3JpdGUifQ==", + } + + connConfig, err := BuildConnConfig(cfg) + require.NoError(t, err) + // target_session_attrs should NOT be in RuntimeParams + _, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") + // It should set ValidateConnect instead + assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set") + expectedValidator := pgconn.ValidateConnectTargetSessionAttrsReadWrite + assert.Equal(t, runtime.FuncForPC(reflect.ValueOf(expectedValidator).Pointer()).Name(), + runtime.FuncForPC(reflect.ValueOf(connConfig.ValidateConnect).Pointer()).Name()) + }) + + t.Run("complex connection options", func(t *testing.T) { + // {"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"} + connOpts := base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":10,"target_session_attrs":"read-write","application_name":"authentik-proxy"}`)) + cfg := config.PostgreSQLConfig{ + Host: "localhost", + Port: "5432", + User: "authentik", + Name: "authentik", + ConnOptions: connOpts, + } + + connConfig, err := BuildConnConfig(cfg) + require.NoError(t, err) + assert.Equal(t, 10*time.Second, connConfig.ConnectTimeout) + // target_session_attrs should NOT be in RuntimeParams + _, hasTargetSessionAttrs := connConfig.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") + // It should set ValidateConnect instead + assert.NotNil(t, connConfig.ValidateConnect, "ValidateConnect should be set") + assert.Equal(t, "authentik-proxy", connConfig.RuntimeParams["application_name"]) + }) +} + // Helper function to create session data JSON func createSessionData(t *testing.T, claims map[string]interface{}) string { sessionData := map[string]interface{}{ @@ -1036,3 +1276,495 @@ func generateTestCerts(t *testing.T) (rootCertPath, clientCertPath, clientKeyPat return rootCertPath, clientCertPath, clientKeyPath, cleanup } + +// TestBuildConnConfig_WithBase64EncodedConnOptions demonstrates that ConnOptions +// should be base64-encoded JSON but is currently being parsed as key=value pairs +func TestBuildConnConfig_WithBase64EncodedConnOptions(t *testing.T) { + tests := []struct { + name string + connOptions string + expected map[string]string + expectError bool + }{ + { + name: "base64 encoded JSON with single parameter", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10"}`)), + expected: map[string]string{ + // connect_timeout is handled specially and NOT added to RuntimeParams + }, + }, + { + name: "base64 encoded JSON with multiple parameters", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"connect_timeout":"10","application_name":"authentik","statement_timeout":"30000"}`)), + expected: map[string]string{ + // connect_timeout is handled specially and NOT added to RuntimeParams + "application_name": "authentik", + "statement_timeout": "30000", + }, + }, + { + name: "base64 encoded JSON with special characters in values", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik proxy v2"}`)), + expected: map[string]string{ + "application_name": "authentik proxy v2", + }, + }, + { + name: "base64 encoded JSON with target_session_attrs", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write","application_name":"authentik"}`)), + expected: map[string]string{ + "application_name": "authentik", + // target_session_attrs should NOT appear in RuntimeParams + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: "localhost", + Port: "5432", + User: "testuser", + Name: "testdb", + ConnOptions: tt.connOptions, + } + + result, err := BuildConnConfig(cfg) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, result) + + // Verify that all expected parameters are present in RuntimeParams + for key, expectedValue := range tt.expected { + actualValue, exists := result.RuntimeParams[key] + assert.True(t, exists, "Expected runtime parameter %s to exist", key) + assert.Equal(t, expectedValue, actualValue, "Runtime parameter %s should have value %s", key, expectedValue) + } + + // Verify that connect_timeout is handled specially (sets ConnectTimeout field, not RuntimeParams) + if tt.name == "base64 encoded JSON with single parameter" || tt.name == "base64 encoded JSON with multiple parameters" { + _, hasConnectTimeout := result.RuntimeParams["connect_timeout"] + assert.False(t, hasConnectTimeout, "connect_timeout should not appear in RuntimeParams") + assert.Equal(t, 10*time.Second, result.ConnectTimeout, "connect_timeout should be set as ConnectTimeout duration") + } + + // Verify that target_session_attrs is NOT in RuntimeParams + // (it affects connection behavior, not a runtime param) + _, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, "target_session_attrs should not appear in RuntimeParams") + }) + } +} + +// TestBuildConnConfig_TargetSessionAttrs demonstrates how target_session_attrs +// should be properly handled using pgx's ValidateConnect callback +func TestBuildConnConfig_TargetSessionAttrs(t *testing.T) { + tests := []struct { + name string + connOptions string + targetSessionAttrs string + expectedValidator pgconn.ValidateConnectFunc + validatorDescription string + }{ + { + name: "target_session_attrs=read-write", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), + targetSessionAttrs: "read-write", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + validatorDescription: "should validate connection is read-write by checking transaction_read_only=off", + }, + { + name: "target_session_attrs=read-only", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-only"}`)), + targetSessionAttrs: "read-only", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadOnly, + validatorDescription: "should validate connection is read-only by checking transaction_read_only=on", + }, + { + name: "target_session_attrs=primary", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"primary"}`)), + targetSessionAttrs: "primary", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPrimary, + validatorDescription: "should validate connection is to primary by checking in_hot_standby=off", + }, + { + name: "target_session_attrs=standby", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"standby"}`)), + targetSessionAttrs: "standby", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsStandby, + validatorDescription: "should validate connection is to standby by checking in_hot_standby=on", + }, + { + name: "target_session_attrs=prefer-standby", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"prefer-standby"}`)), + targetSessionAttrs: "prefer-standby", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsPreferStandby, + validatorDescription: "should prefer standby connections (affects fallback logic)", + }, + { + name: "target_session_attrs=any (default)", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"any"}`)), + targetSessionAttrs: "any", + expectedValidator: nil, + validatorDescription: "should not set validator as any connection is acceptable", + }, + { + name: "no target_session_attrs", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"application_name":"authentik"}`)), + targetSessionAttrs: "", + expectedValidator: nil, + validatorDescription: "should not set validator when target_session_attrs is not specified", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: "localhost", + Port: "5432", + User: "testuser", + Name: "testdb", + ConnOptions: tt.connOptions, + } + + result, err := BuildConnConfig(cfg) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify target_session_attrs is NOT in RuntimeParams + _, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, + "target_session_attrs should not appear in RuntimeParams") + + // Verify ValidateConnect callback is set to the correct standard pgx function + if tt.expectedValidator != nil { + require.NotNil(t, result.ValidateConnect, + "ValidateConnect should be set for target_session_attrs=%s: %s", + tt.targetSessionAttrs, tt.validatorDescription) + + // Compare function pointers using reflect to check if it's the same function + actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer()) + expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer()) + + assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(), + "ValidateConnect should be set to %s for target_session_attrs=%s", + expectedFuncPtr.Name(), tt.targetSessionAttrs) + + t.Logf("Expected validator: %s", expectedFuncPtr.Name()) + t.Logf("Actual validator: %s", actualFuncPtr.Name()) + } else { + assert.Nil(t, result.ValidateConnect, + "ValidateConnect should not be set: %s", tt.validatorDescription) + } + }) + } +} + +// TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts tests that when multiple +// hosts are specified, fallbacks are properly configured along with the validator +func TestBuildConnConfig_TargetSessionAttrs_WithMultipleHosts(t *testing.T) { + tests := []struct { + name string + host string + port string + sslMode string + connOptions string + targetSessionAttrs string + expectedValidator pgconn.ValidateConnectFunc + expectedPrimaryHost string + expectedPrimaryPort uint16 + expectedFallbacks []*pgconn.FallbackConfig + expectTLS bool + validatorDescription string + }{ + { + name: "multiple hosts with read-write", + host: "db1.local,db2.local,db3.local", + port: "5432", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), + targetSessionAttrs: "read-write", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + expectedPrimaryHost: "db1.local", + expectedPrimaryPort: 5432, + expectedFallbacks: []*pgconn.FallbackConfig{ + {Host: "db2.local", Port: 5432, TLSConfig: nil}, + {Host: "db3.local", Port: 5432, TLSConfig: nil}, + }, + expectTLS: false, + validatorDescription: "should set validator and create fallbacks for additional hosts", + }, + { + name: "multiple hosts with ports specified", + host: "db1.local,db2.local,db3.local", + port: "5432,5433,5434", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write"}`)), + targetSessionAttrs: "read-write", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + expectedPrimaryHost: "db1.local", + expectedPrimaryPort: 5432, + expectedFallbacks: []*pgconn.FallbackConfig{ + {Host: "db2.local", Port: 5433, TLSConfig: nil}, + {Host: "db3.local", Port: 5434, TLSConfig: nil}, + }, + expectTLS: false, + validatorDescription: "should handle hosts with explicit ports", + }, + { + name: "multiple hosts with TLS required", + host: "db1.local,db2.local,db3.local", + port: "5432", + sslMode: "require", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"require"}`)), + targetSessionAttrs: "read-write", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + expectedPrimaryHost: "db1.local", + expectedPrimaryPort: 5432, + expectedFallbacks: []*pgconn.FallbackConfig{ + {Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil) + {Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil) + }, + expectTLS: true, + validatorDescription: "should set TLS config for all hosts when sslmode=require", + }, + { + name: "multiple hosts with TLS verify-full", + host: "db1.local,db2.local,db3.local", + port: "5432", + sslMode: "require", + connOptions: base64.StdEncoding.EncodeToString([]byte(`{"target_session_attrs":"read-write", "sslmode":"verify-full"}`)), + targetSessionAttrs: "read-write", + expectedValidator: pgconn.ValidateConnectTargetSessionAttrsReadWrite, + expectedPrimaryHost: "db1.local", + expectedPrimaryPort: 5432, + expectedFallbacks: []*pgconn.FallbackConfig{ + {Host: "db2.local", Port: 5432}, // TLSConfig should be set (non-nil) + {Host: "db3.local", Port: 5432}, // TLSConfig should be set (non-nil) + }, + expectTLS: true, + validatorDescription: "should set TLS config host name for all hosts when sslmode=verify-full", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: tt.host, + Port: tt.port, + User: "testuser", + Name: "testdb", + SSLMode: tt.sslMode, + ConnOptions: tt.connOptions, + } + + result, err := BuildConnConfig(cfg) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify target_session_attrs is NOT in RuntimeParams + _, hasTargetSessionAttrs := result.RuntimeParams["target_session_attrs"] + assert.False(t, hasTargetSessionAttrs, + "target_session_attrs should not appear in RuntimeParams") + + // Verify ValidateConnect is set to the correct function + require.NotNil(t, result.ValidateConnect, + "ValidateConnect should be set for target_session_attrs=%s with multiple hosts", + tt.targetSessionAttrs) + + actualFuncPtr := runtime.FuncForPC(reflect.ValueOf(result.ValidateConnect).Pointer()) + expectedFuncPtr := runtime.FuncForPC(reflect.ValueOf(tt.expectedValidator).Pointer()) + + assert.Equal(t, expectedFuncPtr.Name(), actualFuncPtr.Name(), + "ValidateConnect should be %s for target_session_attrs=%s", + expectedFuncPtr.Name(), tt.targetSessionAttrs) + + // Verify the primary host and port + assert.Equal(t, tt.expectedPrimaryHost, result.Host, + "Primary host should be %s", tt.expectedPrimaryHost) + assert.Equal(t, tt.expectedPrimaryPort, result.Port, + "Primary port should be %d", tt.expectedPrimaryPort) + + // Verify primary TLSConfig based on sslmode + if tt.expectTLS { + assert.NotNil(t, result.TLSConfig, + "Primary connection should have TLSConfig set when sslmode=%s", tt.sslMode) + } else { + assert.Nil(t, result.TLSConfig, + "Primary connection should not have TLSConfig when sslmode is not set") + } + + // Verify Fallbacks are configured for the additional hosts + require.Len(t, result.Fallbacks, len(tt.expectedFallbacks), + "Should have %d fallback configs for the additional hosts", len(tt.expectedFallbacks)) + + // Verify each fallback configuration + for i, expectedFb := range tt.expectedFallbacks { + actualFb := result.Fallbacks[i] + + assert.Equal(t, expectedFb.Host, actualFb.Host, + "Fallback %d host should be %s", i+1, expectedFb.Host) + assert.Equal(t, expectedFb.Port, actualFb.Port, + "Fallback %d port should be %d", i+1, expectedFb.Port) + + // Verify TLSConfig is set appropriately for fallbacks + if tt.expectTLS { + assert.NotNil(t, actualFb.TLSConfig, + "Fallback %d should have TLSConfig set when sslmode=%s", i+1, tt.sslMode) + // Verify InsecureSkipVerify for sslmode=require + switch tt.sslMode { + case "require": + assert.True(t, actualFb.TLSConfig.InsecureSkipVerify, + "Fallback %d TLSConfig should have InsecureSkipVerify=true for sslmode=require", i+1) + case "verify-full": + assert.False(t, actualFb.TLSConfig.InsecureSkipVerify, + "Fallback %d TLSConfig should have InsecureSkipVerify=false for sslmode=verify-full", i+1) + assert.Equal(t, actualFb.Host, actualFb.TLSConfig.ServerName, + "Fallback %d TLSConfig ServerName should match host for sslmode=verify-full", i+1) + } + } else { + assert.Nil(t, actualFb.TLSConfig, + "Fallback %d should not have TLSConfig when sslmode is not set", i+1) + } + } + + // Log the configuration for debugging + t.Logf("Primary host: %s:%d", result.Host, result.Port) + t.Logf("Validator: %s", actualFuncPtr.Name()) + for i, fb := range result.Fallbacks { + t.Logf("Fallback %d: %s:%d", i+1, fb.Host, fb.Port) + } + }) + } +} + +// TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs tests that multiple hosts +// create fallbacks even without target_session_attrs +func TestBuildConnConfig_MultipleHosts_WithoutTargetSessionAttrs(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: "db1.local,db2.local,db3.local", + Port: "5432", + User: "testuser", + Name: "testdb", + } + + result, err := BuildConnConfig(cfg) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify primary host + assert.Equal(t, "db1.local", result.Host) + assert.Equal(t, uint16(5432), result.Port) + + // Verify fallbacks are created + require.Len(t, result.Fallbacks, 2, "Should have 2 fallback configs") + assert.Equal(t, "db2.local", result.Fallbacks[0].Host) + assert.Equal(t, uint16(5432), result.Fallbacks[0].Port) + assert.Equal(t, "db3.local", result.Fallbacks[1].Host) + assert.Equal(t, uint16(5432), result.Fallbacks[1].Port) + + // Verify no ValidateConnect is set (no target_session_attrs) + assert.Nil(t, result.ValidateConnect) +} + +// TestBuildConnConfig_CommaSeparatedPorts_EdgeCases tests edge cases and error scenarios for comma-separated ports +func TestBuildConnConfig_CommaSeparatedPorts_EdgeCases(t *testing.T) { + tests := []struct { + name string + host string + port string + expectError bool + errorContains string + expectedHost string + expectedPort uint16 + expectedFallbacks []*pgconn.FallbackConfig + }{ + { + name: "invalid port in comma-separated list", + host: "db1.local,db2.local", + port: "5432,abc", + expectError: true, + errorContains: "invalid port value", + }, + { + name: "port out of range (too high)", + host: "db1.local,db2.local", + port: "5432,99999", + expectError: true, + errorContains: "PostgreSQL port 99999 is out of valid range", + }, + { + name: "port out of range (zero)", + host: "db1.local,db2.local", + port: "5432,0", + expectError: true, + errorContains: "PostgreSQL port 0 must be positive", + }, + { + name: "empty port string", + host: "db1.local", + port: "", + expectError: true, + errorContains: "PostgreSQL port is required", + }, + { + name: "port with only whitespace", + host: "db1.local", + port: " ", + expectError: true, + errorContains: "invalid port value", + }, + { + name: "mismatched number of hosts and ports", + host: "db1.local,db2.local", + port: "5432", + expectError: false, + expectedHost: "db1.local", + expectedPort: 5432, + expectedFallbacks: []*pgconn.FallbackConfig{ + {Host: "db2.local", Port: 5432}, + }, + }, + { + name: "extra ports than hosts", + host: "db1.local", + port: "5432,5433", + expectError: false, + expectedHost: "db1.local", + expectedPort: 5432, + expectedFallbacks: []*pgconn.FallbackConfig{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := config.PostgreSQLConfig{ + Host: tt.host, + Port: tt.port, + User: "testuser", + Name: "testdb", + } + + c, err := BuildConnConfig(cfg) + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + } else { + require.NoError(t, err) + require.NotNil(t, c) + + assert.Equal(t, tt.expectedHost, c.Host) + assert.Equal(t, tt.expectedPort, c.Port) + require.Len(t, c.Fallbacks, len(tt.expectedFallbacks)) + for i, expectedFb := range tt.expectedFallbacks { + actualFb := c.Fallbacks[i] + assert.Equal(t, expectedFb.Host, actualFb.Host) + assert.Equal(t, expectedFb.Port, actualFb.Port) + } + } + }) + } +}