Skip to content

Commit 52612cc

Browse files
committed
poc client disconnect
1 parent b0322d5 commit 52612cc

File tree

7 files changed

+129
-43
lines changed

7 files changed

+129
-43
lines changed

litestar/connection/request.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import math
44
import warnings
5-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast
5+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Generic, cast
6+
7+
import anyio
68

79
from litestar._multipart import parse_content_header, parse_multipart_form
810
from litestar._parsers import parse_url_encoded_form_data
@@ -23,6 +25,7 @@
2325
LitestarException,
2426
LitestarWarning,
2527
)
28+
from litestar.exceptions.base_exceptions import ClientDisconnect
2629
from litestar.exceptions.http_exceptions import RequestEntityTooLarge
2730
from litestar.serialization import decode_json, decode_msgpack
2831
from litestar.types import Empty, HTTPReceiveMessage
@@ -45,6 +48,9 @@
4548
}
4649

4750

51+
METHODS_WITHOUT_BODY: Final = frozenset(["GET", "HEAD", "DELETE", "TRACE"])
52+
53+
4854
class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]):
4955
"""The Litestar Request class."""
5056

@@ -53,9 +59,11 @@ class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler",
5359
"_body",
5460
"_content_length",
5561
"_content_type",
62+
"_expect_body",
5663
"_form",
5764
"_json",
5865
"_msgpack",
66+
"_stream_consumed",
5967
"is_connected",
6068
"supports_push_promise",
6169
)
@@ -85,6 +93,8 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send =
8593
self._accept: Accept | EmptyType = Empty
8694
self._content_length: int | None | EmptyType = Empty
8795
self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions
96+
self._stream_consumed = anyio.Event()
97+
self._expect_body = self.method not in METHODS_WITHOUT_BODY
8898

8999
@property
90100
def method(self) -> Method:
@@ -172,6 +182,30 @@ def content_length(self) -> int | None:
172182
raise ClientException(f"Invalid content-length: {content_length_header!r}") from None
173183
return content_length
174184

