Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions litestar/connection/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import math
import warnings
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Generic, cast

import anyio

from litestar._multipart import parse_content_header, parse_multipart_form
from litestar._parsers import parse_url_encoded_form_data
Expand All @@ -23,6 +25,7 @@
LitestarException,
LitestarWarning,
)
from litestar.exceptions.base_exceptions import ClientDisconnect
from litestar.exceptions.http_exceptions import RequestEntityTooLarge
from litestar.serialization import decode_json, decode_msgpack
from litestar.types import Empty, HTTPReceiveMessage
Expand All @@ -45,6 +48,9 @@
}


METHODS_WITHOUT_BODY: Final = frozenset(["GET", "HEAD", "DELETE", "TRACE"])


class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]):
"""The Litestar Request class."""

Expand All @@ -53,9 +59,11 @@ class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler",
"_body",
"_content_length",
"_content_type",
"_expect_body",
"_form",
"_json",
"_msgpack",
"_stream_consumed",
"is_connected",
"supports_push_promise",
)
Expand Down Expand Up @@ -85,6 +93,8 @@ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send =
self._accept: Accept | EmptyType = Empty
self._content_length: int | None | EmptyType = Empty
self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions
self._stream_consumed = anyio.Event()
self._expect_body = self.method not in METHODS_WITHOUT_BODY

@property
def method(self) -> Method:
Expand Down Expand Up @@ -172,6 +182,30 @@ def content_length(self) -> int | None:
raise ClientException(f"Invalid content-length: {content_length_header!r}") from None
return content_length

async def _listen_for_disconnect(self) -> None:
# we consider these methods empty and consume their - empty - body, since
# handlers are unlikely to do this, in order to trigger the next phase and
# start listening for 'http.disconnect'.
if self.method in METHODS_WITHOUT_BODY:
body = await self.body()
if body and not self._expect_body:
warnings.warn(
f"Discarding unexpected request body for '{self.method} {self.url.path}' "
"received while listening for disconnect",
category=LitestarWarning,
stacklevel=2,
)

await self._stream_consumed.wait()

message = await self.receive()
if message["type"] == "http.disconnect":
raise ClientDisconnect

raise InternalServerException(
f"Received unexpected {message['type']!r} message while listening for 'http.disconnect'"
)

async def stream(self) -> AsyncGenerator[bytes, None]:
"""Return an async generator that streams chunks of bytes.

Expand Down Expand Up @@ -226,10 +260,11 @@ async def stream(self) -> AsyncGenerator[bytes, None]:
yield body

if not event.get("more_body", False):
self._stream_consumed.set()
break

if event["type"] == "http.disconnect":
raise InternalServerException("client disconnected prematurely")
elif event["type"] == "http.disconnect":
raise ClientDisconnect()

self.is_connected = False
yield b""
Expand Down
4 changes: 4 additions & 0 deletions litestar/exceptions/base_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,7 @@ class SerializationException(LitestarException):

class LitestarWarning(UserWarning):
"""Base class for Litestar warnings"""


class ClientDisconnect(LitestarException):
pass
27 changes: 1 addition & 26 deletions litestar/response/streaming.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from __future__ import annotations

import itertools
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Iterable, Iterator, Union

from anyio import CancelScope, create_task_group

from litestar.enums import MediaType
from litestar.response.base import ASGIResponse, Response
from litestar.types.helper_types import StreamType
Expand Down Expand Up @@ -80,25 +77,6 @@ def __init__(
iterator if isinstance(iterator, (AsyncIterable, AsyncIterator)) else AsyncIteratorWrapper(iterator)
)

async def _listen_for_disconnect(self, cancel_scope: CancelScope, receive: Receive) -> None:
"""Listen for a cancellation message, and if received - call cancel on the cancel scope.

Args:
cancel_scope: A task group cancel scope instance.
receive: The ASGI receive function.

Returns:
None
"""
if not cancel_scope.cancel_called:
message = await receive()
if message["type"] == "http.disconnect":
# despite the IDE warning, this is not a coroutine because anyio 3+ changed this.
# therefore make sure not to await this.
cancel_scope.cancel()
else:
await self._listen_for_disconnect(cancel_scope=cancel_scope, receive=receive)

async def _stream(self, send: Send) -> None:
"""Send the chunks from the iterator as a stream of ASGI 'http.response.body' events.

