diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 4e5054d7a..87c0f51f8 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -6,9 +6,8 @@ from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette._utils import collapse_excgroups -from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request -from starlette.responses import ContentStream, Response, StreamingResponse +from starlette.responses import AsyncContentStream, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] @@ -56,6 +55,7 @@ async def wrapped_receive(self) -> Message: # at this point a disconnect is all that we should be receiving # if we get something else, things went wrong somewhere raise RuntimeError(f"Unexpected message received: {msg['type']}") + self._wrapped_rcv_disconnected = True return msg # wrapped_rcv state 3: not yet consumed @@ -198,20 +198,33 @@ async def dispatch( raise NotImplementedError() # pragma: no cover -class _StreamingResponse(StreamingResponse): +class _StreamingResponse(Response): def __init__( self, - content: ContentStream, + content: AsyncContentStream, status_code: int = 200, headers: typing.Mapping[str, str] | None = None, media_type: str | None = None, - background: BackgroundTask | None = None, info: typing.Mapping[str, typing.Any] | None = None, ) -> None: - self._info = info - super().__init__(content, status_code, headers, media_type, background) + self.info = info + self.body_iterator = content + self.status_code = status_code + self.media_type = media_type + self.init_headers(headers) - async def stream_response(self, send: Send) -> None: - if self._info: - await send({"type": "http.response.debug", "info": self._info}) - return await super().stream_response(send) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if self.info is not None: + await send({"type": "http.response.debug", "info": self.info}) + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + + async for chunk in self.body_iterator: + await send({"type": "http.response.body", "body": chunk, "more_body": True}) + + await send({"type": "http.response.body", "body": b"", "more_body": False}) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 3ad1751a2..8e410cb15 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,6 +5,7 @@ from typing import ( Any, AsyncGenerator, + AsyncIterator, Generator, ) @@ -16,7 +17,7 @@ from starlette.background import BackgroundTask from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request +from starlette.requests import ClientDisconnect, Request from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient @@ -260,7 +261,6 @@ async def homepage(request: Request) -> PlainTextResponse: @pytest.mark.anyio async def test_run_background_tasks_even_if_client_disconnects() -> None: # test for https://github.com/encode/starlette/issues/1438 - request_body_sent = False response_complete = anyio.Event() background_task_run = anyio.Event() @@ -293,13 +293,7 @@ async def passthrough( } async def receive() -> Message: - nonlocal request_body_sent - if not request_body_sent: - request_body_sent = True - return {"type": "http.request", "body": b"", "more_body": False} - # We simulate a client that disconnects immediately after receiving the response - await response_complete.wait() - return {"type": "http.disconnect"} + raise NotImplementedError("Should not be called!") # pragma: no cover async def send(message: Message) -> None: if message["type"] == "http.response.body": @@ -313,7 +307,6 @@ async def send(message: Message) -> None: @pytest.mark.anyio async def test_do_not_block_on_background_tasks() -> None: - request_body_sent = False response_complete = anyio.Event() events: list[str | Message] = [] @@ -345,12 +338,7 @@ async def passthrough( } async def receive() -> Message: - nonlocal request_body_sent - if not request_body_sent: - request_body_sent = True - return {"type": "http.request", "body": b"", "more_body": False} - await response_complete.wait() - return {"type": "http.disconnect"} + raise NotImplementedError("Should not be called!") # pragma: no cover async def send(message: Message) -> None: if message["type"] == "http.response.body": @@ -379,7 +367,6 @@ async def send(message: Message) -> None: @pytest.mark.anyio async def test_run_context_manager_exit_even_if_client_disconnects() -> None: # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042 - request_body_sent = False response_complete = anyio.Event() context_manager_exited = anyio.Event() @@ -424,13 +411,7 @@ async def passthrough( } async def receive() -> Message: - nonlocal request_body_sent - if not request_body_sent: - request_body_sent = True - return {"type": "http.request", "body": b"", "more_body": False} - # We simulate a client that disconnects immediately after receiving the response - await response_complete.wait() - return {"type": "http.disconnect"} + raise NotImplementedError("Should not be called!") # pragma: no cover async def send(message: Message) -> None: if message["type"] == "http.response.body": @@ -778,7 +759,9 @@ async def rcv() -> AsyncGenerator[Message, None]: yield {"type": "http.request", "body": b"1", "more_body": True} yield {"type": "http.request", "body": b"2", "more_body": True} yield {"type": "http.request", "body": b"3"} - await anyio.sleep(float("inf")) + raise AssertionError( # pragma: no cover + "Should not be called, no need to poll for disconnect" + ) sent: list[Message] = [] @@ -1033,3 +1016,139 @@ async def endpoint(request: Request) -> Response: resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] + + +@pytest.mark.anyio +async def test_multiple_middlewares_stacked_client_disconnected() -> None: + class MyMiddleware(BaseHTTPMiddleware): + def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None: + self.version = version + self.events = events + super().__init__(app) + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + self.events.append(f"{self.version}:STARTED") + res = await call_next(request) + self.events.append(f"{self.version}:COMPLETED") + return res + + async def sleepy(request: Request) -> Response: + try: + await request.body() + except ClientDisconnect: + pass + else: # pragma: no cover + raise AssertionError("Should have raised ClientDisconnect") + return Response(b"") + + events: list[str] = [] + + app = Starlette( + routes=[Route("/", sleepy)], + middleware=[ + Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10) + ], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + } + + async def receive() -> AsyncIterator[Message]: + yield {"type": "http.disconnect"} + + sent: list[Message] = [] + + async def send(message: Message) -> None: + sent.append(message) + + await app(scope, receive().__anext__, send) + + assert events == [ + "1:STARTED", + "2:STARTED", + "3:STARTED", + "4:STARTED", + "5:STARTED", + "6:STARTED", + "7:STARTED", + "8:STARTED", + "9:STARTED", + "10:STARTED", + "10:COMPLETED", + "9:COMPLETED", + "8:COMPLETED", + "7:COMPLETED", + "6:COMPLETED", + "5:COMPLETED", + "4:COMPLETED", + "3:COMPLETED", + "2:COMPLETED", + "1:COMPLETED", + ] + + assert sent == [ + { + "type": "http.response.start", + "status": 200, + "headers": [(b"content-length", b"0")], + }, + {"type": "http.response.body", "body": b"", "more_body": False}, + ] + + +@pytest.mark.anyio +@pytest.mark.parametrize("send_body", [True, False]) +async def test_poll_for_disconnect_repeated(send_body: bool) -> None: + async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None: + for _ in range(2): + msg = await receive() + while msg["type"] == "http.request": + msg = await receive() + assert msg["type"] == "http.disconnect" + await Response(b"good!")(scope, receive, send) + + class MyMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = MyMiddleware(app_poll_disconnect) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + } + + async def receive() -> AsyncIterator[Message]: + # the key here is that we only ever send 1 htt.disconnect message + if send_body: + yield {"type": "http.request", "body": b"hello", "more_body": True} + yield {"type": "http.request", "body": b"", "more_body": False} + yield {"type": "http.disconnect"} + raise AssertionError("Should not be called, would hang") # pragma: no cover + + sent: list[Message] = [] + + async def send(message: Message) -> None: + sent.append(message) + + await app(scope, receive().__anext__, send) + + assert sent == [ + { + "type": "http.response.start", + "status": 200, + "headers": [(b"content-length", b"5")], + }, + {"type": "http.response.body", "body": b"good!", "more_body": True}, + {"type": "http.response.body", "body": b"", "more_body": False}, + ]