Skip to content

Commit d985e4b

Browse files
committed
refactor TLS config
TLS config code causes security linters to report false positive about TLS versions that can be configured. This is porting over CAPO PR 2037. Signed-off-by: Tuomo Tanskanen <[email protected]>
1 parent 3dc25fd commit d985e4b

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

Diff for: main.go

+13-23
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,19 @@ func main() {
366366
func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) {
367367
var tlsOptions []func(config *tls.Config)
368368

369-
tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion)
370-
if err != nil {
371-
return nil, err
372-
}
373-
374-
tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion)
375-
if err != nil {
376-
return nil, err
369+
// To make a static analyzer happy, this block ensures there is no code
370+
// path that sets a TLS version outside the acceptable values, even in
371+
// case of unexpected user input.
372+
var tlsMinVersion, tlsMaxVersion uint16
373+
for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} {
374+
switch option {
375+
case TLSVersion12:
376+
*version = tls.VersionTLS12
377+
case TLSVersion13:
378+
*version = tls.VersionTLS13
379+
default:
380+
return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", "))
381+
}
377382
}
378383

379384
if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion {
@@ -418,21 +423,6 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error)
418423
return tlsOptions, nil
419424
}
420425

421-
// GetTLSVersion returns the corresponding tls.Version or error.
422-
func GetTLSVersion(version string) (uint16, error) {
423-
var v uint16
424-
425-
switch version {
426-
case TLSVersion12:
427-
v = tls.VersionTLS12
428-
case TLSVersion13:
429-
v = tls.VersionTLS13
430-
default:
431-
return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", "))
432-
}
433-
return v, nil
434-
}
435-
436426
func getMaxConcurrentReconciles(controllerConcurrency int) (int, error) {
437427
if controllerConcurrency > 0 {
438428
ctrl.Log.Info(fmt.Sprintf("controller concurrency will be set to %d according to command line flag", controllerConcurrency))

Diff for: main_test.go

+45-10
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package main
1818

1919
import (
2020
"bytes"
21+
"crypto/tls"
2122
"testing"
2223

2324
. "github.com/onsi/gomega"
@@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) {
7576
klog.SetOutput(bufWriter)
7677
klog.LogToStderr(false) // this is important, because klog by default logs to stderr only
7778
_, err := GetTLSOptionOverrideFuncs(tlsMockOptions)
78-
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
7979
g.Expect(err).ShouldNot(HaveOccurred())
80+
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
8081
})
8182
}
8283

83-
func TestGetTLSVersion(t *testing.T) {
84-
t.Run("should error out when incorrect tls version passed", func(t *testing.T) {
84+
func TestGetTLSOverrideFuncs(t *testing.T) {
85+
t.Run("should error out when incorrect min tls version passed", func(t *testing.T) {
86+
g := NewWithT(t)
87+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
88+
TLSMinVersion: "TLS11",
89+
TLSMaxVersion: "TLS12",
90+
})
91+
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
92+
})
93+
t.Run("should error out when incorrect max tls version passed", func(t *testing.T) {
8594
g := NewWithT(t)
86-
tlsVersion := "TLS11"
87-
_, err := GetTLSVersion(tlsVersion)
95+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
96+
TLSMinVersion: "TLS12",
97+
TLSMaxVersion: "TLS11",
98+
})
8899
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
89100
})
90-
t.Run("should pass and output correct tls version", func(t *testing.T) {
91-
const VersionTLS12 uint16 = 771
101+
t.Run("should apply the requested TLS versions", func(t *testing.T) {
102+
g := NewWithT(t)
103+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
104+
TLSMinVersion: "TLS12",
105+
TLSMaxVersion: "TLS13",
106+
})
107+
108+
var tlsConfig tls.Config
109+
for _, apply := range tlsOptionOverrides {
110+
apply(&tlsConfig)
111+
}
112+
113+
g.Expect(err).ShouldNot(HaveOccurred())
114+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
115+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
116+
})
117+
t.Run("should apply the requested non-default TLS versions", func(t *testing.T) {
92118
g := NewWithT(t)
93-
tlsVersion := "TLS12"
94-
version, err := GetTLSVersion(tlsVersion)
95-
g.Expect(version).To(Equal(VersionTLS12))
119+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
120+
TLSMinVersion: "TLS13",
121+
TLSMaxVersion: "TLS13",
122+
})
123+
124+
var tlsConfig tls.Config
125+
for _, apply := range tlsOptionOverrides {
126+
apply(&tlsConfig)
127+
}
128+
96129
g.Expect(err).ShouldNot(HaveOccurred())
130+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
131+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
97132
})
98133
}
99134

0 commit comments

Comments
 (0)