Skip to content

Commit 2b97d4e

Browse files
author
Andrew Heberle
committed
Working version using simplecommand
1 parent d7ac703 commit 2b97d4e

File tree

4 files changed

+78
-60
lines changed

4 files changed

+78
-60
lines changed

config_one_nocert.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
service-providers:
2+
- name: nocert
3+
sp-url: http://localhost:9091
4+
idp-metadata: https://mocksaml.com/api/saml/metadata

go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,23 @@ toolchain go1.24.2
77
replace github.com/crewjam/saml v0.4.14 => github.com/rancher/saml v0.4.14-rancher3
88

99
require (
10-
github.com/andrewheberle/simplecommand v0.2.0
10+
github.com/andrewheberle/simplecommand v0.3.0
1111
github.com/bep/simplecobra v0.6.0
1212
github.com/cloudflare/certinel v0.4.1
1313
github.com/crewjam/saml v0.4.14
14-
github.com/go-viper/mapstructure/v2 v2.2.1
1514
github.com/golang-jwt/jwt/v4 v4.5.2
1615
github.com/jackc/pgx/v5 v5.6.0
1716
github.com/karlseguin/ccache/v3 v3.0.6
1817
github.com/oklog/run v1.1.0
1918
github.com/russellhaering/goxmldsig v1.4.0
20-
gopkg.in/yaml.v3 v3.0.1
2119
)
2220

