-
Notifications
You must be signed in to change notification settings - Fork 1
/
csp.go
112 lines (93 loc) · 2.84 KB
/
csp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package bassinet
import (
	"fmt"
	"net/http"
	"sort"
	"strings"
	"unicode"
)
// CSPOptions configures the middleware returned by CSP.
//
// directives maps a directive name to its list of source expressions. Names
// may be written in camelCase ("defaultSrc") or kebab-case ("default-src");
// CSP normalizes them to kebab-case. If directives is empty, CSP uses the
// package defaults instead.
//
// reportOnly selects the Content-Security-Policy-Report-Only header instead of
// the enforcing Content-Security-Policy header.
//
// NOTE(review): both fields are unexported, so code outside package bassinet
// cannot set them and can only obtain the default, enforcing policy — confirm
// whether exported fields (or a constructor) were intended.
type CSPOptions struct {
	directives map[string][]string // directive name -> source expressions
	reportOnly bool                // true: use the Report-Only header variant
}
// defaultDirectives is the policy CSP applies when CSPOptions carries no
// directives of its own. Directives that take no value (such as
// upgrade-insecure-requests) map to an empty source list.
//
// NOTE(review): block-all-mixed-content is marked deprecated in CSP Level 3
// and is a no-op when upgrade-insecure-requests is also set; consider
// dropping it.
var defaultDirectives = map[string][]string{
	"base-uri":                  {"'self'"},
	"block-all-mixed-content":   {},
	"default-src":               {"'self'"},
	"font-src":                  {"'self'", "https:", "data:"},
	"frame-ancestors":           {"'self'"},
	"img-src":                   {"'self'", "data:"},
	"object-src":                {"'none'"},
	"script-src":                {"'self'"},
	"script-src-attr":           {"'none'"},
	"style-src":                 {"'self'", "https:", "'unsafe-inline'"},
	"upgrade-insecure-requests": {},
}
// CSP returns middleware that sets the Content-Security-Policy header (or
// Content-Security-Policy-Report-Only when o.reportOnly is true) on every
// response and then calls the next handler.
//
// When o.directives is empty, defaultDirectives is used; otherwise only the
// caller's directives are used. Directive names may be given in camelCase
// ("defaultSrc") or kebab-case ("default-src") and are normalized to
// kebab-case. The header value is computed once here, not per request.
//
// An error is returned if a directive name contains an invalid character, or
// if no default-src directive is present after normalization.
func CSP(o CSPOptions) (Middleware, error) {
	directives := map[string][]string{}
	if len(o.directives) == 0 {
		directives = mergeDirectives(directives, defaultDirectives)
	}
	directives = mergeDirectives(directives, o.directives)

	normedDirectives, err := normalizeDirectives(directives)
	if err != nil {
		return nil, err
	}
	// Check the normalized map: a caller who wrote "defaultSrc" did supply a
	// default-src even though that exact key is absent before normalization.
	if _, ok := normedDirectives["default-src"]; !ok {
		return nil, fmt.Errorf("Content-Security-Policy needs a default-src but none was provided")
	}

	serializedDirectives := serializeDirectives(normedDirectives)

	headerName := "Content-Security-Policy"
	if o.reportOnly {
		headerName = "Content-Security-Policy-Report-Only"
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set(headerName, serializedDirectives)
			next.ServeHTTP(w, r)
		})
	}, nil
}
// toKebab converts a camelCase directive name such as "defaultSrc" into its
// kebab-case form "default-src". Each uppercase letter becomes a hyphen
// followed by its lowercase form, so a leading capital yields a leading
// hyphen ("DefaultSrc" -> "-default-src"). Every other permitted rune is
// copied unchanged.
//
// Only ASCII letters, ASCII digits and '-' are permitted. Any other rune,
// including invalid UTF-8, produces an error.
func toKebab(s string) (string, error) {
	var b strings.Builder
	b.Grow(len(s))
	for _, r := range s {
		if r != '-' && (r > unicode.MaxASCII || !(unicode.IsLetter(r) || unicode.IsNumber(r))) {
			return "", fmt.Errorf("Input string contains invalid character %q", r)
		}
		if unicode.IsUpper(r) {
			b.WriteByte('-')
			b.WriteRune(unicode.ToLower(r))
			continue
		}
		b.WriteRune(r)
	}
	return b.String(), nil
}
// normalizeDirectives returns a new map whose keys are the kebab-case forms
// (via toKebab) of the keys in directives.
//
// When several input keys normalize to the same name (for example
// "defaultSrc" and "default-src"), their source lists are concatenated. Input
// keys are visited in sorted order, so the combined list is identical on
// every run. Value slices in the result are freshly allocated and never alias
// the input's slices. Checking that a default-src directive exists is left to
// the caller.
//
// If a key contains an invalid character, the first such error (in sorted
// key order) is returned with a nil map.
func normalizeDirectives(directives map[string][]string) (map[string][]string, error) {
	keys := make([]string, 0, len(directives))
	for k := range directives {
		keys = append(keys, k)
	}
	// Map iteration order is random; sorting makes both the merge order of
	// colliding keys and the reported error deterministic.
	sort.Strings(keys)

	normedDirectives := make(map[string][]string, len(directives))
	for _, k := range keys {
		key, err := toKebab(k)
		if err != nil {
			return nil, err
		}
		// Appending to a nil/owned slice copies the values out of the input.
		normedDirectives[key] = append(normedDirectives[key], directives[k]...)
	}
	return normedDirectives, nil
}
// serializeDirectives renders directives as a Content-Security-Policy header
// value. Each directive is written as its name, a space, its source
// expressions joined by single spaces, and a terminating "; ". Directives are
// emitted in sorted name order, so the same input always yields the same
// header value. A nil or empty map yields "".
func serializeDirectives(directives map[string][]string) string {
	names := make([]string, 0, len(directives))
	for name := range directives {
		names = append(names, name)
	}
	sort.Strings(names)

	var b strings.Builder
	for _, name := range names {
		b.WriteString(name)
		b.WriteByte(' ')
		b.WriteString(strings.Join(directives[name], " "))
		b.WriteString("; ")
	}
	return b.String()
}
// mergeDirectives copies every entry of d2 into d1, overwriting entries that
// share a key, and returns d1. The value slices themselves are not copied, so
// the result shares backing arrays with d2.
//
// A nil d1 is replaced by a new map rather than causing a panic on
// assignment; callers must use the returned map in that case.
func mergeDirectives(d1, d2 map[string][]string) map[string][]string {
	if d1 == nil {
		d1 = make(map[string][]string, len(d2))
	}
	for name, sources := range d2 {
		d1[name] = sources
	}
	return d1
}