185+
async def _listen_for_disconnect(self) -> None:
186+
# we consider these methods empty and consume their - empty - body, since
187+
# handlers are unlikely to do this, in order to trigger the next phase and
188+
# start listening for 'http.disconnect'.
189+
if self.method in METHODS_WITHOUT_BODY:
190+
body = await self.body()
191+
if body and not self._expect_body:
192+
warnings.warn(
193+
f"Discarding unexpected request body for '{self.method} {self.url.path}' "
194+
"received while listening for disconnect",
195+
category=LitestarWarning,
196+
stacklevel=2,
197+
)
198+
199+
await self._stream_consumed.wait()
200+
201+
message = await self.receive()
202+
if message["type"] == "http.disconnect":
203+
raise ClientDisconnect
204+
205+
raise InternalServerException(
206+
f"Received unexpected {message['type']!r} message while listening for 'http.disconnect'"
207+
)
208+
175209
async def stream(self) -> AsyncGenerator[bytes, None]:
176210
"""Return an async generator that streams chunks of bytes.
177211
@@ -226,10 +260,11 @@ async def stream(self) -> AsyncGenerator[bytes, None]:
226260
yield body
227261

228262
if not event.get("more_body", False):
263+
self._stream_consumed.set()
229264
break
230265

231-
if event["type"] == "http.disconnect":
232-
raise InternalServerException("client disconnected prematurely")
266+
elif event["type"] == "http.disconnect":
267+
raise ClientDisconnect()
233268

234269
self.is_connected = False
235270
yield b""

litestar/exceptions/base_exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,7 @@ class SerializationException(LitestarException):
5555

5656
class LitestarWarning(UserWarning):
5757
"""Base class for Litestar warnings"""
58+
59+
60+
class ClientDisconnect(LitestarException):
61+
pass

litestar/response/streaming.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from __future__ import annotations
22

33
import itertools
4-
from functools import partial
54
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator, Union
65

7-
from anyio import CancelScope, create_task_group
8-
96
from litestar.enums import MediaType
107
from litestar.response.base import ASGIResponse, Response
118
from litestar.types.helper_types import StreamType
@@ -80,25 +77,6 @@ def __init__(
8077
iterator if isinstance(iterator, (AsyncIterable, AsyncIterator)) else AsyncIteratorWrapper(iterator)
8178
)
8279

83-
async def _listen_for_disconnect(self, cancel_scope: CancelScope, receive: Receive) -> None:
84-
"""Listen for a cancellation message, and if received - call cancel on the cancel scope.
85-
86-
Args:
87-
cancel_scope: A task group cancel scope instance.
88-
receive: The ASGI receive function.
89-
90-
Returns:
91-
None
92-
"""
93-
if not cancel_scope.cancel_called:
94-
message = await receive()
95-
if message["type"] == "http.disconnect":
96-
# despite the IDE warning, this is not a coroutine because anyio 3+ changed this.
97-
# therefore make sure not to await this.
98-
cancel_scope.cancel()
99-
else:
100-
await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive)
101-
10280
async def _stream(self, send: Send) -> None:
10381
"""Send the chunks from the iterator as a stream of ASGI 'http.response.body' events.
10482
@@ -128,10 +106,7 @@ async def send_body(self, send: Send, receive: Receive) -> None:
128106
Returns:
129107
None
130108
"""
131-
132-
async with create_task_group() as task_group:
133-
task_group.start_soon(partial(self._stream, send))
134-
await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)
109+
await self._stream(send)
135110

136111

137112
class Stream(Response[StreamType[Union[str, bytes]]]):

litestar/routes/http.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,38 @@
11
from __future__ import annotations
22

3+
from functools import partial
34
from itertools import chain
45
from typing import TYPE_CHECKING, Any
56

7+
import anyio
68
from msgspec.msgpack import decode as _decode_msgpack_plain
79

810
from litestar.datastructures.multi_dicts import FormMultiDict
911
from litestar.enums import HttpMethod, MediaType, ScopeType
1012
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
13+
from litestar.exceptions.base_exceptions import ClientDisconnect
1114
from litestar.handlers.http_handlers import HTTPRouteHandler
1215
from litestar.response import Response
1316
from litestar.routes.base import BaseRoute
1417
from litestar.status_codes import HTTP_204_NO_CONTENT
1518
from litestar.types.empty import Empty
1619
from litestar.utils.scope.state import ScopeState
20+
from litestar.utils.sync import AsyncCallable
21+
22+
try:
23+
ExceptionGroupType: type[ExceptionGroup] | None = ExceptionGroup # type: ignore[name-defined]
24+
except NameError:
25+
ExceptionGroupType = None
26+
1727

1828
if TYPE_CHECKING:
1929
from litestar._kwargs import KwargsModel
2030
from litestar._kwargs.cleanup import DependencyCleanupGroup
2131
from litestar.connection import Request
2232
from litestar.types import ASGIApp, HTTPScope, Method, Receive, Scope, Send
2333

34+
ExceptionGroupType = ExceptionGroup # type: ignore[name-defined]
35+
2436

2537
class HTTPRoute(BaseRoute):
2638
"""An HTTP route, capable of handling multiple ``HTTPRouteHandler``\\ s.""" # noqa: D301
@@ -59,6 +71,36 @@ def __init__(
5971
handler_names=[route_handler.handler_name for route_handler in self.route_handlers],
6072
)
6173

74+
async def _handle_response_cycle(
75+
self,
76+
scope: HTTPScope,
77+
request: Request[Any, Any, Any],
78+
route_handler: HTTPRouteHandler,
79+
parameter_model: KwargsModel,
80+
receive: Receive,
81+
send: Send,
82+
cancel_scope: anyio.CancelScope | None = None,
83+
) -> None:
84+
try:
85+
response = await self._get_response_for_request(
86+
scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
87+
)
88+
89+
await response(scope, receive, send)
90+
91+
if after_response_handler := route_handler.resolve_after_response():
92+
await after_response_handler(request)
93+
94+
finally:
95+
if cancel_scope is not None:
96+
cancel_scope.cancel()
97+
98+
async def _listen_for_disconnect(self, request: Request, cancel_scope: anyio.CancelScope) -> None:
99+
try:
100+
await request._listen_for_disconnect()
101+
except ClientDisconnect:
102+
cancel_scope.cancel()
103+
62104
async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: # type: ignore[override]
63105
"""ASGI app that creates a Request from the passed in args, determines which handler function to call and then
64106
handles the call.
@@ -78,14 +120,37 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:
78120
await route_handler.authorize_connection(connection=request)
79121

80122
try:
81-
response = await self._get_response_for_request(
82-
scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
83-
)
84-
85-
await response(scope, receive, send)
123+
if route_handler.has_sync_callable or isinstance(route_handler.fn, AsyncCallable):
124+
# if it's a sync or to_thread function we can't actually cancel anything
125+
# so we just await it directly
126+
await self._handle_response_cycle(
127+
scope=scope,
128+
send=send,
129+
receive=receive,
130+
request=request,
131+
route_handler=route_handler,
132+
parameter_model=parameter_model,
133+
)
134+
else:
135+
async with anyio.create_task_group() as tg:
136+
tg.start_soon(
137+
partial(
138+
self._handle_response_cycle,
139+
scope=scope,
140+
send=send,
141+
receive=receive,
142+
request=request,
143+
route_handler=route_handler,
144+
parameter_model=parameter_model,
145+
cancel_scope=tg.cancel_scope,
146+
),
147+
)
148+
tg.start_soon(self._listen_for_disconnect, request, tg.cancel_scope)
149+
except Exception as exc:
150+
if isinstance(exc, ExceptionGroupType):
151+
raise exc.exceptions[0] from exc
152+
raise
86153

