From b466ecf8675d02b6250f94dbd0da059445c82ccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 11 Jul 2023 14:40:12 +0000 Subject: [PATCH 01/22] schema.subscribe is now an async generator. Apply extension contexts while executing. --- strawberry/schema/base.py | 16 +- strawberry/schema/execute.py | 139 +++++++++++++++++- strawberry/schema/schema.py | 33 +++-- .../graphql_transport_ws/handlers.py | 26 +--- .../protocols/graphql_ws/handlers.py | 11 +- 5 files changed, 181 insertions(+), 44 deletions(-) diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index a1c286c6d0..cb6aafefbc 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -2,7 +2,17 @@ from abc import abstractmethod from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Iterable, + List, + Optional, + Type, + Union, +) from typing_extensions import Protocol from strawberry.utils.logging import StrawberryLogger @@ -55,14 +65,14 @@ def execute_sync( raise NotImplementedError @abstractmethod - async def subscribe( + def subscribe( self, query: str, variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Any: + ) -> AsyncGenerator[Any, None]: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index af0bd07a7f..f8cd0b2974 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -4,6 +4,8 @@ from inspect import isawaitable from typing import ( TYPE_CHECKING, + AsyncGenerator, + AsyncIterable, Awaitable, Callable, Iterable, @@ -17,8 +19,10 @@ cast, ) +from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, parse from graphql import execute as original_execute +from graphql import subscribe as original_subscribe from graphql.validation import validate from strawberry.exceptions import MissingQueryError @@ -31,7 +35,6 @@ from typing_extensions import NotRequired, Unpack from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLSchema from graphql.language import DocumentNode from graphql.validation import ASTValidationRule @@ -272,3 +275,137 @@ def execute_sync( errors=execution_context.result.errors, extensions=extensions_runner.get_extensions_results_sync(), ) + + +async def subscribe( + schema: GraphQLSchema, + *, + extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], + execution_context: ExecutionContext, + execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], +) -> AsyncGenerator[Tuple[bool, ExecutionResult], None]: + """ + The graphql-core subscribe function returns either an ExecutionResult or an + AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error + during parsing or validation. + Because we need to maintain execution context, we cannot return an + async generator, we must _be_ an async generator. So we yield a + (bool, ExecutionResult) tuple, where the bool indicates whether the result is an + potentially multiple execution result or a single result. + A False value indicates an single result, most likely an intial + failure (and no more values will be yielded) whereas a True value indicates a + successful subscription, and more values may be yielded. + """ + + extensions_runner = SchemaExtensionsRunner( + execution_context=execution_context, + extensions=list(extensions), + ) + + # unlike execute(), the entire operation, including the results hooks, + # is run within the operation() hook. + async with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + if not execution_context.query: + raise MissingQueryError() + + async with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) + yield False, ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) + # the generator is usually closed here, so the following is not + # reached + return # pragma: no cover + + except Exception as error: # pragma: no cover + error = GraphQLError(str(error), original_error=error) + execution_context.errors = [error] + process_errors([error], execution_context) + yield False, ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) + return # pragma: no cover + + async with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + yield False, ExecutionResult(data=None, errors=execution_context.errors) + return # pragma: no cover + + async def process_result(result: GraphQLExecutionResult): + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + return ExecutionResult( + data=execution_context.result.data, + errors=execution_context.result.errors, + extensions=await extensions_runner.get_extensions_results(), + ) + + async with extensions_runner.executing(): + # currently original_subscribe is an async function. A future release + # of graphql-core will make it optionally awaitable + result: Union[AsyncIterable[GraphQLExecutionResult], GraphQLExecutionResult] + result_or_awaitable = original_subscribe( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + context_value=execution_context.context, + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + ) + if isawaitable(result_or_awaitable): + result = await cast( + Awaitable[ + Union[ + AsyncIterable["GraphQLExecutionResult"], + "GraphQLExecutionResult", + ] + ], + result_or_awaitable, + ) + else: # pragma: no cover + result = cast( + Union[ + AsyncIterable["GraphQLExecutionResult"], + "GraphQLExecutionResult", + ], + result_or_awaitable, + ) + + if isinstance(result, GraphQLExecutionResult): + yield False, await process_result(result) + return + + aiterator = result.__aiter__() + try: + async for result in aiterator: + yield True, await process_result(result) + finally: + if hasattr(aiterator, "aclose"): + await aiterator.aclose() diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index b43963d9b5..b587ef36d3 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -5,7 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, + AsyncGenerator, Dict, Iterable, List, @@ -20,10 +20,8 @@ GraphQLNonNull, GraphQLSchema, get_introspection_query, - parse, validate_schema, ) -from graphql.execution import subscribe from graphql.type.directives import specified_directives from strawberry import relay @@ -43,11 +41,10 @@ from . import compat from .base import BaseSchema from .config import StrawberryConfig -from .execute import execute, execute_sync +from .execute import execute, execute_sync, subscribe if TYPE_CHECKING: from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from strawberry.custom_scalar import ScalarDefinition, ScalarWrapper from strawberry.directive import StrawberryDirective @@ -305,22 +302,30 @@ def execute_sync( async def subscribe( self, - # TODO: make this optional when we support extensions - query: str, + query: Optional[str], variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]: - return await subscribe( - self._schema, - parse(query), + ) -> AsyncGenerator[tuple[bool, ExecutionResult], None]: + execution_context = ExecutionContext( + query=query, + schema=self, + context=context_value, root_value=root_value, - context_value=context_value, - variable_values=variable_values, - operation_name=operation_name, + variables=variable_values, + provided_operation_name=operation_name, ) + async for result in subscribe( + self._schema, + extensions=self.get_extensions(), + execution_context_class=self.execution_context_class, + execution_context=execution_context, + process_errors=self.process_errors, + ): + yield result + def _resolve_node_ids(self): for concrete_type in self.schema_converter.type_map.values(): type_def = concrete_type.definition diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index f20c674124..f26458c09d 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -15,7 +15,6 @@ Optional, ) -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, GraphQLSyntaxError, parse from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -246,7 +245,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: # Get an AsyncGenerator yielding the results if operation_type == OperationType.SUBSCRIPTION: - result_source = await self.schema.subscribe( + result_source = self.schema.subscribe( query=message.payload.query, variable_values=message.payload.variables, operation_name=message.payload.operationName, @@ -256,7 +255,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: else: # create AsyncGenerator returning a single result async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( + yield False, await self.schema.execute( query=message.payload.query, variable_values=message.payload.variables, context_value=context, @@ -268,15 +267,6 @@ async def get_result_source() -> AsyncIterator[ExecutionResult]: operation = Operation(self, message.id, operation_type) - # Handle initial validation errors - if isinstance(result_source, GraphQLExecutionResult): - assert operation_type == OperationType.SUBSCRIPTION - assert result_source.errors - payload = [err.formatted for err in result_source.errors] - await self.send_message(ErrorMessage(id=message.id, payload=payload)) - self.schema.process_errors(result_source.errors) - return - # Create task to handle this subscription, reserve the operation ID operation.task = asyncio.create_task( self.operation_task(result_source, operation) @@ -316,21 +306,15 @@ async def handle_async_results( operation: Operation, ) -> None: try: - async for result in result_source: - if ( - result.errors - and operation.operation_type != OperationType.SUBSCRIPTION - ): + async for multiple, result in result_source: + if result.errors and not multiple: error_payload = [err.formatted for err in result.errors] error_message = ErrorMessage(id=operation.id, payload=error_payload) await operation.send_message(error_message) - # don't need to call schema.process_errors() here because - # it was already done by schema.execute() return else: next_payload = {"data": result.data} if result.errors: - self.schema.process_errors(result.errors) next_payload["errors"] = [ err.formatted for err in result.errors ] @@ -345,6 +329,8 @@ async def handle_async_results( await operation.send_message(error_message) self.schema.process_errors([error]) return + finally: + await result_source.aclose() def forget_id(self, id: str) -> None: # de-register the operation id making it immediately available diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index c4863ff49e..b10ff327f2 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -190,13 +190,12 @@ async def handle_async_results( await self.send_message(GQL_COMPLETE, operation_id, None) async def cleanup_operation(self, operation_id: str) -> None: - await self.subscriptions[operation_id].aclose() - del self.subscriptions[operation_id] - - self.tasks[operation_id].cancel() + iterator = self.subscriptions.pop(operation_id) + task = self.tasks.pop(operation_id) + task.cancel() with suppress(BaseException): - await self.tasks[operation_id] - del self.tasks[operation_id] + await task + await iterator.aclose() async def send_message( self, From 9d5d16f5d2232909926c477f8a2656e02e0cf5dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 27 May 2023 10:29:16 +0000 Subject: [PATCH 02/22] Update graphql_ws handler --- .../protocols/graphql_ws/handlers.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index b10ff327f2..4cbb4d8962 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -5,7 +5,6 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, cast -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError from strawberry.subscriptions.protocols.graphql_ws import ( @@ -124,7 +123,7 @@ async def handle_start(self, message: OperationMessage) -> None: pretty_print_graphql_operation(operation_name, query, variables) try: - result_source = await self.schema.subscribe( + result_source = self.schema.subscribe( query=query, variable_values=variables, operation_name=operation_name, @@ -137,13 +136,6 @@ async def handle_start(self, message: OperationMessage) -> None: self.schema.process_errors([error]) return - if isinstance(result_source, GraphQLExecutionResult): - assert result_source.errors - error_payload = result_source.errors[0].formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors(result_source.errors) - return - self.subscriptions[operation_id] = result_source result_handler = self.handle_async_results(result_source, operation_id) self.tasks[operation_id] = asyncio.create_task(result_handler) @@ -164,7 +156,13 @@ async def handle_async_results( operation_id: str, ) -> None: try: - async for result in result_source: + async for success, result in result_source: + if not success: + assert result.errors + error_payload = result.errors[0].formatted + await self.send_message(GQL_ERROR, operation_id, error_payload) + self.schema.process_errors(result.errors) + return payload = {"data": result.data} if result.errors: payload["errors"] = [err.formatted for err in result.errors] @@ -186,6 +184,8 @@ async def handle_async_results( {"data": None, "errors": [error.formatted]}, ) self.schema.process_errors([error]) + finally: + await result_source.aclose() await self.send_message(GQL_COMPLETE, operation_id, None) From 46c9aa198f6f58fe24deaf9d2e131c3db09d997f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 27 May 2023 22:15:50 +0000 Subject: [PATCH 03/22] Remove unused code --- strawberry/schema/execute.py | 4 +--- strawberry/schema/schema.py | 1 - .../protocols/graphql_ws/handlers.py | 20 +++++++------------ 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index f8cd0b2974..e04bdf02f5 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -282,7 +282,6 @@ async def subscribe( *, extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], execution_context: ExecutionContext, - execution_context_class: Optional[Type[GraphQLExecutionContext]] = None, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], ) -> AsyncGenerator[Tuple[bool, ExecutionResult], None]: """ @@ -308,8 +307,7 @@ async def subscribe( async with extensions_runner.operation(): # Note: In graphql-core the schema would be validated here but in # Strawberry we are validating it at initialisation time instead - if not execution_context.query: - raise MissingQueryError() + assert execution_context.query is not None async with extensions_runner.parsing(): try: diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index b587ef36d3..45edd2f63b 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -320,7 +320,6 @@ async def subscribe( async for result in subscribe( self._schema, extensions=self.get_extensions(), - execution_context_class=self.execution_context_class, execution_context=execution_context, process_errors=self.process_errors, ): diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 4cbb4d8962..1828a80914 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -122,19 +122,13 @@ async def handle_start(self, message: OperationMessage) -> None: if self.debug: pretty_print_graphql_operation(operation_name, query, variables) - try: - result_source = self.schema.subscribe( - query=query, - variable_values=variables, - operation_name=operation_name, - context_value=context, - root_value=root_value, - ) - except GraphQLError as error: - error_payload = error.formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors([error]) - return + result_source = self.schema.subscribe( + query=query, + variable_values=variables, + operation_name=operation_name, + context_value=context, + root_value=root_value, + ) self.subscriptions[operation_id] = result_source result_handler = self.handle_async_results(result_source, operation_id) From e0cd1fcddd7059d7178d60f1879dc823eb6ed05b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 27 May 2023 10:25:52 +0000 Subject: [PATCH 04/22] fix tests now that field verification is done by Strawberry --- tests/websockets/test_graphql_transport_ws.py | 2 +- tests/websockets/test_graphql_ws.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 1ab738fb73..1ccda7b5a7 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -403,7 +403,7 @@ async def test_subscription_field_errors(ws: WebSocketClient): assert response["payload"][0]["locations"] == [{"line": 1, "column": 16}] assert ( response["payload"][0]["message"] - == "The subscription field 'notASubscriptionField' is not defined." + == "Cannot query field 'notASubscriptionField' on type 'Subscription'." ) process_errors.assert_called_once() diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 83b7e66782..880a504417 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -254,7 +254,9 @@ async def test_subscription_field_error(ws: WebSocketClient): assert response["id"] == "invalid-field" assert response["payload"] == { "locations": [{"line": 1, "column": 16}], - "message": ("The subscription field 'notASubscriptionField' is not defined."), + "message": ( + "Cannot query field 'notASubscriptionField' on type 'Subscription'." + ), } From 22855f4f46264fb7ab9f14d8c93aea66424fda99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 27 May 2023 12:22:19 +0000 Subject: [PATCH 05/22] Return "extensions" as part of ExecutionResult if present. --- .../graphql_transport_ws/handlers.py | 2 + .../protocols/graphql_ws/handlers.py | 2 + tests/channels/test_layers.py | 6 + tests/websockets/test_graphql_transport_ws.py | 122 +++++++----------- tests/websockets/test_graphql_ws.py | 23 ++++ 5 files changed, 79 insertions(+), 76 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index f26458c09d..fbd71e60d8 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -318,6 +318,8 @@ async def handle_async_results( next_payload["errors"] = [ err.formatted for err in result.errors ] + if result.extensions: + next_payload["extensions"] = result.extensions next_message = NextMessage(id=operation.id, payload=next_payload) await operation.send_message(next_message) except Exception as error: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 1828a80914..5b94562157 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -160,6 +160,8 @@ async def handle_async_results( payload = {"data": result.data} if result.errors: payload["errors"] = [err.formatted for err in result.errors] + if result.extensions: + payload["extensions"] = result.extensions await self.send_message(GQL_DATA, operation_id, payload) # log errors after send_message to prevent potential # slowdown of sending result diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index b868a0826b..d5e09e1bcf 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -91,6 +91,7 @@ async def test_channel_listen(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -137,6 +138,7 @@ async def test_channel_listen_with_confirmation(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -315,6 +317,7 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): }, ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -331,6 +334,7 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -377,6 +381,7 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): }, ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -393,6 +398,7 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 1ccda7b5a7..92cdde90bd 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -54,6 +54,20 @@ async def ws(ws_raw: WebSocketClient) -> WebSocketClient: return ws_raw +def assert_next(response, id, data, extensions=None): + """ + Assert that the NextMessage payload contains the provided data. + If extensions is provided, it will also assert that the + extensions are present + """ + assert response["type"] == "next" + assert response["id"] == id + assert set(response["payload"].keys()) <= {"data", "errors", "extensions"} + assert response["payload"]["data"] == data + if extensions is not None: + assert response["payload"]["extensions"] == extensions + + async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw @@ -158,13 +172,7 @@ async def test_connection_init_timeout_cancellation( ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", - payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}}, - ).as_dict() - ) + assert_next(response, "sub1", {"debug": {"isConnectionInitTimeoutTaskDone": True}}) @pytest.mark.xfail(reason="This test is flaky") @@ -260,10 +268,7 @@ async def test_server_sent_ping(ws: WebSocketClient): await ws.send_json(PongMessage().as_dict()) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"requestPing": True}}).as_dict() - ) + assert_next(response, "sub1", {"requestPing": True}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -327,9 +332,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -346,9 +349,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) async def test_simple_subscription(ws: WebSocketClient): @@ -362,9 +363,7 @@ async def test_simple_subscription(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -428,12 +427,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"debug": {"numActiveResultHandlers": 2}}} - ).as_dict() - ) + assert_next(response, "sub2", {"debug": {"numActiveResultHandlers": 2}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -450,12 +444,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub3", payload={"data": {"debug": {"numActiveResultHandlers": 1}}} - ).as_dict() - ) + assert_next(response, "sub3", {"debug": {"numActiveResultHandlers": 1}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub3").as_dict() @@ -543,10 +532,7 @@ async def test_single_result_query_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "Hello world"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "Hello world"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -568,12 +554,7 @@ async def test_single_result_query_operation_async(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub1", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -607,12 +588,7 @@ async def test_single_result_query_operation_overlapped(ws: WebSocketClient): # we expect the response to the second query to arrive first response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub2", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -626,10 +602,7 @@ async def test_single_result_mutation_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "strawberry"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -653,12 +626,7 @@ async def test_single_result_operation_selection(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"hello": "Hello Strawberry"}} - ).as_dict() - ) + assert_next(response, "sub1", {"hello": "Hello Strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -808,12 +776,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"connectionParams": "rocks"}} - ).as_dict() - ) + assert_next(response, "sub1", {"connectionParams": "rocks"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -862,12 +825,7 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"longFinalizer": "hello"}} - ).as_dict() - ) + assert_next(response, "sub1", {"longFinalizer": "hello"}) # now cancel the stubscription and send a new query. We expect the response # to the new query to arrive immediately, without waiting for the finalizer @@ -952,9 +910,7 @@ async def test_subscription_errors_continue(ws: WebSocketClient): ) response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "VANILLA"} + assert_next(response, "sub1", {"flavorsInvalid": "VANILLA"}) response = await ws.receive_json() assert response["type"] == NextMessage.type @@ -965,10 +921,24 @@ async def test_subscription_errors_continue(ws: WebSocketClient): process_errors.assert_called_once() response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "CHOCOLATE"} + assert_next(response, "sub1", {"flavorsInvalid": "CHOCOLATE"}) response = await ws.receive_json() assert response["type"] == CompleteMessage.type assert response["id"] == "sub1" + + +async def test_extensions(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query='subscription { echo(message: "Hi") }' + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert_next(response, "sub1", {"echo": "Hi"}, extensions={"example": "example"}) + + await ws.send_json(CompleteMessage(id="sub1").as_dict()) diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 880a504417..a8c8118c40 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -558,3 +558,26 @@ async def test_rejects_connection_params(aiohttp_app_client: HttpClient): # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close assert ws.closed + + +async def test_extensions(ws: WebSocketClient): + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) + + response = await ws.receive_json() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert response["payload"]["data"] == {"echo": "Hi"} + assert response["payload"]["extensions"] == {"example": "example"} + + await ws.send_json({"type": GQL_STOP, "id": "demo"}) + response = await ws.receive_json() + assert response["type"] == GQL_COMPLETE + assert response["id"] == "demo" From 7213377d705dcced72e4629c6ec9945389e4e980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 8 Mar 2023 22:39:22 +0000 Subject: [PATCH 06/22] Add tests for blocking operation validation --- tests/views/schema.py | 25 +++++++++ tests/websockets/test_graphql_transport_ws.py | 56 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/tests/views/schema.py b/tests/views/schema.py index 14a25b0c0d..c2207d3fb3 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -20,6 +20,19 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b return False +class ConditionalFailPermission(BasePermission): + @property + def message(self): + return f"failed after sleep {self.sleep}" + + async def has_permission(self, source, info, **kwargs: Any) -> bool: + self.sleep = kwargs.get("sleep", None) + self.fail = kwargs.get("fail", True) + if self.sleep is not None: + await asyncio.sleep(kwargs["sleep"]) + return not self.fail + + class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} @@ -86,6 +99,12 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str def always_fail(self) -> Optional[str]: return "Hey" + @strawberry.field(permission_classes=[ConditionalFailPermission]) + def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> str: + return "Hey" + @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: yield GraphQLError(message) # type: ignore @@ -268,6 +287,12 @@ async def long_finalizer( finally: await asyncio.sleep(delay) + @strawberry.subscription(permission_classes=[ConditionalFailPermission]) + async def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> AsyncGenerator[str, None]: + yield "Hey" + class Schema(strawberry.Schema): def process_errors( diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 92cdde90bd..38b94c7357 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -942,3 +942,59 @@ async def test_extensions(ws: WebSocketClient): assert_next(response, "sub1", {"echo": "Hi"}, extensions={"example": "example"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) + + +async def test_long_validation_concurrent_query(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + single-result-operation + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first query is stuck in validation + response = await ws.receive_json() + assert_next(response, "sub2", {"conditionalFail": "Hey"}) + + +async def test_long_validation_concurrent_subscription(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck in validation + response = await ws.receive_json() + assert_next(response, "sub2", {"conditionalFail": "Hey"}) From c8cec42d648f7d8cbd4abb63c3902386ace3f24d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Apr 2023 11:22:47 +0000 Subject: [PATCH 07/22] Add some error/validation test cases --- tests/websockets/test_graphql_transport_ws.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 38b94c7357..91a9ad9624 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -407,6 +407,28 @@ async def test_subscription_field_errors(ws: WebSocketClient): process_errors.assert_called_once() +async def test_query_field_errors(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { notASubscriptionField }", + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") is None + assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}] + assert ( + response["payload"][0]["message"] + == "Cannot query field 'notASubscriptionField' on type 'Query'." + ) + + async def test_subscription_cancellation(ws: WebSocketClient): await ws.send_json( SubscribeMessage( @@ -944,6 +966,50 @@ async def test_extensions(ws: WebSocketClient): await ws.send_json(CompleteMessage(id="sub1").as_dict()) +async def test_validation_query(ws: WebSocketClient): + """ + Test validation for query + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_validation_subscription(ws: WebSocketClient): + """ + Test validation for subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + async def test_long_validation_concurrent_query(ws: WebSocketClient): """ Test that the websocket is not blocked while validating a From 6dedf22aee7de465137079c6cb3430b2ae9284dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 23 Apr 2023 20:41:02 +0000 Subject: [PATCH 08/22] Add release.md --- RELEASE.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..6ef7811880 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Subscriptions now support Schema Extensions. From 4998305065af07bcdaa803695091de00e8d0a7fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 27 May 2023 15:58:48 +0000 Subject: [PATCH 09/22] Fix schema tests for new `subscribe()` signature --- docs/operations/testing.md | 5 ++- tests/schema/test_permission.py | 6 ++-- tests/schema/test_subscription.py | 54 ++++++++++++++----------------- 3 files changed, 29 insertions(+), 36 deletions(-) diff --git a/docs/operations/testing.md b/docs/operations/testing.md index a3aca51741..4d90e9f71e 100644 --- a/docs/operations/testing.md +++ b/docs/operations/testing.md @@ -136,10 +136,9 @@ async def test_subscription(): } """ - sub = await schema.subscribe(query) - index = 0 - async for result in sub: + async for ok, result in schema.subscribe(query): + assert ok assert not result.errors assert result.data == {"count": index} diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 0cec4cde60..9213ce9be9 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -80,9 +80,9 @@ async def user( query = "subscription { user }" - result = await schema.subscribe(query) - - assert result.errors[0].message == "You are not authorized" + async for ok, result in schema.subscribe(query): + assert not ok + assert result.errors[0].message == "You are not authorized" @pytest.mark.asyncio diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index b63af92dcc..e1adc9dc8c 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -33,11 +33,10 @@ async def example(self) -> AsyncGenerator[str, None]: query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for ok, result in schema.subscribe(query): + assert ok + assert not result.errors + assert result.data["example"] == "Hi" @pytest.mark.asyncio @@ -89,11 +88,10 @@ async def example(self, name: str) -> AsyncGenerator[str, None]: query = 'subscription { example(name: "Nina") }' - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi Nina" + async for ok, result in schema.subscribe(query): + assert ok + assert not result.errors + assert result.data["example"] == "Hi Nina" requires_builtin_generics = pytest.mark.skipif( @@ -132,11 +130,10 @@ class Subscription: query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for ok, result in schema.subscribe(query): + assert ok + assert not result.errors + assert result.data["example"] == "Hi" @pytest.mark.asyncio @@ -165,11 +162,10 @@ async def example_with_union(self) -> AsyncGenerator[Union[A, B], None]: query = "subscription { exampleWithUnion { ... on A { a } } }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["exampleWithUnion"]["a"] == "Hi" + async for ok, result in schema.subscribe(query): + assert ok + assert not result.errors + assert result.data["exampleWithUnion"]["a"] == "Hi" del A, B @@ -204,11 +200,10 @@ async def example_with_annotated_union( query = "subscription { exampleWithAnnotatedUnion { ... on C { c } } }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["exampleWithAnnotatedUnion"]["c"] == "Hi" + async for ok, result in schema.subscribe(query): + assert ok + assert not result.errors + assert result.data["exampleWithAnnotatedUnion"]["c"] == "Hi" del C, D @@ -231,8 +226,7 @@ async def example( query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for ok, result in schema.subscribe(query): + assert ok + assert not result.errors + assert result.data["example"] == "Hi" From 41d053e1106231cb5625082624a417d770afbe36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 31 May 2023 00:00:48 +0000 Subject: [PATCH 10/22] Add extensive tests for extension hook execution while running subscriptions --- tests/http/clients/aiohttp.py | 3 +- tests/http/clients/asgi.py | 3 +- tests/http/clients/channels.py | 2 +- tests/http/clients/fastapi.py | 3 +- tests/http/clients/starlite.py | 3 +- tests/views/schema.py | 76 +++++++++++++++++++ tests/websockets/test_graphql_transport_ws.py | 49 +++++++++--- 7 files changed, 122 insertions(+), 17 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..14c3dcaeb6 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -15,7 +15,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..843ae9e082 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -16,7 +16,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index ee31c8e88b..9a76867ce6 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -18,7 +18,7 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.typevars import Context, RootValue -from tests.views.schema import Query, schema +from tests.views.schema import Query, async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index f271509f40..ca7fe820df 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -15,7 +15,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .asgi import AsgiWebSocketClient diff --git a/tests/http/clients/starlite.py b/tests/http/clients/starlite.py index 9af3a8206e..2db8593d14 100644 --- a/tests/http/clients/starlite.py +++ b/tests/http/clients/starlite.py @@ -15,7 +15,8 @@ from strawberry.starlite import make_graphql_controller from strawberry.starlite.controller import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/views/schema.py b/tests/views/schema.py index c2207d3fb3..1e42c88f80 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import inspect from enum import Enum from typing import Any, AsyncGenerator, Dict, List, Optional, Union @@ -37,6 +38,74 @@ class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} + def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + self.resolve_called() + return _next(root, info, *args, **kwargs) + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + def on_operation(self): + self.lifecycle_called("operation", "before") + yield + self.lifecycle_called("operation", "after") + + def on_validate(self): + self.lifecycle_called("validate", "before") + yield + self.lifecycle_called("validate", "after") + + def on_parse(self): + self.lifecycle_called("parse", "before") + yield + self.lifecycle_called("parse", "after") + + def on_execute(self): + self.lifecycle_called("execute", "before") + yield + self.lifecycle_called("execute", "after") + + +class MyAsyncExtension(SchemaExtension): + def get_results(self) -> Dict[str, str]: + return {"example": "example"} + + async def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + self.resolve_called() + result = _next(root, info, *args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + async def on_operation(self): + self.lifecycle_called("operation", "before") + yield + self.lifecycle_called("operation", "after") + + async def on_validate(self): + self.lifecycle_called("validate", "before") + yield + self.lifecycle_called("validate", "after") + + async def on_parse(self): + self.lifecycle_called("parse", "before") + yield + self.lifecycle_called("parse", "after") + + async def on_execute(self): + self.lifecycle_called("execute", "before") + yield + self.lifecycle_called("execute", "after") + def _read_file(text_file: Upload) -> str: with contextlib.suppress(ModuleNotFoundError): @@ -310,3 +379,10 @@ def process_errors( subscription=Subscription, extensions=[MyExtension], ) + +async_schema = Schema( + query=Query, + mutation=Mutation, + subscription=Subscription, + extensions=[MyAsyncExtension], +) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 91a9ad9624..7ffa30d294 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -7,6 +7,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, AsyncGenerator, Type from unittest.mock import Mock, patch +from unittest.mock import call as mock_call try: from unittest.mock import AsyncMock @@ -32,6 +33,8 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from tests.views.schema import Schema +from ..views.schema import MyAsyncExtension + if TYPE_CHECKING: from ..http.clients.base import HttpClient, WebSocketClient @@ -951,19 +954,41 @@ async def test_subscription_errors_continue(ws: WebSocketClient): async def test_extensions(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi") }' - ), - ).as_dict() - ) - - response = await ws.receive_json() - assert_next(response, "sub1", {"echo": "Hi"}, extensions={"example": "example"}) + resolve_called = Mock() + lifecycle_called = Mock() + + with patch.object(MyAsyncExtension, "resolve_called", resolve_called): + with patch.object(MyAsyncExtension, "lifecycle_called", lifecycle_called): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query='subscription { echo(message: "Hi") }' + ), + ).as_dict() + ) - await ws.send_json(CompleteMessage(id="sub1").as_dict()) + response = await ws.receive_json() + assert_next( + response, "sub1", {"echo": "Hi"}, extensions={"example": "example"} + ) + response = await ws.receive_json() + assert response == CompleteMessage(id="sub1").as_dict() + + # no resolvers called + assert resolve_called.call_count == 0 + + lifecycle_calls = lifecycle_called.call_args_list + assert lifecycle_calls == [ + mock_call("operation", "before"), + mock_call("parse", "before"), + mock_call("parse", "after"), + mock_call("validate", "before"), + mock_call("validate", "after"), + mock_call("execute", "before"), + mock_call("execute", "after"), + mock_call("operation", "after"), + ] async def test_validation_query(ws: WebSocketClient): From 8a039afcf8bb9dacc10c0c50e9fdba9285dfb626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 31 May 2023 08:32:53 +0000 Subject: [PATCH 11/22] make extension test more robust, wait for old connections to drain. --- tests/views/schema.py | 42 ++++++++++++++++--- tests/websockets/test_graphql_transport_ws.py | 5 +++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/tests/views/schema.py b/tests/views/schema.py index 1e42c88f80..ed73ca40b8 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -35,12 +35,19 @@ async def has_permission(self, source, info, **kwargs: Any) -> bool: class MyExtension(SchemaExtension): + # a counter to keep track of how many operations are active + active_counter = 0 + def get_results(self) -> Dict[str, str]: return {"example": "example"} def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): - self.resolve_called() - return _next(root, info, *args, **kwargs) + self.active_counter += 1 + try: + self.resolve_called() + return _next(root, info, *args, **kwargs) + finally: + self.active_counter -= 1 def resolve_called(self): pass @@ -50,35 +57,50 @@ def lifecycle_called(self, event, phase): def on_operation(self): self.lifecycle_called("operation", "before") + self.active_counter += 1 yield self.lifecycle_called("operation", "after") + self.active_counter -= 1 def on_validate(self): self.lifecycle_called("validate", "before") + self.active_counter += 1 yield self.lifecycle_called("validate", "after") + self.active_counter -= 1 def on_parse(self): self.lifecycle_called("parse", "before") + self.active_counter += 1 yield self.lifecycle_called("parse", "after") + self.active_counter -= 1 def on_execute(self): self.lifecycle_called("execute", "before") + self.active_counter += 1 yield self.lifecycle_called("execute", "after") + self.active_counter -= 1 class MyAsyncExtension(SchemaExtension): + # a counter to keep track of how many operations are active + active_counter = 0 + def get_results(self) -> Dict[str, str]: return {"example": "example"} async def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): self.resolve_called() - result = _next(root, info, *args, **kwargs) - if inspect.isawaitable(result): - return await result - return result + self.active_counter += 1 + try: + result = _next(root, info, *args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + finally: + self.active_counter -= 1 def resolve_called(self): pass @@ -88,23 +110,31 @@ def lifecycle_called(self, event, phase): async def on_operation(self): self.lifecycle_called("operation", "before") + self.active_counter += 1 yield self.lifecycle_called("operation", "after") + self.active_counter -= 1 async def on_validate(self): self.lifecycle_called("validate", "before") + self.active_counter += 1 yield self.lifecycle_called("validate", "after") + self.active_counter -= 1 async def on_parse(self): self.lifecycle_called("parse", "before") + self.active_counter += 1 yield self.lifecycle_called("parse", "after") + self.active_counter -= 1 async def on_execute(self): self.lifecycle_called("execute", "before") + self.active_counter += 1 yield self.lifecycle_called("execute", "after") + self.active_counter -= 1 def _read_file(text_file: Upload) -> str: diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 7ffa30d294..44daa56ff4 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -957,6 +957,11 @@ async def test_extensions(ws: WebSocketClient): resolve_called = Mock() lifecycle_called = Mock() + # we must make sure that earlier requests and drained before we start + # so that their execution events don't interfere with our events + while MyAsyncExtension.active_counter > 0: + await asyncio.sleep(0.01) + with patch.object(MyAsyncExtension, "resolve_called", resolve_called): with patch.object(MyAsyncExtension, "lifecycle_called", lifecycle_called): await ws.send_json( From b622dac0257b24fa967f265faa0c32f7564f40f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 31 May 2023 00:43:40 +0000 Subject: [PATCH 12/22] update docs for `resolve` extension hook --- docs/guides/custom-extensions.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/docs/guides/custom-extensions.md b/docs/guides/custom-extensions.md index 22cc8c1f97..acd5746cba 100644 --- a/docs/guides/custom-extensions.md +++ b/docs/guides/custom-extensions.md @@ -36,8 +36,6 @@ resolvers. If you need to wrap only certain field resolvers with additional logic, please check out [field extensions](field-extensions.md). -Note that `resolve` can also be implemented asynchronously. - ```python from strawberry.extensions import SchemaExtension @@ -47,6 +45,21 @@ class MyExtension(SchemaExtension): return _next(root, info, *args, **kwargs) ``` +Note that `resolve` can also be implemented asynchronously, in which +case the result from `_next` must be optionally awaited: + +```python +from inspect import isawaitable +from strawberry.types import Info +from strawberry.extensions import SchemaExtension + + +class MyExtension(SchemaExtension): + async def resolve(self, _next, root, info: Info, *args, **kwargs): + result = _next(root, info, *args, **kwargs) + return await result if isawaitable(result) else result +``` + ### Get results `get_results` allows to return a dictionary of data or alternatively an From 9a8a6339701b102d2a6b3e1532b499facb2d7827 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Wed, 31 May 2023 22:13:32 +0000 Subject: [PATCH 13/22] fix channels tests to work with sync channels --- tests/http/clients/channels.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 9a76867ce6..22148384c6 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -18,7 +18,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.typevars import Context, RootValue -from tests.views.schema import Query, async_schema as schema +from tests.views.schema import Query, async_schema +from tests.views.schema import schema as sync_schema from ..context import get_context from .base import ( @@ -143,12 +144,12 @@ def __init__( result_override: ResultOverrideFunction = None, ): self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, + schema=async_schema, keep_alive=False, ) self.http_app = DebuggableGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=async_schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, @@ -157,7 +158,7 @@ def __init__( def create_app(self, **kwargs: Any) -> None: self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, **kwargs + schema=async_schema, **kwargs ) async def _graphql_request( @@ -264,7 +265,7 @@ def __init__( result_override: ResultOverrideFunction = None, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=sync_schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, From 7aa200916259b33ee2c8f1a1927db55382d267c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 5 Jun 2023 12:22:48 +0000 Subject: [PATCH 14/22] schema.subscribe() now raises a SubscribeSingleResult when not returning a subscription --- strawberry/schema/__init__.py | 4 +- strawberry/schema/base.py | 11 +++- strawberry/schema/execute.py | 39 ++++++------ strawberry/schema/schema.py | 2 +- .../graphql_transport_ws/handlers.py | 61 +++++++++++-------- .../protocols/graphql_ws/handlers.py | 36 ++++++----- tests/schema/test_permission.py | 9 ++- tests/schema/test_subscription.py | 9 +-- 8 files changed, 97 insertions(+), 74 deletions(-) diff --git a/strawberry/schema/__init__.py b/strawberry/schema/__init__.py index 5cf633ac21..cbc3bf134a 100644 --- a/strawberry/schema/__init__.py +++ b/strawberry/schema/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseSchema +from .base import BaseSchema, SubscribeSingleResult from .schema import Schema -__all__ = ["BaseSchema", "Schema"] +__all__ = ["BaseSchema", "Schema", "SubscribeSingleResult"] diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index cb6aafefbc..25a039a609 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -32,6 +32,15 @@ from .config import StrawberryConfig +class SubscribeSingleResult(RuntimeError): + """Raised when Schema.subscribe() returns a single execution result, instead of a + subscription generator, typically as a result of validation errors. + """ + + def __init__(self, value: ExecutionResult) -> None: + self.value = value + + class BaseSchema(Protocol): config: StrawberryConfig schema_converter: GraphQLCoreConverter @@ -72,7 +81,7 @@ def subscribe( context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> AsyncGenerator[Any, None]: + ) -> AsyncGenerator[ExecutionResult, None]: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index e04bdf02f5..c1811b642c 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -29,6 +29,7 @@ from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.types import ExecutionResult +from .base import SubscribeSingleResult from .exceptions import InvalidOperationTypeError if TYPE_CHECKING: @@ -283,7 +284,7 @@ async def subscribe( extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], execution_context: ExecutionContext, process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], -) -> AsyncGenerator[Tuple[bool, ExecutionResult], None]: +) -> AsyncGenerator[ExecutionResult, None]: """ The graphql-core subscribe function returns either an ExecutionResult or an AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error @@ -319,32 +320,34 @@ async def subscribe( except GraphQLError as error: execution_context.errors = [error] process_errors([error], execution_context) - yield False, ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), + raise SubscribeSingleResult( + ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) ) - # the generator is usually closed here, so the following is not - # reached - return # pragma: no cover except Exception as error: # pragma: no cover error = GraphQLError(str(error), original_error=error) execution_context.errors = [error] process_errors([error], execution_context) - yield False, ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), + + raise SubscribeSingleResult( + ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) ) - return # pragma: no cover async with extensions_runner.validation(): _run_validation(execution_context) if execution_context.errors: process_errors(execution_context.errors, execution_context) - yield False, ExecutionResult(data=None, errors=execution_context.errors) - return # pragma: no cover + raise SubscribeSingleResult( + ExecutionResult(data=None, errors=execution_context.errors) + ) async def process_result(result: GraphQLExecutionResult): execution_context.result = result @@ -397,13 +400,13 @@ async def process_result(result: GraphQLExecutionResult): ) if isinstance(result, GraphQLExecutionResult): - yield False, await process_result(result) - return + raise SubscribeSingleResult(await process_result(result)) aiterator = result.__aiter__() try: async for result in aiterator: - yield True, await process_result(result) + yield await process_result(result) finally: + # grapql-core's iterator may or may not have an aclose() method if hasattr(aiterator, "aclose"): await aiterator.aclose() diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 45edd2f63b..1d5009f829 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -307,7 +307,7 @@ async def subscribe( context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> AsyncGenerator[tuple[bool, ExecutionResult], None]: + ) -> AsyncGenerator[ExecutionResult, None]: execution_context = ExecutionContext( query=query, schema=self, diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index fbd71e60d8..6b1f9fe20b 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -17,6 +17,7 @@ from graphql import GraphQLError, GraphQLSyntaxError, parse +from strawberry.schema import SubscribeSingleResult from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -255,13 +256,17 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: else: # create AsyncGenerator returning a single result async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield False, await self.schema.execute( - query=message.payload.query, - variable_values=message.payload.variables, - context_value=context, - root_value=root_value, - operation_name=message.payload.operationName, + raise SubscribeSingleResult( + await self.schema.execute( + query=message.payload.query, + variable_values=message.payload.variables, + context_value=context, + root_value=root_value, + operation_name=message.payload.operationName, + ) ) + # need a yield here to turn this into an async generator + yield None # pragma: no cover result_source = get_result_source() @@ -306,22 +311,16 @@ async def handle_async_results( operation: Operation, ) -> None: try: - async for multiple, result in result_source: - if result.errors and not multiple: - error_payload = [err.formatted for err in result.errors] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - return - else: - next_payload = {"data": result.data} - if result.errors: - next_payload["errors"] = [ - err.formatted for err in result.errors - ] - if result.extensions: - next_payload["extensions"] = result.extensions - next_message = NextMessage(id=operation.id, payload=next_payload) - await operation.send_message(next_message) + try: + async for result in result_source: + await self.send_result(operation, result, False) + except SubscribeSingleResult as single_result: + await self.send_result(operation, single_result.value, True) + finally: + await result_source.aclose() + except asyncio.CancelledError: + # CancelledErrors are expected during task cleanup. + raise except Exception as error: # GraphQLErrors are handled by graphql-core and included in the # ExecutionResult @@ -331,8 +330,22 @@ async def handle_async_results( await operation.send_message(error_message) self.schema.process_errors([error]) return - finally: - await result_source.aclose() + + async def send_result( + self, operation: Operation, result: ExecutionResult, single: bool + ) -> None: + if result.errors and single: + error_payload = [err.formatted for err in result.errors] + error_message = ErrorMessage(id=operation.id, payload=error_payload) + await operation.send_message(error_message) + else: + next_payload: Dict[str, Any] = {"data": result.data} + if result.errors: + next_payload["errors"] = [err.formatted for err in result.errors] + if result.extensions: + next_payload["extensions"] = result.extensions + next_message = NextMessage(id=operation.id, payload=next_payload) + await operation.send_message(next_message) def forget_id(self, id: str) -> None: # de-register the operation id making it immediately available diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 5b94562157..c36a8ba911 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -7,6 +7,7 @@ from graphql import GraphQLError +from strawberry.schema import SubscribeSingleResult from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, GQL_CONNECTION_ACK, @@ -150,23 +151,22 @@ async def handle_async_results( operation_id: str, ) -> None: try: - async for success, result in result_source: - if not success: - assert result.errors - error_payload = result.errors[0].formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors(result.errors) - return - payload = {"data": result.data} - if result.errors: - payload["errors"] = [err.formatted for err in result.errors] - if result.extensions: - payload["extensions"] = result.extensions - await self.send_message(GQL_DATA, operation_id, payload) - # log errors after send_message to prevent potential - # slowdown of sending result - if result.errors: - self.schema.process_errors(result.errors) + try: + async for result in result_source: + payload = {"data": result.data} + if result.errors: + payload["errors"] = [err.formatted for err in result.errors] + if result.extensions: + payload["extensions"] = result.extensions + await self.send_message(GQL_DATA, operation_id, payload) + except SubscribeSingleResult as single_result: + result = single_result.value + assert result.errors + error_payload = result.errors[0].formatted + await self.send_message(GQL_ERROR, operation_id, error_payload) + return + finally: + await result_source.aclose() except asyncio.CancelledError: # CancelledErrors are expected during task cleanup. pass @@ -180,8 +180,6 @@ async def handle_async_results( {"data": None, "errors": [error.formatted]}, ) self.schema.process_errors([error]) - finally: - await result_source.aclose() await self.send_message(GQL_COMPLETE, operation_id, None) diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 9213ce9be9..fe2fdcde8a 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -12,6 +12,7 @@ ) from strawberry.permission import BasePermission, PermissionExtension from strawberry.printer import print_schema +from strawberry.schema import SubscribeSingleResult def test_raises_graphql_error_when_permission_method_is_missing(): @@ -80,9 +81,11 @@ async def user( query = "subscription { user }" - async for ok, result in schema.subscribe(query): - assert not ok - assert result.errors[0].message == "You are not authorized" + with pytest.raises(SubscribeSingleResult) as err: + async for result in schema.subscribe(query): + pass + result = err.value.value + assert result.errors[0].message == "You are not authorized" @pytest.mark.asyncio diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index e1adc9dc8c..1d816d4fdd 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -33,8 +33,7 @@ async def example(self) -> AsyncGenerator[str, None]: query = "subscription { example }" - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi" @@ -88,8 +87,7 @@ async def example(self, name: str) -> AsyncGenerator[str, None]: query = 'subscription { example(name: "Nina") }' - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi Nina" @@ -130,8 +128,7 @@ class Subscription: query = "subscription { example }" - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi" From c5eec36ced1785be6daeabfe0f3292fa29d383fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 5 Jun 2023 13:11:15 +0000 Subject: [PATCH 15/22] Move closure function out into a separate function --- strawberry/schema/execute.py | 54 ++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index c1811b642c..7f2da62cd0 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -349,25 +349,6 @@ async def subscribe( ExecutionResult(data=None, errors=execution_context.errors) ) - async def process_result(result: GraphQLExecutionResult): - execution_context.result = result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) - - return ExecutionResult( - data=execution_context.result.data, - errors=execution_context.result.errors, - extensions=await extensions_runner.get_extensions_results(), - ) - async with extensions_runner.executing(): # currently original_subscribe is an async function. A future release # of graphql-core will make it optionally awaitable @@ -400,13 +381,44 @@ async def process_result(result: GraphQLExecutionResult): ) if isinstance(result, GraphQLExecutionResult): - raise SubscribeSingleResult(await process_result(result)) + raise SubscribeSingleResult( + await process_subscribe_result( + execution_context, process_errors, extensions_runner, result + ) + ) aiterator = result.__aiter__() try: async for result in aiterator: - yield await process_result(result) + yield await process_subscribe_result( + execution_context, process_errors, extensions_runner, result + ) finally: # grapql-core's iterator may or may not have an aclose() method if hasattr(aiterator, "aclose"): await aiterator.aclose() + + +async def process_subscribe_result( + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + extensions_runner: SchemaExtensionsRunner, + result: GraphQLExecutionResult, +) -> ExecutionResult: + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + return ExecutionResult( + data=execution_context.result.data, + errors=execution_context.result.errors, + extensions=await extensions_runner.get_extensions_results(), + ) From edd6d2e4fbd1ca9a00a34ea091c06687fcb725db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 5 Jun 2023 16:13:35 +0000 Subject: [PATCH 16/22] ruff --- tests/schema/test_permission.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index fe2fdcde8a..67f5bac64c 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -13,6 +13,7 @@ from strawberry.permission import BasePermission, PermissionExtension from strawberry.printer import print_schema from strawberry.schema import SubscribeSingleResult +from strawberry.types import Info def test_raises_graphql_error_when_permission_method_is_missing(): From 64c33ca83428fbee576ef723ed8d579e28639589 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 9 Jul 2023 10:49:08 +0000 Subject: [PATCH 17/22] tests for process_errors() (graphql_ws) --- tests/schema/test_permission.py | 2 +- tests/websockets/test_graphql_ws.py | 74 ++++++++++++++++------------- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 67f5bac64c..ed3dcec9fa 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -83,7 +83,7 @@ async def user( query = "subscription { user }" with pytest.raises(SubscribeSingleResult) as err: - async for result in schema.subscribe(query): + async for _ in schema.subscribe(query): pass result = err.value.value assert result.errors[0].message == "You are not authorized" diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index a8c8118c40..befe7ae797 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -2,6 +2,7 @@ import asyncio from typing import TYPE_CHECKING, AsyncGenerator +from unittest.mock import Mock, patch import pytest import pytest_asyncio @@ -19,6 +20,7 @@ GQL_START, GQL_STOP, ) +from tests.views.schema import Schema if TYPE_CHECKING: from ..http.clients.aiohttp import HttpClient, WebSocketClient @@ -198,25 +200,28 @@ async def test_subscription_cancellation(ws: WebSocketClient): async def test_subscription_errors(ws: WebSocketClient): - await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, - } - ) + process_errors = Mock() + with patch.object(Schema, "process_errors", process_errors): + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, + } + ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] is None - assert len(response["payload"]["errors"]) == 1 - assert response["payload"]["errors"][0]["path"] == ["error"] - assert response["payload"]["errors"][0]["message"] == "TEST ERR" + response = await ws.receive_json() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert response["payload"]["data"] is None + assert len(response["payload"]["errors"]) == 1 + assert response["payload"]["errors"][0]["path"] == ["error"] + assert response["payload"]["errors"][0]["message"] == "TEST ERR" + process_errors.assert_called_once() - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + response = await ws.receive_json() + assert response["type"] == GQL_COMPLETE + assert response["id"] == "demo" async def test_subscription_exceptions(ws: WebSocketClient): @@ -241,23 +246,26 @@ async def test_subscription_exceptions(ws: WebSocketClient): async def test_subscription_field_error(ws: WebSocketClient): - await ws.send_json( - { - "type": GQL_START, - "id": "invalid-field", - "payload": {"query": "subscription { notASubscriptionField }"}, - } - ) + process_errors = Mock() + with patch.object(Schema, "process_errors", process_errors): + await ws.send_json( + { + "type": GQL_START, + "id": "invalid-field", + "payload": {"query": "subscription { notASubscriptionField }"}, + } + ) - response = await ws.receive_json() - assert response["type"] == GQL_ERROR - assert response["id"] == "invalid-field" - assert response["payload"] == { - "locations": [{"line": 1, "column": 16}], - "message": ( - "Cannot query field 'notASubscriptionField' on type 'Subscription'." - ), - } + response = await ws.receive_json() + assert response["type"] == GQL_ERROR + assert response["id"] == "invalid-field" + assert response["payload"] == { + "locations": [{"line": 1, "column": 16}], + "message": ( + "Cannot query field 'notASubscriptionField' on type 'Subscription'." + ), + } + process_errors.assert_called_once() async def test_subscription_syntax_error(ws: WebSocketClient): From a0336243d6ba158b798c7601005f0157a4d86af3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 25 Aug 2023 15:04:55 +0000 Subject: [PATCH 18/22] update benchmark test --- tests/benchmarks/test_subscriptions.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/benchmarks/test_subscriptions.py b/tests/benchmarks/test_subscriptions.py index 075c993904..c85a398f54 100644 --- a/tests/benchmarks/test_subscriptions.py +++ b/tests/benchmarks/test_subscriptions.py @@ -16,11 +16,8 @@ def test_subscription(benchmark: BenchmarkFixture): async def _run(): for _ in range(100): - iterator = await schema.subscribe(s) - - value = await iterator.__anext__() # type: ignore[union-attr] - - assert value.data is not None - assert value.data["something"] == "Hello World!" + async for value in schema.subscribe(s): + assert value.data is not None + assert value.data["something"] == "Hello World!" benchmark(lambda: asyncio.run(_run())) From 699a910f350b2b01df5338043fda45b65005f0bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 25 Aug 2023 16:08:04 +0000 Subject: [PATCH 19/22] Update newly added tests --- docs/operations/testing.md | 3 +-- tests/schema/test_subscription.py | 9 +++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/operations/testing.md b/docs/operations/testing.md index 4d90e9f71e..056cd8cf0f 100644 --- a/docs/operations/testing.md +++ b/docs/operations/testing.md @@ -137,8 +137,7 @@ async def test_subscription(): """ index = 0 - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data == {"count": index} diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index 1d816d4fdd..bb82236633 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -159,8 +159,7 @@ async def example_with_union(self) -> AsyncGenerator[Union[A, B], None]: query = "subscription { exampleWithUnion { ... on A { a } } }" - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data["exampleWithUnion"]["a"] == "Hi" @@ -197,8 +196,7 @@ async def example_with_annotated_union( query = "subscription { exampleWithAnnotatedUnion { ... on C { c } } }" - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data["exampleWithAnnotatedUnion"]["c"] == "Hi" @@ -223,7 +221,6 @@ async def example( query = "subscription { example }" - async for ok, result in schema.subscribe(query): - assert ok + async for result in schema.subscribe(query): assert not result.errors assert result.data["example"] == "Hi" From 4fdbac40ca7bc085b10edd99b156784ea1cbac6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 3 May 2024 13:53:05 +0000 Subject: [PATCH 20/22] Fix typing --- .../subscriptions/protocols/graphql_transport_ws/handlers.py | 3 +-- tests/views/schema.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 6b1f9fe20b..c02b4d0690 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -8,7 +8,6 @@ TYPE_CHECKING, Any, AsyncGenerator, - AsyncIterator, Callable, Dict, List, @@ -255,7 +254,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: ) else: # create AsyncGenerator returning a single result - async def get_result_source() -> AsyncIterator[ExecutionResult]: + async def get_result_source() -> AsyncGenerator[ExecutionResult, None]: raise SubscribeSingleResult( await self.schema.execute( query=message.payload.query, diff --git a/tests/views/schema.py b/tests/views/schema.py index ed73ca40b8..559f343d5e 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -41,7 +41,7 @@ class MyExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} - def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + def resolve(self, _next, root, info: strawberry.Info, *args: Any, **kwargs: Any): self.active_counter += 1 try: self.resolve_called() @@ -91,7 +91,7 @@ class MyAsyncExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} - async def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + async def resolve(self, _next, root, info: strawberry.Info, *args: Any, **kwargs: Any): self.resolve_called() self.active_counter += 1 try: From 5130fba081dc3c487396dbdc7a7aa96ced075064 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 3 May 2024 15:22:21 +0000 Subject: [PATCH 21/22] fix new tests --- tests/schema/test_subscription.py | 8 +++----- tests/websockets/test_graphql_transport_ws.py | 7 ++----- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index bb82236633..f32ac433ed 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -64,11 +64,9 @@ async def example(self) -> AsyncGenerator[str, None]: query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["example"] == "Hi" @pytest.mark.asyncio diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 44daa56ff4..6daf7eb4db 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -250,11 +250,8 @@ async def test_can_send_payload_with_additional_things(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) - assert json.loads(data.data) == { - "type": "next", - "id": "1", - "payload": {"data": {"echo": "Hi"}}, - } + result = json.loads(data.data) + assert_next(result, "1", {"echo": "Hi"}) async def test_server_sent_ping(ws: WebSocketClient): From ef6688d8f601e69ad7af051393c3187b95e24573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 5 May 2024 17:22:08 +0000 Subject: [PATCH 22/22] formatting --- docs/guides/custom-extensions.md | 4 ++-- tests/schema/test_permission.py | 1 - tests/views/schema.py | 4 +++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/guides/custom-extensions.md b/docs/guides/custom-extensions.md index acd5746cba..1dd000409d 100644 --- a/docs/guides/custom-extensions.md +++ b/docs/guides/custom-extensions.md @@ -45,8 +45,8 @@ class MyExtension(SchemaExtension): return _next(root, info, *args, **kwargs) ``` -Note that `resolve` can also be implemented asynchronously, in which -case the result from `_next` must be optionally awaited: +Note that `resolve` can also be implemented asynchronously, in which case the +result from `_next` must be optionally awaited: ```python from inspect import isawaitable diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index ed3dcec9fa..52154b2331 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -13,7 +13,6 @@ from strawberry.permission import BasePermission, PermissionExtension from strawberry.printer import print_schema from strawberry.schema import SubscribeSingleResult -from strawberry.types import Info def test_raises_graphql_error_when_permission_method_is_missing(): diff --git a/tests/views/schema.py b/tests/views/schema.py index 559f343d5e..546834d052 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -91,7 +91,9 @@ class MyAsyncExtension(SchemaExtension): def get_results(self) -> Dict[str, str]: return {"example": "example"} - async def resolve(self, _next, root, info: strawberry.Info, *args: Any, **kwargs: Any): + async def resolve( + self, _next, root, info: strawberry.Info, *args: Any, **kwargs: Any + ): self.resolve_called() self.active_counter += 1 try: