From e4d90046558491a1210b08fcabf4fcf4690191c6 Mon Sep 17 00:00:00 2001 From: Thanatat Tamtan Date: Sat, 20 May 2023 21:31:07 +0700 Subject: [PATCH] optimize header operations --- pkg/authn/authn.go | 4 +- pkg/authn/basic.go | 4 +- pkg/authn/forward.go | 27 +++++---- pkg/compress/compress.go | 31 ++++------- pkg/cors/cors.go | 30 +++++----- pkg/headers/map.go | 5 +- pkg/hsts/hsts.go | 4 +- pkg/internal/header/header.go | 88 ++++++++++++++++++++++++++++++ pkg/internal/header/header_test.go | 46 ++++++++++++++++ pkg/logger/logger.go | 8 ++- pkg/logger/record.go | 2 +- pkg/ratelimit/ratelimit.go | 6 +- pkg/redirect/https.go | 4 +- pkg/redirect/nonwww.go | 4 +- pkg/redirect/www.go | 4 +- pkg/requestid/requestid.go | 10 ++-- pkg/stackdriver/trace.go | 4 +- pkg/upstream/transport.go | 6 +- 18 files changed, 221 insertions(+), 66 deletions(-) create mode 100644 pkg/internal/header/header.go create mode 100644 pkg/internal/header/header_test.go diff --git a/pkg/authn/authn.go b/pkg/authn/authn.go index fda2e0e..89c1521 100644 --- a/pkg/authn/authn.go +++ b/pkg/authn/authn.go @@ -2,6 +2,8 @@ package authn import ( "net/http" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // Authenticator middleware @@ -27,7 +29,7 @@ func (m Authenticator) ServeHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := m.Authenticate(r); err != nil { if m.Type != "" { - w.Header().Set("WWW-Authenticate", m.Type) + header.Set(w.Header(), header.WWWAuthenticate, m.Type) } m.Forbidden(w, r, err) return diff --git a/pkg/authn/basic.go b/pkg/authn/basic.go index c928a25..b812a01 100644 --- a/pkg/authn/basic.go +++ b/pkg/authn/basic.go @@ -5,6 +5,8 @@ import ( "errors" "net/http" "net/url" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) var ( @@ -45,7 +47,7 @@ func (m BasicAuthenticator) ServeHandler(h http.Handler) http.Handler { Type: t, Authenticate: func(r *http.Request) error { username, password, ok := r.BasicAuth() - r.Header.Del("Authorization") + header.Del(r.Header, header.Authorization) if !ok { return ErrMissingAuthorization } diff --git a/pkg/authn/forward.go b/pkg/authn/forward.go index aee5c6b..a8cfee7 100644 --- a/pkg/authn/forward.go +++ b/pkg/authn/forward.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" + "github.com/moonrhythm/parapet/pkg/internal/header" "github.com/moonrhythm/parapet/pkg/internal/pool" ) @@ -38,6 +39,12 @@ func (m ForwardAuthenticator) ServeHandler(h http.Handler) http.Handler { if m.URL != nil { urlStr = m.URL.String() } + for i, h := range m.AuthRequestHeaders { + m.AuthRequestHeaders[i] = http.CanonicalHeaderKey(h) + } + for i, h := range m.AuthResponseHeaders { + m.AuthResponseHeaders[i] = http.CanonicalHeaderKey(h) + } return Authenticator{ Authenticate: func(r *http.Request) error { if urlStr == "" { @@ -50,21 +57,21 @@ func (m ForwardAuthenticator) ServeHandler(h http.Handler) http.Handler { } if len(m.AuthRequestHeaders) == 0 { req.Header = r.Header.Clone() - req.Header.Del("Content-Length") + header.Del(req.Header, header.ContentLength) } else { for _, h := range m.AuthRequestHeaders { - req.Header.Del(h) + header.Del(req.Header, h) for _, v := range r.Header.Values(h) { - req.Header.Add(h, v) + header.Add(req.Header, h, v) } } } - req.Header.Set("X-Forwarded-Method", r.Method) - req.Header.Set("X-Forwarded-Host", r.Host) - req.Header.Set("X-Forwarded-Uri", r.RequestURI) - req.Header.Set("X-Forwarded-Proto", r.Header.Get("X-Forwarded-Proto")) - req.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For")) + header.Set(req.Header, header.XForwardedMethod, r.Method) + header.Set(req.Header, header.XForwardedHost, r.Host) + header.Set(req.Header, header.XForwardedURI, r.RequestURI) + header.Set(req.Header, header.XForwardedProto, header.Get(r.Header, header.XForwardedProto)) + header.Set(req.Header, header.XForwardedFor, header.Get(r.Header, header.XForwardedFor)) resp, err := client.Do(req) if err != nil { @@ -86,9 +93,9 @@ func (m ForwardAuthenticator) ServeHandler(h http.Handler) http.Handler { resp.Body.Close() for _, h := range m.AuthResponseHeaders { - r.Header.Del(h) + header.Del(r.Header, h) for _, v := range resp.Header.Values(h) { - r.Header.Add(h, v) + header.Add(r.Header, h, v) } } diff --git a/pkg/compress/compress.go b/pkg/compress/compress.go index 73f9608..10107a6 100644 --- a/pkg/compress/compress.go +++ b/pkg/compress/compress.go @@ -6,10 +6,11 @@ import ( "mime" "net" "net/http" - "net/textproto" "strconv" "strings" "sync" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // Compress is the compress middleware @@ -49,13 +50,13 @@ func (m Compress) ServeHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // skip if client not support - if !strings.Contains(r.Header.Get("Accept-Encoding"), m.Encoding) { + if !strings.Contains(header.Get(r.Header, header.AcceptEncoding), m.Encoding) { h.ServeHTTP(w, r) return } // skip if web socket - if r.Header.Get("Sec-WebSocket-Key") != "" { + if header.Exists(r.Header, header.SecWebsocketKey) { h.ServeHTTP(w, r) return } @@ -63,13 +64,13 @@ func (m Compress) ServeHandler(h http.Handler) http.Handler { hh := w.Header() // skip if already encode - if hh.Get("Content-Encoding") != "" { + if header.Exists(hh, header.ContentEncoding) { h.ServeHTTP(w, r) return } if m.Vary { - addHeaderIfNotExists(hh, "Vary", "Accept-Encoding") + header.AddIfNotExists(hh, header.Vary, header.AcceptEncoding) } cw := &compressWriter{ @@ -108,13 +109,13 @@ func (w *compressWriter) init() { h := w.Header() // skip if already encode - if h.Get("Content-Encoding") != "" { + if header.Exists(h, header.ContentEncoding) { return } // skip if length < min length if w.minLength > 0 { - if sl := h.Get("Content-Length"); sl != "" { + if sl := header.Get(h, header.ContentLength); sl != "" { l, _ := strconv.Atoi(sl) if l > 0 && l < w.minLength { return @@ -124,7 +125,7 @@ func (w *compressWriter) init() { // skip if no match type if _, ok := w.types["*"]; !ok { - ct, _, err := mime.ParseMediaType(h.Get("Content-Type")) + ct, _, err := mime.ParseMediaType(header.Get(h, header.ContentType)) if err != nil { ct = "application/octet-stream" } @@ -135,8 +136,8 @@ func (w *compressWriter) init() { w.encoder = w.pool.Get().(Compressor) w.encoder.Reset(w.ResponseWriter) - h.Del("Content-Length") - h.Set("Content-Encoding", w.encoding) + header.Del(h, header.ContentLength) + header.Set(h, header.ContentEncoding, w.encoding) } func (w *compressWriter) Write(b []byte) (int, error) { @@ -195,13 +196,3 @@ func (w *compressWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { } return nil, nil, http.ErrNotSupported } - -func addHeaderIfNotExists(h http.Header, key, value string) { - key = textproto.CanonicalMIMEHeaderKey(key) - for _, v := range h[key] { - if v == value { - return - } - } - h.Add(key, value) -} diff --git a/pkg/cors/cors.go b/pkg/cors/cors.go index 576a961..b3339df 100644 --- a/pkg/cors/cors.go +++ b/pkg/cors/cors.go @@ -5,6 +5,8 @@ import ( "strconv" "strings" "time" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // New creates new default cors middleware for public api @@ -50,37 +52,37 @@ func (m CORS) ServeHandler(h http.Handler) http.Handler { headers := make(http.Header) if m.AllowCredentials { - preflightHeaders.Set("Access-Control-Allow-Credentials", "true") - headers.Set("Access-Control-Allow-Credentials", "true") + header.Set(preflightHeaders, header.AccessControlAllowCredentials, "true") + header.Set(headers, header.AccessControlAllowCredentials, "true") } if len(m.AllowMethods) > 0 { - preflightHeaders.Set("Access-Control-Allow-Methods", strings.Join(m.AllowMethods, ",")) + header.Set(preflightHeaders, header.AccessControlAllowMethods, strings.Join(m.AllowMethods, ",")) } if len(m.AllowHeaders) > 0 { - preflightHeaders.Set("Access-Control-Allow-Headers", strings.Join(m.AllowHeaders, ",")) + header.Set(preflightHeaders, header.AccessControlAllowHeaders, strings.Join(m.AllowHeaders, ",")) } if len(m.ExposeHeaders) > 0 { - headers.Set("Access-Control-Expose-Headers", strings.Join(m.ExposeHeaders, ",")) + header.Set(headers, header.AccessControlExposeHeaders, strings.Join(m.ExposeHeaders, ",")) } if m.MaxAge > time.Duration(0) { - preflightHeaders.Set("Access-Control-Max-Age", strconv.FormatInt(int64(m.MaxAge/time.Second), 10)) + header.Set(preflightHeaders, header.AccessControlMaxAge, strconv.FormatInt(int64(m.MaxAge/time.Second), 10)) } if m.AllowAllOrigins { - preflightHeaders.Set("Access-Control-Allow-Origin", "*") - headers.Set("Access-Control-Allow-Origin", "*") + header.Set(preflightHeaders, header.AccessControlAllowOrigin, "*") + header.Set(headers, header.AccessControlAllowOrigin, "*") } else { - preflightHeaders.Add("Vary", "Origin") - preflightHeaders.Add("Vary", "Access-Control-Request-Method") - preflightHeaders.Add("Vary", "Access-Control-Request-Headers") - headers.Set("Vary", "Origin") + header.Add(preflightHeaders, header.Vary, header.Origin) + header.Add(preflightHeaders, header.Vary, header.AccessControlRequestMethod) + header.Add(preflightHeaders, header.Vary, header.AccessControlRequestHeaders) + header.Set(headers, header.Vary, header.Origin) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if origin := r.Header.Get("Origin"); len(origin) > 0 { + if origin := header.Get(r.Header, header.Origin); len(origin) > 0 { h := w.Header() if !m.AllowAllOrigins { if m.AllowOrigins(origin) { - h.Set("Access-Control-Allow-Origin", origin) + header.Set(h, header.AccessControlAllowOrigin, origin) } else { w.WriteHeader(http.StatusForbidden) return diff --git a/pkg/headers/map.go b/pkg/headers/map.go index 3412f40..4a63e66 100644 --- a/pkg/headers/map.go +++ b/pkg/headers/map.go @@ -2,12 +2,11 @@ package headers import ( "net/http" - "net/textproto" ) // MapRequest creates new request interceptor for map a header func MapRequest(header string, mapper func(string) string) *RequestInterceptor { - header = textproto.CanonicalMIMEHeaderKey(header) + header = http.CanonicalHeaderKey(header) return InterceptRequest(func(h http.Header) { for i, v := range h[header] { @@ -18,7 +17,7 @@ func MapRequest(header string, mapper func(string) string) *RequestInterceptor { // MapResponse creates new response interceptor for map a header func MapResponse(header string, mapper func(string) string) *ResponseInterceptor { - header = textproto.CanonicalMIMEHeaderKey(header) + header = http.CanonicalHeaderKey(header) return InterceptResponse(func(w ResponseHeaderWriter) { h := w.Header() diff --git a/pkg/hsts/hsts.go b/pkg/hsts/hsts.go index 9c5cc69..3293ed0 100644 --- a/pkg/hsts/hsts.go +++ b/pkg/hsts/hsts.go @@ -4,6 +4,8 @@ import ( "net/http" "strconv" "time" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // HSTS middleware @@ -42,7 +44,7 @@ func (m HSTS) ServeHandler(h http.Handler) http.Handler { } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Strict-Transport-Security", hs) + header.Set(w.Header(), header.StrictTransportSecurity, hs) h.ServeHTTP(w, r) }) } diff --git a/pkg/internal/header/header.go b/pkg/internal/header/header.go new file mode 100644 index 0000000..6a89f70 --- /dev/null +++ b/pkg/internal/header/header.go @@ -0,0 +1,88 @@ +package header + +import ( + "net/http" +) + +// Headers in canonical format +const ( + AcceptEncoding = "Accept-Encoding" + AccessControlAllowCredentials = "Access-Control-Allow-Credentials" + AccessControlAllowHeaders = "Access-Control-Allow-Headers" + AccessControlAllowMethods = "Access-Control-Allow-Methods" + AccessControlAllowOrigin = "Access-Control-Allow-Origin" + AccessControlExposeHeaders = "Access-Control-Expose-Headers" + AccessControlMaxAge = "Access-Control-Max-Age" + AccessControlRequestHeaders = "Access-Control-Request-Headers" + AccessControlRequestMethod = "Access-Control-Request-Method" + Authorization = "Authorization" + ContentEncoding = "Content-Encoding" + ContentLength = "Content-Length" + ContentType = "Content-Type" + Origin = "Origin" + RetryAfter = "Retry-After" + SecWebsocketKey = "Sec-Websocket-Key" + StrictTransportSecurity = "Strict-Transport-Security" + Upgrade = "Upgrade" + Vary = "Vary" + WWWAuthenticate = "Www-Authenticate" + XForwardedFor = "X-Forwarded-For" + XForwardedHost = "X-Forwarded-Host" + XForwardedMethod = "X-Forwarded-Method" + XForwardedProto = "X-Forwarded-Proto" + XForwardedURI = "X-Forwarded-Uri" + XRealIP = "X-Real-Ip" + XRequestID = "X-Request-Id" +) + +func AddIfNotExists(h http.Header, key, value string) { + for _, v := range h[key] { + if v == value { + return + } + } + h[key] = append(h[key], value) +} + +func Get(h http.Header, key string) string { + if h == nil { + return "" + } + v := h[key] + if len(v) == 0 { + return "" + } + return v[0] +} + +func Exists(h http.Header, key string) bool { + if h == nil { + return false + } + v := h[key] + if len(v) == 0 { + return false + } + return v[0] != "" +} + +func Del(h http.Header, key string) { + if h == nil { + return + } + delete(h, key) +} + +func Set(h http.Header, key, value string) { + if h == nil { + return + } + h[key] = []string{value} +} + +func Add(h http.Header, key, value string) { + if h == nil { + return + } + h[key] = append(h[key], value) +} diff --git a/pkg/internal/header/header_test.go b/pkg/internal/header/header_test.go new file mode 100644 index 0000000..b3d8f94 --- /dev/null +++ b/pkg/internal/header/header_test.go @@ -0,0 +1,46 @@ +package header_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/moonrhythm/parapet/pkg/internal/header" +) + +func TestHeaders(t *testing.T) { + list := []string{ + header.AcceptEncoding, + header.AccessControlAllowCredentials, + header.AccessControlAllowHeaders, + header.AccessControlAllowMethods, + header.AccessControlAllowOrigin, + header.AccessControlExposeHeaders, + header.AccessControlMaxAge, + header.AccessControlRequestHeaders, + header.AccessControlRequestMethod, + header.Authorization, + header.ContentEncoding, + header.ContentLength, + header.ContentType, + header.Origin, + header.RetryAfter, + header.SecWebsocketKey, + header.StrictTransportSecurity, + header.Upgrade, + header.Vary, + header.WWWAuthenticate, + header.XForwardedFor, + header.XForwardedHost, + header.XForwardedMethod, + header.XForwardedProto, + header.XForwardedURI, + header.XRealIP, + header.XRequestID, + } + + for _, x := range list { + assert.Equal(t, http.CanonicalHeaderKey(x), x) + } +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 77f613f..254ff5b 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -9,6 +9,8 @@ import ( "net/http" "os" "time" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // Logger middleware @@ -41,9 +43,9 @@ func (m Logger) ServeHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() - proto := r.Header.Get("X-Forwarded-Proto") - realIP := r.Header.Get("X-Real-Ip") - xff := r.Header.Get("X-Forwarded-For") + proto := header.Get(r.Header, header.XForwardedProto) + realIP := header.Get(r.Header, header.XRealIP) + xff := header.Get(r.Header, header.XForwardedFor) remoteIP, _, _ := net.SplitHostPort(r.RemoteAddr) d := newRecord() diff --git a/pkg/logger/record.go b/pkg/logger/record.go index 4b2f32c..e5e71d3 100644 --- a/pkg/logger/record.go +++ b/pkg/logger/record.go @@ -15,7 +15,7 @@ type record struct { } func newRecord() *record { - return &record{data: make(map[string]interface{})} + return &record{data: make(map[string]interface{}, 18)} } func (r *record) Set(name string, value interface{}) { diff --git a/pkg/ratelimit/ratelimit.go b/pkg/ratelimit/ratelimit.go index dc6f029..64c77a7 100644 --- a/pkg/ratelimit/ratelimit.go +++ b/pkg/ratelimit/ratelimit.go @@ -6,6 +6,8 @@ import ( "net/http" "strconv" "time" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // New creates new rate limiter @@ -42,7 +44,7 @@ type ExceededHandler func(w http.ResponseWriter, r *http.Request, after time.Dur func defaultExceededHandler(w http.ResponseWriter, _ *http.Request, after time.Duration) { if after > 0 { - w.Header().Set("Retry-After", strconv.FormatInt(int64(after/time.Second), 10)) + header.Set(w.Header(), header.RetryAfter, strconv.FormatInt(int64(after/time.Second), 10)) } http.Error(w, "Too Many Requests", http.StatusTooManyRequests) } @@ -53,7 +55,7 @@ func defaultKey(_ *http.Request) string { // ClientIP returns client ip from request func ClientIP(r *http.Request) string { - ipStr := r.Header.Get("X-Real-Ip") + ipStr := header.Get(r.Header, header.XRealIP) ip := net.ParseIP(ipStr) if ip == nil { return ipStr diff --git a/pkg/redirect/https.go b/pkg/redirect/https.go index 55eb311..5dd4df5 100644 --- a/pkg/redirect/https.go +++ b/pkg/redirect/https.go @@ -2,6 +2,8 @@ package redirect import ( "net/http" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // HTTPS creates new https redirector @@ -21,7 +23,7 @@ func (m HTTPSRedirector) ServeHandler(h http.Handler) http.Handler { } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - proto := r.Header.Get("X-Forwarded-Proto") + proto := header.Get(r.Header, header.XForwardedProto) if proto == "http" { http.Redirect(w, r, "https://"+r.Host+r.RequestURI, m.StatusCode) return diff --git a/pkg/redirect/nonwww.go b/pkg/redirect/nonwww.go index 1e6315b..6177908 100644 --- a/pkg/redirect/nonwww.go +++ b/pkg/redirect/nonwww.go @@ -3,6 +3,8 @@ package redirect import ( "net/http" "strings" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // NonWWW creates new non www redirector @@ -24,7 +26,7 @@ func (m NonWWWRedirector) ServeHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { host := strings.TrimPrefix(r.Host, "www.") if len(host) < len(r.Host) { - proto := r.Header.Get("X-Forwarded-Proto") + proto := header.Get(r.Header, header.XForwardedProto) http.Redirect(w, r, proto+"://"+host+r.RequestURI, m.StatusCode) return } diff --git a/pkg/redirect/www.go b/pkg/redirect/www.go index 7003f24..1780dcd 100644 --- a/pkg/redirect/www.go +++ b/pkg/redirect/www.go @@ -3,6 +3,8 @@ package redirect import ( "net/http" "strings" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // WWW creates new www redirector @@ -23,7 +25,7 @@ func (m WWWRedirector) ServeHandler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if !strings.HasPrefix(r.Host, "www.") { - proto := r.Header.Get("X-Forwarded-Proto") + proto := header.Get(r.Header, header.XForwardedProto) http.Redirect(w, r, proto+"://www."+r.Host+r.RequestURI, m.StatusCode) return } diff --git a/pkg/requestid/requestid.go b/pkg/requestid/requestid.go index fd8be9d..b5f8286 100644 --- a/pkg/requestid/requestid.go +++ b/pkg/requestid/requestid.go @@ -5,6 +5,7 @@ import ( "github.com/gofrs/uuid" + "github.com/moonrhythm/parapet/pkg/internal/header" "github.com/moonrhythm/parapet/pkg/logger" ) @@ -28,21 +29,22 @@ func New() *RequestID { } // DefaultHeader is the default request, response header -const DefaultHeader = "X-Request-Id" +const DefaultHeader = header.XRequestID // ServeHandler implements middleware interface func (m RequestID) ServeHandler(h http.Handler) http.Handler { if m.Header == "" { m.Header = DefaultHeader } + m.Header = http.CanonicalHeaderKey(m.Header) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - id := r.Header.Get(m.Header) + id := header.Get(r.Header, m.Header) if id == "" || !m.TrustProxy { id = uuid.Must(uuid.NewV4()).String() - r.Header.Set(m.Header, id) + header.Set(r.Header, m.Header, id) } - w.Header().Set(m.Header, id) + header.Set(w.Header(), m.Header, id) logger.Set(r.Context(), "requestId", id) h.ServeHTTP(w, r) diff --git a/pkg/stackdriver/trace.go b/pkg/stackdriver/trace.go index aed6e30..78666f1 100644 --- a/pkg/stackdriver/trace.go +++ b/pkg/stackdriver/trace.go @@ -11,6 +11,8 @@ import ( "go.opencensus.io/trace" "go.opencensus.io/trace/propagation" "google.golang.org/api/option" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) // NewTrace creates new stack driver trace middleware @@ -40,7 +42,7 @@ func (m Trace) ServeHandler(h http.Handler) http.Handler { } if m.FormatSpanName == nil { m.FormatSpanName = func(r *http.Request) string { - proto := r.Header.Get("X-Forwarded-Proto") + proto := header.Get(r.Header, header.XForwardedProto) return proto + "://" + r.Host + r.RequestURI } } diff --git a/pkg/upstream/transport.go b/pkg/upstream/transport.go index 8eff296..58efb3a 100644 --- a/pkg/upstream/transport.go +++ b/pkg/upstream/transport.go @@ -10,6 +10,8 @@ import ( "time" "golang.org/x/net/http2" + + "github.com/moonrhythm/parapet/pkg/internal/header" ) const ( @@ -69,7 +71,7 @@ func (t *H2CTransport) RoundTrip(r *http.Request) (*http.Response, error) { r.URL.Scheme = "http" // Currently Go does not support RFC 8441, downgrade to http1 - if r.Header.Get("Upgrade") != "" { + if header.Exists(r.Header, header.Upgrade) { return t.h1.RoundTrip(r) } @@ -325,7 +327,7 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) { r.URL.Scheme = "http" // Currently Go does not support RFC 8441, downgrade to http1 - if r.Header.Get("Upgrade") != "" { + if header.Exists(r.Header, header.Upgrade) { tr = t.httpTr } case "unix":