Expand Down Expand Up @@ -128,10 +106,7 @@ async def send_body(self, send: Send, receive: Receive) -> None:
Returns:
None
"""

async with create_task_group() as task_group:
task_group.start_soon(partial(self._stream, send))
await self._listen_for_disconnect(cancel_scope=task_group.cancel_scope, receive=receive)
await self._stream(send)


class Stream(Response[StreamType[Union[str, bytes]]]):
Expand Down
72 changes: 65 additions & 7 deletions litestar/routes/http.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from __future__ import annotations

from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any

import anyio
from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.datastructures.multi_dicts import FormMultiDict
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
from litestar.exceptions.base_exceptions import ClientDisconnect
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT
from litestar.types.empty import Empty
from litestar.utils.scope.state import ScopeState
from litestar.utils.sync import AsyncCallable
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from litestar._kwargs import KwargsModel
Expand Down Expand Up @@ -59,6 +64,36 @@ def __init__(
handler_names=[route_handler.handler_name for route_handler in self.route_handlers],
)

async def _handle_response_cycle(
self,
scope: HTTPScope,
request: Request[Any, Any, Any],
route_handler: HTTPRouteHandler,
parameter_model: KwargsModel,
receive: Receive,
send: Send,
cancel_scope: anyio.CancelScope | None = None,
) -> None:
try:
response = await self._get_response_for_request(
scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
)

await response(scope, receive, send)

if after_response_handler := route_handler.resolve_after_response():
await after_response_handler(request)

finally:
if cancel_scope is not None:
cancel_scope.cancel()

async def _listen_for_disconnect(self, request: Request, cancel_scope: anyio.CancelScope) -> None:
try:
await request._listen_for_disconnect()
except ClientDisconnect:
cancel_scope.cancel()

async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: # type: ignore[override]
"""ASGI app that creates a Request from the passed in args, determines which handler function to call and then
handles the call.
Expand All @@ -78,14 +113,37 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:
await route_handler.authorize_connection(connection=request)

try:
response = await self._get_response_for_request(
scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
)

await response(scope, receive, send)
if route_handler.has_sync_callable or isinstance(route_handler.fn, AsyncCallable):
# if it's a sync or to_thread function we can't actually cancel anything
# so we just await it directly
await self._handle_response_cycle(
scope=scope,
send=send,
receive=receive,
request=request,
route_handler=route_handler,
parameter_model=parameter_model,
)
else:
async with anyio.create_task_group() as tg:
tg.start_soon(
partial(
self._handle_response_cycle,
scope=scope,
send=send,
receive=receive,
request=request,
route_handler=route_handler,
parameter_model=parameter_model,
cancel_scope=tg.cancel_scope,
),
)
tg.start_soon(self._listen_for_disconnect, request, tg.cancel_scope)
Comment on lines +128 to +141
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a more performant, but we can't really use it for reasons of anyio/trio compatibility.

Suggested change
async with anyio.create_task_group() as tg:
tg.start_soon(
partial(
self._handle_response_cycle,
scope=scope,
send=send,
receive=receive,
request=request,
route_handler=route_handler,
parameter_model=parameter_model,
cancel_scope=tg.cancel_scope,
),
)
tg.start_soon(self._listen_for_disconnect, request, tg.cancel_scope)
tasks = [
asyncio.create_task(
self._handle_response_cycle(
scope=scope,
send=send,
receive=receive,
request=request,
route_handler=route_handler,
parameter_model=parameter_model,
)
),
asyncio.create_task(request._listen_for_disconnect()),
]
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
done.pop().result()
if done:
done.pop().result()
if pending:
pending = pending.pop()
pending.cancel()
await pending
pending.result()

except Exception as exc:
if isinstance(exc, ExceptionGroup):
raise exc.exceptions[0] from exc
raise

if after_response_handler := route_handler.resolve_after_response():
await after_response_handler(request)
finally:
if (form_data := ScopeState.from_scope(scope).form) is not Empty:
await FormMultiDict.from_form_data(form_data).close()
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ dependencies = [
"rich-click",
"multipart>=1.2.0",
# default litestar plugins
"litestar-htmx>=0.4.0"
"litestar-htmx>=0.4.0",
"exceptiongroup>=1.2.2",
]
description = "Litestar - A production-ready, highly performant, extensible ASGI API Framework"
keywords = ["api", "rest", "asgi", "litestar", "starlite"]
Expand Down Expand Up @@ -413,8 +414,6 @@ lint.select = [
"W", # pycodestyle - warning
"YTT", # flake8-2020
]