2321
require (
2422
github.com/andrewheberle/simpleviper v1.1.1 // indirect
2523
github.com/beevik/etree v1.2.0 // indirect
2624
github.com/crewjam/httperr v0.2.0 // indirect
2725
github.com/fsnotify/fsnotify v1.8.0 // indirect
26+
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
2827
github.com/inconshreveable/mousetrap v1.1.0 // indirect
2928
github.com/jackc/pgpassfile v1.0.0 // indirect
3029
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
@@ -47,4 +46,5 @@ require (
4746
golang.org/x/sync v0.10.0 // indirect
4847
golang.org/x/sys v0.29.0 // indirect
4948
golang.org/x/text v0.21.0 // indirect
49+
gopkg.in/yaml.v3 v3.0.1 // indirect
5050
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
github.com/andrewheberle/simplecommand v0.2.0 h1:ZJRESmjd8zsHGTK1EY5ickZOtWNlQLCh4qSUiSPVIsE=
2-
github.com/andrewheberle/simplecommand v0.2.0/go.mod h1:mcCWB3Ano9gsv7FW288hXDZr4H9XPYIgXxDJP1Z1cYY=
1+
github.com/andrewheberle/simplecommand v0.3.0 h1:pjTQae9YwajvSEFjdaaczAd/7i/LSic6yuxUQENYTlU=
2+
github.com/andrewheberle/simplecommand v0.3.0/go.mod h1:D9L/jnIotmn3rxyAYIKAAd9rSA+QEHsbfU1UUAo1Upg=
33
github.com/andrewheberle/simpleviper v1.1.1 h1:9cgJDjcQZoQD1OrgjdMgWP4oFVlFGaHXzxVOsJz0abE=
44
github.com/andrewheberle/simpleviper v1.1.1/go.mod h1:xMIWZmEaiCzd86Pq1YNb0PQ/4Fz5thKInTscmfUvUmw=
55
github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=

internal/pkg/cmd/cmd.go

Lines changed: 69 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ import (
1515
"github.com/andrewheberle/simplecommand"
1616
"github.com/bep/simplecobra"
1717
"github.com/cloudflare/certinel/fswatcher"
18-
"github.com/go-viper/mapstructure/v2"
1918
"github.com/oklog/run"
20-
"gopkg.in/yaml.v3"
2119
)
2220

2321
type rootCommand struct {
@@ -27,7 +25,6 @@ type rootCommand struct {
2725
cert string
2826
key string
2927
listen string
30-
config string
3128
debug bool
3229

3330
// sp flags
@@ -50,15 +47,17 @@ type rootCommand struct {
5047
}
5148

5249
func (c *rootCommand) Init(cd *simplecobra.Commandeer) error {
53-
c.Command.Init(cd)
50+
if err := c.Command.Init(cd); err != nil {
51+
return err
52+
}
5453

5554
cmd := cd.CobraCommand
5655
// general command line flags
5756
cmd.Flags().StringVar(&c.cert, "cert", "", "HTTPS Certificate")
5857
cmd.Flags().StringVar(&c.key, "key", "", "HTTPS Key")
5958
cmd.MarkFlagsRequiredTogether("cert", "key")
6059
cmd.Flags().StringVar(&c.listen, "listen", "127.0.0.1:9091", "Listen address")
61-
cmd.Flags().StringVarP(&c.config, "config", "c", "", "Configuration file")
60+
cmd.Flags().StringVarP(&c.Config, "config", "c", "", "Configuration file")
6261
cmd.Flags().BoolVar(&c.debug, "debug", false, "Enable debug logging")
6362

6463
// sp command line flags
@@ -97,6 +96,8 @@ func (c *rootCommand) PreRun(this, runner *simplecobra.Commandeer) error {
9796
logLevel.Set(slog.LevelDebug)
9897
}
9998

99+
c.logger.Debug("service provider list", "list", c.serviceProviders())
100+
100101
return nil
101102
}
102103

@@ -115,67 +116,23 @@ type serviceProvider struct {
115116
}
116117

117118
func (c *rootCommand) Run(ctx context.Context, cd *simplecobra.Commandeer, args []string) error {
118-
var serviceProviders []serviceProvider
119-
120-
// did we load in via a config file
121-
if c.config != "" {
122-
// read in file
123-
b, err := os.ReadFile(c.config)
124-
if err == nil {
125-
var y map[string]any
126-
127-
// unmarshal to map
128-
if err := yaml.Unmarshal(b, &y); err == nil {
129-
// was there a service_providers key
130-
if splist, ok := y["service_providers"]; ok {
131-
// multiple sp
132-
if err := mapstructure.Decode(splist, &serviceProviders); err != nil {
133-
return fmt.Errorf("error with service providers list: %w", err)
134-
}
135-
} else {
136-
// single sp
137-
var sp serviceProvider
138-
139-
// try to unmarshal as single
140-
if err := yaml.Unmarshal(b, &sp); err != nil {
141-
return fmt.Errorf("error with service provider: %w", err)
142-
}
143-
144-
serviceProviders = []serviceProvider{sp}
145-
}
146-
}
147-
}
148-
} else {
149-
// create sp conf from flags directly
150-
serviceProviders = []serviceProvider{
151-
{
152-
ServiceProviderURL: c.spUrl,
153-
ServiceProviderClaimMapping: c.spClaimMapping,
154-
ServiceProviderCertificate: c.spCert,
155-
ServiceProviderKey: c.spKey,
156-
IdPMetadata: c.idpMetadata,
157-
IdPIssuer: c.idpIssuer,
158-
IdPSSOEndpoint: c.idpSSOEndpoint,
159-
IdPCertificate: c.idpCertificate,
160-
DatabaseConnection: c.dbConnection,
161-
DatabaseTablePrefix: c.dbPrefix,
162-
},
163-
}
164-
}
165-
166119
// create run group
167120
g := run.Group{}
168121

169122
// new mux
170123
mux := http.NewServeMux()
171124

172125
// set up service provider(s)
173-
for _, spConfig := range serviceProviders {
126+
for _, spConfig := range c.serviceProviders() {
127+
// use global values as a fallback if some values are not set
128+
spConfig.ServiceProviderCertificate = fallback(spConfig.ServiceProviderCertificate, c.spCert)
129+
spConfig.ServiceProviderKey = fallback(spConfig.ServiceProviderKey, c.spKey)
130+
174131
// show config in debug mode
175132
c.logger.Debug("setting up service provider",
176133
"name", spConfig.Name,
177134
"url", spConfig.ServiceProviderURL,
178-
"metdata", spConfig.IdPMetadata,
135+
"metadata", spConfig.IdPMetadata,
179136
"cert", spConfig.ServiceProviderCertificate,
180137
"key", spConfig.ServiceProviderKey,
181138
)
@@ -334,6 +291,53 @@ func (c *rootCommand) Run(ctx context.Context, cd *simplecobra.Commandeer, args
334291
return g.Run()
335292
}
336293

294+
func (c *rootCommand) serviceProviders() []serviceProvider {
295+
var serviceProviders []serviceProvider
296+
297+
// no config file or no viper
298+
if c.Config == "" || c.Viper() == nil {
299+
return []serviceProvider{
300+
{
301+
ServiceProviderURL: c.spUrl,
302+
ServiceProviderClaimMapping: c.spClaimMapping,
303+
ServiceProviderCertificate: c.spCert,
304+
ServiceProviderKey: c.spKey,
305+
IdPMetadata: c.idpMetadata,
306+
IdPIssuer: c.idpIssuer,
307+
IdPSSOEndpoint: c.idpSSOEndpoint,
308+
IdPCertificate: c.idpCertificate,
309+
DatabaseConnection: c.dbConnection,
310+
DatabaseTablePrefix: c.dbPrefix,
311+
},
312+
}
313+
}
314+
315+
// plain config file (not a list)
316+
if c.Viper().Get("service-providers") == nil {
317+
return []serviceProvider{
318+
{
319+
ServiceProviderURL: c.spUrl,
320+
ServiceProviderClaimMapping: c.spClaimMapping,
321+
ServiceProviderCertificate: c.spCert,
322+
ServiceProviderKey: c.spKey,
323+
IdPMetadata: c.idpMetadata,
324+
IdPIssuer: c.idpIssuer,
325+
IdPSSOEndpoint: c.idpSSOEndpoint,
326+
IdPCertificate: c.idpCertificate,
327+
DatabaseConnection: c.dbConnection,
328+
DatabaseTablePrefix: c.dbPrefix,
329+
},
330+
}
331+
}
332+
333+
// try to unmarshal from list
334+
if err := c.Viper().UnmarshalKey("service-providers", &serviceProviders); err != nil {
335+
return []serviceProvider{}
336+
}
337+
338+
return serviceProviders
339+
}
340+
337341
func Execute(args []string) error {
338342
rootCmd := &rootCommand{
339343
Command: simplecommand.New(
@@ -356,3 +360,13 @@ a SAML IdP in order to provide SSO to a proxied service.`),
356360

357361
return nil
358362
}
363+
364+
func fallback[T comparable](a, b T) T {
365+
var zero T
366+
367+
if a == zero {
368+
return b
369+
}
370+
371+
return a
372+
}

0 commit comments

Comments
 (0)