diff --git a/middleware/adaptor/adaptor.go b/middleware/adaptor/adaptor.go index 2b2bc53faa..ea65dfc5d4 100644 --- a/middleware/adaptor/adaptor.go +++ b/middleware/adaptor/adaptor.go @@ -1,6 +1,8 @@ package adaptor import ( + "errors" + "fmt" "io" "net" "net/http" @@ -138,6 +140,32 @@ func FiberApp(app *fiber.App) http.HandlerFunc { return handlerFunc(app) } +func isUnixNetwork(network string) bool { + return network == "unix" || network == "unixgram" || network == "unixpacket" +} + +func resolveRemoteAddr(remoteAddr string, localAddr any) (net.Addr, error) { + if addr, ok := localAddr.(net.Addr); ok && isUnixNetwork(addr.Network()) { + return addr, nil + } + + resolved, err := net.ResolveTCPAddr("tcp", remoteAddr) + if err == nil { + return resolved, nil + } + + var addrErr *net.AddrError + if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + remoteAddr = net.JoinHostPort(remoteAddr, "80") + resolved, err2 := net.ResolveTCPAddr("tcp", remoteAddr) + if err2 != nil { + return nil, fmt.Errorf("failed to resolve TCP address after adding port: %w", err2) + } + return resolved, nil + } + return nil, fmt.Errorf("failed to resolve TCP address: %w", err) +} + func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { req := fasthttp.AcquireRequest() @@ -164,14 +192,10 @@ func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc { } } - if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil && err.(*net.AddrError).Err == "missing port in address" { //nolint:errorlint,forcetypeassert,errcheck // overlinting - r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80") - } - - remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) + remoteAddr, err := resolveRemoteAddr(r.RemoteAddr, r.Context().Value(http.LocalAddrContextKey)) if err != nil { - http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError) - return + // fallback: fasthttp handles nil remoteAddr + remoteAddr = nil } // New fasthttp Ctx from pool diff --git a/middleware/adaptor/adaptor_test.go b/middleware/adaptor/adaptor_test.go index c3e65fd791..47782151dd 100644 --- a/middleware/adaptor/adaptor_test.go +++ b/middleware/adaptor/adaptor_test.go @@ -10,8 +10,11 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" + "path/filepath" "strings" "testing" + "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/assert" @@ -663,3 +666,116 @@ func Benchmark_HTTPHandler(b *testing.B) { require.NoError(b, err) } + +func TestUnixSocketAdaptor(t *testing.T) { + dir := t.TempDir() + socketPath := filepath.Join(dir, "test.sock") + defer func() { + if err := os.Remove(socketPath); err != nil { + t.Logf("cleanup failed: %v", err) + } + }() + + app := fiber.New() + app.Get("/hello", func(c fiber.Ctx) error { + return c.SendString("ok") + }) + handler := FiberApp(app) + + listener, err := net.Listen("unix", socketPath) + if err != nil { + // Skip on platforms where the "unix" network is unsupported + if strings.Contains(err.Error(), "unknown network") || + strings.Contains(err.Error(), "address family not supported") { + t.Skipf("Unix domain sockets not supported on this platform: %v", err) + } + t.Fatal(err) + } + defer func() { + if closeErr := listener.Close(); closeErr != nil { + t.Logf("listener close failed: %v", closeErr) + } + }() + + // start server with timeouts + srv := &http.Server{ + Handler: handler, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + } + done := make(chan struct{}) + go func() { + if serveErr := srv.Serve(listener); serveErr != nil && serveErr != http.ErrServerClosed { + t.Errorf("http server failed: %v", serveErr) + } + close(done) + }() + + conn, err := net.Dial("unix", socketPath) + require.NoError(t, err) + defer func() { + if closeErr := conn.Close(); closeErr != nil { + t.Logf("conn close failed: %v", closeErr) + } + }() + + // set deadline for both write + read (2s) + require.NoError(t, conn.SetDeadline(time.Now().Add(2*time.Second))) + + // write request + _, err = conn.Write([]byte("GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n")) + require.NoError(t, err) + + // read response + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + + // clear deadline to avoid affecting further calls + require.NoError(t, conn.SetDeadline(time.Time{})) + + raw := string(buf[:n]) + t.Logf("Raw response:\n%s", raw) + require.Contains(t, raw, "HTTP/1.1 200 OK") + require.Contains(t, raw, "ok") + + // now shutdown the server explicitly before waiting for done + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + require.NoError(t, srv.Shutdown(ctx)) + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("server shutdown timed out") + } +} + +func TestHandlerFunc_FallbackRemoteAddr(t *testing.T) { + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("ok") + }) + + handler := handlerFunc(app) + + // Fake request with bad RemoteAddr + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) + require.NoError(t, err) + req.RemoteAddr = "bad-addr" + + rr := httptest.NewRecorder() + handler(rr, req) + + res := rr.Result() + defer func() { + closeErr := res.Body.Close() + require.NoError(t, closeErr) + }() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, res.StatusCode) + require.Contains(t, string(body), "ok") +}