Skip to content
58 changes: 45 additions & 13 deletions middleware/adaptor/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/valyala/fasthttp/fasthttpadaptor"
)

// disableLogger implements the fasthttp Logger interface and discards log output.
type disableLogger struct{}

// Printf implements the fasthttp Logger interface and discards log output.
Expand Down Expand Up @@ -53,13 +54,32 @@ func ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error) {
}

// CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx.
func CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx) {
contextValues := reflect.ValueOf(context).Elem()
contextKeys := reflect.TypeOf(context).Elem()

if contextKeys.Kind() != reflect.Struct {
// This function safely handles struct fields, using unsafe operations only when necessary for unexported fields.
// Deprecated: This function uses reflection and unsafe pointers; consider using explicit context passing.
func CopyContextToFiberContext(src any, requestContext *fasthttp.RequestCtx) {
v := reflect.ValueOf(src)
if !v.IsValid() {
return
}
// Deref pointer chains
for v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
v = v.Elem()
}
t := v.Type()
if t.Kind() != reflect.Struct {
return
}
// Ensure addressable for safe unsafe-access of unexported fields
if !v.CanAddr() {
tmp := reflect.New(t)
tmp.Elem().Set(v)
v = tmp.Elem()
}
contextValues := v
contextKeys := t

var lastKey any
for i := 0; i < contextValues.NumField(); i++ {
Expand All @@ -70,8 +90,8 @@ func CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx)
break
}

// Use unsafe to access potentially unexported fields.
if reflectValue.CanAddr() {
// Avoid unsafe access for unexported fields; use safe reflection where possible
if !reflectValue.CanInterface() {
/* #nosec */
reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem()
}
Expand All @@ -84,7 +104,7 @@ func CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx)
case "val":
if lastKey != nil {
requestContext.SetUserValue(lastKey, reflectValue.Interface())
lastKey = nil // Reset lastKey after setting the value
lastKey = nil
}
default:
continue
Expand All @@ -97,7 +117,6 @@ func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler {
return func(c fiber.Ctx) error {
var next bool
nextHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
// Convert again in case request may modify by middleware
next = true
c.Request().Header.SetMethod(r.Method)
c.Request().SetRequestURI(r.RequestURI)
Expand Down Expand Up @@ -149,13 +168,21 @@ func resolveRemoteAddr(remoteAddr string, localAddr any) (net.Addr, error) {
return addr, nil
}

// Validate input to prevent malformed addresses
if remoteAddr == "" {
return nil, errors.New("remote address cannot be empty")
}

resolved, err := net.ResolveTCPAddr("tcp", remoteAddr)
if err == nil {
return resolved, nil
}

var addrErr *net.AddrError
if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" {
if len(remoteAddr) > 253 { // Max hostname length
return nil, errors.New("remote address too long")
}
remoteAddr = net.JoinHostPort(remoteAddr, "80")
resolved, err2 := net.ResolveTCPAddr("tcp", remoteAddr)
if err2 != nil {
Expand All @@ -171,9 +198,15 @@ func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)

// Convert net/http -> fasthttp request
// Convert net/http -> fasthttp request with size limit
const maxBodySize = 10 * 1024 * 1024 // 10MB limit
if r.Body != nil {
n, err := io.Copy(req.BodyWriter(), r.Body)
if r.ContentLength > maxBodySize {
http.Error(w, utils.StatusMessage(fiber.StatusRequestEntityTooLarge), fiber.StatusRequestEntityTooLarge)
return
}
limitedReader := io.LimitReader(r.Body, maxBodySize)
n, err := io.Copy(req.BodyWriter(), limitedReader)
req.Header.SetContentLength(int(n))

if err != nil {
Expand All @@ -194,8 +227,7 @@ func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {

remoteAddr, err := resolveRemoteAddr(r.RemoteAddr, r.Context().Value(http.LocalAddrContextKey))
if err != nil {
// fallback: fasthttp handles nil remoteAddr
remoteAddr = nil
remoteAddr = nil // Fallback to nil
}

// New fasthttp Ctx from pool
Expand Down
Loading
Loading