@@ -15,9 +15,7 @@ import (
1515 "github.com/andrewheberle/simplecommand"
1616 "github.com/bep/simplecobra"
1717 "github.com/cloudflare/certinel/fswatcher"
18- "github.com/go-viper/mapstructure/v2"
1918 "github.com/oklog/run"
20- "gopkg.in/yaml.v3"
2119)
2220
2321type rootCommand struct {
@@ -27,7 +25,6 @@ type rootCommand struct {
2725 cert string
2826 key string
2927 listen string
30- config string
3128 debug bool
3229
3330 // sp flags
@@ -50,15 +47,17 @@ type rootCommand struct {
5047}
5148
5249func (c * rootCommand ) Init (cd * simplecobra.Commandeer ) error {
53- c .Command .Init (cd )
50+ if err := c .Command .Init (cd ); err != nil {
51+ return err
52+ }
5453
5554 cmd := cd .CobraCommand
5655 // general command line flags
5756 cmd .Flags ().StringVar (& c .cert , "cert" , "" , "HTTPS Certificate" )
5857 cmd .Flags ().StringVar (& c .key , "key" , "" , "HTTPS Key" )
5958 cmd .MarkFlagsRequiredTogether ("cert" , "key" )
6059 cmd .Flags ().StringVar (& c .listen , "listen" , "127.0.0.1:9091" , "Listen address" )
61- cmd .Flags ().StringVarP (& c .config , "config" , "c" , "" , "Configuration file" )
60+ cmd .Flags ().StringVarP (& c .Config , "config" , "c" , "" , "Configuration file" )
6261 cmd .Flags ().BoolVar (& c .debug , "debug" , false , "Enable debug logging" )
6362
6463 // sp command line flags
@@ -97,6 +96,8 @@ func (c *rootCommand) PreRun(this, runner *simplecobra.Commandeer) error {
9796 logLevel .Set (slog .LevelDebug )
9897 }
9998
99+ c .logger .Debug ("service provider list" , "list" , c .serviceProviders ())
100+
100101 return nil
101102}
102103
@@ -115,67 +116,23 @@ type serviceProvider struct {
115116}
116117
117118func (c * rootCommand ) Run (ctx context.Context , cd * simplecobra.Commandeer , args []string ) error {
118- var serviceProviders []serviceProvider
119-
120- // did we load in via a config file
121- if c .config != "" {
122- // read in file
123- b , err := os .ReadFile (c .config )
124- if err == nil {
125- var y map [string ]any
126-
127- // unmarshal to map
128- if err := yaml .Unmarshal (b , & y ); err == nil {
129- // was there a service_providers key
130- if splist , ok := y ["service_providers" ]; ok {
131- // multiple sp
132- if err := mapstructure .Decode (splist , & serviceProviders ); err != nil {
133- return fmt .Errorf ("error with service providers list: %w" , err )
134- }
135- } else {
136- // single sp
137- var sp serviceProvider
138-
139- // try to unmarshal as single
140- if err := yaml .Unmarshal (b , & sp ); err != nil {
141- return fmt .Errorf ("error with service provider: %w" , err )
142- }
143-
144- serviceProviders = []serviceProvider {sp }
145- }
146- }
147- }
148- } else {
149- // create sp conf from flags directly
150- serviceProviders = []serviceProvider {
151- {
152- ServiceProviderURL : c .spUrl ,
153- ServiceProviderClaimMapping : c .spClaimMapping ,
154- ServiceProviderCertificate : c .spCert ,
155- ServiceProviderKey : c .spKey ,
156- IdPMetadata : c .idpMetadata ,
157- IdPIssuer : c .idpIssuer ,
158- IdPSSOEndpoint : c .idpSSOEndpoint ,
159- IdPCertificate : c .idpCertificate ,
160- DatabaseConnection : c .dbConnection ,
161- DatabaseTablePrefix : c .dbPrefix ,
162- },
163- }
164- }
165-
166119 // create run group
167120 g := run.Group {}
168121
169122 // new mux
170123 mux := http .NewServeMux ()
171124
172125 // set up service provider(s)
173- for _ , spConfig := range serviceProviders {
126+ for _ , spConfig := range c .serviceProviders () {
127+ // use global values as a fallback if some values are not set
128+ spConfig .ServiceProviderCertificate = fallback (spConfig .ServiceProviderCertificate , c .spCert )
129+ spConfig .ServiceProviderKey = fallback (spConfig .ServiceProviderKey , c .spKey )
130+
174131 // show config in debug mode
175132 c .logger .Debug ("setting up service provider" ,
176133 "name" , spConfig .Name ,
177134 "url" , spConfig .ServiceProviderURL ,
178- "metdata " , spConfig .IdPMetadata ,
135+ "metadata " , spConfig .IdPMetadata ,
179136 "cert" , spConfig .ServiceProviderCertificate ,
180137 "key" , spConfig .ServiceProviderKey ,
181138 )
@@ -334,6 +291,53 @@ func (c *rootCommand) Run(ctx context.Context, cd *simplecobra.Commandeer, args
334291 return g .Run ()
335292}
336293
294+ func (c * rootCommand ) serviceProviders () []serviceProvider {
295+ var serviceProviders []serviceProvider
296+
297+ // no config file or no viper
298+ if c .Config == "" || c .Viper () == nil {
299+ return []serviceProvider {
300+ {
301+ ServiceProviderURL : c .spUrl ,
302+ ServiceProviderClaimMapping : c .spClaimMapping ,
303+ ServiceProviderCertificate : c .spCert ,
304+ ServiceProviderKey : c .spKey ,
305+ IdPMetadata : c .idpMetadata ,
306+ IdPIssuer : c .idpIssuer ,
307+ IdPSSOEndpoint : c .idpSSOEndpoint ,
308+ IdPCertificate : c .idpCertificate ,
309+ DatabaseConnection : c .dbConnection ,
310+ DatabaseTablePrefix : c .dbPrefix ,
311+ },
312+ }
313+ }
314+
315+ // plain config file (not a list)
316+ if c .Viper ().Get ("service-providers" ) == nil {
317+ return []serviceProvider {
318+ {
319+ ServiceProviderURL : c .spUrl ,
320+ ServiceProviderClaimMapping : c .spClaimMapping ,
321+ ServiceProviderCertificate : c .spCert ,
322+ ServiceProviderKey : c .spKey ,
323+ IdPMetadata : c .idpMetadata ,
324+ IdPIssuer : c .idpIssuer ,
325+ IdPSSOEndpoint : c .idpSSOEndpoint ,
326+ IdPCertificate : c .idpCertificate ,
327+ DatabaseConnection : c .dbConnection ,
328+ DatabaseTablePrefix : c .dbPrefix ,
329+ },
330+ }
331+ }
332+
333+ // try to unmarshal from list
334+ if err := c .Viper ().UnmarshalKey ("service-providers" , & serviceProviders ); err != nil {
335+ return []serviceProvider {}
336+ }
337+
338+ return serviceProviders
339+ }
340+
337341func Execute (args []string ) error {
338342 rootCmd := & rootCommand {
339343 Command : simplecommand .New (
@@ -356,3 +360,13 @@ a SAML IdP in order to provide SSO to a proxied service.`),
356360
357361 return nil
358362}
363+
364+ func fallback [T comparable ](a , b T ) T {
365+ var zero T
366+
367+ if a == zero {
368+ return b
369+ }
370+
371+ return a
372+ }
0 commit comments