diff --git a/header_test.go b/header_test.go index 0fcca30806..f82d221a0e 100644 --- a/header_test.go +++ b/header_test.go @@ -2479,6 +2479,12 @@ func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, e } } +func verifyResponseHeaderConnection(t *testing.T, h *ResponseHeader, expectConnection string) { + if string(h.Peek(HeaderConnection)) != expectConnection { + t.Fatalf("Unexpected Connection %q. Expected %q", h.Peek(HeaderConnection), expectConnection) + } +} + func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength int, expectedRequestURI, expectedHost, expectedReferer, expectedContentType string) { if h.ContentLength() != expectedContentLength { diff --git a/server.go b/server.go index ec83471a18..0cda0f83de 100644 --- a/server.go +++ b/server.go @@ -375,11 +375,16 @@ type Server struct { // which will close it when needed. KeepHijackedConns bool + + // CloseOnShutdown when true adds a `Connection: close` header when when the server is shutting down. + CloseOnShutdown bool + // StreamRequestBody enables request body streaming, // and calls the handler sooner when given body is // larger then the current limit. StreamRequestBody bool + tlsConfig *tls.Config nextProtos map[string]ServeHandler @@ -2221,6 +2226,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { } connectionClose = connectionClose || ctx.Response.ConnectionClose() + connectionClose = connectionClose || ctx.Response.ConnectionClose() || (s.CloseOnShutdown && atomic.LoadInt32(&s.stop) == 1) if connectionClose { ctx.Response.Header.SetCanonical(strConnection, strClose) } else if !isHTTP11 { diff --git a/server_test.go b/server_test.go index 91bc6a8a79..c7baa3e080 100644 --- a/server_test.go +++ b/server_test.go @@ -3117,7 +3117,70 @@ func TestShutdown(t *testing.T) { t.Errorf("unexpected error: %s", err) } br := bufio.NewReader(conn) - verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + verifyResponseHeaderConnection(t, &resp.Header, "") + clientCh <- struct{}{} + }() + time.Sleep(time.Millisecond * 100) + shutdownCh := make(chan struct{}) + go func() { + if err := s.Shutdown(); err != nil { + t.Errorf("unexepcted error: %s", err) + } + shutdownCh <- struct{}{} + }() + done := 0 + for { + select { + case <-time.After(time.Second): + t.Fatal("shutdown took too long") + case <-serveCh: + done++ + case <-clientCh: + done++ + case <-shutdownCh: + done++ + } + if done == 3 { + return + } + } +} + +func TestCloseOnShutdown(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + time.Sleep(time.Millisecond * 500) + ctx.Success("aaa/bbb", []byte("real response")) + }, + CloseOnShutdown: true, + } + serveCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + _, err := ln.Dial() + if err == nil { + t.Error("server is still listening") + } + serveCh <- struct{}{} + }() + clientCh := make(chan struct{}) + go func() { + conn, err := ln.Dial() + if err != nil { + t.Errorf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + verifyResponseHeaderConnection(t, &resp.Header, "close") clientCh <- struct{}{} }() time.Sleep(time.Millisecond * 100) @@ -3580,7 +3643,7 @@ func TestIncompleteBodyReturnsUnexpectedEOF(t *testing.T) { } } -func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) { +func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response { var resp Response if err := resp.Read(r); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) @@ -3590,6 +3653,7 @@ func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expec t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody)) } verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType) + return &resp } type readWriter struct {