Skip to content

Commit

Permalink
Add option for middleware to set custom remote address (#1009)
Browse files Browse the repository at this point in the history
* Add option for middleware to set custom remote address

* Update Init2 to clear custom context remoteAddr
  • Loading branch information
lafriks authored Apr 23, 2021
1 parent 894272e commit 620f0c8
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
14 changes: 14 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ type RequestCtx struct {
connID uint64
connRequestNum uint64
connTime time.Time
remoteAddr net.Addr

time time.Time

Expand Down Expand Up @@ -1091,6 +1092,9 @@ func (ctx *RequestCtx) IsHead() bool {
//
// Always returns non-nil result.
func (ctx *RequestCtx) RemoteAddr() net.Addr {
if ctx.remoteAddr != nil {
return ctx.remoteAddr
}
if ctx.c == nil {
return zeroTCPAddr
}
Expand All @@ -1101,6 +1105,14 @@ func (ctx *RequestCtx) RemoteAddr() net.Addr {
return addr
}

// SetRemoteAddr sets remote address to the given value.
//
// Set nil value to resore default behaviour for using
// connection remote address.
func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) {
ctx.remoteAddr = remoteAddr
}

// LocalAddr returns server address for the given request.
//
// Always returns non-nil result.
Expand Down Expand Up @@ -2524,6 +2536,7 @@ func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
// See https://github.com/valyala/httpteleport for details.
func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) {
ctx.c = conn
ctx.remoteAddr = nil
ctx.logger.logger = logger
ctx.connID = nextConnID()
ctx.s = fakeServer
Expand Down Expand Up @@ -2636,6 +2649,7 @@ func (s *Server) releaseCtx(ctx *RequestCtx) {
panic("BUG: cannot release timed out RequestCtx")
}
ctx.c = nil
ctx.remoteAddr = nil
ctx.fbr.c = nil
s.ctxPool.Put(ctx)
}
Expand Down
50 changes: 50 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2966,6 +2966,56 @@ func TestServerRemoteAddr(t *testing.T) {
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4")
}

func TestServerCustomRemoteAddr(t *testing.T) {
t.Parallel()

customRemoteAddrHandler := func(h RequestHandler) RequestHandler {
return func(ctx *RequestCtx) {
ctx.SetRemoteAddr(&net.TCPAddr{
IP: []byte{1, 2, 3, 5},
Port: 0,
})
h(ctx)
}
}

s := &Server{
Handler: customRemoteAddrHandler(func(ctx *RequestCtx) {
h := &ctx.Request.Header
ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s",
h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP())))
}),
}

rw := &readWriter{}
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")

rwx := &readWriterRemoteAddr{
rw: rw,
addr: &net.TCPAddr{
IP: []byte{1, 2, 3, 4},
Port: 8765,
},
}

ch := make(chan error)
go func() {
ch <- s.ServeConn(rwx)
}()

select {
case err := <-ch:
if err != nil {
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
}

br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5")
}

type readWriterRemoteAddr struct {
net.Conn
rw io.ReadWriteCloser
Expand Down

0 comments on commit 620f0c8

Please sign in to comment.