@@ -4,21 +4,23 @@ import (
44 "context"
55 "crypto/tls"
66 "crypto/x509"
7+ "encoding/base64"
78 "encoding/json"
89 "errors"
910 "fmt"
1011 "net/http"
1112 "os"
13+ "strconv"
1214 "strings"
1315 "time"
1416
1517 "github.com/google/uuid"
1618 "github.com/gorilla/sessions"
1719 "github.com/jackc/pgx/v5"
20+ "github.com/jackc/pgx/v5/pgconn"
1821 "github.com/jackc/pgx/v5/stdlib"
1922 "github.com/mitchellh/mapstructure"
2023 log "github.com/sirupsen/logrus"
21- _ "gorm.io/driver/postgres"
2224 "gorm.io/gorm"
2325 "gorm.io/gorm/clause"
2426
@@ -65,8 +67,8 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
6567 if cfg .Name == "" {
6668 return nil , fmt .Errorf ("PostgreSQL database name is required" )
6769 }
68- if cfg .Port <= 0 {
69- return nil , fmt .Errorf ("PostgreSQL port must be positive " )
70+ if cfg .Port == "" {
71+ return nil , fmt .Errorf ("PostgreSQL port is required " )
7072 }
7173
7274 // Start with a default config
@@ -75,9 +77,38 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
7577 return nil , fmt .Errorf ("failed to create default config: %w" , err )
7678 }
7779
78- // Set connection parameters
79- connConfig .Host = cfg .Host
80- connConfig .Port = uint16 (cfg .Port )
80+ // Parse comma-separated hosts and create fallbacks
81+ // cfg.Host can be a comma-separated list like "host1,host2,host3"
82+ hosts := strings .Split (cfg .Host , "," )
83+ for i , host := range hosts {
84+ hosts [i ] = strings .TrimSpace (host )
85+ }
86+
87+ // Parse and validate comma-separated ports
88+ portStrs := strings .Split (cfg .Port , "," )
89+ ports := make ([]uint16 , len (portStrs ))
90+ for i , portStr := range portStrs {
91+ portStr = strings .TrimSpace (portStr )
92+ port , err := strconv .Atoi (portStr )
93+ if err != nil {
94+ return nil , fmt .Errorf ("invalid port value %q: %w" , portStr , err )
95+ }
96+ if port <= 0 {
97+ return nil , fmt .Errorf ("PostgreSQL port %d must be positive" , port )
98+ }
99+ if port > 65535 {
100+ return nil , fmt .Errorf ("PostgreSQL port %d is out of valid range" , port )
101+ }
102+ ports [i ] = uint16 (port )
103+ }
104+
105+ // Get port for primary host
106+ primaryHost := hosts [0 ]
107+ primaryPort := ports [0 ]
108+
109+ // Set connection parameters for primary host
110+ connConfig .Host = primaryHost
111+ connConfig .Port = primaryPort
81112 connConfig .User = cfg .User
82113 connConfig .Password = cfg .Password
83114 connConfig .Database = cfg .Name
@@ -123,13 +154,35 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
123154 case "verify-full" :
124155 // Verify the certificate and hostname
125156 tlsConfig .InsecureSkipVerify = false
126- tlsConfig .ServerName = cfg . Host
157+ tlsConfig .ServerName = primaryHost
127158 }
128159
129160 connConfig .TLSConfig = tlsConfig
130161 }
131162 }
132163
164+ // Create fallback configurations for additional hosts
165+ if len (hosts ) > 1 {
166+ connConfig .Fallbacks = make ([]* pgconn.FallbackConfig , 0 , len (hosts )- 1 )
167+ for i , host := range hosts [1 :] {
168+ port := getPortForIndex (ports , i + 1 )
169+ fallback := & pgconn.FallbackConfig {
170+ Host : host ,
171+ Port : port ,
172+ }
173+ // Copy TLS config to fallback if present
174+ if connConfig .TLSConfig != nil {
175+ fallbackTLS := connConfig .TLSConfig .Clone ()
176+ // Update ServerName for verify-full mode
177+ if cfg .SSLMode == "verify-full" {
178+ fallbackTLS .ServerName = host
179+ }
180+ fallback .TLSConfig = fallbackTLS
181+ }
182+ connConfig .Fallbacks = append (connConfig .Fallbacks , fallback )
183+ }
184+ }
185+
133186 // Set runtime params
134187 if connConfig .RuntimeParams == nil {
135188 connConfig .RuntimeParams = make (map [string ]string )
@@ -141,23 +194,106 @@ func BuildConnConfig(cfg config.PostgreSQLConfig) (*pgx.ConnConfig, error) {
141194
142195 // Parse and apply connection options if specified
143196 if cfg .ConnOptions != "" {
144- // Parse key=value pairs from ConnOptions
145- // Format: "key1=value1 key2=value2"
146- pairs := strings .Split (cfg .ConnOptions , " " )
147- for _ , pair := range pairs {
148- if pair == "" {
149- continue
150- }
151- kv := strings .SplitN (pair , "=" , 2 )
152- if len (kv ) == 2 {
153- connConfig .RuntimeParams [kv [0 ]] = kv [1 ]
154- }
197+ connOpts , err := parseConnOptions (cfg .ConnOptions )
198+ if err != nil {
199+ return nil , fmt .Errorf ("failed to parse connection options: %w" , err )
200+ }
201+
202+ if err := applyConnOptions (connConfig , connOpts ); err != nil {
203+ return nil , fmt .Errorf ("failed to apply connection options: %w" , err )
155204 }
156205 }
157206
158207 return connConfig , nil
159208}
160209
210+ // getPortForIndex returns the port for the given host index.
211+ // If there are fewer ports than needed, returns the last port (libpq behavior).
212+ func getPortForIndex (ports []uint16 , i int ) uint16 {
213+ if i >= len (ports ) {
214+ return ports [len (ports )- 1 ]
215+ }
216+ return ports [i ]
217+ }
218+
219+ // parseConnOptions decodes a base64-encoded JSON string into a map of connection options.
220+ // This matches the Python behavior in authentik/lib/config.py:get_dict_from_b64_json
221+ func parseConnOptions (encoded string ) (map [string ]string , error ) {
222+ // Base64 decode
223+ decoded , err := base64 .StdEncoding .DecodeString (encoded )
224+ if err != nil {
225+ return nil , fmt .Errorf ("invalid base64 encoding: %w" , err )
226+ }
227+
228+ // Parse JSON
229+ var opts map [string ]interface {}
230+ if err := json .Unmarshal (decoded , & opts ); err != nil {
231+ return nil , fmt .Errorf ("invalid JSON: %w" , err )
232+ }
233+
234+ // Convert all values to strings
235+ result := make (map [string ]string )
236+ for k , v := range opts {
237+ switch val := v .(type ) {
238+ case string :
239+ result [k ] = val
240+ case float64 :
241+ // JSON numbers are float64
242+ if val == float64 (int (val )) {
243+ result [k ] = strconv .Itoa (int (val ))
244+ } else {
245+ result [k ] = strconv .FormatFloat (val , 'f' , - 1 , 64 )
246+ }
247+ case bool :
248+ result [k ] = strconv .FormatBool (val )
249+ default :
250+ result [k ] = fmt .Sprintf ("%v" , v )
251+ }
252+ }
253+
254+ return result , nil
255+ }
256+
257+ // applyConnOptions applies parsed connection options to the pgx.ConnConfig.
258+ func applyConnOptions (connConfig * pgx.ConnConfig , opts map [string ]string ) error {
259+ for key , value := range opts {
260+ // connect_timeout needs special handling as it's a connection-level timeout
261+ if key == "connect_timeout" {
262+ timeout , err := strconv .Atoi (value )
263+ if err != nil {
264+ return fmt .Errorf ("invalid connect_timeout value: %w" , err )
265+ }
266+ connConfig .ConnectTimeout = time .Duration (timeout ) * time .Second
267+ continue
268+ }
269+ // target_session_attrs needs special handling to set ValidateConnect function
270+ if key == "target_session_attrs" {
271+ switch value {
272+ case "read-write" :
273+ connConfig .ValidateConnect = pgconn .ValidateConnectTargetSessionAttrsReadWrite
274+ case "read-only" :
275+ connConfig .ValidateConnect = pgconn .ValidateConnectTargetSessionAttrsReadOnly
276+ case "primary" :
277+ connConfig .ValidateConnect = pgconn .ValidateConnectTargetSessionAttrsPrimary
278+ case "standby" :
279+ connConfig .ValidateConnect = pgconn .ValidateConnectTargetSessionAttrsStandby
280+ case "prefer-standby" :
281+ connConfig .ValidateConnect = pgconn .ValidateConnectTargetSessionAttrsPreferStandby
282+ case "any" :
283+ // "any" is the default (no validation needed)
284+ connConfig .ValidateConnect = nil
285+ default :
286+ return fmt .Errorf ("unknown target_session_attrs value: %s" , value )
287+ }
288+ // Do not add target_session_attrs to RuntimeParams
289+ continue
290+ }
291+ // All other options go to RuntimeParams
292+ connConfig .RuntimeParams [key ] = value
293+ }
294+ return nil
295+ }
296+
161297// BuildDSN constructs a PostgreSQL connection string from a ConnConfig.
162298func BuildDSN (cfg config.PostgreSQLConfig ) (string , error ) {
163299 connConfig , err := BuildConnConfig (cfg )
0 commit comments