Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 145 additions & 15 deletions internal/outpost/proxyv2/postgresstore/postgresstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,23 @@
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"github.com/gorilla/sessions"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
log "github.com/sirupsen/logrus"
_ "gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"

Expand Down Expand Up @@ -75,9 +77,23 @@
return nil, fmt.Errorf("failed to create default config: %w", err)
}

// Set connection parameters
connConfig.Host = cfg.Host
connConfig.Port = uint16(cfg.Port)
// Parse comma-separated hosts and create fallbacks
// cfg.Host can be a comma-separated list like "host1:5433,host2,host3:5434"
hosts := strings.Split(cfg.Host, ",")
for i, host := range hosts {
hosts[i] = strings.TrimSpace(host)
}

if len(hosts) == 0 {
return nil, fmt.Errorf("no hosts specified")
}

// Parse first host (primary)
primaryHost, primaryPort := parseHostPort(hosts[0], cfg.Port)

// Set connection parameters for primary host
connConfig.Host = primaryHost
connConfig.Port = primaryPort
connConfig.User = cfg.User
connConfig.Password = cfg.Password
connConfig.Database = cfg.Name
Expand Down Expand Up @@ -123,13 +139,36 @@
case "verify-full":
// Verify the certificate and hostname
tlsConfig.InsecureSkipVerify = false
tlsConfig.ServerName = cfg.Host
tlsConfig.ServerName = primaryHost
}

connConfig.TLSConfig = tlsConfig
}
}

// Create fallback configurations for additional hosts
if len(hosts) > 1 {
connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1)
for _, host := range hosts[1:] {
fallbackHost, fallbackPort := parseHostPort(host, cfg.Port)
fallback := &pgconn.FallbackConfig{
Host: fallbackHost,
Port: fallbackPort,
}
// Copy TLS config to fallback if present
if connConfig.TLSConfig != nil {
// Create a copy of the TLS config for the fallback
fallbackTLS := connConfig.TLSConfig.Clone()
// Update ServerName for verify-full mode
if cfg.SSLMode == "verify-full" {
fallbackTLS.ServerName = fallbackHost
}
fallback.TLSConfig = fallbackTLS
}
connConfig.Fallbacks = append(connConfig.Fallbacks, fallback)
}
}

// Set runtime params
if connConfig.RuntimeParams == nil {
connConfig.RuntimeParams = make(map[string]string)
Expand All @@ -141,21 +180,112 @@

// Parse and apply connection options if specified
if cfg.ConnOptions != "" {
// Parse key=value pairs from ConnOptions
// Format: "key1=value1 key2=value2"
pairs := strings.Split(cfg.ConnOptions, " ")
for _, pair := range pairs {
if pair == "" {
continue
connOpts, err := parseConnOptions(cfg.ConnOptions)
if err != nil {
return nil, fmt.Errorf("failed to parse connection options: %w", err)
}

// Apply each connection option to the appropriate config field
if err := applyConnOptions(connConfig, connOpts); err != nil {
return nil, fmt.Errorf("failed to apply connection options: %w", err)
}
}

return connConfig, nil
}

// parseHostPort parses a host string that may contain a port ("host:port")
// If no port is specified, returns the default port
func parseHostPort(hostStr string, defaultPort int) (string, uint16) {
if strings.Contains(hostStr, ":") {
// Host has explicit port
parts := strings.Split(hostStr, ":")
if len(parts) == 2 {
if port, err := strconv.Atoi(parts[1]); err == nil && port > 0 {
return parts[0], uint16(port)
}
kv := strings.SplitN(pair, "=", 2)
if len(kv) == 2 {
connConfig.RuntimeParams[kv[0]] = kv[1]
}
}
// Use default port
return hostStr, uint16(defaultPort)
}

// parseConnOptions decodes a base64-encoded JSON string into a map of connection options.
// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json
func parseConnOptions(encoded string) (map[string]string, error) {
// Base64 decode
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
}

// Parse JSON
var opts map[string]interface{}
if err := json.Unmarshal(decoded, &opts); err != nil {
return nil, fmt.Errorf("invalid JSON: %w", err)
}

// Convert all values to strings
result := make(map[string]string)
for k, v := range opts {
switch val := v.(type) {
case string:
result[k] = val
case float64:
// JSON numbers are float64
if val == float64(int(val)) {
result[k] = strconv.Itoa(int(val))
} else {
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
}
case bool:
result[k] = strconv.FormatBool(val)
default:
result[k] = fmt.Sprintf("%v", v)
}
}

return connConfig, nil
return result, nil
}

// applyConnOptions applies parsed connection options to the pgx.ConnConfig.
func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error {
for key, value := range opts {
// connect_timeout needs special handling as it's a connection-level timeout
if key == "connect_timeout" {
timeout, err := strconv.Atoi(value)
if err != nil {
return fmt.Errorf("invalid connect_timeout value: %w", err)
}
connConfig.ConnectTimeout = time.Duration(timeout) * time.Second
continue
}
// target_session_attrs needs special handling to set ValidateConnect function
if key == "target_session_attrs" {
switch value {
case "read-write":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
case "read-only":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly
case "primary":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary
case "standby":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby
case "prefer-standby":
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// "any" is the default (no validation needed)
connConfig.ValidateConnect = nil
default:
return fmt.Errorf("unknown target_session_attrs value: %s", value)
}
// Do not add target_session_attrs to RuntimeParams
continue
}
// All other options go to RuntimeParams
connConfig.RuntimeParams[key] = value
}
return nil
}

// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
Expand Down
Loading
Loading