Skip to content

Commit 0ce3852

Browse files
authored
Adding AllowRequestFunc (#85)
1 parent 56ae1bd commit 0ce3852

File tree

3 files changed

+60
-49
lines changed

3 files changed

+60
-49
lines changed

README.md

+14-5
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ Secure comes with a variety of configuration options (Note: these are not the de
6262
// ...
6363
s := secure.New(secure.Options{
6464
AllowedHosts: []string{"ssl.example.com"}, // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
65-
AllowedHostsFunc: func() []string { return []string{"example.com", "www.example.com" } // AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. This can be used in combination with the above AllowedHosts.
66-
AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. This does not apply to the `AllowedHostsFunc` values! Default is false.
65+
AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. Default is false.
66+
AllowRequestFunc: nil, // AllowRequestFunc is a custom function type that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil.
6767
HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request.
6868
SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false.
6969
SSLTemporaryRedirect: false, // If SSLTemporaryRedirect is true, the a 302 will be used while redirecting. Default is false (301).
@@ -102,8 +102,8 @@ s := secure.New()
102102

103103
l := secure.New(secure.Options{
104104
AllowedHosts: []string,
105-
AllowedHostsFunc: nil,
106105
AllowedHostsAreRegex: false,
106+
AllowRequestFunc: nil,
107107
HostsProxyHeaders: []string,
108108
SSLRedirect: false,
109109
SSLTemporaryRedirect: false,
@@ -127,11 +127,20 @@ l := secure.New(secure.Options{
127127
IsDevelopment: false,
128128
})
129129
~~~
130-
Also note the default bad host handler returns an error:
130+
The default bad host handler returns the following error:
131131
~~~ go
132132
http.Error(w, "Bad Host", http.StatusInternalServerError)
133133
~~~
134-
Call `secure.SetBadHostHandler` to change the bad host handler.
134+
Call `secure.SetBadHostHandler` to set your own custom handler.
135+
136+
The default bad request handler returns the following error:
137+
~~~ go
138+
http.Error(w, "Bad Request", http.StatusBadRequest)
139+
~~~
140+
Call `secure.SetBadRequestHandler` to set your own custom handler.
141+
142+
### Allow Request Function
143+
Secure allows you to set a custom function (`func(r *http.Request) bool`) for the `AllowRequestFunc` option. You can use this function as a custom filter to allow the request to continue or simply reject it. This can be handy if you need to do any dynamic filtering on any of the request properties. It should be noted that this function will be called on every request, so be sure to make your checks quick and not relying on time consuming external calls (or you will be slowing down all requests). See above on how to set a custom handler for the rejected requests.
135144

136145
### Redirecting HTTP to HTTPS
137146
If you want to redirect all HTTP requests to HTTPS, you can use the following example.

secure.go

+28-22
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ const (
3636
// SSLHostFunc is a custom function type that can be used to dynamically set the SSL host of a request.
3737
type SSLHostFunc func(host string) (newHost string)
3838

39-
// AllowedHostsFunc is a custom function type that can be used to dynamically return a slice of strings that will be used in the `AllowHosts` check.
40-
type AllowedHostsFunc func() []string
39+
// AllowRequestFunc is a custom function type that can be used to dynamically determine if a request should proceed or not.
40+
type AllowRequestFunc func(r *http.Request) bool
4141

4242
func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) {
4343
http.Error(w, "Bad Host", http.StatusInternalServerError)
4444
}
4545

46+
func defaultBadRequestHandler(w http.ResponseWriter, r *http.Request) {
47+
http.Error(w, "Bad Request", http.StatusBadRequest)
48+
}
49+
4650
// Options is a struct for specifying configuration options for the secure.Secure middleware.
4751
type Options struct {
4852
// If BrowserXssFilter is true, adds the X-XSS-Protection header with the value `1; mode=block`. Default is false.
@@ -95,10 +99,10 @@ type Options struct {
9599
SSLHost string
96100
// AllowedHosts is a slice of fully qualified domain names that are allowed. Default is an empty slice, which allows any and all host names.
97101
AllowedHosts []string
98-
// AllowedHostsFunc is a custom function that returns a slice of fully qualified domain names that are allowed. If set, values will be used in combination with the above AllowedHosts. Default is nil.
99-
AllowedHostsFunc AllowedHostsFunc
100-
// AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. This does not apply to `AllowedHostsFunc`! If this flag is set to true, every request's host will be checked against these expressions. Default is false.
102+
// AllowedHostsAreRegex determines, if the provided `AllowedHosts` slice contains valid regular expressions. If this flag is set to true, every request's host will be checked against these expressions. Default is false.
101103
AllowedHostsAreRegex bool
104+
// AllowRequestFunc is a custom function that allows you to determine if the request should proceed or not based on your own custom logic. Default is nil.
105+
AllowRequestFunc AllowRequestFunc
102106
// HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request.
103107
HostsProxyHeaders []string
104108
// SSLHostFunc is a function pointer, the return value of the function is the host name that has same functionality as `SSHost`. Default is nil.
@@ -123,6 +127,9 @@ type Secure struct {
123127
// badHostHandler is the handler used when an incorrect host is passed in.
124128
badHostHandler http.Handler
125129

130+
// badRequestHandler is the handler used when the AllowRequestFunc rejects a request.
131+
badRequestHandler http.Handler
132+
126133
// cRegexAllowedHosts saves the compiled regular expressions of the AllowedHosts
127134
// option for subsequent use in processRequest
128135
cRegexAllowedHosts []*regexp.Regexp
@@ -146,8 +153,9 @@ func New(options ...Options) *Secure {
146153
o.nonceEnabled = strings.Contains(o.ContentSecurityPolicy, "%[1]s") || strings.Contains(o.ContentSecurityPolicyReportOnly, "%[1]s")
147154

148155
s := &Secure{
149-
opt: o,
150-
badHostHandler: http.HandlerFunc(defaultBadHostHandler),
156+
opt: o,
157+
badHostHandler: http.HandlerFunc(defaultBadHostHandler),
158+
badRequestHandler: http.HandlerFunc(defaultBadRequestHandler),
151159
}
152160

153161
if s.opt.AllowedHostsAreRegex {
@@ -174,6 +182,11 @@ func (s *Secure) SetBadHostHandler(handler http.Handler) {
174182
s.badHostHandler = handler
175183
}
176184

185+
// SetBadRequestHandler sets the handler to call when the AllowRequestFunc rejects a request.
186+
func (s *Secure) SetBadRequestHandler(handler http.Handler) {
187+
s.badRequestHandler = handler
188+
}
189+
177190
// Handler implements the http.HandlerFunc for integration with the standard net/http lib.
178191
func (s *Secure) Handler(h http.Handler) http.Handler {
179192
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -294,14 +307,7 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
294307
}
295308

296309
// Allowed hosts check.
297-
combinedAllowedHosts := s.opt.AllowedHosts
298-
var allowedFuncHosts []string
299-
if s.opt.AllowedHostsFunc != nil {
300-
allowedFuncHosts = s.opt.AllowedHostsFunc()
301-
combinedAllowedHosts = append(combinedAllowedHosts, allowedFuncHosts...)
302-
}
303-
304-
if len(combinedAllowedHosts) > 0 && !s.opt.IsDevelopment {
310+
if len(s.opt.AllowedHosts) > 0 && !s.opt.IsDevelopment {
305311
isGoodHost := false
306312
if s.opt.AllowedHostsAreRegex {
307313
for _, allowedHost := range s.cRegexAllowedHosts {
@@ -310,14 +316,8 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
310316
break
311317
}
312318
}
313-
for _, allowedHost := range allowedFuncHosts {
314-
if strings.EqualFold(allowedHost, host) {
315-
isGoodHost = true
316-
break
317-
}
318-
}
319319
} else {
320-
for _, allowedHost := range combinedAllowedHosts {
320+
for _, allowedHost := range s.opt.AllowedHosts {
321321
if strings.EqualFold(allowedHost, host) {
322322
isGoodHost = true
323323
break
@@ -380,6 +380,12 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
380380
}
381381
}
382382

383+
// If the AllowRequestFunc is set, call it and exit early if needed.
384+
if s.opt.AllowRequestFunc != nil && !s.opt.AllowRequestFunc(r) {
385+
s.badRequestHandler.ServeHTTP(w, r)
386+
return nil, nil, fmt.Errorf("request not allowed")
387+
}
388+
383389
// Create our header container.
384390
responseHeader := make(http.Header)
385391

secure_test.go

+18-22
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http"
66
"net/http/httptest"
77
"reflect"
8+
"strings"
89
"testing"
910
)
1011

@@ -1448,57 +1449,52 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) {
14481449
expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy)
14491450
}
14501451

1451-
func TestAllowHostsFunc(t *testing.T) {
1452+
func TestAllowRequestFuncTrue(t *testing.T) {
14521453
s := New(Options{
1453-
AllowedHostsFunc: func() []string { return []string{"www.allow-func.com"} },
1454+
AllowRequestFunc: func(r *http.Request) bool { return true },
14541455
})
14551456

14561457
res := httptest.NewRecorder()
14571458
req, _ := http.NewRequest("GET", "/foo", nil)
1458-
req.Host = "www.allow-func.com"
1459+
req.Host = "www.allow-request.com"
14591460

14601461
s.Handler(myHandler).ServeHTTP(res, req)
14611462

14621463
expect(t, res.Code, http.StatusOK)
14631464
expect(t, res.Body.String(), `bar`)
14641465
}
14651466

1466-
func TestAllowHostsFuncWithAllowedHostsList(t *testing.T) {
1467+
func TestAllowRequestFuncFalse(t *testing.T) {
14671468
s := New(Options{
1468-
AllowedHosts: []string{"www.allow.com"},
1469-
AllowedHostsFunc: func() []string { return []string{"www.allow-func.com"} },
1469+
AllowRequestFunc: func(r *http.Request) bool { return false },
14701470
})
14711471

14721472
res := httptest.NewRecorder()
14731473
req, _ := http.NewRequest("GET", "/foo", nil)
1474-
req.Host = "www.allow.com"
1474+
req.Host = "www.deny-request.com"
14751475

14761476
s.Handler(myHandler).ServeHTTP(res, req)
14771477

1478-
expect(t, res.Code, http.StatusOK)
1479-
expect(t, res.Body.String(), `bar`)
1478+
expect(t, res.Code, http.StatusBadRequest)
14801479
}
14811480

1482-
func TestAllowHostsFuncWithAllowedHostsListWithRegex(t *testing.T) {
1481+
func TestBadRequestHandler(t *testing.T) {
14831482
s := New(Options{
1484-
AllowedHosts: []string{"*\\.allow\\.com"},
1485-
AllowedHostsFunc: func() []string { return []string{"foo.bar.allow.com"} },
1486-
AllowedHostsAreRegex: true,
1483+
AllowRequestFunc: func(r *http.Request) bool { return false },
14871484
})
1485+
badRequestFunc := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1486+
http.Error(w, "custom error", http.StatusConflict)
1487+
})
1488+
s.SetBadRequestHandler(badRequestFunc)
14881489

14891490
res := httptest.NewRecorder()
14901491
req, _ := http.NewRequest("GET", "/foo", nil)
1491-
req.Host = "foo.bar.allow.com"
1492-
s.Handler(myHandler).ServeHTTP(res, req)
1493-
expect(t, res.Code, http.StatusOK)
1494-
expect(t, res.Body.String(), `bar`)
1492+
req.Host = "www.deny-request.com"
14951493

1496-
res = httptest.NewRecorder()
1497-
req, _ = http.NewRequest("GET", "/foo", nil)
1498-
req.Host = "bar.allow.com"
14991494
s.Handler(myHandler).ServeHTTP(res, req)
1500-
expect(t, res.Code, http.StatusOK)
1501-
expect(t, res.Body.String(), `bar`)
1495+
1496+
expect(t, res.Code, http.StatusConflict)
1497+
expect(t, strings.TrimSpace(res.Body.String()), `custom error`)
15021498
}
15031499

15041500
/* Test Helpers */

0 commit comments

Comments
 (0)