From 002ceb496616a402069908b60074381f72ccd693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= <25355197+provinzkraut@users.noreply.github.com> Date: Sun, 2 Feb 2025 15:21:04 +0100 Subject: [PATCH] poc client disconnect --- litestar/connection/request.py | 41 ++++++++++- litestar/exceptions/base_exceptions.py | 4 ++ litestar/response/streaming.py | 27 +------ litestar/routes/http.py | 72 +++++++++++++++++-- pyproject.toml | 7 +- tests/unit/test_asgi/test_asgi_router.py | 14 +++- tests/unit/test_connection/test_request.py | 5 +- .../test_reserved_kwargs_injection.py | 2 - uv.lock | 29 +++++--- 9 files changed, 145 insertions(+), 56 deletions(-) diff --git a/litestar/connection/request.py b/litestar/connection/request.py index 065312d7ed..e6e087ec26 100644 --- a/litestar/connection/request.py +++ b/litestar/connection/request.py @@ -2,7 +2,9 @@ import math import warnings -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast +from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Generic, cast + +import anyio from litestar._multipart import parse_content_header, parse_multipart_form from litestar._parsers import parse_url_encoded_form_data @@ -23,6 +25,7 @@ LitestarException, LitestarWarning, ) +from litestar.exceptions.base_exceptions import ClientDisconnect from litestar.exceptions.http_exceptions import RequestEntityTooLarge from litestar.serialization import decode_json, decode_msgpack from litestar.types import Empty, HTTPReceiveMessage @@ -45,6 +48,9 @@ } +METHODS_WITHOUT_BODY: Final = frozenset(["GET", "HEAD", "DELETE", "TRACE"]) + + class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]): """The Litestar Request class.""" @@ -53,9 +59,11 @@ class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", "_body", "_content_length", "_content_type", + "_expect_body", "_form", "_json", "_msgpack", + "_stream_consumed", "is_connected", "supports_push_promise", ) @@ -85,6 +93,8 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = self._accept: Accept | EmptyType = Empty self._content_length: int | None | EmptyType = Empty self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions + self._stream_consumed = anyio.Event() + self._expect_body = self.method not in METHODS_WITHOUT_BODY @property def method(self) -> Method: @@ -172,6 +182,30 @@ def content_length(self) -> int | None: raise ClientException(f"Invalid content-length: {content_length_header!r}") from None return content_length + async def _listen_for_disconnect(self) -> None: + # we consider these methods empty and consume their - empty - body, since + # handlers are unlikely to do this, in order to trigger the next phase and + # start listening for 'http.disconnect'. + if self.method in METHODS_WITHOUT_BODY: + body = await self.body() + if body and not self._expect_body: + warnings.warn( + f"Discarding unexpected request body for '{self.method} {self.url.path}' " + "received while listening for disconnect", + category=LitestarWarning, + stacklevel=2, + ) + + await self._stream_consumed.wait() + + message = await self.receive() + if message["type"] == "http.disconnect": + raise ClientDisconnect + + raise InternalServerException( + f"Received unexpected {message['type']!r} message while listening for 'http.disconnect'" + ) + async def stream(self) -> AsyncGenerator[bytes, None]: """Return an async generator that streams chunks of bytes. @@ -226,10 +260,11 @@ async def stream(self) -> AsyncGenerator[bytes, None]: yield body if not event.get("more_body", False): + self._stream_consumed.set() break - if event["type"] == "http.disconnect": - raise InternalServerException("client disconnected prematurely") + elif event["type"] == "http.disconnect": + raise ClientDisconnect() self.is_connected = False yield b"" diff --git a/litestar/exceptions/base_exceptions.py b/litestar/exceptions/base_exceptions.py index d23d3b5957..eb87bd1ba0 100644 --- a/litestar/exceptions/base_exceptions.py +++ b/litestar/exceptions/base_exceptions.py @@ -55,3 +55,7 @@ class SerializationException(LitestarException): class LitestarWarning(UserWarning): """Base class for Litestar warnings""" + + +class ClientDisconnect(LitestarException): + pass diff --git a/litestar/response/streaming.py b/litestar/response/streaming.py index fc76522416..09d2eeb844 100644 --- a/litestar/response/streaming.py +++ b/litestar/response/streaming.py @@ -1,11 +1,8 @@ from __future__ import annotations import itertools -from functools import partial from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator, Union -from anyio import CancelScope, create_task_group - from litestar.enums import MediaType from litestar.response.base import ASGIResponse, Response from litestar.types.helper_types import StreamType @@ -80,25 +77,6 @@ def __init__( iterator if isinstance(iterator, (AsyncIterable, AsyncIterator)) else AsyncIteratorWrapper(iterator) ) - async def _listen_for_disconnect(self, cancel_scope: CancelScope, receive: Receive) -> None: - """Listen for a cancellation message, and if received - call cancel on the cancel scope. - - Args: - cancel_scope: A task group cancel scope instance. - receive: The ASGI receive function. - - Returns: - None - """ - if not cancel_scope.cancel_called: - message = await receive() - if message["type"] == "http.disconnect": - # despite the IDE warning, this is not a coroutine because anyio 3+ changed this. - # therefore make sure not to await this. - cancel_scope.cancel() - else: - await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive) - async def _stream(self, send: Send) -> None: """Send the chunks from the iterator as a stream of ASGI 'http.response.body' events. @@ -128,10 +106,7 @@ async def send_body(self, send: Send, receive: Receive) -> None: Returns: None """ - - async with create_task_group() as task_group: - task_group.start_soon(partial(self._stream, send)) - await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive) + await self._stream(send) class Stream(Response[StreamType[Union[str, bytes]]]): diff --git a/litestar/routes/http.py b/litestar/routes/http.py index 80368e852e..c4e9d5901b 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -1,19 +1,24 @@ from __future__ import annotations +from functools import partial from itertools import chain from typing import TYPE_CHECKING, Any +import anyio from msgspec.msgpack import decode as _decode_msgpack_plain from litestar.datastructures.multi_dicts import FormMultiDict from litestar.enums import HttpMethod, MediaType, ScopeType from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException +from litestar.exceptions.base_exceptions import ClientDisconnect from litestar.handlers.http_handlers import HTTPRouteHandler from litestar.response import Response from litestar.routes.base import BaseRoute from litestar.status_codes import HTTP_204_NO_CONTENT from litestar.types.empty import Empty from litestar.utils.scope.state import ScopeState +from litestar.utils.sync import AsyncCallable +from exceptiongroup import ExceptionGroup if TYPE_CHECKING: from litestar._kwargs import KwargsModel @@ -59,6 +64,36 @@ def __init__( handler_names=[route_handler.handler_name for route_handler in self.route_handlers], ) + async def _handle_response_cycle( + self, + scope: HTTPScope, + request: Request[Any, Any, Any], + route_handler: HTTPRouteHandler, + parameter_model: KwargsModel, + receive: Receive, + send: Send, + cancel_scope: anyio.CancelScope | None = None, + ) -> None: + try: + response = await self._get_response_for_request( + scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model + ) + + await response(scope, receive, send) + + if after_response_handler := route_handler.resolve_after_response(): + await after_response_handler(request) + + finally: + if cancel_scope is not None: + cancel_scope.cancel() + + async def _listen_for_disconnect(self, request: Request, cancel_scope: anyio.CancelScope) -> None: + try: + await request._listen_for_disconnect() + except ClientDisconnect: + cancel_scope.cancel() + async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: # type: ignore[override] """ASGI app that creates a Request from the passed in args, determines which handler function to call and then handles the call. @@ -78,14 +113,37 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: await route_handler.authorize_connection(connection=request) try: - response = await self._get_response_for_request( - scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model - ) - - await response(scope, receive, send) + if route_handler.has_sync_callable or isinstance(route_handler.fn, AsyncCallable): + # if it's a sync or to_thread function we can't actually cancel anything + # so we just await it directly + await self._handle_response_cycle( + scope=scope, + send=send, + receive=receive, + request=request, + route_handler=route_handler, + parameter_model=parameter_model, + ) + else: + async with anyio.create_task_group() as tg: + tg.start_soon( + partial( + self._handle_response_cycle, + scope=scope, + send=send, + receive=receive, + request=request, + route_handler=route_handler, + parameter_model=parameter_model, + cancel_scope=tg.cancel_scope, + ), + ) + tg.start_soon(self._listen_for_disconnect, request, tg.cancel_scope) + except Exception as exc: + if isinstance(exc, ExceptionGroup): + raise exc.exceptions[0] from exc + raise - if after_response_handler := route_handler.resolve_after_response(): - await after_response_handler(request) finally: if (form_data := ScopeState.from_scope(scope).form) is not Empty: await FormMultiDict.from_form_data(form_data).close() diff --git a/pyproject.toml b/pyproject.toml index 6e33df1749..873caab7c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ dependencies = [ "rich-click", "multipart>=1.2.0", # default litestar plugins - "litestar-htmx>=0.4.0" + "litestar-htmx>=0.4.0", + "exceptiongroup>=1.2.2", ] description = "Litestar - A production-ready, highly performant, extensible ASGI API Framework" keywords = ["api", "rest", "asgi", "litestar", "starlite"] @@ -413,8 +414,6 @@ lint.select = [ "W", # pycodestyle - warning "YTT", # flake8-2020 ] - -line-length = 120 lint.ignore = [ "A003", # flake8-builtins - class attribute {name} is shadowing a python builtin "B010", # flake8-bugbear - do not call setattr with a constant attribute value @@ -435,6 +434,8 @@ lint.ignore = [ "ISC001", # Ruff formatter incompatible "CPY001", # ruff - copyright notice at the top of the file ] + +line-length = 120 src = ["litestar", "tests", "docs/examples"] target-version = "py38" diff --git a/tests/unit/test_asgi/test_asgi_router.py b/tests/unit/test_asgi/test_asgi_router.py index dc5c9f053e..4a34c1596f 100644 --- a/tests/unit/test_asgi/test_asgi_router.py +++ b/tests/unit/test_asgi/test_asgi_router.py @@ -19,11 +19,19 @@ if TYPE_CHECKING: from contextlib import AbstractAsyncContextManager - from litestar.types import Receive, Scope, Send + from litestar.types import HTTPReceiveMessage, Message, Receive, Scope, Send _ExceptionGroup = get_exception_group() +async def no_op_send(message: Message) -> None: + pass + + +async def empty_http_receive() -> HTTPReceiveMessage: + return {"type": "http.request", "body": b"", "more_body": False} + + def test_add_mount_route_disallow_path_parameter() -> None: async def handler(scope: Scope, receive: Receive, send: Send) -> None: return None @@ -243,7 +251,7 @@ async def handler() -> None: app = Litestar(route_handlers=[handler], exception_handlers={RuntimeError: app_exception_handlers_mock}) scope["path"] = "/nowhere-to-be-found" with pytest.raises(NotFoundException): - await app.asgi_router(scope, AsyncMock(), AsyncMock()) + await app.asgi_router(scope, empty_http_receive, no_op_send) state = ScopeState.from_scope(scope) assert state.exception_handlers is not Empty @@ -266,7 +274,7 @@ async def handler() -> None: app = Litestar(route_handlers=[handler], exception_handlers={RuntimeError: app_exception_handlers_mock}) router = app.asgi_router scope["path"] = "/" - await router(scope, AsyncMock(), AsyncMock()) + await router(scope, empty_http_receive, no_op_send) state = ScopeState.from_scope(scope) assert state.exception_handlers is not Empty assert state.exception_handlers[RuntimeError] is app_exception_handlers_mock diff --git a/tests/unit/test_connection/test_request.py b/tests/unit/test_connection/test_request.py index 688f5250a7..7c8fd1b24d 100644 --- a/tests/unit/test_connection/test_request.py +++ b/tests/unit/test_connection/test_request.py @@ -20,6 +20,7 @@ LitestarWarning, SerializationException, ) +from litestar.exceptions.base_exceptions import ClientDisconnect from litestar.middleware import MiddlewareProtocol from litestar.response.base import ASGIResponse from litestar.serialization import encode_json, encode_msgpack @@ -382,7 +383,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: async def test_request_disconnect(create_scope: Callable[..., Scope]) -> None: - """If a client disconnect occurs while reading request body then InternalServerException should be raised.""" + """If a client disconnect occurs while reading request body then ClientDisconnect should be raised.""" async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request[Any, Any, State](scope, receive) @@ -391,7 +392,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: async def receiver() -> dict[str, str]: return {"type": "http.disconnect"} - with pytest.raises(InternalServerException): + with pytest.raises(ClientDisconnect): await app( create_scope(type="http", route_handler=_route_handler, method="POST", path="/"), receiver, # type: ignore[arg-type] diff --git a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py index 4fdc12205f..1b91aa7675 100644 --- a/tests/unit/test_kwargs/test_reserved_kwargs_injection.py +++ b/tests/unit/test_kwargs/test_reserved_kwargs_injection.py @@ -72,7 +72,6 @@ def route_handler(state: state_typing) -> str: # type: ignore[valid-type] (post, HttpMethod.POST, HTTP_201_CREATED), (put, HttpMethod.PUT, HTTP_200_OK), (patch, HttpMethod.PATCH, HTTP_200_OK), - (delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT), ], ) def test_data_using_model(decorator: Any, http_method: Any, expected_status_code: Any) -> None: @@ -96,7 +95,6 @@ def test_method(self, data: DataclassPerson) -> None: (post, HttpMethod.POST, HTTP_201_CREATED), (put, HttpMethod.PUT, HTTP_200_OK), (patch, HttpMethod.PATCH, HTTP_200_OK), - (delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT), ], ) def test_data_using_list_of_models(decorator: Any, http_method: Any, expected_status_code: Any) -> None: diff --git a/uv.lock b/uv.lock index 0e4f1daa92..20710aad7f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,11 +1,13 @@ version = 1 requires-python = ">=3.8, <4.0" resolution-markers = [ - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform != 'win32'", "python_full_version < '3.9' and sys_platform != 'win32'", "python_full_version >= '3.13' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform == 'win32'", "python_full_version < '3.9' and sys_platform == 'win32'", "python_full_version >= '3.13' and sys_platform == 'win32'", ] @@ -111,10 +113,12 @@ name = "anyio" version = "4.8.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform != 'win32'", "python_full_version >= '3.13' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform == 'win32'", "python_full_version >= '3.13' and sys_platform == 'win32'", ] dependencies = [ @@ -1780,7 +1784,7 @@ dependencies = [ { name = "anyio", version = "4.5.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, { name = "anyio", version = "4.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, { name = "click" }, - { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "exceptiongroup" }, { name = "httpx" }, { name = "importlib-metadata", marker = "python_full_version < '3.10'" }, { name = "importlib-resources", marker = "python_full_version < '3.9'" }, @@ -1962,6 +1966,7 @@ requires-dist = [ { name = "cryptography", marker = "extra == 'cryptography'" }, { name = "cryptography", marker = "extra == 'jwt'" }, { name = "email-validator", marker = "extra == 'pydantic'" }, + { name = "exceptiongroup", specifier = ">=1.2.2" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "fast-query-parsers", marker = "extra == 'standard'", specifier = ">=1.0.2" }, { name = "httpx", specifier = ">=0.22" }, @@ -2343,10 +2348,12 @@ name = "msgspec" version = "0.19.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform != 'win32'", "python_full_version >= '3.13' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform == 'win32'", "python_full_version >= '3.13' and sys_platform == 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/cf/9b/95d8ce458462b8b71b8a70fa94563b2498b89933689f3a7b8911edfae3d7/msgspec-0.19.0.tar.gz", hash = "sha256:604037e7cd475345848116e89c553aa9a233259733ab51986ac924ab1b976f8e", size = 216934 } @@ -3339,10 +3346,12 @@ name = "pytest-asyncio" version = "0.25.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform != 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform != 'win32'", "python_full_version >= '3.13' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.9' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.11' and python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version >= '3.9' and python_full_version < '3.11' and sys_platform == 'win32'", "python_full_version >= '3.13' and sys_platform == 'win32'", ] dependencies = [ @@ -4117,7 +4126,7 @@ name = "taskgroup" version = "0.0.0a4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "exceptiongroup", marker = "python_full_version < '3.13'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0a/40/02753c40fa30dfdde7567c1daeefbf957dcf8c99e6534a80afb438adf07e/taskgroup-0.0.0a4.tar.gz", hash = "sha256:eb08902d221e27661950f2a0320ddf3f939f579279996f81fe30779bca3a159c", size = 8553 } wheels = [