From 7bbee00579e9260f67b268843feecbcb8a9e608b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 23 Apr 2023 20:17:12 +0000 Subject: [PATCH] schema.subscribe returns a Union[ExecutionResult, AsyncGen] Update graphql_ws handler fix tests now that field verification is done by Strawberry Return "extensions" as part of ExecutionResult if present. Add some error/validation test cases Add release.md Fix tests, remove extra "extensions" payload we don't want to inspect Add extensive tests for extension hook execution while running subscriptions update docs for `resolve` extension hook --- RELEASE.md | 3 + docs/guides/custom-extensions.md | 17 +- strawberry/schema/base.py | 14 +- strawberry/schema/execute.py | 326 ++++++++++++++---- strawberry/schema/schema.py | 29 +- .../graphql_transport_ws/handlers.py | 6 +- .../protocols/graphql_ws/handlers.py | 18 +- tests/channels/test_layers.py | 3 + tests/http/clients/aiohttp.py | 3 +- tests/http/clients/asgi.py | 3 +- tests/http/clients/channels.py | 11 +- tests/http/clients/fastapi.py | 3 +- tests/http/clients/starlite.py | 3 +- tests/views/schema.py | 131 +++++++ tests/websockets/test_graphql_transport_ws.py | 230 +++++++----- tests/websockets/test_graphql_ws.py | 27 +- 16 files changed, 634 insertions(+), 193 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..6ef7811880 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Subscriptions now support Schema Extensions. diff --git a/docs/guides/custom-extensions.md b/docs/guides/custom-extensions.md index 453ca449eb..ff614045d1 100644 --- a/docs/guides/custom-extensions.md +++ b/docs/guides/custom-extensions.md @@ -36,8 +36,6 @@ resolvers. If you need to wrap only certain field resolvers with additional logic, please check out [field extensions](field-extensions.md). -Note that `resolve` can also be implemented asynchronously. - ```python from strawberry.types import Info from strawberry.extensions import SchemaExtension @@ -48,6 +46,21 @@ class MyExtension(SchemaExtension): return _next(root, info, *args, **kwargs) ``` +Note that `resolve` can also be implemented asynchronously, in which +case the result from `_next` must be optionally awaited: + +```python +from inspect import isawaitable +from strawberry.types import Info +from strawberry.extensions import SchemaExtension + + +class MyExtension(SchemaExtension): + async def resolve(self, _next, root, info: Info, *args, **kwargs): + result = _next(root, info, *args, **kwargs) + return await result if isawaitable(result) else result +``` + ### Get results `get_results` allows to return a dictionary of data or alternatively an diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index 74c0de4356..3553cf9da5 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -2,7 +2,17 @@ from abc import abstractmethod from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Iterable, + List, + Optional, + Type, + Union, +) from typing_extensions import Protocol from strawberry.utils.logging import StrawberryLogger @@ -62,7 +72,7 @@ async def subscribe( context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Any: + ) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 7a35560410..acae82887a 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -4,6 +4,8 @@ from inspect import isawaitable from typing import ( TYPE_CHECKING, + AsyncGenerator, + AsyncIterable, Awaitable, Callable, Iterable, @@ -16,8 +18,10 @@ cast, ) +from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, parse from graphql import execute as original_execute +from graphql import subscribe as original_subscribe from graphql.validation import validate from strawberry.exceptions import MissingQueryError @@ -30,7 +34,6 @@ from typing_extensions import Unpack from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLSchema from graphql.language import DocumentNode from graphql.validation import ASTValidationRule @@ -57,6 +60,100 @@ def validate_document( ) +async def _parse_and_validate_async( + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + extensions_runner: SchemaExtensionsRunner, + allowed_operation_types: Optional[Iterable[OperationType]] = None, +): + assert execution_context.query + async with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) + + except Exception as error: # pragma: no cover + error = GraphQLError(str(error), original_error=error) + + execution_context.errors = [error] + process_errors([error], execution_context) + + return ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) + + if ( + allowed_operation_types is not None + and execution_context.operation_type not in allowed_operation_types + ): + raise InvalidOperationTypeError(execution_context.operation_type) + + async with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) + + +def _parse_and_validate_sync( + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + extensions_runner: SchemaExtensionsRunner, + allowed_operation_types: Iterable[OperationType], +) -> Optional[ExecutionResult]: + assert execution_context.query + with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=extensions_runner.get_extensions_results_sync(), + ) + + except Exception as error: # pragma: no cover + error = GraphQLError(str(error), original_error=error) + + execution_context.errors = [error] + process_errors([error], execution_context) + + return ExecutionResult( + data=None, + errors=[error], + extensions=extensions_runner.get_extensions_results_sync(), + ) + + if execution_context.operation_type not in allowed_operation_types: + raise InvalidOperationTypeError(execution_context.operation_type) + + with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) + return None + + def _run_validation(execution_context: ExecutionContext) -> None: # Check if there are any validation rules or if validation has # already been run by an extension @@ -89,45 +186,18 @@ async def execute( if not execution_context.query: raise MissingQueryError() - async with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options - ) - - except GraphQLError as error: - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), - ) - - except Exception as error: # pragma: no cover - error = GraphQLError(str(error), original_error=error) - - execution_context.errors = [error] - process_errors([error], execution_context) - - return ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - async with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) + error_result = await _parse_and_validate_async( + execution_context, + process_errors, + extensions_runner, + allowed_operation_types, + ) + if error_result is not None: + return error_result async with extensions_runner.executing(): if not execution_context.result: + assert execution_context.graphql_document result = original_execute( schema, execution_context.graphql_document, @@ -182,44 +252,18 @@ def execute_sync( if not execution_context.query: raise MissingQueryError() - with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options - ) - - except GraphQLError as error: - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=extensions_runner.get_extensions_results_sync(), - ) - - except Exception as error: # pragma: no cover - error = GraphQLError(str(error), original_error=error) - - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=extensions_runner.get_extensions_results_sync(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) + error_result = _parse_and_validate_sync( + execution_context, + process_errors, + extensions_runner, + allowed_operation_types, + ) + if error_result is not None: + return error_result with extensions_runner.executing(): if not execution_context.result: + assert execution_context.graphql_document result = original_execute( schema, execution_context.graphql_document, @@ -256,3 +300,141 @@ def execute_sync( errors=execution_context.result.errors, extensions=extensions_runner.get_extensions_results_sync(), ) + + +async def subscribe( + schema: GraphQLSchema, + *, + extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], +) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]: + # The graphql-core subscribe function returns either an ExecutionResult or an + # AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error + # during parsing or validation. + # We repeat that pattern here, but to maintain the context of the extensions + # context manager, we must delegate to an inner async generator. The inner + # generator yields an initial result, either a None, or an ExecutionResult, + # to indicate the two different cases. + + asyncgen = _subscribe( + schema, + extensions=extensions, + execution_context=execution_context, + process_errors=process_errors, + ) + # start the generator + first = await asyncgen.__anext__() + if first is not None: + # Single result. Close the generator to exit any context managers + await asyncgen.aclose() + return first + else: + # return the started generator. Cast away the Optional[] type + return cast(AsyncGenerator[ExecutionResult, None], asyncgen) + + +async def _subscribe( + schema: GraphQLSchema, + *, + extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], +) -> AsyncGenerator[Optional[ExecutionResult], None]: + # This Async generator first yields either a single ExecutionResult or None. + # If None is yielded, then the subscription has failed and the generator should + # be closed. + # Otherwise, if None is yielded, the subscription can continue. + + extensions_runner = SchemaExtensionsRunner( + execution_context=execution_context, + extensions=list(extensions), + ) + + # unlike execute(), the entire operation, including the results hooks, + # is run within the operation() hook. + async with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + + error_result = await _parse_and_validate_async( + execution_context, process_errors, extensions_runner + ) + if error_result is not None: + yield error_result + return # pragma: no cover + + async with extensions_runner.executing(): + # currently original_subscribe is an async function. A future release + # of graphql-core will make it optionally awaitable + assert execution_context.graphql_document + result: Union[AsyncIterable[GraphQLExecutionResult], GraphQLExecutionResult] + result_or_awaitable = original_subscribe( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + context_value=execution_context.context, + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + ) + if isawaitable(result_or_awaitable): + result = await cast( + Awaitable[ + Union[ + AsyncIterable["GraphQLExecutionResult"], + "GraphQLExecutionResult", + ] + ], + result_or_awaitable, + ) + else: # pragma: no cover + result = cast( + Union[ + AsyncIterable["GraphQLExecutionResult"], + "GraphQLExecutionResult", + ], + result_or_awaitable, + ) + + if isinstance(result, GraphQLExecutionResult): + yield await process_subscribe_result( + execution_context, process_errors, extensions_runner, result + ) + return # pragma: no cover + + yield None # signal that we are returning an async generator + aiterator = result.__aiter__() + try: + async for result in aiterator: + yield await process_subscribe_result( + execution_context, process_errors, extensions_runner, result + ) + finally: + # grapql-core's iterator may or may not have an aclose() method + if hasattr(aiterator, "aclose"): + await aiterator.aclose() + + +async def process_subscribe_result( + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + extensions_runner: SchemaExtensionsRunner, + result: GraphQLExecutionResult, +) -> ExecutionResult: + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + return ExecutionResult( + data=execution_context.result.data, + errors=execution_context.result.errors, + extensions=await extensions_runner.get_extensions_results(), + ) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 1c6a9254a4..2ef096202f 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -5,7 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, + AsyncGenerator, Dict, Iterable, List, @@ -20,10 +20,8 @@ GraphQLNonNull, GraphQLSchema, get_introspection_query, - parse, validate_schema, ) -from graphql.execution import subscribe from graphql.type.directives import specified_directives from strawberry import relay @@ -43,11 +41,10 @@ from . import compat from .base import BaseSchema from .config import StrawberryConfig -from .execute import execute, execute_sync +from .execute import execute, execute_sync, subscribe if TYPE_CHECKING: from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from strawberry.custom_scalar import ScalarDefinition, ScalarWrapper from strawberry.directive import StrawberryDirective @@ -298,20 +295,26 @@ def execute_sync( async def subscribe( self, - # TODO: make this optional when we support extensions - query: str, + query: Optional[str], variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]: + ) -> Union[ExecutionResult, AsyncGenerator[ExecutionResult, None]]: + execution_context = ExecutionContext( + query=query, + schema=self, + context=context_value, + root_value=root_value, + variables=variable_values, + provided_operation_name=operation_name, + ) + return await subscribe( self._schema, - parse(query), - root_value=root_value, - context_value=context_value, - variable_values=variable_values, - operation_name=operation_name, + extensions=self.get_extensions(), + execution_context=execution_context, + process_errors=self.process_errors, ) def _resolve_node_ids(self): diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index fb29b85b86..9e450e494f 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -6,7 +6,6 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, GraphQLSyntaxError, parse from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -20,6 +19,7 @@ SubscribeMessage, SubscribeMessagePayload, ) +from strawberry.types import ExecutionResult from strawberry.types.graphql import OperationType from strawberry.unset import UNSET from strawberry.utils.debug import pretty_print_graphql_operation @@ -247,7 +247,7 @@ async def get_result_source(): operation = Operation(self, message.id, operation_type) # Handle initial validation errors - if isinstance(result_source, GraphQLExecutionResult): + if isinstance(result_source, ExecutionResult): assert result_source.errors payload = [err.formatted for err in result_source.errors] await self.send_message(ErrorMessage(id=message.id, payload=payload)) @@ -309,6 +309,8 @@ async def handle_async_results( next_payload["errors"] = [ err.formatted for err in result.errors ] + if result.extensions: + next_payload["extensions"] = result.extensions next_message = NextMessage(id=operation.id, payload=next_payload) await operation.send_message(next_message) except asyncio.CancelledError: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 609f9fdab8..10a1f9a275 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -5,7 +5,6 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, cast -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError from strawberry.subscriptions.protocols.graphql_ws import ( @@ -20,6 +19,7 @@ GQL_START, GQL_STOP, ) +from strawberry.types import ExecutionResult from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: @@ -137,7 +137,7 @@ async def handle_start(self, message: OperationMessage) -> None: self.schema.process_errors([error]) return - if isinstance(result_source, GraphQLExecutionResult): + if isinstance(result_source, ExecutionResult): assert result_source.errors error_payload = result_source.errors[0].formatted await self.send_message(GQL_ERROR, operation_id, error_payload) @@ -168,6 +168,8 @@ async def handle_async_results( payload = {"data": result.data} if result.errors: payload["errors"] = [err.formatted for err in result.errors] + if result.extensions: + payload["extensions"] = result.extensions await self.send_message(GQL_DATA, operation_id, payload) # log errors after send_message to prevent potential # slowdown of sending result @@ -186,17 +188,17 @@ async def handle_async_results( {"data": None, "errors": [error.formatted]}, ) self.schema.process_errors([error]) + finally: + await result_source.aclose() await self.send_message(GQL_COMPLETE, operation_id, None) async def cleanup_operation(self, operation_id: str) -> None: - await self.subscriptions[operation_id].aclose() - del self.subscriptions[operation_id] - - self.tasks[operation_id].cancel() + self.subscriptions.pop(operation_id) + task = self.tasks.pop(operation_id) + task.cancel() with suppress(BaseException): - await self.tasks[operation_id] - del self.tasks[operation_id] + await task async def send_message( self, diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index 2a80a6f285..6a21190e8e 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -79,6 +79,7 @@ async def test_channel_listen(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + del response["payload"]["extensions"] assert ( response == NextMessage( @@ -179,6 +180,7 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): }, ) response = await ws.receive_json_from() + del response["payload"]["extensions"] assert ( response == NextMessage( @@ -195,6 +197,7 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + del response["payload"]["extensions"] assert ( response == NextMessage( diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 454b7b26b5..f7d42a2f56 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -14,7 +14,8 @@ from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 8207cd4a05..0930303d1a 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -15,7 +15,8 @@ from strawberry.asgi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 15f34277ff..f268fc8213 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -17,7 +17,8 @@ from strawberry.channels.handlers.base import ChannelsConsumer from strawberry.http import GraphQLHTTPResponse from strawberry.http.typevars import Context, RootValue -from tests.views.schema import Query, schema +from tests.views.schema import Query, async_schema +from tests.views.schema import schema as sync_schema from ..context import get_context from .base import ( @@ -141,12 +142,12 @@ def __init__( result_override: ResultOverrideFunction = None, ): self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, + schema=async_schema, keep_alive=False, ) self.http_app = DebuggableGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=async_schema, graphiql=graphiql, allow_queries_via_get=allow_queries_via_get, result_override=result_override, @@ -154,7 +155,7 @@ def __init__( def create_app(self, **kwargs: Any) -> None: self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, **kwargs + schema=async_schema, **kwargs ) async def _graphql_request( @@ -260,7 +261,7 @@ def __init__( result_override: ResultOverrideFunction = None, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=sync_schema, graphiql=graphiql, allow_queries_via_get=allow_queries_via_get, result_override=result_override, diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index bf5d33bf7b..c488263563 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -14,7 +14,8 @@ from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .asgi import AsgiWebSocketClient diff --git a/tests/http/clients/starlite.py b/tests/http/clients/starlite.py index 0ba88dc6ff..d56d4e8a17 100644 --- a/tests/http/clients/starlite.py +++ b/tests/http/clients/starlite.py @@ -14,7 +14,8 @@ from strawberry.starlite import make_graphql_controller from strawberry.starlite.controller import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/views/schema.py b/tests/views/schema.py index 5e71649c9c..71a49b1a42 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import inspect from enum import Enum from typing import Any, AsyncGenerator, Dict, List, Optional @@ -20,10 +21,121 @@ def has_permission(self, source: Any, info: Info[Any, Any], **kwargs: Any) -> bo return False +class ConditionalFailPermission(BasePermission): + @property + def message(self): + return f"failed after sleep {self.sleep}" + + async def has_permission(self, source, info, **kwargs: Any) -> bool: + self.sleep = kwargs.get("sleep", None) + self.fail = kwargs.get("fail", True) + if self.sleep is not None: + await asyncio.sleep(kwargs["sleep"]) + return not self.fail + + class MyExtension(SchemaExtension): + # a counter to keep track of how many operations are active + active_counter = 0 + def get_results(self) -> Dict[str, str]: return {"example": "example"} + def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + self.active_counter += 1 + try: + self.resolve_called() + return _next(root, info, *args, **kwargs) + finally: + self.active_counter -= 1 + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + def on_operation(self): + self.lifecycle_called("operation", "before") + self.active_counter += 1 + yield + self.lifecycle_called("operation", "after") + self.active_counter -= 1 + + def on_validate(self): + self.lifecycle_called("validate", "before") + self.active_counter += 1 + yield + self.lifecycle_called("validate", "after") + self.active_counter -= 1 + + def on_parse(self): + self.lifecycle_called("parse", "before") + self.active_counter += 1 + yield + self.lifecycle_called("parse", "after") + self.active_counter -= 1 + + def on_execute(self): + self.lifecycle_called("execute", "before") + self.active_counter += 1 + yield + self.lifecycle_called("execute", "after") + self.active_counter -= 1 + + +class MyAsyncExtension(SchemaExtension): + # a counter to keep track of how many operations are active + active_counter = 0 + + def get_results(self) -> Dict[str, str]: + return {"example": "example"} + + async def resolve(self, _next, root, info: Info, *args: Any, **kwargs: Any): + self.resolve_called() + self.active_counter += 1 + try: + result = _next(root, info, *args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + finally: + self.active_counter -= 1 + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + async def on_operation(self): + self.lifecycle_called("operation", "before") + self.active_counter += 1 + yield + self.lifecycle_called("operation", "after") + self.active_counter -= 1 + + async def on_validate(self): + self.lifecycle_called("validate", "before") + self.active_counter += 1 + yield + self.lifecycle_called("validate", "after") + self.active_counter -= 1 + + async def on_parse(self): + self.lifecycle_called("parse", "before") + self.active_counter += 1 + yield + self.lifecycle_called("parse", "after") + self.active_counter -= 1 + + async def on_execute(self): + self.lifecycle_called("execute", "before") + self.active_counter += 1 + yield + self.lifecycle_called("execute", "after") + self.active_counter -= 1 + def _read_file(text_file: Upload) -> str: from starlette.datastructures import UploadFile @@ -79,6 +191,12 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str def always_fail(self) -> Optional[str]: return "Hey" + @strawberry.field(permission_classes=[ConditionalFailPermission]) + def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> str: + return "Hey" + @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: yield GraphQLError(message) # type: ignore @@ -244,6 +362,12 @@ async def long_finalizer( finally: await asyncio.sleep(delay) + @strawberry.subscription(permission_classes=[ConditionalFailPermission]) + async def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> AsyncGenerator[str, None]: + yield "Hey" + schema = strawberry.Schema( query=Query, @@ -251,3 +375,10 @@ async def long_finalizer( subscription=Subscription, extensions=[MyExtension], ) + +async_schema = strawberry.Schema( + query=Query, + mutation=Mutation, + subscription=Subscription, + extensions=[MyAsyncExtension], +) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 22829a1421..067c7f6551 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -4,13 +4,15 @@ import time from datetime import timedelta from typing import AsyncGenerator, Type -from unittest.mock import patch +from unittest.mock import Mock, patch +from unittest.mock import call as mock_call try: from unittest.mock import AsyncMock except ImportError: AsyncMock = None + import pytest import pytest_asyncio from pytest_mock import MockerFixture @@ -21,7 +23,6 @@ ConnectionAckMessage, ConnectionInitMessage, ErrorMessage, - NextMessage, PingMessage, PongMessage, SubscribeMessage, @@ -31,6 +32,7 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from ..http.clients import HttpClient, WebSocketClient +from ..views.schema import MyAsyncExtension @pytest_asyncio.fixture @@ -51,6 +53,20 @@ async def ws(ws_raw: WebSocketClient) -> WebSocketClient: return ws_raw +def assert_next(response, id, data, extensions=None): + """ + Assert that the NextMessage payload contains the provided data. + If extensions is provided, it will also assert that the + extensions are present + """ + assert response["type"] == "next" + assert response["id"] == id + assert set(response["payload"].keys()) <= {"data", "errors", "extensions"} + assert response["payload"]["data"] == data + if extensions is not None: + assert response["payload"]["extensions"] == extensions + + async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw @@ -149,13 +165,7 @@ async def test_connection_init_timeout_cancellation( ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", - payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}}, - ).as_dict() - ) + assert_next(response, "sub1", {"debug": {"isConnectionInitTimeoutTaskDone": True}}) @pytest.mark.skipif( @@ -221,10 +231,7 @@ async def test_server_sent_ping(ws: WebSocketClient): await ws.send_json(PongMessage().as_dict()) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"requestPing": True}}).as_dict() - ) + assert_next(response, "sub1", {"requestPing": True}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -288,9 +295,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -307,9 +312,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) async def test_simple_subscription(ws: WebSocketClient): @@ -323,9 +326,7 @@ async def test_simple_subscription(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -362,7 +363,29 @@ async def test_subscription_field_errors(ws: WebSocketClient): assert response["payload"][0]["locations"] == [{"line": 1, "column": 16}] assert ( response["payload"][0]["message"] - == "The subscription field 'notASubscriptionField' is not defined." + == "Cannot query field 'notASubscriptionField' on type 'Subscription'." + ) + + +async def test_query_field_errors(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { notASubscriptionField }", + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") is None + assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}] + assert ( + response["payload"][0]["message"] + == "Cannot query field 'notASubscriptionField' on type 'Query'." ) @@ -386,12 +409,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"debug": {"numActiveResultHandlers": 2}}} - ).as_dict() - ) + assert_next(response, "sub2", {"debug": {"numActiveResultHandlers": 2}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -408,12 +426,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub3", payload={"data": {"debug": {"numActiveResultHandlers": 1}}} - ).as_dict() - ) + assert_next(response, "sub3", {"debug": {"numActiveResultHandlers": 1}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub3").as_dict() @@ -430,8 +443,7 @@ async def test_subscription_errors(ws: WebSocketClient): ) response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" + assert_next(response, "sub1", None) assert len(response["payload"]["errors"]) == 1 assert response["payload"]["errors"][0]["path"] == ["error"] assert response["payload"]["errors"][0]["message"] == "TEST ERR" @@ -498,10 +510,7 @@ async def test_single_result_query_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "Hello world"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "Hello world"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -523,12 +532,7 @@ async def test_single_result_query_operation_async(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub1", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -562,12 +566,7 @@ async def test_single_result_query_operation_overlapped(ws: WebSocketClient): # we expect the response to the second query to arrive first response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub2", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -581,10 +580,7 @@ async def test_single_result_mutation_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "strawberry"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -608,12 +604,7 @@ async def test_single_result_operation_selection(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"hello": "Hello Strawberry"}} - ).as_dict() - ) + assert_next(response, "sub1", {"hello": "Hello Strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -757,12 +748,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"connectionParams": "rocks"}} - ).as_dict() - ) + assert_next(response, "sub1", {"connectionParams": "rocks"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -805,12 +791,7 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"longFinalizer": "hello"}} - ).as_dict() - ) + assert_next(response, "sub1", {"longFinalizer": "hello"}) # now cancel the stubscription and send a new query. We expect the response # to the new query to arrive immediately, without waiting for the finalizer @@ -889,22 +870,103 @@ async def test_subscription_errors_continue(ws: WebSocketClient): ) response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "VANILLA"} + assert_next(response, "sub1", {"flavorsInvalid": "VANILLA"}) response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] is None + assert_next(response, "sub1", None) errors = response["payload"]["errors"] assert "cannot represent value" in str(errors) response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "CHOCOLATE"} + assert_next(response, "sub1", {"flavorsInvalid": "CHOCOLATE"}) response = await ws.receive_json() assert response["type"] == CompleteMessage.type assert response["id"] == "sub1" + + +async def test_extensions(ws: WebSocketClient): + resolve_called = Mock() + lifecycle_called = Mock() + + # we must make sure that earlier requests and drained before we start + # so that their execution events don't interfere with our events + while MyAsyncExtension.active_counter > 0: + await asyncio.sleep(0.01) + + with patch.object(MyAsyncExtension, "resolve_called", resolve_called): + with patch.object(MyAsyncExtension, "lifecycle_called", lifecycle_called): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query='subscription { echo(message: "Hi") }' + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert_next( + response, "sub1", {"echo": "Hi"}, extensions={"example": "example"} + ) + response = await ws.receive_json() + assert response == CompleteMessage(id="sub1").as_dict() + + # no resolvers called + assert resolve_called.call_count == 0 + + lifecycle_calls = lifecycle_called.call_args_list + assert lifecycle_calls == [ + mock_call("operation", "before"), + mock_call("parse", "before"), + mock_call("parse", "after"), + mock_call("validate", "before"), + mock_call("validate", "after"), + mock_call("execute", "before"), + mock_call("execute", "after"), + mock_call("operation", "after"), + ] + + +async def test_validation_query(ws: WebSocketClient): + """ + Test validation for query + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_validation_subscription(ws: WebSocketClient): + """ + Test validation for subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 6a07fb6ac4..3da07ea73f 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -251,7 +251,9 @@ async def test_subscription_field_error(ws: WebSocketClient): assert response["id"] == "invalid-field" assert response["payload"] == { "locations": [{"line": 1, "column": 16}], - "message": ("The subscription field 'notASubscriptionField' is not defined."), + "message": ( + "Cannot query field 'notASubscriptionField' on type 'Subscription'." + ), } @@ -548,3 +550,26 @@ async def test_rejects_connection_params(aiohttp_app_client: HttpClient): # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close assert ws.closed + + +async def test_extensions(ws: WebSocketClient): + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) + + response = await ws.receive_json() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert response["payload"]["data"] == {"echo": "Hi"} + assert response["payload"]["extensions"] == {"example": "example"} + + await ws.send_json({"type": GQL_STOP, "id": "demo"}) + response = await ws.receive_json() + assert response["type"] == GQL_COMPLETE + assert response["id"] == "demo"