Skip to content

Commit d301013

Browse files
dtomcejunrolled
authored andcommitted
Allow a custom context key (#65)
* custom context key * remove empty line
1 parent d03975e commit d301013

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

secure.go

+15-5
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ const (
2727
featurePolicyHeader = "Feature-Policy"
2828
expectCTHeader = "Expect-CT"
2929

30-
ctxSecureHeaderKey = secureCtxKey("SecureResponseHeader")
31-
cspNonceSize = 16
30+
ctxDefaultSecureHeaderKey = secureCtxKey("SecureResponseHeader")
31+
cspNonceSize = 16
3232
)
3333

3434
// SSLHostFunc a type whose pointer is the type of field `SSLHostFunc` of `Options` struct
@@ -97,6 +97,8 @@ type Options struct {
9797
STSSeconds int64
9898
// ExpectCTHeader allows the Expect-CT header value to be set with a custom value. Default is "".
9999
ExpectCTHeader string
100+
// SecureContextKey allows a custom key to be specified for context storage.
101+
SecureContextKey string
100102
}
101103

102104
// Secure is a middleware that helps setup a few basic security features. A single secure.Options struct can be
@@ -111,6 +113,9 @@ type Secure struct {
111113
// cRegexAllowedHosts saves the compiled regular expressions of the AllowedHosts
112114
// option for subsequent use in processRequest
113115
cRegexAllowedHosts []*regexp.Regexp
116+
117+
// ctxSecureHeaderKey is the key used for context storage for request modification.
118+
ctxSecureHeaderKey secureCtxKey
114119
}
115120

116121
// New constructs a new Secure instance with the supplied options.
@@ -143,6 +148,11 @@ func New(options ...Options) *Secure {
143148
}
144149
}
145150

151+
s.ctxSecureHeaderKey = ctxDefaultSecureHeaderKey
152+
if len(s.opt.SecureContextKey) > 0 {
153+
s.ctxSecureHeaderKey = secureCtxKey(s.opt.SecureContextKey)
154+
}
155+
146156
return s
147157
}
148158

@@ -182,7 +192,7 @@ func (s *Secure) HandlerForRequestOnly(h http.Handler) http.Handler {
182192
}
183193

184194
// Save response headers in the request context.
185-
ctx := context.WithValue(r.Context(), ctxSecureHeaderKey, responseHeader)
195+
ctx := context.WithValue(r.Context(), s.ctxSecureHeaderKey, responseHeader)
186196

187197
// No headers will be written to the ResponseWriter.
188198
h.ServeHTTP(w, r.WithContext(ctx))
@@ -212,7 +222,7 @@ func (s *Secure) HandlerFuncWithNextForRequestOnly(w http.ResponseWriter, r *htt
212222
// If there was an error, do not call next.
213223
if err == nil && next != nil {
214224
// Save response headers in the request context
215-
ctx := context.WithValue(r.Context(), ctxSecureHeaderKey, responseHeader)
225+
ctx := context.WithValue(r.Context(), s.ctxSecureHeaderKey, responseHeader)
216226

217227
// No headers will be written to the ResponseWriter.
218228
next(w, r.WithContext(ctx))
@@ -450,7 +460,7 @@ func (s *Secure) ModifyResponseHeaders(res *http.Response) error {
450460
res.Header.Set("Location", location)
451461
}
452462

453-
responseHeader := res.Request.Context().Value(ctxSecureHeaderKey)
463+
responseHeader := res.Request.Context().Value(s.ctxSecureHeaderKey)
454464
if responseHeader != nil {
455465
for header, values := range responseHeader.(http.Header) {
456466
if len(values) > 0 {

secure_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,56 @@ func TestModifyResponseHeadersWithSSLAndPathInLocationResponse(t *testing.T) {
13701370
expect(t, res.Header.Get("Location"), "https://secure.example.com/admin/login")
13711371
}
13721372

1373+
func TestCustomSecureContextKey(t *testing.T) {
1374+
s1 := New(Options{
1375+
BrowserXssFilter: true,
1376+
CustomBrowserXssValue: "0",
1377+
SecureContextKey: "totallySecureContextKey",
1378+
})
1379+
1380+
res := httptest.NewRecorder()
1381+
req, _ := http.NewRequest("GET", "/foo", nil)
1382+
1383+
var actual *http.Request
1384+
hf := func(w http.ResponseWriter, r *http.Request) {
1385+
actual = r
1386+
}
1387+
1388+
s1.HandlerFuncWithNextForRequestOnly(res, req, hf)
1389+
contextHeaders := actual.Context().Value(s1.ctxSecureHeaderKey).(http.Header)
1390+
expect(t, contextHeaders.Get(xssProtectionHeader), s1.opt.CustomBrowserXssValue)
1391+
}
1392+
1393+
func TestMultipleCustomSecureContextKeys(t *testing.T) {
1394+
s1 := New(Options{
1395+
BrowserXssFilter: true,
1396+
CustomBrowserXssValue: "0",
1397+
SecureContextKey: "totallySecureContextKey",
1398+
})
1399+
1400+
s2 := New(Options{
1401+
FeaturePolicy: "test",
1402+
SecureContextKey: "anotherSecureContextKey",
1403+
})
1404+
1405+
res := httptest.NewRecorder()
1406+
req, _ := http.NewRequest("GET", "/foo", nil)
1407+
1408+
var actual *http.Request
1409+
hf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1410+
actual = r
1411+
})
1412+
1413+
next := s1.HandlerForRequestOnly(hf)
1414+
s2.HandlerFuncWithNextForRequestOnly(res, req, next.ServeHTTP)
1415+
1416+
s1Headers := actual.Context().Value(s1.ctxSecureHeaderKey).(http.Header)
1417+
s2Headers := actual.Context().Value(s2.ctxSecureHeaderKey).(http.Header)
1418+
1419+
expect(t, s1Headers.Get(xssProtectionHeader), s1.opt.CustomBrowserXssValue)
1420+
expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy)
1421+
}
1422+
13731423
/* Test Helpers */
13741424
func expect(t *testing.T, a interface{}, b interface{}) {
13751425
if a != b {

0 commit comments

Comments
 (0)