Skip to content

Commit 2588a36

Browse files
authored
feat: support rest.WithCorsHeaders to customize cors headers (zeromicro#4284)
1 parent c2421be commit 2588a36

File tree

4 files changed

+144
-0
lines changed

4 files changed

+144
-0
lines changed

rest/internal/cors/handlers.go

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ const (
2626
originHeader = "Origin"
2727
)
2828

29+
// AddAllowHeaders sets the allowed headers.
30+
func AddAllowHeaders(header http.Header, headers ...string) {
31+
header.Add(allowHeaders, strings.Join(headers, ", "))
32+
}
33+
2934
// NotAllowedHandler handles cross domain not allowed requests.
3035
// At most one origin can be specified, other origins are ignored if given, default to be *.
3136
func NotAllowedHandler(fn func(w http.ResponseWriter), origins ...string) http.Handler {

rest/internal/cors/handlers_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,80 @@ package cors
33
import (
44
"net/http"
55
"net/http/httptest"
6+
"strings"
67
"testing"
78

89
"github.com/stretchr/testify/assert"
910
)
1011

12+
func TestAddAllowHeaders(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
initial string
16+
headers []string
17+
expected string
18+
}{
19+
{
20+
name: "single header",
21+
initial: "",
22+
headers: []string{"Content-Type"},
23+
expected: "Content-Type",
24+
},
25+
{
26+
name: "multiple headers",
27+
initial: "",
28+
headers: []string{"Content-Type", "Authorization", "X-Requested-With"},
29+
expected: "Content-Type, Authorization, X-Requested-With",
30+
},
31+
{
32+
name: "add to existing headers",
33+
initial: "Origin, Accept",
34+
headers: []string{"Content-Type"},
35+
expected: "Origin, Accept, Content-Type",
36+
},
37+
{
38+
name: "no headers",
39+
initial: "",
40+
headers: []string{},
41+
expected: "",
42+
},
43+
}
44+
45+
for _, tt := range tests {
46+
t.Run(tt.name, func(t *testing.T) {
47+
header := http.Header{}
48+
headers := make(map[string]struct{})
49+
if tt.initial != "" {
50+
header.Set(allowHeaders, tt.initial)
51+
vals := strings.Split(tt.initial, ", ")
52+
for _, v := range vals {
53+
headers[v] = struct{}{}
54+
}
55+
}
56+
for _, h := range tt.headers {
57+
headers[h] = struct{}{}
58+
}
59+
AddAllowHeaders(header, tt.headers...)
60+
var actual []string
61+
vals := header.Values(allowHeaders)
62+
for _, v := range vals {
63+
bunch := strings.Split(v, ", ")
64+
for _, b := range bunch {
65+
if len(b) > 0 {
66+
actual = append(actual, b)
67+
}
68+
}
69+
}
70+
71+
var expect []string
72+
for k := range headers {
73+
expect = append(expect, k)
74+
}
75+
assert.ElementsMatch(t, expect, actual)
76+
})
77+
}
78+
}
79+
1180
func TestCorsHandlerWithOrigins(t *testing.T) {
1281
tests := []struct {
1382
name string

rest/server.go

+12
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,18 @@ func WithCors(origin ...string) RunOption {
161161
}
162162
}
163163

164+
// WithCorsHeaders returns a RunOption to enable CORS with given headers.
165+
func WithCorsHeaders(headers ...string) RunOption {
166+
const allDomains = "*"
167+
168+
return func(server *Server) {
169+
server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, allDomains))
170+
server.router = newCorsRouter(server.router, func(header http.Header) {
171+
cors.AddAllowHeaders(header, headers...)
172+
}, allDomains)
173+
}
174+
}
175+
164176
// WithCustomCors returns a func to enable CORS for given origin, or default to all origins (*),
165177
// fn lets caller customizing the response.
166178
func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(http.ResponseWriter),

rest/server_test.go

+58
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,64 @@ Port: 54321
420420
opt(svr)
421421
}
422422

423+
func TestWithCorsHeaders(t *testing.T) {
424+
tests := []struct {
425+
name string
426+
headers []string
427+
}{
428+
{
429+
name: "single header",
430+
headers: []string{"UserHeader"},
431+
},
432+
{
433+
name: "multiple headers",
434+
headers: []string{"UserHeader", "X-Requested-With"},
435+
},
436+
{
437+
name: "no headers",
438+
headers: []string{},
439+
},
440+
}
441+
442+
for _, tt := range tests {
443+
t.Run(tt.name, func(t *testing.T) {
444+
const configYaml = `
445+
Name: foo
446+
Port: 54321
447+
`
448+
var cnf RestConf
449+
assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf))
450+
rt := router.NewRouter()
451+
svr, err := NewServer(cnf, WithRouter(rt))
452+
assert.Nil(t, err)
453+
defer svr.Stop()
454+
option := WithCorsHeaders(tt.headers...)
455+
option(svr)
456+
457+
// Assuming newCorsRouter sets headers correctly,
458+
// we would need to verify the behavior here. Since we don't have
459+
// direct access to headers, we'll mock newCorsRouter to capture it.
460+
w := httptest.NewRecorder()
461+
svr.ServeHTTP(w, httptest.NewRequest(http.MethodOptions, "/", nil))
462+
463+
vals := w.Header().Values("Access-Control-Allow-Headers")
464+
respHeaders := make(map[string]struct{})
465+
for _, header := range vals {
466+
headers := strings.Split(header, ", ")
467+
for _, h := range headers {
468+
if len(h) > 0 {
469+
respHeaders[h] = struct{}{}
470+
}
471+
}
472+
}
473+
for _, h := range tt.headers {
474+
_, ok := respHeaders[h]
475+
assert.Truef(t, ok, "expected header %s not found", h)
476+
}
477+
})
478+
}
479+
}
480+
423481
func TestServer_PrintRoutes(t *testing.T) {
424482
const (
425483
configYaml = `

0 commit comments

Comments
 (0)