11from __future__ import annotations
22
3+ from functools import partial
34from itertools import chain
45from typing import TYPE_CHECKING , Any
56
7+ import anyio
68from msgspec .msgpack import decode as _decode_msgpack_plain
79
810from litestar .datastructures .multi_dicts import FormMultiDict
911from litestar .enums import HttpMethod , MediaType , ScopeType
1012from litestar .exceptions import ClientException , ImproperlyConfiguredException , SerializationException
13+ from litestar .exceptions .base_exceptions import ClientDisconnect
1114from litestar .handlers .http_handlers import HTTPRouteHandler
1215from litestar .response import Response
1316from litestar .routes .base import BaseRoute
1417from litestar .status_codes import HTTP_204_NO_CONTENT
1518from litestar .types .empty import Empty
1619from litestar .utils .scope .state import ScopeState
20+ from litestar .utils .sync import AsyncCallable
21+
22+ try :
23+ ExceptionGroupType : type [ExceptionGroup ] | None = ExceptionGroup # type: ignore[name-defined]
24+ except NameError :
25+ ExceptionGroupType = None
26+
1727
1828if TYPE_CHECKING :
1929 from litestar ._kwargs import KwargsModel
2030 from litestar ._kwargs .cleanup import DependencyCleanupGroup
2131 from litestar .connection import Request
2232 from litestar .types import ASGIApp , HTTPScope , Method , Receive , Scope , Send
2333
34+ ExceptionGroupType = ExceptionGroup # type: ignore[name-defined]
35+
2436
2537class HTTPRoute (BaseRoute ):
2638 """An HTTP route, capable of handling multiple ``HTTPRouteHandler``\\ s.""" # noqa: D301
@@ -59,6 +71,36 @@ def __init__(
5971 handler_names = [route_handler .handler_name for route_handler in self .route_handlers ],
6072 )
6173
74+ async def _handle_response_cycle (
75+ self ,
76+ scope : HTTPScope ,
77+ request : Request [Any , Any , Any ],
78+ route_handler : HTTPRouteHandler ,
79+ parameter_model : KwargsModel ,
80+ receive : Receive ,
81+ send : Send ,
82+ cancel_scope : anyio .CancelScope | None = None ,
83+ ) -> None :
84+ try :
85+ response = await self ._get_response_for_request (
86+ scope = scope , request = request , route_handler = route_handler , parameter_model = parameter_model
87+ )
88+
89+ await response (scope , receive , send )
90+
91+ if after_response_handler := route_handler .resolve_after_response ():
92+ await after_response_handler (request )
93+
94+ finally :
95+ if cancel_scope is not None :
96+ cancel_scope .cancel ()
97+
98+ async def _listen_for_disconnect (self , request : Request , cancel_scope : anyio .CancelScope ) -> None :
99+ try :
100+ await request ._listen_for_disconnect ()
101+ except ClientDisconnect :
102+ cancel_scope .cancel ()
103+
62104 async def handle (self , scope : HTTPScope , receive : Receive , send : Send ) -> None : # type: ignore[override]
63105 """ASGI app that creates a Request from the passed in args, determines which handler function to call and then
64106 handles the call.
@@ -78,14 +120,37 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:
78120 await route_handler .authorize_connection (connection = request )
79121
80122 try :
81- response = await self ._get_response_for_request (
82- scope = scope , request = request , route_handler = route_handler , parameter_model = parameter_model
83- )
84-
85- await response (scope , receive , send )
123+ if route_handler .has_sync_callable or isinstance (route_handler .fn , AsyncCallable ):
124+ # if it's a sync or to_thread function we can't actually cancel anything
125+ # so we just await it directly
126+ await self ._handle_response_cycle (
127+ scope = scope ,
128+ send = send ,
129+ receive = receive ,
130+ request = request ,
131+ route_handler = route_handler ,
132+ parameter_model = parameter_model ,
133+ )
134+ else :
135+ async with anyio .create_task_group () as tg :
136+ tg .start_soon (
137+ partial (
138+ self ._handle_response_cycle ,
139+ scope = scope ,
140+ send = send ,
141+ receive = receive ,
142+ request = request ,
143+ route_handler = route_handler ,
144+ parameter_model = parameter_model ,
145+ cancel_scope = tg .cancel_scope ,
146+ ),
147+ )
148+ tg .start_soon (self ._listen_for_disconnect , request , tg .cancel_scope )
149+ except Exception as exc :
150+ if isinstance (exc , ExceptionGroupType ):
151+ raise exc .exceptions [0 ] from exc
152+ raise
86153
87- if after_response_handler := route_handler .resolve_after_response ():
88- await after_response_handler (request )
89154 finally :
90155 if (form_data := ScopeState .from_scope (scope ).form ) is not Empty :
91156 await FormMultiDict .from_form_data (form_data ).close ()
0 commit comments