Skip to content

Commit 6bd69c1

Browse files
committed
Push more flag handling to main instead of referring to flags in other modules
1 parent 29b88ce commit 6bd69c1

File tree

5 files changed

+42
-44
lines changed

5 files changed

+42
-44
lines changed

main.go

+23-20
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ var exitFunc = os.Exit
9898

9999
// Context groups listening context data together
100100
type Context struct {
101-
watcher chan bool
102-
status *statusHandler
103-
statusHTTP *http.Server
104-
dial func() (net.Conn, error)
105-
metrics *sqmetrics.SquareMetrics
106-
cert *certificate
101+
watcher chan bool
102+
status *statusHandler
103+
statusHTTP *http.Server
104+
shutdownTimeout time.Duration
105+
dial func() (net.Conn, error)
106+
metrics *sqmetrics.SquareMetrics
107+
cert *certificate
107108
}
108109

109110
// Dialer is an interface for dialers (either net.Dialer, or http_dialer.HttpTunnel)
@@ -280,7 +281,7 @@ func run(args []string) error {
280281
logger.Printf("using target address %s", *serverForwardAddress)
281282

282283
status := newStatusHandler(dial)
283-
context := &Context{watcher, status, nil, dial, metrics, cert}
284+
context := &Context{watcher, status, nil, *shutdownTimeout, dial, metrics, cert}
284285

285286
// Start listening
286287
err = serverListen(context)
@@ -309,7 +310,7 @@ func run(args []string) error {
309310
}
310311

311312
status := newStatusHandler(dial)
312-
context := &Context{watcher, status, nil, dial, metrics, cert}
313+
context := &Context{watcher, status, nil, *shutdownTimeout, dial, metrics, cert}
313314

314315
// Start listening
315316
err = clientListen(context)
@@ -328,7 +329,7 @@ func run(args []string) error {
328329
// connections. This is useful for the purpose of replacing certificates
329330
// in-place without having to take downtime, e.g. if a certificate is expiring.
330331
func serverListen(context *Context) error {
331-
config, err := buildConfig(*caBundlePath)
332+
config, err := buildConfig(*enabledCipherSuites, *caBundlePath)
332333
if err != nil {
333334
logger.Printf("error trying to read CA bundle: %s", err)
334335
return err
@@ -353,10 +354,11 @@ func serverListen(context *Context) error {
353354
}
354355

355356
proxy := &proxy{
356-
quit: 0,
357-
listener: tls.NewListener(listener, config),
358-
handlers: &sync.WaitGroup{},
359-
dial: context.dial,
357+
quit: 0,
358+
listener: tls.NewListener(listener, config),
359+
handlers: &sync.WaitGroup{},
360+
connectTimeout: *timeoutDuration,
361+
dial: context.dial,
360362
}
361363

362364
if *statusAddress != "" {
@@ -399,10 +401,11 @@ func clientListen(context *Context) error {
399401
}
400402

401403
proxy := &proxy{
402-
quit: 0,
403-
listener: listener,
404-
handlers: &sync.WaitGroup{},
405-
dial: context.dial,
404+
quit: 0,
405+
listener: listener,
406+
handlers: &sync.WaitGroup{},
407+
connectTimeout: *timeoutDuration,
408+
dial: context.dial,
406409
}
407410

408411
if *statusAddress != "" {
@@ -437,7 +440,7 @@ func (context *Context) serveStatus() error {
437440
mux.Handle("/debug/pprof/trace", http.HandlerFunc(pprof.Trace))
438441
}
439442

440-
config, err := buildConfig(*caBundlePath)
443+
config, err := buildConfig(*enabledCipherSuites, *caBundlePath)
441444
if err != nil {
442445
return err
443446
}
@@ -495,7 +498,7 @@ func serverBackendDialer() (func() (net.Conn, error), error) {
495498

496499
// Get backend dialer function in client mode (connecting to a TLS port)
497500
func clientBackendDialer(cert *certificate, network, address, host string) (func() (net.Conn, error), error) {
498-
config, err := buildConfig(*caBundlePath)
501+
config, err := buildConfig(*enabledCipherSuites, *caBundlePath)
499502
if err != nil {
500503
return nil, err
501504
}
@@ -522,7 +525,7 @@ func clientBackendDialer(cert *certificate, network, address, host string) (func
522525
logger.Printf("using HTTP(S) CONNECT proxy %s", (*clientConnectProxy).String())
523526

524527
// Use HTTP CONNECT proxy to connect to target.
525-
proxyConfig, err := buildConfig(*caBundlePath)
528+
proxyConfig, err := buildConfig(*enabledCipherSuites, *caBundlePath)
526529
if err != nil {
527530
return nil, err
528531
}

net.go

+8-7
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ import (
2929
)
3030

3131
type proxy struct {
32-
quit int32
33-
listener net.Listener
34-
handlers *sync.WaitGroup
35-
dial func() (net.Conn, error)
32+
quit int32
33+
listener net.Listener
34+
handlers *sync.WaitGroup
35+
connectTimeout time.Duration
36+
dial func() (net.Conn, error)
3637
}
3738

3839
var (
@@ -70,7 +71,7 @@ func (p *proxy) accept() {
7071
defer conn.Close()
7172
defer openCounter.Dec(1)
7273

73-
err := forceHandshake(conn)
74+
err := forceHandshake(p.connectTimeout, conn)
7475
if err != nil {
7576
errorCounter.Inc(1)
7677
logger.Printf("error on TLS handshake from %s: %s", conn.RemoteAddr(), err)
@@ -96,13 +97,13 @@ func (p *proxy) accept() {
9697
// Otherwise, unauthenticated clients would be able to open connections
9798
// and leave them hanging forever. Going through the handshake verifies
9899
// that clients have a valid client cert and are allowed to talk to us.
99-
func forceHandshake(conn net.Conn) error {
100+
func forceHandshake(timeout time.Duration, conn net.Conn) error {
100101
if tlsConn, ok := conn.(*tls.Conn); ok {
101102
startTime := time.Now()
102103
defer handshakeTimer.UpdateSince(startTime)
103104

104105
// Set deadline to avoid blocking forever
105-
err := tlsConn.SetDeadline(time.Now().Add(*timeoutDuration))
106+
err := tlsConn.SetDeadline(time.Now().Add(timeout))
106107
if err != nil {
107108
return err
108109
}

signals.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ func (context *Context) signalHandler(proxy *proxy, closeables []io.Closer) {
5858
}
5959

6060
// Force-exit after timeout
61-
time.AfterFunc(*shutdownTimeout, func() {
61+
time.AfterFunc(context.shutdownTimeout, func() {
6262
// Graceful shutdown timeout reached. If we can't drain connections
6363
// to exit gracefully after this timeout, let's just exit.
6464
logger.Printf("graceful shutdown timeout: forcing exit")

tls.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ func dialWithDialer(dialer Dialer, timeout time.Duration, network, addr string,
225225
}
226226

227227
// buildConfig reads command-line options and builds a tls.Config
228-
func buildConfig(caBundlePath string) (*tls.Config, error) {
228+
func buildConfig(enabledCipherSuites string, caBundlePath string) (*tls.Config, error) {
229229
ca, err := caBundle(caBundlePath)
230230
if err != nil {
231231
return nil, err
@@ -236,7 +236,7 @@ func buildConfig(caBundlePath string) (*tls.Config, error) {
236236
// * We list AES-128 ahead of AES-256 for performance reasons.
237237

238238
suites := []uint16{}
239-
for _, suite := range strings.Split(*enabledCipherSuites, ",") {
239+
for _, suite := range strings.Split(enabledCipherSuites, ",") {
240240
ciphers, ok := cipherSuites[strings.TrimSpace(suite)]
241241
if !ok {
242242
return nil, fmt.Errorf("invalid cipher suite '%s' selected", suite)

tls_test.go

+8-14
Original file line numberDiff line numberDiff line change
@@ -180,18 +180,16 @@ func TestBuildConfig(t *testing.T) {
180180
defer os.Remove(tmpCaBundle.Name())
181181
defer os.Remove(tmpKeystoreNoPrivKey.Name())
182182

183-
*enabledCipherSuites = ""
184-
conf, err := buildConfig(tmpCaBundle.Name())
183+
conf, err := buildConfig("", tmpCaBundle.Name())
185184
assert.NotNil(t, err, "should fail to build config with no cipher suites")
186185

187-
*enabledCipherSuites = "AES,CHACHA"
188-
conf, err = buildConfig(tmpCaBundle.Name())
186+
conf, err = buildConfig("AES,CHACHA", tmpCaBundle.Name())
189187
assert.Nil(t, err, "should be able to build TLS config")
190188
assert.NotNil(t, conf.RootCAs, "config must have CA certs")
191189
assert.NotNil(t, conf.ClientCAs, "config must have CA certs")
192190
assert.True(t, conf.MinVersion == tls.VersionTLS12, "must have correct TLS min version")
193191

194-
conf, err = buildConfig("does-not-exist")
192+
conf, err = buildConfig("AES", "does-not-exist")
195193
assert.Nil(t, conf, "conf with invalid params should be nil")
196194
assert.NotNil(t, err, "should reject invalid CA cert bundle")
197195

@@ -222,21 +220,17 @@ func TestCipherSuitePreference(t *testing.T) {
222220
tmpCaBundle.Sync()
223221
defer os.Remove(tmpCaBundle.Name())
224222

225-
*enabledCipherSuites = "XYZ"
226-
conf, err := buildConfig(tmpCaBundle.Name())
223+
conf, err := buildConfig("XYZ", tmpCaBundle.Name())
227224
assert.NotNil(t, err, "should not be able to build TLS config with invalid cipher suite option")
228225

229-
*enabledCipherSuites = ""
230-
conf, err = buildConfig(tmpCaBundle.Name())
226+
conf, err = buildConfig("", tmpCaBundle.Name())
231227
assert.NotNil(t, err, "should not be able to build TLS config wihout cipher suite selection")
232228

233-
*enabledCipherSuites = "CHACHA,AES"
234-
conf, err = buildConfig(tmpCaBundle.Name())
229+
conf, err = buildConfig("CHACHA,AES", tmpCaBundle.Name())
235230
assert.Nil(t, err, "should be able to build TLS config")
236231
assert.True(t, conf.CipherSuites[0] == tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, "expecting ChaCha20")
237232

238-
*enabledCipherSuites = "AES,CHACHA"
239-
conf, err = buildConfig(tmpCaBundle.Name())
233+
conf, err = buildConfig("AES,CHACHA", tmpCaBundle.Name())
240234
assert.Nil(t, err, "should be able to build TLS config")
241235
assert.True(t, conf.CipherSuites[0] == tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, "expecting AES")
242236
}
@@ -262,7 +256,7 @@ func TestBuildConfigSystemRoots(t *testing.T) {
262256
t.SkipNow()
263257
return
264258
}
265-
conf, err := buildConfig("")
259+
conf, err := buildConfig("AES", "")
266260
assert.Nil(t, err, "should be able to build TLS config")
267261
assert.NotNil(t, conf.RootCAs, "config must have CA certs")
268262
assert.NotNil(t, conf.ClientCAs, "config must have CA certs")

0 commit comments

Comments
 (0)