From 620f0c83adb3d72021545ecbb47ff1bef42fbd0c Mon Sep 17 00:00:00 2001 From: Lauris BH Date: Fri, 23 Apr 2021 14:25:02 +0300 Subject: [PATCH] Add option for middleware to set custom remote address (#1009) * Add option for middleware to set custom remote address * Update Init2 to clear custom context remoteAddr --- server.go | 14 ++++++++++++++ server_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/server.go b/server.go index 67f47d8638..9f92cce1b1 100644 --- a/server.go +++ b/server.go @@ -572,6 +572,7 @@ type RequestCtx struct { connID uint64 connRequestNum uint64 connTime time.Time + remoteAddr net.Addr time time.Time @@ -1091,6 +1092,9 @@ func (ctx *RequestCtx) IsHead() bool { // // Always returns non-nil result. func (ctx *RequestCtx) RemoteAddr() net.Addr { + if ctx.remoteAddr != nil { + return ctx.remoteAddr + } if ctx.c == nil { return zeroTCPAddr } @@ -1101,6 +1105,14 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr { return addr } +// SetRemoteAddr sets remote address to the given value. +// +// Set nil value to resore default behaviour for using +// connection remote address. +func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) { + ctx.remoteAddr = remoteAddr +} + // LocalAddr returns server address for the given request. // // Always returns non-nil result. @@ -2524,6 +2536,7 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) { // See https://github.com/valyala/httpteleport for details. func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) { ctx.c = conn + ctx.remoteAddr = nil ctx.logger.logger = logger ctx.connID = nextConnID() ctx.s = fakeServer @@ -2636,6 +2649,7 @@ func (s *Server) releaseCtx(ctx *RequestCtx) { panic("BUG: cannot release timed out RequestCtx") } ctx.c = nil + ctx.remoteAddr = nil ctx.fbr.c = nil s.ctxPool.Put(ctx) } diff --git a/server_test.go b/server_test.go index a0583f6408..b8cf7eebd4 100644 --- a/server_test.go +++ b/server_test.go @@ -2966,6 +2966,56 @@ func TestServerRemoteAddr(t *testing.T) { verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4") } +func TestServerCustomRemoteAddr(t *testing.T) { + t.Parallel() + + customRemoteAddrHandler := func(h RequestHandler) RequestHandler { + return func(ctx *RequestCtx) { + ctx.SetRemoteAddr(&net.TCPAddr{ + IP: []byte{1, 2, 3, 5}, + Port: 0, + }) + h(ctx) + } + } + + s := &Server{ + Handler: customRemoteAddrHandler(func(ctx *RequestCtx) { + h := &ctx.Request.Header + ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", + h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP()))) + }), + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + + rwx := &readWriterRemoteAddr{ + rw: rw, + addr: &net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 8765, + }, + } + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rwx) + }() + + select { + case err := <-ch: + if err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout") + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5") +} + type readWriterRemoteAddr struct { net.Conn rw io.ReadWriteCloser