Skip to content

Commit

Permalink
optimize header operations (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
acoshift committed May 20, 2023
1 parent 528300c commit 87f8ad7
Show file tree
Hide file tree
Showing 18 changed files with 221 additions and 66 deletions.
4 changes: 3 additions & 1 deletion pkg/authn/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package authn

import (
"net/http"

"github.com/moonrhythm/parapet/pkg/internal/header"
)

// Authenticator middleware
Expand All @@ -27,7 +29,7 @@ func (m Authenticator) ServeHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := m.Authenticate(r); err != nil {
if m.Type != "" {
w.Header().Set("WWW-Authenticate", m.Type)
header.Set(w.Header(), header.WWWAuthenticate, m.Type)
}
m.Forbidden(w, r, err)
return
Expand Down
4 changes: 3 additions & 1 deletion pkg/authn/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"net/http"
"net/url"

"github.com/moonrhythm/parapet/pkg/internal/header"
)

var (
Expand Down Expand Up @@ -45,7 +47,7 @@ func (m BasicAuthenticator) ServeHandler(h http.Handler) http.Handler {
Type: t,
Authenticate: func(r *http.Request) error {
username, password, ok := r.BasicAuth()
r.Header.Del("Authorization")
header.Del(r.Header, header.Authorization)
if !ok {
return ErrMissingAuthorization
}
Expand Down
27 changes: 17 additions & 10 deletions pkg/authn/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"

"github.com/moonrhythm/parapet/pkg/internal/header"
"github.com/moonrhythm/parapet/pkg/internal/pool"
)

Expand Down Expand Up @@ -38,6 +39,12 @@ func (m ForwardAuthenticator) ServeHandler(h http.Handler) http.Handler {
if m.URL != nil {
urlStr = m.URL.String()
}
for i, h := range m.AuthRequestHeaders {
m.AuthRequestHeaders[i] = http.CanonicalHeaderKey(h)
}
for i, h := range m.AuthResponseHeaders {
m.AuthResponseHeaders[i] = http.CanonicalHeaderKey(h)
}
return Authenticator{
Authenticate: func(r *http.Request) error {
if urlStr == "" {
Expand All @@ -50,21 +57,21 @@ func (m ForwardAuthenticator) ServeHandler(h http.Handler) http.Handler {
}
if len(m.AuthRequestHeaders) == 0 {
req.Header = r.Header.Clone()
req.Header.Del("Content-Length")
header.Del(req.Header, header.ContentLength)
} else {
for _, h := range m.AuthRequestHeaders {
req.Header.Del(h)
header.Del(req.Header, h)
for _, v := range r.Header.Values(h) {
req.Header.Add(h, v)
header.Add(req.Header, h, v)
}
}
}

req.Header.Set("X-Forwarded-Method", r.Method)
req.Header.Set("X-Forwarded-Host", r.Host)
req.Header.Set("X-Forwarded-Uri", r.RequestURI)
req.Header.Set("X-Forwarded-Proto", r.Header.Get("X-Forwarded-Proto"))
req.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
header.Set(req.Header, header.XForwardedMethod, r.Method)
header.Set(req.Header, header.XForwardedHost, r.Host)
header.Set(req.Header, header.XForwardedURI, r.RequestURI)
header.Set(req.Header, header.XForwardedProto, header.Get(r.Header, header.XForwardedProto))
header.Set(req.Header, header.XForwardedFor, header.Get(r.Header, header.XForwardedFor))

resp, err := client.Do(req)
if err != nil {
Expand All @@ -86,9 +93,9 @@ func (m ForwardAuthenticator) ServeHandler(h http.Handler) http.Handler {
resp.Body.Close()

for _, h := range m.AuthResponseHeaders {
r.Header.Del(h)
header.Del(r.Header, h)
for _, v := range resp.Header.Values(h) {
r.Header.Add(h, v)
header.Add(r.Header, h, v)
}
}

Expand Down
31 changes: 11 additions & 20 deletions pkg/compress/compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"mime"
"net"
"net/http"
"net/textproto"
"strconv"
"strings"
"sync"

"github.com/moonrhythm/parapet/pkg/internal/header"
)

// Compress is the compress middleware
Expand Down Expand Up @@ -49,27 +50,27 @@ func (m Compress) ServeHandler(h http.Handler) http.Handler {

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// skip if client not support
if !strings.Contains(r.Header.Get("Accept-Encoding"), m.Encoding) {
if !strings.Contains(header.Get(r.Header, header.AcceptEncoding), m.Encoding) {
h.ServeHTTP(w, r)
return
}

// skip if web socket
if r.Header.Get("Sec-WebSocket-Key") != "" {
if header.Exists(r.Header, header.SecWebsocketKey) {
h.ServeHTTP(w, r)
return
}

hh := w.Header()

// skip if already encode
if hh.Get("Content-Encoding") != "" {
if header.Exists(hh, header.ContentEncoding) {
h.ServeHTTP(w, r)
return
}

if m.Vary {
addHeaderIfNotExists(hh, "Vary", "Accept-Encoding")
header.AddIfNotExists(hh, header.Vary, header.AcceptEncoding)
}

cw := &compressWriter{
Expand Down Expand Up @@ -108,13 +109,13 @@ func (w *compressWriter) init() {
h := w.Header()

// skip if already encode
if h.Get("Content-Encoding") != "" {
if header.Exists(h, header.ContentEncoding) {
return
}

// skip if length < min length
if w.minLength > 0 {
if sl := h.Get("Content-Length"); sl != "" {
if sl := header.Get(h, header.ContentLength); sl != "" {
l, _ := strconv.Atoi(sl)
if l > 0 && l < w.minLength {
return
Expand All @@ -124,7 +125,7 @@ func (w *compressWriter) init() {

// skip if no match type
if _, ok := w.types["*"]; !ok {
ct, _, err := mime.ParseMediaType(h.Get("Content-Type"))
ct, _, err := mime.ParseMediaType(header.Get(h, header.ContentType))
if err != nil {
ct = "application/octet-stream"
}
Expand All @@ -135,8 +136,8 @@ func (w *compressWriter) init() {

w.encoder = w.pool.Get().(Compressor)
w.encoder.Reset(w.ResponseWriter)
h.Del("Content-Length")
h.Set("Content-Encoding", w.encoding)
header.Del(h, header.ContentLength)
header.Set(h, header.ContentEncoding, w.encoding)
}

func (w *compressWriter) Write(b []byte) (int, error) {
Expand Down Expand Up @@ -195,13 +196,3 @@ func (w *compressWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
return nil, nil, http.ErrNotSupported
}

func addHeaderIfNotExists(h http.Header, key, value string) {
key = textproto.CanonicalMIMEHeaderKey(key)
for _, v := range h[key] {
if v == value {
return
}
}
h.Add(key, value)
}
30 changes: 16 additions & 14 deletions pkg/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"strconv"
"strings"
"time"

"github.com/moonrhythm/parapet/pkg/internal/header"
)

// New creates new default cors middleware for public api
Expand Down Expand Up @@ -50,37 +52,37 @@ func (m CORS) ServeHandler(h http.Handler) http.Handler {
headers := make(http.Header)

if m.AllowCredentials {
preflightHeaders.Set("Access-Control-Allow-Credentials", "true")
headers.Set("Access-Control-Allow-Credentials", "true")
header.Set(preflightHeaders, header.AccessControlAllowCredentials, "true")
header.Set(headers, header.AccessControlAllowCredentials, "true")
}
if len(m.AllowMethods) > 0 {
preflightHeaders.Set("Access-Control-Allow-Methods", strings.Join(m.AllowMethods, ","))
header.Set(preflightHeaders, header.AccessControlAllowMethods, strings.Join(m.AllowMethods, ","))
}
if len(m.AllowHeaders) > 0 {
preflightHeaders.Set("Access-Control-Allow-Headers", strings.Join(m.AllowHeaders, ","))
header.Set(preflightHeaders, header.AccessControlAllowHeaders, strings.Join(m.AllowHeaders, ","))
}
if len(m.ExposeHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(m.ExposeHeaders, ","))
header.Set(headers, header.AccessControlExposeHeaders, strings.Join(m.ExposeHeaders, ","))
}
if m.MaxAge > time.Duration(0) {
preflightHeaders.Set("Access-Control-Max-Age", strconv.FormatInt(int64(m.MaxAge/time.Second), 10))
header.Set(preflightHeaders, header.AccessControlMaxAge, strconv.FormatInt(int64(m.MaxAge/time.Second), 10))
}
if m.AllowAllOrigins {
preflightHeaders.Set("Access-Control-Allow-Origin", "*")
headers.Set("Access-Control-Allow-Origin", "*")
header.Set(preflightHeaders, header.AccessControlAllowOrigin, "*")
header.Set(headers, header.AccessControlAllowOrigin, "*")
} else {
preflightHeaders.Add("Vary", "Origin")
preflightHeaders.Add("Vary", "Access-Control-Request-Method")
preflightHeaders.Add("Vary", "Access-Control-Request-Headers")
headers.Set("Vary", "Origin")
header.Add(preflightHeaders, header.Vary, header.Origin)
header.Add(preflightHeaders, header.Vary, header.AccessControlRequestMethod)
header.Add(preflightHeaders, header.Vary, header.AccessControlRequestHeaders)
header.Set(headers, header.Vary, header.Origin)
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if origin := r.Header.Get("Origin"); len(origin) > 0 {
if origin := header.Get(r.Header, header.Origin); len(origin) > 0 {
h := w.Header()
if !m.AllowAllOrigins {
if m.AllowOrigins(origin) {
h.Set("Access-Control-Allow-Origin", origin)
header.Set(h, header.AccessControlAllowOrigin, origin)
} else {
w.WriteHeader(http.StatusForbidden)
return
Expand Down
5 changes: 2 additions & 3 deletions pkg/headers/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ package headers

import (
"net/http"
"net/textproto"
)

// MapRequest creates new request interceptor for map a header
func MapRequest(header string, mapper func(string) string) *RequestInterceptor {
header = textproto.CanonicalMIMEHeaderKey(header)
header = http.CanonicalHeaderKey(header)

return InterceptRequest(func(h http.Header) {
for i, v := range h[header] {
Expand All @@ -18,7 +17,7 @@ func MapRequest(header string, mapper func(string) string) *RequestInterceptor {

// MapResponse creates new response interceptor for map a header
func MapResponse(header string, mapper func(string) string) *ResponseInterceptor {
header = textproto.CanonicalMIMEHeaderKey(header)
header = http.CanonicalHeaderKey(header)

return InterceptResponse(func(w ResponseHeaderWriter) {
h := w.Header()
Expand Down
4 changes: 3 additions & 1 deletion pkg/hsts/hsts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"net/http"
"strconv"
"time"

"github.com/moonrhythm/parapet/pkg/internal/header"
)

// HSTS middleware
Expand Down Expand Up @@ -42,7 +44,7 @@ func (m HSTS) ServeHandler(h http.Handler) http.Handler {
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Strict-Transport-Security", hs)
header.Set(w.Header(), header.StrictTransportSecurity, hs)
h.ServeHTTP(w, r)
})
}
88 changes: 88 additions & 0 deletions pkg/internal/header/header.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package header

import (
"net/http"
)

// Headers in canonical format
const (
AcceptEncoding = "Accept-Encoding"
AccessControlAllowCredentials = "Access-Control-Allow-Credentials"
AccessControlAllowHeaders = "Access-Control-Allow-Headers"
AccessControlAllowMethods = "Access-Control-Allow-Methods"
AccessControlAllowOrigin = "Access-Control-Allow-Origin"
AccessControlExposeHeaders = "Access-Control-Expose-Headers"
AccessControlMaxAge = "Access-Control-Max-Age"
AccessControlRequestHeaders = "Access-Control-Request-Headers"
AccessControlRequestMethod = "Access-Control-Request-Method"
Authorization = "Authorization"
ContentEncoding = "Content-Encoding"
ContentLength = "Content-Length"
ContentType = "Content-Type"
Origin = "Origin"
RetryAfter = "Retry-After"
SecWebsocketKey = "Sec-Websocket-Key"
StrictTransportSecurity = "Strict-Transport-Security"
Upgrade = "Upgrade"
Vary = "Vary"
WWWAuthenticate = "Www-Authenticate"
XForwardedFor = "X-Forwarded-For"
XForwardedHost = "X-Forwarded-Host"
XForwardedMethod = "X-Forwarded-Method"
XForwardedProto = "X-Forwarded-Proto"
XForwardedURI = "X-Forwarded-Uri"
XRealIP = "X-Real-Ip"
XRequestID = "X-Request-Id"
)

func AddIfNotExists(h http.Header, key, value string) {
for _, v := range h[key] {
if v == value {
return
}
}
h[key] = append(h[key], value)
}

func Get(h http.Header, key string) string {
if h == nil {
return ""
}
v := h[key]
if len(v) == 0 {
return ""
}
return v[0]
}

func Exists(h http.Header, key string) bool {
if h == nil {
return false
}
v := h[key]
if len(v) == 0 {
return false
}
return v[0] != ""
}

func Del(h http.Header, key string) {
if h == nil {
return
}
delete(h, key)
}

func Set(h http.Header, key, value string) {
if h == nil {
return
}
h[key] = []string{value}
}

func Add(h http.Header, key, value string) {
if h == nil {
return
}
h[key] = append(h[key], value)
}
Loading

0 comments on commit 87f8ad7

Please sign in to comment.