line-length = 120
lint.ignore = [
"A003", # flake8-builtins - class attribute {name} is shadowing a python builtin
"B010", # flake8-bugbear - do not call setattr with a constant attribute value
Expand All @@ -435,6 +434,8 @@ lint.ignore = [
"ISC001", # Ruff formatter incompatible
"CPY001", # ruff - copyright notice at the top of the file
]

line-length = 120
src = ["litestar", "tests", "docs/examples"]
target-version = "py38"

Expand Down
14 changes: 11 additions & 3 deletions tests/unit/test_asgi/test_asgi_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@
if TYPE_CHECKING:
from contextlib import AbstractAsyncContextManager

from litestar.types import Receive, Scope, Send
from litestar.types import HTTPReceiveMessage, Message, Receive, Scope, Send

_ExceptionGroup = get_exception_group()


async def no_op_send(message: Message) -> None:
pass


async def empty_http_receive() -> HTTPReceiveMessage:
return {"type": "http.request", "body": b"", "more_body": False}


def test_add_mount_route_disallow_path_parameter() -> None:
async def handler(scope: Scope, receive: Receive, send: Send) -> None:
return None
Expand Down Expand Up @@ -243,7 +251,7 @@ async def handler() -> None:
app = Litestar(route_handlers=[handler], exception_handlers={RuntimeError: app_exception_handlers_mock})
scope["path"] = "/nowhere-to-be-found"
with pytest.raises(NotFoundException):
await app.asgi_router(scope, AsyncMock(), AsyncMock())
await app.asgi_router(scope, empty_http_receive, no_op_send)

state = ScopeState.from_scope(scope)
assert state.exception_handlers is not Empty
Expand All @@ -266,7 +274,7 @@ async def handler() -> None:
app = Litestar(route_handlers=[handler], exception_handlers={RuntimeError: app_exception_handlers_mock})
router = app.asgi_router
scope["path"] = "/"
await router(scope, AsyncMock(), AsyncMock())
await router(scope, empty_http_receive, no_op_send)
state = ScopeState.from_scope(scope)
assert state.exception_handlers is not Empty
assert state.exception_handlers[RuntimeError] is app_exception_handlers_mock
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_connection/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
LitestarWarning,
SerializationException,
)
from litestar.exceptions.base_exceptions import ClientDisconnect
from litestar.middleware import MiddlewareProtocol
from litestar.response.base import ASGIResponse
from litestar.serialization import encode_json, encode_msgpack
Expand Down Expand Up @@ -382,7 +383,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:


async def test_request_disconnect(create_scope: Callable[..., Scope]) -> None:
"""If a client disconnect occurs while reading request body then InternalServerException should be raised."""
"""If a client disconnect occurs while reading request body then ClientDisconnect should be raised."""

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request[Any, Any, State](scope, receive)
Expand All @@ -391,7 +392,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
async def receiver() -> dict[str, str]:
return {"type": "http.disconnect"}

with pytest.raises(InternalServerException):
with pytest.raises(ClientDisconnect):
await app(
create_scope(type="http", route_handler=_route_handler, method="POST", path="/"),
receiver, # type: ignore[arg-type]
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_kwargs/test_reserved_kwargs_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def route_handler(state: state_typing) -> str: # type: ignore[valid-type]
(post, HttpMethod.POST, HTTP_201_CREATED),
(put, HttpMethod.PUT, HTTP_200_OK),
(patch, HttpMethod.PATCH, HTTP_200_OK),
(delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT),
],
)
def test_data_using_model(decorator: Any, http_method: Any, expected_status_code: Any) -> None:
Expand All @@ -96,7 +95,6 @@ def test_method(self, data: DataclassPerson) -> None:
(post, HttpMethod.POST, HTTP_201_CREATED),
(put, HttpMethod.PUT, HTTP_200_OK),
(patch, HttpMethod.PATCH, HTTP_200_OK),
(delete, HttpMethod.DELETE, HTTP_204_NO_CONTENT),
],
)
def test_data_using_list_of_models(decorator: Any, http_method: Any, expected_status_code: Any) -> None:
Expand Down
Loading
Loading