diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..f5bbbae577 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This release fixes the issue that some coroutines in the WebSocket protocol handlers were never awaited if clients disconnected shortly after starting an operation. diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 5e9220b5af..76c30f4005 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -245,41 +245,42 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: elif hasattr(self.context, "connection_params"): self.context.connection_params = self.connection_params + operation = Operation( + self, + message.id, + operation_type, + message.payload.query, + message.payload.variables, + message.payload.operationName, + ) + + operation.task = asyncio.create_task(self.run_operation(operation)) + self.operations[message.id] = operation + + async def run_operation(self, operation: Operation) -> None: + """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" + # TODO: Handle errors in this method using self.handle_task_exception() + result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult] # Get an AsyncGenerator yielding the results - if operation_type == OperationType.SUBSCRIPTION: + if operation.operation_type == OperationType.SUBSCRIPTION: result_source = self.schema.subscribe( - query=message.payload.query, - variable_values=message.payload.variables, - operation_name=message.payload.operationName, + query=operation.query, + variable_values=operation.variables, + operation_name=operation.operation_name, context_value=self.context, root_value=self.root_value, ) else: result_source = self.schema.execute( - query=message.payload.query, - variable_values=message.payload.variables, + query=operation.query, + variable_values=operation.variables, context_value=self.context, root_value=self.root_value, - operation_name=message.payload.operationName, + operation_name=operation.operation_name, ) - operation = Operation(self, message.id, operation_type) - - # Create task to handle this subscription, reserve the operation ID - operation.task = asyncio.create_task( - self.operation_task(result_source, operation) - ) - self.operations[message.id] = operation - - async def operation_task( - self, - result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult], - operation: Operation, - ) -> None: - """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" - # TODO: Handle errors in this method using self.handle_task_exception() try: first_res_or_agen = await result_source # that's an immediate error we should end the operation @@ -340,17 +341,32 @@ async def reap_completed_tasks(self) -> None: class Operation: """A class encapsulating a single operation with its id. Helps enforce protocol state transition.""" - __slots__ = ["handler", "id", "operation_type", "completed", "task"] + __slots__ = [ + "handler", + "id", + "operation_type", + "query", + "variables", + "operation_name", + "completed", + "task", + ] def __init__( self, handler: BaseGraphQLTransportWSHandler, id: str, operation_type: OperationType, + query: str, + variables: Optional[Dict[str, Any]], + operation_name: Optional[str], ) -> None: self.handler = handler self.id = id self.operation_type = operation_type + self.query = query + self.variables = variables + self.operation_name = operation_name self.completed = False self.task: Optional[asyncio.Task] = None diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index af648380ff..3237bade18 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -5,7 +5,6 @@ from typing import ( TYPE_CHECKING, AsyncGenerator, - Awaitable, Dict, Optional, cast, @@ -37,7 +36,6 @@ if TYPE_CHECKING: from strawberry.http.async_base_view import AsyncWebSocketAdapter from strawberry.schema import BaseSchema - from strawberry.schema.subscribe import SubscriptionResult class BaseGraphQLWSHandler: @@ -136,15 +134,9 @@ async def handle_start(self, message: OperationMessage) -> None: if self.debug: pretty_print_graphql_operation(operation_name, query, variables) - result_source = self.schema.subscribe( - query=query, - variable_values=variables, - operation_name=operation_name, - context_value=self.context, - root_value=self.root_value, + result_handler = self.handle_async_results( + operation_id, query, operation_name, variables ) - - result_handler = self.handle_async_results(result_source, operation_id) self.tasks[operation_id] = asyncio.create_task(result_handler) async def handle_stop(self, message: OperationMessage) -> None: @@ -160,11 +152,19 @@ async def handle_keep_alive(self) -> None: async def handle_async_results( self, - result_source: Awaitable[SubscriptionResult], operation_id: str, + query: str, + operation_name: Optional[str], + variables: Optional[Dict[str, object]], ) -> None: try: - agen_or_err = await result_source + agen_or_err = await self.schema.subscribe( + query=query, + variable_values=variables, + operation_name=operation_name, + context_value=self.context, + root_value=self.root_value, + ) if isinstance(agen_or_err, PreExecutionError): assert agen_or_err.errors error_payload = agen_or_err.errors[0].formatted