@@ -17,7 +17,6 @@ import (
1717 "github.com/spf13/cobra"
1818 "github.com/spf13/pflag"
1919 "github.com/spf13/viper"
20- "gitlab.com/andrewheberle/routerswapper"
2120)
2221
2322var rootCmd = & cobra.Command {
@@ -51,13 +50,14 @@ func init() {
5150 rootCmd .Flags ().String ("idp-certificate" , "" , "IdP Certificate/Public Key" )
5251 rootCmd .Flags ().String ("db-connection" , "" , "Database connection string" )
5352 rootCmd .Flags ().String ("db-prefix" , "" , "Database table prefix" )
53+ rootCmd .Flags ().StringP ("config" , "c" , "" , "Configuration file" )
5454 rootCmd .Flags ().Bool ("debug" , false , "Enable debug logging" )
5555
5656 // flag requirements
57- rootCmd .MarkFlagsRequiredTogether ("cert" , "key" )
5857 rootCmd .MarkFlagsRequiredTogether ("sp-cert" , "sp-key" )
5958 rootCmd .MarkFlagRequired ("sp-cert" )
6059 rootCmd .MarkFlagRequired ("sp-key" )
60+ rootCmd .MarkFlagsRequiredTogether ("cert" , "key" )
6161 rootCmd .MarkFlagsRequiredTogether ("idp-issuer" , "idp-sso-endpoint" , "idp-certificate" )
6262 rootCmd .MarkFlagsMutuallyExclusive ("idp-metadata" , "idp-issuer" )
6363 rootCmd .MarkFlagsMutuallyExclusive ("idp-metadata" , "idp-sso-endpoint" )
@@ -74,14 +74,45 @@ func initConfig() {
7474 // bind flags to viper
7575 viper .BindPFlags (rootCmd .Flags ())
7676
77- // set any flags found in environment via viper
77+ // load config file if flag is set
78+ if config := viper .GetString ("config" ); config != "" {
79+ viper .SetConfigFile (config )
80+ if err := viper .ReadInConfig (); err != nil {
81+ slog .Error ("problem loading configuration" , "error" , err )
82+ os .Exit (1 )
83+ }
84+
85+ // set sp-cert and sp-key to something just to allow things to work when using multiple SP's
86+ for _ , name := range []string {"sp-cert" , "sp-key" } {
87+ if ! viper .IsSet (name ) {
88+ rootCmd .Flags ().Set (name , "unused" )
89+ }
90+ }
91+ }
92+
93+ // set any flags found in environment/config via viper
7894 rootCmd .Flags ().VisitAll (func (f * pflag.Flag ) {
7995 if viper .IsSet (f .Name ) && viper .GetString (f .Name ) != "" {
96+ slog .Info ("setting flag" , "name" , f .Name , "value" , viper .GetString (f .Name ))
8097 rootCmd .Flags ().Set (f .Name , viper .GetString (f .Name ))
8198 }
8299 })
83100}
84101
102+ type serviceProvider struct {
103+ Name string `mapstructure:"name"`
104+ ServiceProviderURL string `mapstructure:"sp-url"`
105+ ServiceProviderClaimMapping map [string ]string `mapstructure:"sp-claim-mapping"`
106+ ServiceProviderCertificate string `mapstructure:"sp-cert"`
107+ ServiceProviderKey string `mapstructure:"sp-key"`
108+ IdPMetadata string `mapstructure:"idp-metadata"`
109+ IdPIssuer string `mapstructure:"idp-issuer"`
110+ IdPSSOEndpoint string `mapstructure:"idp-sso-endpoint"`
111+ IdPCertificate string `mapstructure:"idp-certificate"`
112+ DatabaseConnection string `mapstructure:"db-connection"`
113+ DatabaseTablePrefix string `mapstructure:"db-prefix"`
114+ }
115+
85116func runRootCmd () error {
86117 // logging setup
87118 var logLevel = new (slog.LevelVar )
@@ -91,76 +122,131 @@ func runRootCmd() error {
91122 logLevel .Set (slog .LevelDebug )
92123 }
93124
94- // validate service provider root url
95- root , err := url .Parse (viper .GetString ("sp-url" ))
96- if err != nil {
97- return fmt .Errorf ("problem with SP URL: %w" , err )
98- }
125+ // did we load in via a config file
126+ var serviceProviders []serviceProvider
127+ if viper .ConfigFileUsed () != "" {
128+ // has a list of service providers been provided?
129+ if viper .Get ("service-providers" ) != nil {
130+ if err := viper .UnmarshalKey ("service-providers" , & serviceProviders ); err != nil {
131+ return fmt .Errorf ("error with service providers list: %w" , err )
132+ }
133+ } else {
134+ var sp serviceProvider
135+ if err := viper .Unmarshal (& sp ); err != nil {
136+ return fmt .Errorf ("error with service provider: %w" , err )
137+ }
99138
100- // set up service provider options
101- opts := []sp.ServiceProviderOption {
102- sp .WithClaimMapping (viper .GetStringMapString ("sp-claim-mapping" )),
139+ serviceProviders = []serviceProvider {sp }
140+ }
103141 }
104142
105- // handle metadata
106- if m := viper .GetString ("idp-metadata" ); m != "" {
107- metadata , err := url .Parse (m )
143+ // create run group
144+ g := run.Group {}
145+
146+ // new mux
147+ mux := http .NewServeMux ()
148+
149+ // set up service provider(s)
150+ for _ , spConfig := range serviceProviders {
151+ // validate service provider root url
152+ root , err := url .Parse (spConfig .ServiceProviderURL )
108153 if err != nil {
109- return fmt .Errorf ("problem parsing IdP metadata url : %w" , err )
154+ return fmt .Errorf ("problem with SP URL : %w" , err )
110155 }
111156
112- opts = append (opts , sp .WithMetadataURL (metadata ))
113- } else {
114- metadata := sp.ServiceProviderMetadata {
115- Issuer : viper .GetString ("idp-issuer" ),
116- Endpoint : viper .GetString ("idp-sso-endpoint" ),
117- NameId : "persistent" ,
118- Certificate : viper .GetString ("idp-certificate" ),
157+ // set up service provider options
158+ opts := []sp.ServiceProviderOption {
159+ sp .WithClaimMapping (spConfig .ServiceProviderClaimMapping ),
119160 }
120161
121- opts = append (opts , sp .WithCustomMetadata (metadata ))
122- }
162+ // handle metadata
163+ if spConfig .IdPMetadata != "" {
164+ metadata , err := url .Parse (spConfig .IdPMetadata )
165+ if err != nil {
166+ return fmt .Errorf ("problem parsing IdP metadata url: %w" , err )
167+ }
123168
124- // are we using a database for storing session attributes
125- if dsn := viper .GetString ("db-connection" ); dsn != "" {
126- store , err := sp .NewDbAttributeStore (viper .GetString ("db-prefix" ), dsn )
169+ opts = append (opts , sp .WithMetadataURL (metadata ))
170+ } else {
171+ metadata := sp.ServiceProviderMetadata {
172+ Issuer : spConfig .IdPIssuer ,
173+ Endpoint : spConfig .IdPSSOEndpoint ,
174+ Certificate : spConfig .IdPCertificate ,
175+ }
176+
177+ opts = append (opts , sp .WithCustomMetadata (metadata ))
178+ }
179+
180+ // are we using a database for storing session attributes
181+ if dsn := spConfig .DatabaseConnection ; dsn != "" {
182+ store , err := sp .NewDbAttributeStore (spConfig .DatabaseTablePrefix , dsn )
183+ if err != nil {
184+ return fmt .Errorf ("problem setting up db attribute store: %w" , err )
185+ }
186+ defer store .Close ()
187+
188+ opts = append (opts , sp .WithAttributeStore (store ))
189+ }
190+
191+ // set Service Provider name if provided
192+ if spConfig .Name != "" {
193+ opts = append (opts , sp .WithName (spConfig .Name ))
194+ }
195+
196+ // set up auth provider
197+ provider , err := sp .NewServiceProvider (spConfig .ServiceProviderCertificate , spConfig .ServiceProviderKey , root , opts ... )
127198 if err != nil {
128- return fmt .Errorf ("problem setting up db attribute store : %w" , err )
199+ return fmt .Errorf ("problem setting up SP : %w" , err )
129200 }
130- defer store .Close ()
131201
132- opts = append (opts , sp .WithAttributeStore (store ))
133- }
202+ // set up refresh/reload of service provider metdata
203+ if spConfig .IdPMetadata != "" {
204+ quit := make (chan struct {})
205+ g .Add (func () error {
206+ slog .Info ("service provider refresh" , "action" , "started" , "next" , time .Now ().Add (time .Hour * 24 ))
207+ for {
208+ select {
209+ case <- quit :
210+ return nil
211+ default :
212+ if err := provider .RefreshMetadata (); err != nil {
213+ // not a fatal error
214+ slog .Error ("saml service provider reload" , "error" , err )
215+ continue
216+ }
217+ }
134218
135- // set up auth provider
136- provider , err := sp .NewServiceProvider (viper .GetString ("sp-cert" ), viper .GetString ("sp-key" ), root , opts ... )
137- if err != nil {
138- return fmt .Errorf ("problem setting up SP: %w" , err )
139- }
219+ // some logging
220+ slog .Info ("service provider refresh" , "action" , "refreshed" , "next" , time .Now ().Add (time .Hour * 24 ))
221+ }
222+ }, func (err error ) {
223+ slog .Info ("service provider refresh" , "action" , "shutting down" )
224+ close (quit )
225+ })
226+ }
140227
141- // new server mux
142- mux := sp .NewMux (provider )
228+ // new server mux
229+ if err := provider .NewMux (mux ); err != nil {
230+ return fmt .Errorf ("error setting up mux: %w" , err )
231+ }
143232
144- // allow swapping of mux
145- rs := routerswapper .New (mux )
233+ slog .Info ("set up service provider" ,
234+ "acs-url" , provider .AcsURL ().String (),
235+ "metdata-url" , provider .MetadataURL ().String (),
236+ "logout-url" , provider .LogoutUrl ().String (),
237+ "name" , spConfig .Name ,
238+ )
239+ }
146240
147241 // set up server
148242 srv := & http.Server {
149243 Addr : viper .GetString ("listen" ),
150- Handler : rs ,
244+ Handler : mux ,
151245 ReadTimeout : time .Second * 3 ,
152246 WriteTimeout : time .Second * 3 ,
153247 }
154248
155- slog .Info ("starting service" ,
156- "listen" , srv .Addr ,
157- "sp-acs-url" , provider .AcsURL ().String (),
158- "sp-metdata-url" , provider .MetadataURL ().String (),
159- "sp-logout-url" , provider .LogoutUrl ().String (),
160- )
161-
162- // create run group
163- g := run.Group {}
249+ slog .Info ("starting service" , "listen" , srv .Addr )
164250
165251 // add http server
166252 if viper .GetString ("cert" ) == "" && viper .GetString ("key" ) == "" {
@@ -213,54 +299,6 @@ func runRootCmd() error {
213299 })
214300 }
215301
216- // set up refresh/reload of service provider metdata
217- if viper .GetString ("idp-metadata" ) != "" {
218- quit := make (chan struct {})
219- g .Add (func () error {
220- slog .Info ("service provider refresh" , "action" , "started" , "next" , time .Now ().Add (time .Hour * 24 ))
221- for {
222- select {
223- case <- quit :
224- return nil
225- default :
226- time .Sleep (time .Hour * 24 )
227-
228- // parse url
229- metadata , _ := url .Parse (viper .GetString ("idp-metadata" ))
230- if err != nil {
231- return fmt .Errorf ("problem parsing IdP metadata url: %w" , err )
232- }
233-
234- // set up service provider options
235- opts := []sp.ServiceProviderOption {
236- sp .WithClaimMapping (viper .GetStringMapString ("sp-claim-mapping" )),
237- sp .WithMetadataURL (metadata ),
238- }
239-
240- // set up provider
241- provider , err := sp .NewServiceProvider (viper .GetString ("sp-cert" ), viper .GetString ("sp-key" ), root , opts ... )
242- if err != nil {
243- // not a fatal error
244- slog .Error ("saml service provider reload" , "error" , err )
245- continue
246- }
247-
248- // new server mux
249- mux := sp .NewMux (provider )
250-
251- // swap to new mux
252- rs .Swap (mux )
253- }
254-
255- // some logging
256- slog .Info ("service provider refresh" , "action" , "refreshed" , "next" , time .Now ().Add (time .Hour * 24 ))
257- }
258- }, func (err error ) {
259- slog .Info ("service provider refresh" , "action" , "shutting down" )
260- close (quit )
261- })
262- }
263-
264302 if err := g .Run (); err != nil {
265303 return fmt .Errorf ("problem while running: %w" , err )
266304 }
0 commit comments