Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig {
// Map of environment variable suffix to config field pointer
envVars := map[string]*string{
"HOST": &refreshed.Host,
"PORT": &refreshed.Port,
"USER": &refreshed.User,
"PASSWORD": &refreshed.Password,
"NAME": &refreshed.Name,
Expand Down
2 changes: 1 addition & 1 deletion internal/config/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Config struct {

type PostgreSQLConfig struct {
Host string `yaml:"host" env:"HOST, overwrite"`
Port int `yaml:"port" env:"PORT, overwrite"`
Port string `yaml:"port" env:"PORT, overwrite"`
User string `yaml:"user" env:"USER, overwrite"`
Password string `yaml:"password" env:"PASSWORD, overwrite"`
Name string `yaml:"name" env:"NAME, overwrite"`
Expand Down
172 changes: 154 additions & 18 deletions internal/outpost/proxyv2/postgresstore/postgresstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,23 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
_ "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"

Expand Down Expand Up @@ -65,8 +67,8 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
if cfg.Name == "" {
return nil, fmt.Errorf("PostgreSQL database name is required")
}
if cfg.Port <= 0 {
return nil, fmt.Errorf("PostgreSQL port must be positive")
if cfg.Port == "" {
return nil, fmt.Errorf("PostgreSQL port is required")
}

// Start with a default config
Expand All @@ -75,9 +77,38 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
return nil, fmt.Errorf("failed to create default config: %w", err)
}

// Set connection parameters
connConfig.Host = cfg.Host
connConfig.Port = uint16(cfg.Port)
// Parse comma-separated hosts and create fallbacks
// cfg.Host can be a comma-separated list like "host1,host2,host3"
hosts := strings.Split(cfg.Host, ",")
for i, host := range hosts {
hosts[i] = strings.TrimSpace(host)
}

// Parse and validate comma-separated ports
portStrs := strings.Split(cfg.Port, ",")
ports := make([]uint16, len(portStrs))
for i, portStr := range portStrs {
portStr = strings.TrimSpace(portStr)
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("invalid port value %q: %w", portStr, err)
}
if port <= 0 {
return nil, fmt.Errorf("PostgreSQL port %d must be positive", port)
}
if port > 65535 {
return nil, fmt.Errorf("PostgreSQL port %d is out of valid range", port)
}
ports[i] = uint16(port)
}

// Get port for primary host
primaryHost := hosts[0]
primaryPort := ports[0]

// Set connection parameters for primary host
connConfig.Host = primaryHost
connConfig.Port = primaryPort
connConfig.User = cfg.User
connConfig.Password = cfg.Password
connConfig.Database = cfg.Name
Expand Down Expand Up @@ -123,13 +154,35 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
case "verify-full":
// Verify the certificate and hostname
tlsConfig.InsecureSkipVerify = false
tlsConfig.ServerName = cfg.Host
tlsConfig.ServerName = primaryHost
}

connConfig.TLSConfig = tlsConfig
}
}

// Create fallback configurations for additional hosts
if len(hosts) > 1 {
connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1)
for i, host := range hosts[1:] {
port := getPortForIndex(ports, i+1)
fallback := &pgconn.FallbackConfig{
Host: host,
Port: port,
}
// Copy TLS config to fallback if present
if connConfig.TLSConfig != nil {
fallbackTLS := connConfig.TLSConfig.Clone()
// Update ServerName for verify-full mode
if cfg.SSLMode == "verify-full" {
fallbackTLS.ServerName = host
}
fallback.TLSConfig = fallbackTLS
}
connConfig.Fallbacks = append(connConfig.Fallbacks, fallback)
}
}

// Set runtime params
if connConfig.RuntimeParams == nil {
connConfig.RuntimeParams = make(map[string]string)
Expand All @@ -141,23 +194,106 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {

// Parse and apply connection options if specified
if cfg.ConnOptions != "" {
// Parse key=value pairs from ConnOptions
// Format: "key1=value1 key2=value2"
pairs := strings.Split(cfg.ConnOptions, " ")
for _, pair := range pairs {
if pair == "" {
continue
}
kv := strings.SplitN(pair, "=", 2)
if len(kv) == 2 {
connConfig.RuntimeParams[kv[0]] = kv[1]
}
connOpts, err := parseConnOptions(cfg.ConnOptions)
if err != nil {
return nil, fmt.Errorf("failed to parse connection options: %w", err)
}

if err := applyConnOptions(connConfig, connOpts); err != nil {
return nil, fmt.Errorf("failed to apply connection options: %w", err)
}
}

return connConfig, nil
}

// getPortForIndex returns the port for the given host index.
// If there are fewer ports than needed, returns the last port (libpq behavior).
func getPortForIndex(ports []uint16, i int) uint16 {
if i >= len(ports) {
return ports[len(ports)-1]
}
return ports[i]
}

// parseConnOptions decodes a base64-encoded JSON string into a map of connection options.
// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json
func parseConnOptions(encoded string) (map[string]string, error) {
// Base64 decode
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
}

// Parse JSON
var opts map[string]interface{}
if err := json.Unmarshal(decoded, &opts); err != nil {
return nil, fmt.Errorf("invalid JSON: %w", err)
}

// Convert all values to strings
result := make(map[string]string)
for k, v := range opts {
switch val := v.(type) {
case string:
result[k] = val
case float64:
// JSON numbers are float64
if val == float64(int(val)) {
result[k] = strconv.Itoa(int(val))
} else {
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
}
case bool:
result[k] = strconv.FormatBool(val)
default:
result[k] = fmt.Sprintf("%v", v)
}
}

return result, nil
}

// applyConnOptions applies parsed connection options to the pgx.ConnConfig.
func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error {
for key, value := range opts {
// connect_timeout needs special handling as it's a connection-level timeout
if key == "connect_timeout" {
timeout, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("invalid connect_timeout value: %w", err)
}
connConfig.ConnectTimeout = time.Duration(timeout) * time.Second
continue
}
// target_session_attrs needs special handling to set ValidateConnect function
if key == "target_session_attrs" {
switch value {
case "read-write":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
case "read-only":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly
case "primary":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary
case "standby":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby
case "prefer-standby":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// "any" is the default (no validation needed)
connConfig.ValidateConnect = nil
default:
return fmt.Errorf("unknown target_session_attrs value: %s", value)
}
// Do not add target_session_attrs to RuntimeParams
continue
}
// All other options go to RuntimeParams
connConfig.RuntimeParams[key] = value
}
return nil
}

// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
connConfig, err := BuildConnConfig(cfg)
Expand Down
Loading
Loading