Skip to content

Commit 12e3de2

Browse files
dominic-rD-Tasker207
authored andcommitted
internal/outpost: improve PostgreSQL connection options parsing (#19118)
* internal: Outpost's conn options should be base64 json * correctly parse target_session_attrs + tests * fix port handling to use env provided port * add multiple port handling abilities to mirror the python config parser --------- Co-authored-by: Duncan Tasker <tasatree@gmail.com>
1 parent ad818a2 commit 12e3de2

File tree

4 files changed

+918
-49
lines changed

4 files changed

+918
-49
lines changed

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ func (c *Config) RefreshPostgreSQLConfig() PostgreSQLConfig {
204204
// Map of environment variable suffix to config field pointer
205205
envVars := map[string]*string{
206206
"HOST": &refreshed.Host,
207+
"PORT": &refreshed.Port,
207208
"USER": &refreshed.User,
208209
"PASSWORD": &refreshed.Password,
209210
"NAME": &refreshed.Name,

internal/config/struct.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ type Config struct {
2727

2828
type PostgreSQLConfig struct {
2929
Host string `yaml:"host" env:"HOST, overwrite"`
30-
Port int `yaml:"port" env:"PORT, overwrite"`
30+
Port string `yaml:"port" env:"PORT, overwrite"`
3131
User string `yaml:"user" env:"USER, overwrite"`
3232
Password string `yaml:"password" env:"PASSWORD, overwrite"`
3333
Name string `yaml:"name" env:"NAME, overwrite"`

internal/outpost/proxyv2/postgresstore/postgresstore.go

Lines changed: 154 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,23 @@ import (
44
"context"
55
"crypto/tls"
66
"crypto/x509"
7+
"encoding/base64"
78
"encoding/json"
89
"errors"
910
"fmt"
1011
"net/http"
1112
"os"
13+
"strconv"
1214
"strings"
1315
"time"
1416

1517
"github.com/google/uuid"
1618
"github.com/gorilla/sessions"
1719
"github.com/jackc/pgx/v5"
20+
"github.com/jackc/pgx/v5/pgconn"
1821
"github.com/jackc/pgx/v5/stdlib"
1922
"github.com/mitchellh/mapstructure"
2023
log "github.com/sirupsen/logrus"
21-
_ "gorm.io/driver/postgres"
2224
"gorm.io/gorm"
2325
"gorm.io/gorm/clause"
2426

@@ -65,8 +67,8 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
6567
if cfg.Name == "" {
6668
return nil, fmt.Errorf("PostgreSQL database name is required")
6769
}
68-
if cfg.Port <= 0 {
69-
return nil, fmt.Errorf("PostgreSQL port must be positive")
70+
if cfg.Port == "" {
71+
return nil, fmt.Errorf("PostgreSQL port is required")
7072
}
7173

7274
// Start with a default config
@@ -75,9 +77,38 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
7577
return nil, fmt.Errorf("failed to create default config: %w", err)
7678
}
7779

78-
// Set connection parameters
79-
connConfig.Host = cfg.Host
80-
connConfig.Port = uint16(cfg.Port)
80+
// Parse comma-separated hosts and create fallbacks
81+
// cfg.Host can be a comma-separated list like "host1,host2,host3"
82+
hosts := strings.Split(cfg.Host, ",")
83+
for i, host := range hosts {
84+
hosts[i] = strings.TrimSpace(host)
85+
}
86+
87+
// Parse and validate comma-separated ports
88+
portStrs := strings.Split(cfg.Port, ",")
89+
ports := make([]uint16, len(portStrs))
90+
for i, portStr := range portStrs {
91+
portStr = strings.TrimSpace(portStr)
92+
port, err := strconv.Atoi(portStr)
93+
if err != nil {
94+
return nil, fmt.Errorf("invalid port value %q: %w", portStr, err)
95+
}
96+
if port <= 0 {
97+
return nil, fmt.Errorf("PostgreSQL port %d must be positive", port)
98+
}
99+
if port > 65535 {
100+
return nil, fmt.Errorf("PostgreSQL port %d is out of valid range", port)
101+
}
102+
ports[i] = uint16(port)
103+
}
104+
105+
// Get port for primary host
106+
primaryHost := hosts[0]
107+
primaryPort := ports[0]
108+
109+
// Set connection parameters for primary host
110+
connConfig.Host = primaryHost
111+
connConfig.Port = primaryPort
81112
connConfig.User = cfg.User
82113
connConfig.Password = cfg.Password
83114
connConfig.Database = cfg.Name
@@ -123,13 +154,35 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
123154
case "verify-full":
124155
// Verify the certificate and hostname
125156
tlsConfig.InsecureSkipVerify = false
126-
tlsConfig.ServerName = cfg.Host
157+
tlsConfig.ServerName = primaryHost
127158
}
128159

129160
connConfig.TLSConfig = tlsConfig
130161
}
131162
}
132163

164+
// Create fallback configurations for additional hosts
165+
if len(hosts) > 1 {
166+
connConfig.Fallbacks = make([]*pgconn.FallbackConfig, 0, len(hosts)-1)
167+
for i, host := range hosts[1:] {
168+
port := getPortForIndex(ports, i+1)
169+
fallback := &pgconn.FallbackConfig{
170+
Host: host,
171+
Port: port,
172+
}
173+
// Copy TLS config to fallback if present
174+
if connConfig.TLSConfig != nil {
175+
fallbackTLS := connConfig.TLSConfig.Clone()
176+
// Update ServerName for verify-full mode
177+
if cfg.SSLMode == "verify-full" {
178+
fallbackTLS.ServerName = host
179+
}
180+
fallback.TLSConfig = fallbackTLS
181+
}
182+
connConfig.Fallbacks = append(connConfig.Fallbacks, fallback)
183+
}
184+
}
185+
133186
// Set runtime params
134187
if connConfig.RuntimeParams == nil {
135188
connConfig.RuntimeParams = make(map[string]string)
@@ -141,23 +194,106 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
141194

142195
// Parse and apply connection options if specified
143196
if cfg.ConnOptions != "" {
144-
// Parse key=value pairs from ConnOptions
145-
// Format: "key1=value1 key2=value2"
146-
pairs := strings.Split(cfg.ConnOptions, " ")
147-
for _, pair := range pairs {
148-
if pair == "" {
149-
continue
150-
}
151-
kv := strings.SplitN(pair, "=", 2)
152-
if len(kv) == 2 {
153-
connConfig.RuntimeParams[kv[0]] = kv[1]
154-
}
197+
connOpts, err := parseConnOptions(cfg.ConnOptions)
198+
if err != nil {
199+
return nil, fmt.Errorf("failed to parse connection options: %w", err)
200+
}
201+
202+
if err := applyConnOptions(connConfig, connOpts); err != nil {
203+
return nil, fmt.Errorf("failed to apply connection options: %w", err)
155204
}
156205
}
157206

158207
return connConfig, nil
159208
}
160209

210+
// getPortForIndex returns the port for the given host index.
211+
// If there are fewer ports than needed, returns the last port (libpq behavior).
212+
func getPortForIndex(ports []uint16, i int) uint16 {
213+
if i >= len(ports) {
214+
return ports[len(ports)-1]
215+
}
216+
return ports[i]
217+
}
218+
219+
// parseConnOptions decodes a base64-encoded JSON string into a map of connection options.
220+
// This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json
221+
func parseConnOptions(encoded string) (map[string]string, error) {
222+
// Base64 decode
223+
decoded, err := base64.StdEncoding.DecodeString(encoded)
224+
if err != nil {
225+
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
226+
}
227+
228+
// Parse JSON
229+
var opts map[string]interface{}
230+
if err := json.Unmarshal(decoded, &opts); err != nil {
231+
return nil, fmt.Errorf("invalid JSON: %w", err)
232+
}
233+
234+
// Convert all values to strings
235+
result := make(map[string]string)
236+
for k, v := range opts {
237+
switch val := v.(type) {
238+
case string:
239+
result[k] = val
240+
case float64:
241+
// JSON numbers are float64
242+
if val == float64(int(val)) {
243+
result[k] = strconv.Itoa(int(val))
244+
} else {
245+
result[k] = strconv.FormatFloat(val, 'f', -1, 64)
246+
}
247+
case bool:
248+
result[k] = strconv.FormatBool(val)
249+
default:
250+
result[k] = fmt.Sprintf("%v", v)
251+
}
252+
}
253+
254+
return result, nil
255+
}
256+
257+
// applyConnOptions applies parsed connection options to the pgx.ConnConfig.
258+
func applyConnOptions(connConfig *pgx.ConnConfig, opts map[string]string) error {
259+
for key, value := range opts {
260+
// connect_timeout needs special handling as it's a connection-level timeout
261+
if key == "connect_timeout" {
262+
timeout, err := strconv.Atoi(value)
263+
if err != nil {
264+
return fmt.Errorf("invalid connect_timeout value: %w", err)
265+
}
266+
connConfig.ConnectTimeout = time.Duration(timeout) * time.Second
267+
continue
268+
}
269+
// target_session_attrs needs special handling to set ValidateConnect function
270+
if key == "target_session_attrs" {
271+
switch value {
272+
case "read-write":
273+
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
274+
case "read-only":
275+
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadOnly
276+
case "primary":
277+
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPrimary
278+
case "standby":
279+
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsStandby
280+
case "prefer-standby":
281+
connConfig.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsPreferStandby
282+
case "any":
283+
// "any" is the default (no validation needed)
284+
connConfig.ValidateConnect = nil
285+
default:
286+
return fmt.Errorf("unknown target_session_attrs value: %s", value)
287+
}
288+
// Do not add target_session_attrs to RuntimeParams
289+
continue
290+
}
291+
// All other options go to RuntimeParams
292+
connConfig.RuntimeParams[key] = value
293+
}
294+
return nil
295+
}
296+
161297
// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
162298
func BuildDSN(cfg config.PostgreSQLConfig) (string, error) {
163299
connConfig, err := BuildConnConfig(cfg)

0 commit comments

Comments
 (0)