Skip to content

Commit

Permalink
Fix result source was never awaited if clients disconnect fast (#3687)
Browse files Browse the repository at this point in the history
* Fix result source may never get awaited

* Add release file
  • Loading branch information
DoctorJohn authored Nov 5, 2024
1 parent 6839371 commit 3561948
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 35 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release fixes the issue that some coroutines in the WebSocket protocol handlers were never awaited if clients disconnected shortly after starting an operation.
62 changes: 39 additions & 23 deletions strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,41 +245,42 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
elif hasattr(self.context, "connection_params"):
self.context.connection_params = self.connection_params

operation = Operation(
self,
message.id,
operation_type,
message.payload.query,
message.payload.variables,
message.payload.operationName,
)

operation.task = asyncio.create_task(self.run_operation(operation))
self.operations[message.id] = operation

async def run_operation(self, operation: Operation) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()

result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]

# Get an AsyncGenerator yielding the results
if operation_type == OperationType.SUBSCRIPTION:
if operation.operation_type == OperationType.SUBSCRIPTION:
result_source = self.schema.subscribe(
query=message.payload.query,
variable_values=message.payload.variables,
operation_name=message.payload.operationName,
query=operation.query,
variable_values=operation.variables,
operation_name=operation.operation_name,
context_value=self.context,
root_value=self.root_value,
)
else:
result_source = self.schema.execute(
query=message.payload.query,
variable_values=message.payload.variables,
query=operation.query,
variable_values=operation.variables,
context_value=self.context,
root_value=self.root_value,
operation_name=message.payload.operationName,
operation_name=operation.operation_name,
)

operation = Operation(self, message.id, operation_type)

# Create task to handle this subscription, reserve the operation ID
operation.task = asyncio.create_task(
self.operation_task(result_source, operation)
)
self.operations[message.id] = operation

async def operation_task(
self,
result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult],
operation: Operation,
) -> None:
"""The operation task's top level method. Cleans-up and de-registers the operation once it is done."""
# TODO: Handle errors in this method using self.handle_task_exception()
try:
first_res_or_agen = await result_source
# that's an immediate error we should end the operation
Expand Down Expand Up @@ -340,17 +341,32 @@ async def reap_completed_tasks(self) -> None:
class Operation:
"""A class encapsulating a single operation with its id. Helps enforce protocol state transition."""

__slots__ = ["handler", "id", "operation_type", "completed", "task"]
__slots__ = [
"handler",
"id",
"operation_type",
"query",
"variables",
"operation_name",
"completed",
"task",
]

def __init__(
self,
handler: BaseGraphQLTransportWSHandler,
id: str,
operation_type: OperationType,
query: str,
variables: Optional[Dict[str, Any]],
operation_name: Optional[str],
) -> None:
self.handler = handler
self.id = id
self.operation_type = operation_type
self.query = query
self.variables = variables
self.operation_name = operation_name
self.completed = False
self.task: Optional[asyncio.Task] = None

Expand Down
24 changes: 12 additions & 12 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Awaitable,
Dict,
Optional,
cast,
Expand Down Expand Up @@ -37,7 +36,6 @@
if TYPE_CHECKING:
from strawberry.http.async_base_view import AsyncWebSocketAdapter
from strawberry.schema import BaseSchema
from strawberry.schema.subscribe import SubscriptionResult


class BaseGraphQLWSHandler:
Expand Down Expand Up @@ -136,15 +134,9 @@ async def handle_start(self, message: OperationMessage) -> None:
if self.debug:
pretty_print_graphql_operation(operation_name, query, variables)

result_source = self.schema.subscribe(
query=query,
variable_values=variables,
operation_name=operation_name,
context_value=self.context,
root_value=self.root_value,
result_handler = self.handle_async_results(
operation_id, query, operation_name, variables
)

result_handler = self.handle_async_results(result_source, operation_id)
self.tasks[operation_id] = asyncio.create_task(result_handler)

async def handle_stop(self, message: OperationMessage) -> None:
Expand All @@ -160,11 +152,19 @@ async def handle_keep_alive(self) -> None:

async def handle_async_results(
self,
result_source: Awaitable[SubscriptionResult],
operation_id: str,
query: str,
operation_name: Optional[str],
variables: Optional[Dict[str, object]],
) -> None:
try:
agen_or_err = await result_source
agen_or_err = await self.schema.subscribe(
query=query,
variable_values=variables,
operation_name=operation_name,
context_value=self.context,
root_value=self.root_value,
)
if isinstance(agen_or_err, PreExecutionError):
assert agen_or_err.errors
error_payload = agen_or_err.errors[0].formatted
Expand Down

0 comments on commit 3561948

Please sign in to comment.