-
Notifications
You must be signed in to change notification settings - Fork 1
/
cors.go
151 lines (128 loc) · 4.25 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Copyright 2021 Teal.Finance/Garcon contributors
// This file is part of Teal.Finance/Garcon,
// an API and website server under the MIT License.
// SPDX-License-Identifier: MIT
package garcon
import (
"net/http"
"net/url"
"strings"
"github.com/rs/cors"
"github.com/teal-finance/garcon/gg"
)
// MiddlewareCORS returns the CORS (Cross-Origin Resource Sharing) middleware
// of this Garcon without requesting any specific methods or headers:
// both are passed as nil to MiddlewareCORSWithMethodsHeaders.
func (g *Garcon) MiddlewareCORS() gg.Middleware {
	var methods, headers []string // nil: the caller requests nothing specific

	return g.MiddlewareCORSWithMethodsHeaders(methods, headers)
}
// MiddlewareCORSWithMethodsHeaders returns a CORS (Cross-Origin Resource Sharing)
// middleware built from the allowed origins and the dev-mode flag stored in
// the Garcon, combined with the given methods and headers.
func (g *Garcon) MiddlewareCORSWithMethodsHeaders(methods, headers []string) gg.Middleware {
	origins, debug := g.allowedOrigins, g.devMode

	return MiddlewareCORS(origins, methods, headers, debug)
}
// MiddlewareCORS builds a CORS handler with restrictive values from the given
// origins, methods and headers, and returns it as a middleware.
// The debug flag is forwarded to the underlying CORS handler.
func MiddlewareCORS(allowedOrigins, methods, headers []string, debug bool) func(next http.Handler) http.Handler {
	handler := newCORS(allowedOrigins, methods, headers, debug)

	// Only swap the logger when the handler already carries one.
	// NOTE(review): presumably the CORS library sets Log only in debug mode — confirm.
	if hasLogger := handler.Log != nil; hasLogger {
		handler.Log = corsLogger{}
	}

	return handler.Handler
}
// corsLogger forwards the messages of the CORS handler to the security log.
type corsLogger struct{}