87-
if after_response_handler := route_handler.resolve_after_response():
88-
await after_response_handler(request)
89154
finally:
90155
if (form_data := ScopeState.from_scope(scope).form) is not Empty:
91156
await FormMultiDict.from_form_data(form_data).close()

tests/unit/test_asgi/test_asgi_router.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,19 @@
1919
if TYPE_CHECKING:
2020
from contextlib import AbstractAsyncContextManager
2121

22-
from litestar.types import Receive, Scope, Send
22+
from litestar.types import HTTPReceiveMessage, Message, Receive, Scope, Send
2323

2424
_ExceptionGroup = get_exception_group()
2525

2626

27+
async def no_op_send(message: Message) -> None:
28+
pass
29+
30+
31+
async def empty_http_receive() -> HTTPReceiveMessage:
32+
return {"type": "http.request", "body": b"", "more_body": False}
33+
34+
2735
def test_add_mount_route_disallow_path_parameter() -> None:
2836
async def handler(scope: Scope, receive: Receive, send: Send) -> None:
2937
return None
@@ -243,7 +251,7 @@ async def handler() -> None:
243251
app = Litestar(route_handlers=[handler], exception_handlers={RuntimeError: app_exception_handlers_mock})
244252
scope["path"] = "/nowhere-to-be-found"
245253
with pytest.raises(NotFoundException):
246-
await app.asgi_router(scope, AsyncMock(), AsyncMock())
254+
await app.asgi_router(scope, empty_http_receive, no_op_send)
247255

248256
state = ScopeState.from_scope(scope)
249257
assert state.exception_handlers is not Empty
@@ -266,7 +274,7 @@ async def handler() -> None:
266274
app = Litestar(route_handlers=[handler], exception_handlers={RuntimeError: app_exception_handlers_mock})
267275
router = app.asgi_router
268276
scope["path"] = "/"
269-
await router(scope, AsyncMock(), AsyncMock())
277+
await router(scope, empty_http_receive, no_op_send)
270278
state = ScopeState.from_scope(scope)
271279
assert state.exception_handlers is not Empty
272280
assert state.exception_handlers[RuntimeError] is app_exception_handlers_mock

tests/unit/test_connection/test_request.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
LitestarWarning,
2121
SerializationException,
2222
)
23+
from litestar.exceptions.base_exceptions import ClientDisconnect
2324
from litestar.middleware import MiddlewareProtocol
2425
from litestar.response.base import ASGIResponse
2526
from litestar.serialization import encode_json, encode_msgpack
@@ -382,7 +383,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
382383

383384

384385
async def test_request_disconnect(create_scope: Callable[..., Scope]) -> None:
385-
"""If a client disconnect occurs while reading request body then InternalServerException should be raised."""
386+
"""If a client disconnect occurs while reading request body then ClientDisconnect should be raised."""
386387

387388
async def app(scope: Scope, receive: Receive, send: Send) -> None:
388389
request = Request[Any, Any, State](scope, receive)
@@ -391,7 +392,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
391392
async def receiver() -> dict[str, str]:
392393
return {"type": "http.disconnect"}
393394

394-
with pytest.raises(InternalServerException):
395+
with pytest.raises(ClientDisconnect):
395396
await app(
396397
create_scope(type="http", route_handler=_route_handler, method="POST", path="/"),
397398
receiver, # type: ignore[arg-type]

tests/unit/test_kwargs/test_reserved_kwargs_injection.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def route_handler(state: state_typing) -> str: # type: ignore[valid-type]
7272
(post, HttpMethod.POST, HTTP_201_CREATED),
7373
(put, HttpMethod.PUT, HTTP_200_OK),
7474
(patch, HttpMethod.PATCH, HTTP_200_OK),
75-
(delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT),
7675
],
7776
)
7877
def test_data_using_model(decorator: Any, http_method: Any, expected_status_code: Any) -> None:
@@ -96,7 +95,6 @@ def test_method(self, data: DataclassPerson) -> None:
9695
(post, HttpMethod.POST, HTTP_201_CREATED),
9796
(put, HttpMethod.PUT, HTTP_200_OK),
9897
(patch, HttpMethod.PATCH, HTTP_200_OK),
99-
(delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT),
10098
],
10199
)
102100
def test_data_using_list_of_models(decorator: Any, http_method: Any, expected_status_code: Any) -> None:

0 commit comments

Comments
 (0)