Skip to content

Commit 1fe34cd

Browse files
authored
🐛 bug: Handle Unix sockets in adaptor middleware (#3760)
* Fix Fiber v3 adapter for Unix socket testing * test(adaptor): use temp dir for Unix socket and skip on non-Unix * formatted test file and handle error check * resolves issues in adapter test * *net.TCPaddr to net.addr * fixes in adaptor test * fixes in adaptor test * fixes in adaptor test * checked error return in adataptor test * resolved variable shadowing in adaptor test * resolved variable shadowing in adaptor test * resolved variable shadowing in adaptor test * added resolveRemoteAddr function for resolving addr properly * changes in resolveRemoteAddr function * resolved lint error * added test for bad remote address * resolve lint issue in adaptor test * resolve lint issue in adaptor test
1 parent cd273d2 commit 1fe34cd

File tree

2 files changed

+147
-7
lines changed

2 files changed

+147
-7
lines changed

middleware/adaptor/adaptor.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package adaptor
22

33
import (
4+
"errors"
5+
"fmt"
46
"io"
57
"net"
68
"net/http"
@@ -138,6 +140,32 @@ func FiberApp(app *fiber.App) http.HandlerFunc {
138140
return handlerFunc(app)
139141
}
140142

143+
func isUnixNetwork(network string) bool {
144+
return network == "unix" || network == "unixgram" || network == "unixpacket"
145+
}
146+
147+
func resolveRemoteAddr(remoteAddr string, localAddr any) (net.Addr, error) {
148+
if addr, ok := localAddr.(net.Addr); ok && isUnixNetwork(addr.Network()) {
149+
return addr, nil
150+
}
151+
152+
resolved, err := net.ResolveTCPAddr("tcp", remoteAddr)
153+
if err == nil {
154+
return resolved, nil
155+
}
156+
157+
var addrErr *net.AddrError
158+
if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" {
159+
remoteAddr = net.JoinHostPort(remoteAddr, "80")
160+
resolved, err2 := net.ResolveTCPAddr("tcp", remoteAddr)
161+
if err2 != nil {
162+
return nil, fmt.Errorf("failed to resolve TCP address after adding port: %w", err2)
163+
}
164+
return resolved, nil
165+
}
166+
return nil, fmt.Errorf("failed to resolve TCP address: %w", err)
167+
}
168+
141169
func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
142170
return func(w http.ResponseWriter, r *http.Request) {
143171
req := fasthttp.AcquireRequest()
@@ -164,14 +192,10 @@ func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc {
164192
}
165193
}
166194

167-
if _, _, err := net.SplitHostPort(r.RemoteAddr); err != nil && err.(*net.AddrError).Err == "missing port in address" { //nolint:errorlint,forcetypeassert,errcheck // overlinting
168-
r.RemoteAddr = net.JoinHostPort(r.RemoteAddr, "80")
169-
}
170-
171-
remoteAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
195+
remoteAddr, err := resolveRemoteAddr(r.RemoteAddr, r.Context().Value(http.LocalAddrContextKey))
172196
if err != nil {
173-
http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError)
174-
return
197+
// fallback: fasthttp handles nil remoteAddr
198+
remoteAddr = nil
175199
}
176200

177201
// New fasthttp Ctx from pool

middleware/adaptor/adaptor_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ import (
1010
"net/http"
1111
"net/http/httptest"
1212
"net/url"
13+
"os"
14+
"path/filepath"
1315
"strings"
1416
"testing"
17+
"time"
1518

1619
"github.com/gofiber/fiber/v3"
1720
"github.com/stretchr/testify/assert"
@@ -663,3 +666,116 @@ func Benchmark_HTTPHandler(b *testing.B) {
663666

664667
require.NoError(b, err)
665668
}
669+
670+
func TestUnixSocketAdaptor(t *testing.T) {
671+
dir := t.TempDir()
672+
socketPath := filepath.Join(dir, "test.sock")
673+
defer func() {
674+
if err := os.Remove(socketPath); err != nil {
675+
t.Logf("cleanup failed: %v", err)
676+
}
677+
}()
678+
679+
app := fiber.New()
680+
app.Get("/hello", func(c fiber.Ctx) error {
681+
return c.SendString("ok")
682+
})
683+
handler := FiberApp(app)
684+
685+
listener, err := net.Listen("unix", socketPath)
686+
if err != nil {
687+
// Skip on platforms where the "unix" network is unsupported
688+
if strings.Contains(err.Error(), "unknown network") ||
689+
strings.Contains(err.Error(), "address family not supported") {
690+
t.Skipf("Unix domain sockets not supported on this platform: %v", err)
691+
}
692+
t.Fatal(err)
693+
}
694+
defer func() {
695+
if closeErr := listener.Close(); closeErr != nil {
696+
t.Logf("listener close failed: %v", closeErr)
697+
}
698+
}()
699+
700+
// start server with timeouts
701+
srv := &http.Server{
702+
Handler: handler,
703+
ReadTimeout: 5 * time.Second,
704+
WriteTimeout: 10 * time.Second,
705+
}
706+
done := make(chan struct{})
707+
go func() {
708+
if serveErr := srv.Serve(listener); serveErr != nil && serveErr != http.ErrServerClosed {
709+
t.Errorf("http server failed: %v", serveErr)
710+
}
711+
close(done)
712+
}()
713+
714+
conn, err := net.Dial("unix", socketPath)
715+
require.NoError(t, err)
716+
defer func() {
717+
if closeErr := conn.Close(); closeErr != nil {
718+
t.Logf("conn close failed: %v", closeErr)
719+
}
720+
}()
721+
722+
// set deadline for both write + read (2s)
723+
require.NoError(t, conn.SetDeadline(time.Now().Add(2*time.Second)))
724+
725+
// write request
726+
_, err = conn.Write([]byte("GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n"))
727+
require.NoError(t, err)
728+
729+
// read response
730+
buf := make([]byte, 1024)
731+
n, err := conn.Read(buf)
732+
require.NoError(t, err)
733+
734+
// clear deadline to avoid affecting further calls
735+
require.NoError(t, conn.SetDeadline(time.Time{}))
736+
737+
raw := string(buf[:n])
738+
t.Logf("Raw response:\n%s", raw)
739+
require.Contains(t, raw, "HTTP/1.1 200 OK")
740+
require.Contains(t, raw, "ok")
741+
742+
// now shutdown the server explicitly before waiting for done
743+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
744+
defer cancel()
745+
require.NoError(t, srv.Shutdown(ctx))
746+
747+
select {
748+
case <-done:
749+
case <-time.After(5 * time.Second):
750+
t.Fatal("server shutdown timed out")
751+
}
752+
}
753+
754+
func TestHandlerFunc_FallbackRemoteAddr(t *testing.T) {
755+
app := fiber.New()
756+
app.Get("/", func(c fiber.Ctx) error {
757+
return c.SendString("ok")
758+
})
759+
760+
handler := handlerFunc(app)
761+
762+
// Fake request with bad RemoteAddr
763+
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil)
764+
require.NoError(t, err)
765+
req.RemoteAddr = "bad-addr"
766+
767+
rr := httptest.NewRecorder()
768+
handler(rr, req)
769+
770+
res := rr.Result()
771+
defer func() {
772+
closeErr := res.Body.Close()
773+
require.NoError(t, closeErr)
774+
}()
775+
776+
body, err := io.ReadAll(res.Body)
777+
require.NoError(t, err)
778+
779+
require.Equal(t, http.StatusOK, res.StatusCode)
780+
require.Contains(t, string(body), "ok")
781+
}

0 commit comments

Comments
 (0)