// Printf writes the message to the security log with a "CORS " prefix.
// Messages whose format contains "Actual request" are dropped.
func (corsLogger) Printf(format string, args ...any) {
	if !strings.Contains(format, "Actual request") {
		log.Securityf("CORS "+format, args...)
	}
}
// DevOrigins returns the origins used during development.
// Each URL has the "http" scheme and a partial host
// ("localhost:" and "192.168.1.") covering the following setups:
//   - yarn run vite --port 3000
//   - yarn run vite preview --port 5000
//   - localhost:8085 on multi devices: web auto-reload using https://github.com/synw/fwr
//   - flutter run --web-port=8080
//   - 192.168.1.x + any port on tablet: mobile app using fast builtin auto-reload.
func DevOrigins() []*url.URL {
	hosts := [...]string{"localhost:", "192.168.1."}

	origins := make([]*url.URL, 0, len(hosts))
	for _, host := range hosts {
		origins = append(origins, &url.URL{Scheme: "http", Host: host})
	}

	return origins
}
// newCORS creates the CORS handler.
//
// Empty methods default to GET, POST and DELETE.
// Empty headers default to "Origin", "Content-Type" and "Authorization".
// Origins are validated by the function returned by allowOriginFunc,
// credentials are allowed, the max age is one day (in seconds)
// and successful preflight requests are answered with "204 No Content".
// The resulting settings are written to the security log.
func newCORS(allowedOrigins, methods, headers []string, debug bool) *cors.Cors {
	// original default: http.MethodGet, http.MethodPost, http.MethodHead
	if len(methods) == 0 {
		methods = []string{http.MethodGet, http.MethodPost, http.MethodDelete}
	}

	// original default: "Origin", "Accept", "Content-Type", "X-Requested-With"
	if len(headers) == 0 {
		headers = []string{"Origin", "Content-Type", "Authorization"}
	}

	// https://developer.mozilla.org/docs/Web/HTTP/Headers/Access-Control-Max-Age
	const oneDayInSeconds = 24 * 3600

	originChecker := allowOriginFunc(allowedOrigins)

	// All fields are listed, including those left at their zero value.
	options := cors.Options{
		AllowedOrigins:             nil,
		AllowOriginFunc:            originChecker,
		AllowOriginRequestFunc:     nil,
		AllowOriginVaryRequestFunc: nil,
		AllowedMethods:             methods,
		AllowedHeaders:             headers,
		ExposedHeaders:             nil,
		MaxAge:                     oneDayInSeconds,
		AllowCredentials:           true,
		AllowPrivateNetwork:        false,
		OptionsPassthrough:         false,
		OptionsSuccessStatus:       http.StatusNoContent,
		Debug:                      debug, // verbose logs
		Logger:                     nil,
	}

	log.Security("CORS Methods:", options.AllowedMethods)
	log.Security("CORS Headers:", options.AllowedHeaders)
	log.Securityf("CORS Credentials=%v MaxAge=%v", options.AllowCredentials, options.MaxAge)
	if debug {
		log.Ok("CORS Debug mode")
	}

	return cors.New(options)
}
// allowOriginFunc returns the function deciding whether an origin is accepted.
// The allowed origins are first completed in place by InsertSchema, then:
//   - no origin: every origin is accepted,
//   - one origin: the origin must be equal to it,
//   - several origins: the origin must start with one of them.
func allowOriginFunc(allowedOrigins []string) func(string) bool {
	InsertSchema(allowedOrigins)

	count := len(allowedOrigins)
	if count == 0 {
		return allOrigins()
	}
	if count == 1 {
		return oneOrigin(allowedOrigins[0])
	}
	return multipleOriginPrefixes(allowedOrigins)
}
// InsertSchema prepends "http://" to every URL starting
// with neither "http://" nor "https://".
// The slice is modified in place.
func InsertSchema(urls []string) {
	for i := range urls {
		hasScheme := strings.HasPrefix(urls[i], "https://") ||
			strings.HasPrefix(urls[i], "http://")
		if hasScheme {
			continue
		}
		urls[i] = "http://" + urls[i]
	}
}
// allOrigins logs that all origins are allowed
// and returns a function accepting any origin.
func allOrigins() func(string) bool {
	log.Security("CORS Allow all origins")

	acceptAny := func(string) bool { return true }

	return acceptAny
}
// oneOrigin logs the allowed origin and returns a function accepting
// only the origins strictly equal to allowedOrigin.
// Every refused origin is written to the security log.
func oneOrigin(allowedOrigin string) func(string) bool {
	log.Security("CORS Allow one origin:", allowedOrigin)

	return func(origin string) bool {
		accepted := origin == allowedOrigin
		if !accepted {
			log.Security("CORS Refuse", origin, "is not "+allowedOrigin)
		}
		return accepted
	}
}
// multipleOriginPrefixes logs the allowed prefixes and returns a function
// accepting the origins starting with at least one of allowedPrefixes.
// Every refused origin is written to the security log.
//
// NOTE(review): this is a plain string-prefix comparison, thus a prefix such as
// "https://example.com" also accepts "https://example.com.other.tld" —
// confirm this is acceptable for production origins.
func multipleOriginPrefixes(allowedPrefixes []string) func(origin string) bool {
	log.Security("CORS Allow origin prefixes:", allowedPrefixes)

	return func(origin string) bool {
		matched := false
		for i := range allowedPrefixes {
			if strings.HasPrefix(origin, allowedPrefixes[i]) {
				matched = true
				break
			}
		}

		if !matched {
			log.Security("CORS Refuse", origin, "without prefixes", allowedPrefixes)
		}
		return matched
	}
}