Skip to content

Commit da25c8a

Browse files
Merge pull request #40 from andrewheberle/multi-sp
Multi sp
2 parents fbb7aee + 5a356fb commit da25c8a

File tree

9 files changed

+256
-141
lines changed

9 files changed

+256
-141
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ AUTH_IDP_METADATA=https://idp.example.net/metadata \
3535
```
3636
--cert string HTTPS Certificate
3737
--db-connection string Database connection string
38+
--db-prefix string Database table prefix
3839
--debug Enable debug logging
3940
-h, --help help for http-auth-server
4041
--idp-certificate string IdP Certificate/Public Key

config_multiple.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
service-providers:
2+
- sp-url: http://localhost:9091/a
3+
sp-cert: ./samlsp.crt
4+
sp-key: ./samlsp.key
5+
idp-metadata: https://mocksaml.com/api/saml/metadata
6+
- name: b
7+
sp-url: http://localhost:9091/b
8+
sp-cert: ./samlsp.crt
9+
sp-key: ./samlsp.key
10+
idp-metadata: https://mocksaml.com/api/saml/metadata

config_one.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
service-providers:
2+
- name: one
3+
sp-url: http://localhost:9091
4+
sp-cert: ./samlsp.crt
5+
sp-key: ./samlsp.key
6+
idp-metadata: https://mocksaml.com/api/saml/metadata

config_single.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
sp-cert: ./samlsp.crt
2+
sp-key: ./samlsp.key
3+
idp-metadata: https://mocksaml.com/api/saml/metadata

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ require (
1414
github.com/spf13/cobra v1.8.1
1515
github.com/spf13/pflag v1.0.5
1616
github.com/spf13/viper v1.19.0
17-
gitlab.com/andrewheberle/routerswapper v1.2.0
1817
)
1918

2019
require (

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
100100
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
101101
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
102102
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
103-
gitlab.com/andrewheberle/routerswapper v1.2.0 h1:43e23lnlcTI31DoI/4HP2aw27WCgsghLCcezgCCraz0=
104-
gitlab.com/andrewheberle/routerswapper v1.2.0/go.mod h1:olw/7+vGWD6II0k84qQuevoj46o5DIcG1OvM9MmyW5Q=
105103
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
106104
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
107105
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=

internal/cmd/root.go

Lines changed: 136 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"github.com/spf13/cobra"
1818
"github.com/spf13/pflag"
1919
"github.com/spf13/viper"
20-
"gitlab.com/andrewheberle/routerswapper"
2120
)
2221

2322
var rootCmd = &cobra.Command{
@@ -51,13 +50,14 @@ func init() {
5150
rootCmd.Flags().String("idp-certificate", "", "IdP Certificate/Public Key")
5251
rootCmd.Flags().String("db-connection", "", "Database connection string")
5352
rootCmd.Flags().String("db-prefix", "", "Database table prefix")
53+
rootCmd.Flags().StringP("config", "c", "", "Configuration file")
5454
rootCmd.Flags().Bool("debug", false, "Enable debug logging")
5555

5656
// flag requirements
57-
rootCmd.MarkFlagsRequiredTogether("cert", "key")
5857
rootCmd.MarkFlagsRequiredTogether("sp-cert", "sp-key")
5958
rootCmd.MarkFlagRequired("sp-cert")
6059
rootCmd.MarkFlagRequired("sp-key")
60+
rootCmd.MarkFlagsRequiredTogether("cert", "key")
6161
rootCmd.MarkFlagsRequiredTogether("idp-issuer", "idp-sso-endpoint", "idp-certificate")
6262
rootCmd.MarkFlagsMutuallyExclusive("idp-metadata", "idp-issuer")
6363
rootCmd.MarkFlagsMutuallyExclusive("idp-metadata", "idp-sso-endpoint")
@@ -74,14 +74,45 @@ func initConfig() {
7474
// bind flags to viper
7575
viper.BindPFlags(rootCmd.Flags())
7676

77-
// set any flags found in environment via viper
77+
// load config file if flag is set
78+
if config := viper.GetString("config"); config != "" {
79+
viper.SetConfigFile(config)
80+
if err := viper.ReadInConfig(); err != nil {
81+
slog.Error("problem loading configuration", "error", err)
82+
os.Exit(1)
83+
}
84+
85+
// set sp-cert and sp-key to something just to allow things to work when using multiple SP's
86+
for _, name := range []string{"sp-cert", "sp-key"} {
87+
if !viper.IsSet(name) {
88+
rootCmd.Flags().Set(name, "unused")
89+
}
90+
}
91+
}
92+
93+
// set any flags found in environment/config via viper
7894
rootCmd.Flags().VisitAll(func(f *pflag.Flag) {
7995
if viper.IsSet(f.Name) && viper.GetString(f.Name) != "" {
96+
slog.Info("setting flag", "name", f.Name, "value", viper.GetString(f.Name))
8097
rootCmd.Flags().Set(f.Name, viper.GetString(f.Name))
8198
}
8299
})
83100
}
84101

102+
type serviceProvider struct {
103+
Name string `mapstructure:"name"`
104+
ServiceProviderURL string `mapstructure:"sp-url"`
105+
ServiceProviderClaimMapping map[string]string `mapstructure:"sp-claim-mapping"`
106+
ServiceProviderCertificate string `mapstructure:"sp-cert"`
107+
ServiceProviderKey string `mapstructure:"sp-key"`
108+
IdPMetadata string `mapstructure:"idp-metadata"`
109+
IdPIssuer string `mapstructure:"idp-issuer"`
110+
IdPSSOEndpoint string `mapstructure:"idp-sso-endpoint"`
111+
IdPCertificate string `mapstructure:"idp-certificate"`
112+
DatabaseConnection string `mapstructure:"db-connection"`
113+
DatabaseTablePrefix string `mapstructure:"db-prefix"`
114+
}
115+
85116
func runRootCmd() error {
86117
// logging setup
87118
var logLevel = new(slog.LevelVar)
@@ -91,76 +122,131 @@ func runRootCmd() error {
91122
logLevel.Set(slog.LevelDebug)
92123
}
93124

94-
// validate service provider root url
95-
root, err := url.Parse(viper.GetString("sp-url"))
96-
if err != nil {
97-
return fmt.Errorf("problem with SP URL: %w", err)
98-
}
125+
// did we load in via a config file
126+
var serviceProviders []serviceProvider
127+
if viper.ConfigFileUsed() != "" {
128+
// has a list of service providers been provided?
129+
if viper.Get("service-providers") != nil {
130+
if err := viper.UnmarshalKey("service-providers", &serviceProviders); err != nil {
131+
return fmt.Errorf("error with service providers list: %w", err)
132+
}
133+
} else {
134+
var sp serviceProvider
135+
if err := viper.Unmarshal(&sp); err != nil {
136+
return fmt.Errorf("error with service provider: %w", err)
137+
}
99138

100-
// set up service provider options
101-
opts := []sp.ServiceProviderOption{
102-
sp.WithClaimMapping(viper.GetStringMapString("sp-claim-mapping")),
139+
serviceProviders = []serviceProvider{sp}
140+
}
103141
}
104142

105-
// handle metadata
106-
if m := viper.GetString("idp-metadata"); m != "" {
107-
metadata, err := url.Parse(m)
143+
// create run group
144+
g := run.Group{}
145+
146+
// new mux
147+
mux := http.NewServeMux()
148+
149+
// set up service provider(s)
150+
for _, spConfig := range serviceProviders {
151+
// validate service provider root url
152+
root, err := url.Parse(spConfig.ServiceProviderURL)
108153
if err != nil {
109-
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
154+
return fmt.Errorf("problem with SP URL: %w", err)
110155
}
111156

112-
opts = append(opts, sp.WithMetadataURL(metadata))
113-
} else {
114-
metadata := sp.ServiceProviderMetadata{
115-
Issuer: viper.GetString("idp-issuer"),
116-
Endpoint: viper.GetString("idp-sso-endpoint"),
117-
NameId: "persistent",
118-
Certificate: viper.GetString("idp-certificate"),
157+
// set up service provider options
158+
opts := []sp.ServiceProviderOption{
159+
sp.WithClaimMapping(spConfig.ServiceProviderClaimMapping),
119160
}
120161

121-
opts = append(opts, sp.WithCustomMetadata(metadata))
122-
}
162+
// handle metadata
163+
if spConfig.IdPMetadata != "" {
164+
metadata, err := url.Parse(spConfig.IdPMetadata)
165+
if err != nil {
166+
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
167+
}
123168

124-
// are we using a database for storing session attributes
125-
if dsn := viper.GetString("db-connection"); dsn != "" {
126-
store, err := sp.NewDbAttributeStore(viper.GetString("db-prefix"), dsn)
169+
opts = append(opts, sp.WithMetadataURL(metadata))
170+
} else {
171+
metadata := sp.ServiceProviderMetadata{
172+
Issuer: spConfig.IdPIssuer,
173+
Endpoint: spConfig.IdPSSOEndpoint,
174+
Certificate: spConfig.IdPCertificate,
175+
}
176+
177+
opts = append(opts, sp.WithCustomMetadata(metadata))
178+
}
179+
180+
// are we using a database for storing session attributes
181+
if dsn := spConfig.DatabaseConnection; dsn != "" {
182+
store, err := sp.NewDbAttributeStore(spConfig.DatabaseTablePrefix, dsn)
183+
if err != nil {
184+
return fmt.Errorf("problem setting up db attribute store: %w", err)
185+
}
186+
defer store.Close()
187+
188+
opts = append(opts, sp.WithAttributeStore(store))
189+
}
190+
191+
// set Service Provider name if provided
192+
if spConfig.Name != "" {
193+
opts = append(opts, sp.WithName(spConfig.Name))
194+
}
195+
196+
// set up auth provider
197+
provider, err := sp.NewServiceProvider(spConfig.ServiceProviderCertificate, spConfig.ServiceProviderKey, root, opts...)
127198
if err != nil {
128-
return fmt.Errorf("problem setting up db attribute store: %w", err)
199+
return fmt.Errorf("problem setting up SP: %w", err)
129200
}
130-
defer store.Close()
131201

132-
opts = append(opts, sp.WithAttributeStore(store))
133-
}
202+
// set up refresh/reload of service provider metdata
203+
if spConfig.IdPMetadata != "" {
204+
quit := make(chan struct{})
205+
g.Add(func() error {
206+
slog.Info("service provider refresh", "action", "started", "next", time.Now().Add(time.Hour*24))
207+
for {
208+
select {
209+
case <-quit:
210+
return nil
211+
default:
212+
if err := provider.RefreshMetadata(); err != nil {
213+
// not a fatal error
214+
slog.Error("saml service provider reload", "error", err)
215+
continue
216+
}
217+
}
134218

135-
// set up auth provider
136-
provider, err := sp.NewServiceProvider(viper.GetString("sp-cert"), viper.GetString("sp-key"), root, opts...)
137-
if err != nil {
138-
return fmt.Errorf("problem setting up SP: %w", err)
139-
}
219+
// some logging
220+
slog.Info("service provider refresh", "action", "refreshed", "next", time.Now().Add(time.Hour*24))
221+
}
222+
}, func(err error) {
223+
slog.Info("service provider refresh", "action", "shutting down")
224+
close(quit)
225+
})
226+
}
140227

141-
// new server mux
142-
mux := sp.NewMux(provider)
228+
// new server mux
229+
if err := provider.NewMux(mux); err != nil {
230+
return fmt.Errorf("error setting up mux: %w", err)
231+
}
143232

144-
// allow swapping of mux
145-
rs := routerswapper.New(mux)
233+
slog.Info("set up service provider",
234+
"acs-url", provider.AcsURL().String(),
235+
"metdata-url", provider.MetadataURL().String(),
236+
"logout-url", provider.LogoutUrl().String(),
237+
"name", spConfig.Name,
238+
)
239+
}
146240

147241
// set up server
148242
srv := &http.Server{
149243
Addr: viper.GetString("listen"),
150-
Handler: rs,
244+
Handler: mux,
151245
ReadTimeout: time.Second * 3,
152246
WriteTimeout: time.Second * 3,
153247
}
154248

155-
slog.Info("starting service",
156-
"listen", srv.Addr,
157-
"sp-acs-url", provider.AcsURL().String(),
158-
"sp-metdata-url", provider.MetadataURL().String(),
159-
"sp-logout-url", provider.LogoutUrl().String(),
160-
)
161-
162-
// create run group
163-
g := run.Group{}
249+
slog.Info("starting service", "listen", srv.Addr)
164250

165251
// add http server
166252
if viper.GetString("cert") == "" && viper.GetString("key") == "" {
@@ -213,54 +299,6 @@ func runRootCmd() error {
213299
})
214300
}
215301

216-
// set up refresh/reload of service provider metdata
217-
if viper.GetString("idp-metadata") != "" {
218-
quit := make(chan struct{})
219-
g.Add(func() error {
220-
slog.Info("service provider refresh", "action", "started", "next", time.Now().Add(time.Hour*24))
221-
for {
222-
select {
223-
case <-quit:
224-
return nil
225-
default:
226-
time.Sleep(time.Hour * 24)
227-
228-
// parse url
229-
metadata, _ := url.Parse(viper.GetString("idp-metadata"))
230-
if err != nil {
231-
return fmt.Errorf("problem parsing IdP metadata url: %w", err)
232-
}
233-
234-
// set up service provider options
235-
opts := []sp.ServiceProviderOption{
236-
sp.WithClaimMapping(viper.GetStringMapString("sp-claim-mapping")),
237-
sp.WithMetadataURL(metadata),
238-
}
239-
240-
// set up provider
241-
provider, err := sp.NewServiceProvider(viper.GetString("sp-cert"), viper.GetString("sp-key"), root, opts...)
242-
if err != nil {
243-
// not a fatal error
244-
slog.Error("saml service provider reload", "error", err)
245-
continue
246-
}
247-
248-
// new server mux
249-
mux := sp.NewMux(provider)
250-
251-
// swap to new mux
252-
rs.Swap(mux)
253-
}
254-
255-
// some logging
256-
slog.Info("service provider refresh", "action", "refreshed", "next", time.Now().Add(time.Hour*24))
257-
}
258-
}, func(err error) {
259-
slog.Info("service provider refresh", "action", "shutting down")
260-
close(quit)
261-
})
262-
}
263-
264302
if err := g.Run(); err != nil {
265303
return fmt.Errorf("problem while running: %w", err)
266304
}

pkg/sp/options.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ type ServiceProviderOption func(*ServiceProvider)
1414

1515
func WithMetadataURL(metadata *url.URL) ServiceProviderOption {
1616
return func(s *ServiceProvider) {
17-
// populate metadata either from a metadata URL or from custom values
17+
// populate metadata from a metadata URL
1818
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
1919
defer cancel()
2020

@@ -26,13 +26,14 @@ func WithMetadataURL(metadata *url.URL) ServiceProviderOption {
2626
}
2727

2828
s.idpMetadata = idpMetadata
29+
s.idpMetadataURL = metadata
2930
}
3031
}
3132

3233
func WithCustomMetadata(metadata ServiceProviderMetadata) ServiceProviderOption {
3334
return func(s *ServiceProvider) {
3435
// build metadata from provided values
35-
b, err := buildMetadata(metadata.Issuer, metadata.Endpoint, metadata.NameId, metadata.Certificate)
36+
b, err := buildMetadata(metadata.Issuer, metadata.Endpoint, metadata.Certificate)
3637
if err != nil {
3738
slog.Error("metadata build error", "error", err)
3839
return
@@ -59,3 +60,15 @@ func WithAttributeStore(store AttributeStore) ServiceProviderOption {
5960
s.store = store
6061
}
6162
}
63+
64+
func WithMetadataRefreshInterval(d time.Duration) ServiceProviderOption {
65+
return func(s *ServiceProvider) {
66+
s.idpMetadataRefreshInterval = d
67+
}
68+
}
69+
70+
func WithName(name string) ServiceProviderOption {
71+
return func(s *ServiceProvider) {
72+
s.name = name
73+
}
74+
}

0 commit comments

Comments
 (0)