From a8f064c39408b53e3712740b9a7d4f1fbde7d5b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 9 Jun 2023 10:28:27 +0000 Subject: [PATCH] Add a test for the task error handler --- .../graphql_transport_ws/handlers.py | 5 ++- tests/websockets/test_graphql_transport_ws.py | 39 ++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index ceec00b4eb..596404d0eb 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -288,13 +288,14 @@ async def operation_task(self, operation: Operation) -> None: except asyncio.CancelledError: raise except Exception as error: + # Log any unhandled exceptions in the operation task await self.handle_task_exception(error) - # cleanup in case of something really unexpected finally: - # add this task to a list to be reaped later + # Clenaup. Remove the operation from the list of active operations if operation.id in self.operations: del self.operations[operation.id] # TODO: Stop collecting background tasks, not necessary. + # Add this task to a list to be reaped later self.completed_tasks.append(task) async def handle_operation( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 21f55bb7e1..8fd21f1738 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -3,7 +3,7 @@ import sys import time from datetime import timedelta -from typing import AsyncGenerator, Type +from typing import Any, AsyncGenerator, Type from unittest.mock import patch try: @@ -16,6 +16,9 @@ from pytest_mock import MockerFixture from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -1068,3 +1071,37 @@ async def slow_get_context(ctxt): id="sub1", payload={"data": {"valueFromContext": "slow"}} ).as_dict() ) + + +async def test_task_error_handler(ws: WebSocketClient): + """ + Test that error handling works + """ + # can't use a simple Event here, because the handler may run + # on a different thread + wakeup = False + + # a replacement method which causes an error in th eTask + async def op(*args: Any, **kwargs: Any): + nonlocal wakeup + wakeup = True + raise ZeroDivisionError("test") + + with patch.object(BaseGraphQLTransportWSHandler, "task_logger") as logger: + with patch.object(BaseGraphQLTransportWSHandler, "handle_operation", op): + # send any old subscription request. It will raise an error + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0) }" + ), + ).as_dict() + ) + + # wait for the error to be logged + while not wakeup: + await asyncio.sleep(0.01) + # and another little bit, for the thread to finish + await asyncio.sleep(0.01) + assert logger.exception.called