-
-
Notifications
You must be signed in to change notification settings - Fork 953
Commit
#2620) * Don't poll for disconnects in BaseHTTPMiddleware via StreamingResponse Fixes #2516 * add test * fmt * Update tests/middleware/test_base.py Co-authored-by: Mikkel Duif <[email protected]> * add test for line now missing coverage * more coverage, fix test * add comment * fmt * tweak test * fix * fix coverage * relint --------- Co-authored-by: Mikkel Duif <[email protected]>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,9 +6,8 @@ | |
from anyio.abc import ObjectReceiveStream, ObjectSendStream | ||
|
||
from starlette._utils import collapse_excgroups | ||
from starlette.background import BackgroundTask | ||
from starlette.requests import ClientDisconnect, Request | ||
from starlette.responses import ContentStream, Response, StreamingResponse | ||
from starlette.responses import AsyncContentStream, Response | ||
from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
|
||
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] | ||
|
@@ -56,6 +55,7 @@ async def wrapped_receive(self) -> Message: | |
# at this point a disconnect is all that we should be receiving | ||
# if we get something else, things went wrong somewhere | ||
raise RuntimeError(f"Unexpected message received: {msg['type']}") | ||
self._wrapped_rcv_disconnected = True | ||
return msg | ||
|
||
# wrapped_rcv state 3: not yet consumed | ||
|
@@ -198,20 +198,33 @@ async def dispatch( | |
raise NotImplementedError() # pragma: no cover | ||
|
||
|
||
class _StreamingResponse(StreamingResponse): | ||
class _StreamingResponse(Response): | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
adriangb
Author
Member
|
||
def __init__( | ||
self, | ||
content: ContentStream, | ||
content: AsyncContentStream, | ||
status_code: int = 200, | ||
headers: typing.Mapping[str, str] | None = None, | ||
media_type: str | None = None, | ||
background: BackgroundTask | None = None, | ||
info: typing.Mapping[str, typing.Any] | None = None, | ||
) -> None: | ||
self._info = info | ||
super().__init__(content, status_code, headers, media_type, background) | ||
self.info = info | ||
self.body_iterator = content | ||
self.status_code = status_code | ||
self.media_type = media_type | ||
self.init_headers(headers) | ||
|
||
async def stream_response(self, send: Send) -> None: | ||
if self._info: | ||
await send({"type": "http.response.debug", "info": self._info}) | ||
return await super().stream_response(send) | ||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
if self.info is not None: | ||
await send({"type": "http.response.debug", "info": self.info}) | ||
await send( | ||
{ | ||
"type": "http.response.start", | ||
"status": self.status_code, | ||
"headers": self.raw_headers, | ||
} | ||
) | ||
|
||
async for chunk in self.body_iterator: | ||
await send({"type": "http.response.body", "body": chunk, "more_body": True}) | ||
|
||
await send({"type": "http.response.body", "body": b"", "more_body": False}) |
Guys, really?
So now BaseHttpMiddleware returns a class that is not subclass of StreamingResponse when it's a streaming response.
Very nice change and obviously not breaking anything...