Skip to content

Commit

Permalink
Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse (
Browse files Browse the repository at this point in the history
#2620)

* Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse

Fixes #2516

* add test

* fmt

* Update tests/middleware/test_base.py

Co-authored-by: Mikkel Duif <[email protected]>

* add test for line now missing coverage

* more coverage, fix test

* add comment

* fmt

* tweak test

* fix

* fix coverage

* relint

---------

Co-authored-by: Mikkel Duif <[email protected]>
  • Loading branch information
adriangb and mikkelduif authored Sep 1, 2024
1 parent c78c9aa commit d771bb7
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 36 deletions.
35 changes: 24 additions & 11 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -56,6 +55,7 @@ async def wrapped_receive(self) -> Message:
# at this point a disconnect is all that we should be receiving
# if we get something else, things went wrong somewhere
raise RuntimeError(f"Unexpected message received: {msg['type']}")
self._wrapped_rcv_disconnected = True
return msg

# wrapped_rcv state 3: not yet consumed
Expand Down Expand Up @@ -198,20 +198,33 @@ async def dispatch(
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
class _StreamingResponse(Response):

This comment has been minimized.

Copy link
@Yurzs

Yurzs Sep 17, 2024

Guys, really?
So now BaseHTTPMiddleware returns a class that is not a subclass of StreamingResponse when it's a streaming response.
Very nice change and obviously not breaking anything...

This comment has been minimized.

Copy link
@Kludex

Kludex Sep 17, 2024

Member

?

This comment has been minimized.

Copy link
@adriangb

adriangb Sep 17, 2024

Author Member

Can you give an example of something this broke that wasn't making assumptions about the class returned that were never actually promised anywhere?

This comment has been minimized.

Copy link
@ThomasChiroux

ThomasChiroux Sep 21, 2024

I was investigating why, after updating to the latest FastAPI / latest Starlette in my app, all the tests were broken, so here is an example:
I have a middleware that logs every call, request and response, working like this:

async def dispatch(self, request: Request, call_next):
        """Get info before the call and after and then send them to amqp"""
        start_time = time.perf_counter()

        request_body_bytes = await request.body()

        # use request values
        request_with_body = RequestWithBody(request.scope, request_body_bytes)
        response = await call_next(request_with_body)
        (
            response_content_bytes,
            response_headers,
            response_status,
        ) = await self._get_response_params(response)
...

and in _get_response_params:

async def _get_response_params(
        self, response: StreamingResponse
    ) -> Tuple[bytes, dict[str, str], int]:
        """Drain ``response`` and return its body, headers and status code.

        The response is replayed through a local ``send`` callable that records
        each ASGI message instead of forwarding it anywhere.

        Returns:
            A ``(content, headers, status)`` tuple: the concatenated body
            chunks, the headers of the first ``http.response.start`` message
            decoded as UTF-8 strings, and that message's status code.

        Raises:
            IndexError: if no ``http.response.start`` message was recorded.
        """
        # Each list is filled by ``send`` below; lists are used so the nested
        # function can record values without ``nonlocal``.
        response_byte_chunks: list[bytes] = []
        response_status: list[int] = []
        response_headers: list[dict[str, str]] = []

        async def send(message: Message) -> None:
            # Capture the start message's status and headers; every other
            # message is treated as a body message and must carry a "body" key.
            if message["type"] == "http.response.start":
                response_status.append(message["status"])
                response_headers.append(
                    {k.decode("utf8"): v.decode("utf8") for k, v in message["headers"]}
                )
            else:
                response_byte_chunks.append(message["body"])

        # NOTE(review): this relies on the response object exposing a
        # ``stream_response`` method.
        await response.stream_response(send)
        content = b"".join(response_byte_chunks)
        return content, response_headers[0], response_status[0]

which now fails with:

>       await response.stream_response(send)
E       AttributeError: '_StreamingResponse' object has no attribute 'stream_response'

I do find it at least strange / non-intuitive that a _StreamingResponse object does not inherit from StreamingResponse but from Response.

OK, I understand that _xxx objects are private and should not be used in external work, but I did not even decide that myself: for me the response from call_next was a StreamingResponse, not a _StreamingResponse.

By the way, I do not know yet how to correct this; I will downgrade pydantic again in the meantime.

This comment has been minimized.

Copy link
@adriangb

adriangb via email Sep 21, 2024

Author Member
def __init__(
self,
content: ContentStream,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """Send the wrapped response to ``send`` as a sequence of ASGI messages.

    Emits, in order: an optional ``http.response.debug`` message (only when
    ``self.info`` is not ``None``), one ``http.response.start`` message, one
    ``http.response.body`` message per chunk of ``self.body_iterator`` (each
    with ``more_body=True``), and a final empty body message with
    ``more_body=False``. ``receive`` is never called, so the client is not
    polled for disconnects while streaming.

    Once the body has been sent, a background task attached to this response
    (for example by a ``dispatch`` implementation setting
    ``response.background``) is awaited, if there is one.

    Args:
        scope: The ASGI connection scope (unused here).
        receive: The ASGI receive callable (deliberately unused).
        send: The ASGI send callable that the messages are written to.
    """
    if self.info is not None:
        await send({"type": "http.response.debug", "info": self.info})
    await send(
        {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
    )

    async for chunk in self.body_iterator:
        await send({"type": "http.response.body", "body": chunk, "more_body": True})

    await send({"type": "http.response.body", "body": b"", "more_body": False})

    # This class does not go through ``StreamingResponse.__call__``, so a
    # background task attached to the response would otherwise be dropped
    # silently. ``getattr`` is used because ``background`` may never have
    # been set on this instance.
    background = getattr(self, "background", None)
    if background is not None:
        await background()
169 changes: 144 additions & 25 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Generator,
)

Expand All @@ -16,7 +17,7 @@
from starlette.background import BackgroundTask
from starlette.middleware import Middleware, _MiddlewareClass
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.requests import ClientDisconnect, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
Expand Down Expand Up @@ -260,7 +261,6 @@ async def homepage(request: Request) -> PlainTextResponse:
@pytest.mark.anyio
async def test_run_background_tasks_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1438
request_body_sent = False
response_complete = anyio.Event()
background_task_run = anyio.Event()

Expand Down Expand Up @@ -293,13 +293,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand All @@ -313,7 +307,6 @@ async def send(message: Message) -> None:

@pytest.mark.anyio
async def test_do_not_block_on_background_tasks() -> None:
request_body_sent = False
response_complete = anyio.Event()
events: list[str | Message] = []

Expand Down Expand Up @@ -345,12 +338,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -379,7 +367,6 @@ async def send(message: Message) -> None:
@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects() -> None:
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
request_body_sent = False
response_complete = anyio.Event()
context_manager_exited = anyio.Event()

Expand Down Expand Up @@ -424,13 +411,7 @@ async def passthrough(
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
# We simulate a client that disconnects immediately after receiving the response
await response_complete.wait()
return {"type": "http.disconnect"}
raise NotImplementedError("Should not be called!") # pragma: no cover

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -778,7 +759,9 @@ async def rcv() -> AsyncGenerator[Message, None]:
yield {"type": "http.request", "body": b"1", "more_body": True}
yield {"type": "http.request", "body": b"2", "more_body": True}
yield {"type": "http.request", "body": b"3"}
await anyio.sleep(float("inf"))
raise AssertionError( # pragma: no cover
"Should not be called, no need to poll for disconnect"
)

sent: list[Message] = []

Expand Down Expand Up @@ -1033,3 +1016,139 @@ async def endpoint(request: Request) -> Response:
resp.raise_for_status()

assert bodies == [b"Hello, World!-foo"]


@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
    """Ten stacked middlewares must all complete when the client is already gone.

    The only message the fake client ever delivers is ``http.disconnect``. The
    endpoint is expected to observe ``ClientDisconnect`` when reading the body,
    every middleware must see its ``call_next`` return, and the empty response
    must still be written to ``send``.
    """

    class MyMiddleware(BaseHTTPMiddleware):
        # Records a STARTED/COMPLETED marker around the downstream call.
        def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
            self.version = version
            self.events = events
            super().__init__(app)

        async def dispatch(
            self, request: Request, call_next: RequestResponseEndpoint
        ) -> Response:
            self.events.append(f"{self.version}:STARTED")
            response = await call_next(request)
            self.events.append(f"{self.version}:COMPLETED")
            return response

    async def sleepy(request: Request) -> Response:
        # Reading the body must surface the disconnect as ClientDisconnect.
        disconnected = False
        try:
            await request.body()
        except ClientDisconnect:
            disconnected = True
        if not disconnected:  # pragma: no cover
            raise AssertionError("Should have raised ClientDisconnect")
        return Response(b"")

    events: list[str] = []

    stack = [
        Middleware(MyMiddleware, version=number, events=events)
        for number in range(1, 11)
    ]
    app = Starlette(routes=[Route("/", sleepy)], middleware=stack)

    scope = {
        "type": "http",
        "version": "3",
        "method": "GET",
        "path": "/",
    }

    async def receive() -> AsyncIterator[Message]:
        yield {"type": "http.disconnect"}

    sent: list[Message] = []

    async def send(message: Message) -> None:
        sent.append(message)

    incoming = receive()
    await app(scope, incoming.__anext__, send)

    # Middlewares start outermost-first and complete innermost-first.
    started = [f"{number}:STARTED" for number in range(1, 11)]
    completed = [f"{number}:COMPLETED" for number in range(10, 0, -1)]
    assert events == started + completed

    assert sent == [
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-length", b"0")],
        },
        {"type": "http.response.body", "body": b"", "more_body": False},
    ]


@pytest.mark.anyio
@pytest.mark.parametrize("send_body", [True, False])
async def test_poll_for_disconnect_repeated(send_body: bool) -> None:
    """An app may poll for the disconnect twice although the client sends it once.

    The wrapped app drains ``receive`` until it sees ``http.disconnect`` two
    times in a row. The fake client yields a single disconnect message and
    raises if it is asked for anything after that, so the middleware has to
    answer the second poll itself.
    """

    async def app_poll_disconnect(scope: Scope, receive: Receive, send: Send) -> None:
        for _attempt in range(2):
            # Skip any request-body messages, then expect the disconnect.
            while True:
                message = await receive()
                if message["type"] != "http.request":
                    break
            assert message["type"] == "http.disconnect"
        response = Response(b"good!")
        await response(scope, receive, send)

    class MyMiddleware(BaseHTTPMiddleware):
        # Pure pass-through: forwards the request and returns the response.
        async def dispatch(
            self, request: Request, call_next: RequestResponseEndpoint
        ) -> Response:
            return await call_next(request)

    app = MyMiddleware(app_poll_disconnect)

    scope = {
        "type": "http",
        "version": "3",
        "method": "GET",
        "path": "/",
    }

    async def receive() -> AsyncIterator[Message]:
        # The key here is that we only ever send one http.disconnect message.
        pending: list[Message] = []
        if send_body:
            pending.append({"type": "http.request", "body": b"hello", "more_body": True})
            pending.append({"type": "http.request", "body": b"", "more_body": False})
        pending.append({"type": "http.disconnect"})
        for message in pending:
            yield message
        raise AssertionError("Should not be called, would hang")  # pragma: no cover

    sent: list[Message] = []

    async def send(message: Message) -> None:
        sent.append(message)

    incoming = receive()
    await app(scope, incoming.__anext__, send)

    assert sent == [
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [(b"content-length", b"5")],
        },
        {"type": "http.response.body", "body": b"good!", "more_body": True},
        {"type": "http.response.body", "body": b"", "more_body": False},
    ]

0 comments on commit d771bb7

Please sign in to comment.