diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..6ef7811880 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +Subscriptions now support Schema Extensions. diff --git a/docs/guides/custom-extensions.md b/docs/guides/custom-extensions.md index 22cc8c1f97..1dd000409d 100644 --- a/docs/guides/custom-extensions.md +++ b/docs/guides/custom-extensions.md @@ -36,8 +36,6 @@ resolvers. If you need to wrap only certain field resolvers with additional logic, please check out [field extensions](field-extensions.md). -Note that `resolve` can also be implemented asynchronously. - ```python from strawberry.extensions import SchemaExtension @@ -47,6 +45,21 @@ class MyExtension(SchemaExtension): return _next(root, info, *args, **kwargs) ``` +Note that `resolve` can also be implemented asynchronously, in which case the +result from `_next` must be optionally awaited: + +```python +from inspect import isawaitable +from strawberry.types import Info +from strawberry.extensions import SchemaExtension + + +class MyExtension(SchemaExtension): + async def resolve(self, _next, root, info: Info, *args, **kwargs): + result = _next(root, info, *args, **kwargs) + return await result if isawaitable(result) else result +``` + ### Get results `get_results` allows to return a dictionary of data or alternatively an diff --git a/docs/operations/testing.md b/docs/operations/testing.md index a3aca51741..056cd8cf0f 100644 --- a/docs/operations/testing.md +++ b/docs/operations/testing.md @@ -136,10 +136,8 @@ async def test_subscription(): } """ - sub = await schema.subscribe(query) - index = 0 - async for result in sub: + async for result in schema.subscribe(query): assert not result.errors assert result.data == {"count": index} diff --git a/strawberry/schema/__init__.py b/strawberry/schema/__init__.py index 5cf633ac21..cbc3bf134a 100644 --- a/strawberry/schema/__init__.py +++ b/strawberry/schema/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseSchema +from .base import BaseSchema, SubscribeSingleResult from .schema import Schema -__all__ = ["BaseSchema", "Schema"] +__all__ = ["BaseSchema", "Schema", "SubscribeSingleResult"] diff --git a/strawberry/schema/base.py b/strawberry/schema/base.py index a1c286c6d0..25a039a609 100644 --- a/strawberry/schema/base.py +++ b/strawberry/schema/base.py @@ -2,7 +2,17 @@ from abc import abstractmethod from functools import lru_cache -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Dict, + Iterable, + List, + Optional, + Type, + Union, +) from typing_extensions import Protocol from strawberry.utils.logging import StrawberryLogger @@ -22,6 +32,15 @@ from .config import StrawberryConfig +class SubscribeSingleResult(RuntimeError): + """Raised when Schema.subscribe() returns a single execution result, instead of a + subscription generator, typically as a result of validation errors. + """ + + def __init__(self, value: ExecutionResult) -> None: + self.value = value + + class BaseSchema(Protocol): config: StrawberryConfig schema_converter: GraphQLCoreConverter @@ -55,14 +74,14 @@ def execute_sync( raise NotImplementedError @abstractmethod - async def subscribe( + def subscribe( self, query: str, variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Any: + ) -> AsyncGenerator[ExecutionResult, None]: raise NotImplementedError @abstractmethod diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index af0bd07a7f..7f2da62cd0 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -4,6 +4,8 @@ from inspect import isawaitable from typing import ( TYPE_CHECKING, + AsyncGenerator, + AsyncIterable, Awaitable, Callable, Iterable, @@ -17,21 +19,23 @@ cast, ) +from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, parse from graphql import execute as original_execute +from graphql import subscribe as original_subscribe from graphql.validation import validate from strawberry.exceptions import MissingQueryError from strawberry.extensions.runner import SchemaExtensionsRunner from strawberry.types import ExecutionResult +from .base import SubscribeSingleResult from .exceptions import InvalidOperationTypeError if TYPE_CHECKING: from typing_extensions import NotRequired, Unpack from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLSchema from graphql.language import DocumentNode from graphql.validation import ASTValidationRule @@ -272,3 +276,149 @@ def execute_sync( errors=execution_context.result.errors, extensions=extensions_runner.get_extensions_results_sync(), ) + + +async def subscribe( + schema: GraphQLSchema, + *, + extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], +) -> AsyncGenerator[ExecutionResult, None]: + """ + The graphql-core subscribe function returns either an ExecutionResult or an + AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error + during parsing or validation. + Because we need to maintain execution context, we cannot return an + async generator, we must _be_ an async generator. So we yield a + (bool, ExecutionResult) tuple, where the bool indicates whether the result is an + potentially multiple execution result or a single result. + A False value indicates an single result, most likely an intial + failure (and no more values will be yielded) whereas a True value indicates a + successful subscription, and more values may be yielded. + """ + + extensions_runner = SchemaExtensionsRunner( + execution_context=execution_context, + extensions=list(extensions), + ) + + # unlike execute(), the entire operation, including the results hooks, + # is run within the operation() hook. + async with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + assert execution_context.query is not None + + async with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) + raise SubscribeSingleResult( + ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) + ) + + except Exception as error: # pragma: no cover + error = GraphQLError(str(error), original_error=error) + execution_context.errors = [error] + process_errors([error], execution_context) + + raise SubscribeSingleResult( + ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) + ) + + async with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + raise SubscribeSingleResult( + ExecutionResult(data=None, errors=execution_context.errors) + ) + + async with extensions_runner.executing(): + # currently original_subscribe is an async function. A future release + # of graphql-core will make it optionally awaitable + result: Union[AsyncIterable[GraphQLExecutionResult], GraphQLExecutionResult] + result_or_awaitable = original_subscribe( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + context_value=execution_context.context, + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + ) + if isawaitable(result_or_awaitable): + result = await cast( + Awaitable[ + Union[ + AsyncIterable["GraphQLExecutionResult"], + "GraphQLExecutionResult", + ] + ], + result_or_awaitable, + ) + else: # pragma: no cover + result = cast( + Union[ + AsyncIterable["GraphQLExecutionResult"], + "GraphQLExecutionResult", + ], + result_or_awaitable, + ) + + if isinstance(result, GraphQLExecutionResult): + raise SubscribeSingleResult( + await process_subscribe_result( + execution_context, process_errors, extensions_runner, result + ) + ) + + aiterator = result.__aiter__() + try: + async for result in aiterator: + yield await process_subscribe_result( + execution_context, process_errors, extensions_runner, result + ) + finally: + # grapql-core's iterator may or may not have an aclose() method + if hasattr(aiterator, "aclose"): + await aiterator.aclose() + + +async def process_subscribe_result( + execution_context: ExecutionContext, + process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], + extensions_runner: SchemaExtensionsRunner, + result: GraphQLExecutionResult, +) -> ExecutionResult: + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + return ExecutionResult( + data=execution_context.result.data, + errors=execution_context.result.errors, + extensions=await extensions_runner.get_extensions_results(), + ) diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index b43963d9b5..1d5009f829 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -5,7 +5,7 @@ from typing import ( TYPE_CHECKING, Any, - AsyncIterator, + AsyncGenerator, Dict, Iterable, List, @@ -20,10 +20,8 @@ GraphQLNonNull, GraphQLSchema, get_introspection_query, - parse, validate_schema, ) -from graphql.execution import subscribe from graphql.type.directives import specified_directives from strawberry import relay @@ -43,11 +41,10 @@ from . import compat from .base import BaseSchema from .config import StrawberryConfig -from .execute import execute, execute_sync +from .execute import execute, execute_sync, subscribe if TYPE_CHECKING: from graphql import ExecutionContext as GraphQLExecutionContext - from graphql import ExecutionResult as GraphQLExecutionResult from strawberry.custom_scalar import ScalarDefinition, ScalarWrapper from strawberry.directive import StrawberryDirective @@ -305,22 +302,29 @@ def execute_sync( async def subscribe( self, - # TODO: make this optional when we support extensions - query: str, + query: Optional[str], variable_values: Optional[Dict[str, Any]] = None, context_value: Optional[Any] = None, root_value: Optional[Any] = None, operation_name: Optional[str] = None, - ) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]: - return await subscribe( - self._schema, - parse(query), + ) -> AsyncGenerator[ExecutionResult, None]: + execution_context = ExecutionContext( + query=query, + schema=self, + context=context_value, root_value=root_value, - context_value=context_value, - variable_values=variable_values, - operation_name=operation_name, + variables=variable_values, + provided_operation_name=operation_name, ) + async for result in subscribe( + self._schema, + extensions=self.get_extensions(), + execution_context=execution_context, + process_errors=self.process_errors, + ): + yield result + def _resolve_node_ids(self): for concrete_type in self.schema_converter.type_map.values(): type_def = concrete_type.definition diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index f20c674124..c02b4d0690 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -8,16 +8,15 @@ TYPE_CHECKING, Any, AsyncGenerator, - AsyncIterator, Callable, Dict, List, Optional, ) -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError, GraphQLSyntaxError, parse +from strawberry.schema import SubscribeSingleResult from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -246,7 +245,7 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: # Get an AsyncGenerator yielding the results if operation_type == OperationType.SUBSCRIPTION: - result_source = await self.schema.subscribe( + result_source = self.schema.subscribe( query=message.payload.query, variable_values=message.payload.variables, operation_name=message.payload.operationName, @@ -255,28 +254,23 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: ) else: # create AsyncGenerator returning a single result - async def get_result_source() -> AsyncIterator[ExecutionResult]: - yield await self.schema.execute( - query=message.payload.query, - variable_values=message.payload.variables, - context_value=context, - root_value=root_value, - operation_name=message.payload.operationName, + async def get_result_source() -> AsyncGenerator[ExecutionResult, None]: + raise SubscribeSingleResult( + await self.schema.execute( + query=message.payload.query, + variable_values=message.payload.variables, + context_value=context, + root_value=root_value, + operation_name=message.payload.operationName, + ) ) + # need a yield here to turn this into an async generator + yield None # pragma: no cover result_source = get_result_source() operation = Operation(self, message.id, operation_type) - # Handle initial validation errors - if isinstance(result_source, GraphQLExecutionResult): - assert operation_type == OperationType.SUBSCRIPTION - assert result_source.errors - payload = [err.formatted for err in result_source.errors] - await self.send_message(ErrorMessage(id=message.id, payload=payload)) - self.schema.process_errors(result_source.errors) - return - # Create task to handle this subscription, reserve the operation ID operation.task = asyncio.create_task( self.operation_task(result_source, operation) @@ -316,26 +310,16 @@ async def handle_async_results( operation: Operation, ) -> None: try: - async for result in result_source: - if ( - result.errors - and operation.operation_type != OperationType.SUBSCRIPTION - ): - error_payload = [err.formatted for err in result.errors] - error_message = ErrorMessage(id=operation.id, payload=error_payload) - await operation.send_message(error_message) - # don't need to call schema.process_errors() here because - # it was already done by schema.execute() - return - else: - next_payload = {"data": result.data} - if result.errors: - self.schema.process_errors(result.errors) - next_payload["errors"] = [ - err.formatted for err in result.errors - ] - next_message = NextMessage(id=operation.id, payload=next_payload) - await operation.send_message(next_message) + try: + async for result in result_source: + await self.send_result(operation, result, False) + except SubscribeSingleResult as single_result: + await self.send_result(operation, single_result.value, True) + finally: + await result_source.aclose() + except asyncio.CancelledError: + # CancelledErrors are expected during task cleanup. + raise except Exception as error: # GraphQLErrors are handled by graphql-core and included in the # ExecutionResult @@ -346,6 +330,22 @@ async def handle_async_results( self.schema.process_errors([error]) return + async def send_result( + self, operation: Operation, result: ExecutionResult, single: bool + ) -> None: + if result.errors and single: + error_payload = [err.formatted for err in result.errors] + error_message = ErrorMessage(id=operation.id, payload=error_payload) + await operation.send_message(error_message) + else: + next_payload: Dict[str, Any] = {"data": result.data} + if result.errors: + next_payload["errors"] = [err.formatted for err in result.errors] + if result.extensions: + next_payload["extensions"] = result.extensions + next_message = NextMessage(id=operation.id, payload=next_payload) + await operation.send_message(next_message) + def forget_id(self, id: str) -> None: # de-register the operation id making it immediately available # for re-use diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index c4863ff49e..c36a8ba911 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -5,9 +5,9 @@ from contextlib import suppress from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, cast -from graphql import ExecutionResult as GraphQLExecutionResult from graphql import GraphQLError +from strawberry.schema import SubscribeSingleResult from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, GQL_CONNECTION_ACK, @@ -123,26 +123,13 @@ async def handle_start(self, message: OperationMessage) -> None: if self.debug: pretty_print_graphql_operation(operation_name, query, variables) - try: - result_source = await self.schema.subscribe( - query=query, - variable_values=variables, - operation_name=operation_name, - context_value=context, - root_value=root_value, - ) - except GraphQLError as error: - error_payload = error.formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors([error]) - return - - if isinstance(result_source, GraphQLExecutionResult): - assert result_source.errors - error_payload = result_source.errors[0].formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) - self.schema.process_errors(result_source.errors) - return + result_source = self.schema.subscribe( + query=query, + variable_values=variables, + operation_name=operation_name, + context_value=context, + root_value=root_value, + ) self.subscriptions[operation_id] = result_source result_handler = self.handle_async_results(result_source, operation_id) @@ -164,15 +151,22 @@ async def handle_async_results( operation_id: str, ) -> None: try: - async for result in result_source: - payload = {"data": result.data} - if result.errors: - payload["errors"] = [err.formatted for err in result.errors] - await self.send_message(GQL_DATA, operation_id, payload) - # log errors after send_message to prevent potential - # slowdown of sending result - if result.errors: - self.schema.process_errors(result.errors) + try: + async for result in result_source: + payload = {"data": result.data} + if result.errors: + payload["errors"] = [err.formatted for err in result.errors] + if result.extensions: + payload["extensions"] = result.extensions + await self.send_message(GQL_DATA, operation_id, payload) + except SubscribeSingleResult as single_result: + result = single_result.value + assert result.errors + error_payload = result.errors[0].formatted + await self.send_message(GQL_ERROR, operation_id, error_payload) + return + finally: + await result_source.aclose() except asyncio.CancelledError: # CancelledErrors are expected during task cleanup. pass @@ -190,13 +184,12 @@ async def handle_async_results( await self.send_message(GQL_COMPLETE, operation_id, None) async def cleanup_operation(self, operation_id: str) -> None: - await self.subscriptions[operation_id].aclose() - del self.subscriptions[operation_id] - - self.tasks[operation_id].cancel() + iterator = self.subscriptions.pop(operation_id) + task = self.tasks.pop(operation_id) + task.cancel() with suppress(BaseException): - await self.tasks[operation_id] - del self.tasks[operation_id] + await task + await iterator.aclose() async def send_message( self, diff --git a/tests/benchmarks/test_subscriptions.py b/tests/benchmarks/test_subscriptions.py index 075c993904..c85a398f54 100644 --- a/tests/benchmarks/test_subscriptions.py +++ b/tests/benchmarks/test_subscriptions.py @@ -16,11 +16,8 @@ def test_subscription(benchmark: BenchmarkFixture): async def _run(): for _ in range(100): - iterator = await schema.subscribe(s) - - value = await iterator.__anext__() # type: ignore[union-attr] - - assert value.data is not None - assert value.data["something"] == "Hello World!" + async for value in schema.subscribe(s): + assert value.data is not None + assert value.data["something"] == "Hello World!" benchmark(lambda: asyncio.run(_run())) diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index b868a0826b..d5e09e1bcf 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -91,6 +91,7 @@ async def test_channel_listen(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -137,6 +138,7 @@ async def test_channel_listen_with_confirmation(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -315,6 +317,7 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): }, ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -331,6 +334,7 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -377,6 +381,7 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): }, ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( @@ -393,6 +398,7 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): ) response = await ws.receive_json_from() + response["payload"].pop("extensions", None) assert ( response == NextMessage( diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..14c3dcaeb6 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -15,7 +15,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..843ae9e082 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -16,7 +16,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index ee31c8e88b..22148384c6 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -18,7 +18,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.http.typevars import Context, RootValue -from tests.views.schema import Query, schema +from tests.views.schema import Query, async_schema +from tests.views.schema import schema as sync_schema from ..context import get_context from .base import ( @@ -143,12 +144,12 @@ def __init__( result_override: ResultOverrideFunction = None, ): self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, + schema=async_schema, keep_alive=False, ) self.http_app = DebuggableGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=async_schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, @@ -157,7 +158,7 @@ def __init__( def create_app(self, **kwargs: Any) -> None: self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, **kwargs + schema=async_schema, **kwargs ) async def _graphql_request( @@ -264,7 +265,7 @@ def __init__( result_override: ResultOverrideFunction = None, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( - schema=schema, + schema=sync_schema, graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index f271509f40..ca7fe820df 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -15,7 +15,8 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .asgi import AsgiWebSocketClient diff --git a/tests/http/clients/starlite.py b/tests/http/clients/starlite.py index 9af3a8206e..2db8593d14 100644 --- a/tests/http/clients/starlite.py +++ b/tests/http/clients/starlite.py @@ -15,7 +15,8 @@ from strawberry.starlite import make_graphql_controller from strawberry.starlite.controller import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.types import ExecutionResult -from tests.views.schema import Query, schema +from tests.views.schema import Query +from tests.views.schema import async_schema as schema from ..context import get_context from .base import ( diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 0cec4cde60..52154b2331 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -12,6 +12,7 @@ ) from strawberry.permission import BasePermission, PermissionExtension from strawberry.printer import print_schema +from strawberry.schema import SubscribeSingleResult def test_raises_graphql_error_when_permission_method_is_missing(): @@ -80,8 +81,10 @@ async def user( query = "subscription { user }" - result = await schema.subscribe(query) - + with pytest.raises(SubscribeSingleResult) as err: + async for _ in schema.subscribe(query): + pass + result = err.value.value assert result.errors[0].message == "You are not authorized" diff --git a/tests/schema/test_subscription.py b/tests/schema/test_subscription.py index b63af92dcc..f32ac433ed 100644 --- a/tests/schema/test_subscription.py +++ b/tests/schema/test_subscription.py @@ -33,11 +33,9 @@ async def example(self) -> AsyncGenerator[str, None]: query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["example"] == "Hi" @pytest.mark.asyncio @@ -66,11 +64,9 @@ async def example(self) -> AsyncGenerator[str, None]: query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["example"] == "Hi" @pytest.mark.asyncio @@ -89,11 +85,9 @@ async def example(self, name: str) -> AsyncGenerator[str, None]: query = 'subscription { example(name: "Nina") }' - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi Nina" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["example"] == "Hi Nina" requires_builtin_generics = pytest.mark.skipif( @@ -132,11 +126,9 @@ class Subscription: query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["example"] == "Hi" @pytest.mark.asyncio @@ -165,11 +157,9 @@ async def example_with_union(self) -> AsyncGenerator[Union[A, B], None]: query = "subscription { exampleWithUnion { ... on A { a } } }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["exampleWithUnion"]["a"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["exampleWithUnion"]["a"] == "Hi" del A, B @@ -204,11 +194,9 @@ async def example_with_annotated_union( query = "subscription { exampleWithAnnotatedUnion { ... on C { c } } }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["exampleWithAnnotatedUnion"]["c"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["exampleWithAnnotatedUnion"]["c"] == "Hi" del C, D @@ -231,8 +219,6 @@ async def example( query = "subscription { example }" - sub = await schema.subscribe(query) - result = await sub.__anext__() - - assert not result.errors - assert result.data["example"] == "Hi" + async for result in schema.subscribe(query): + assert not result.errors + assert result.data["example"] == "Hi" diff --git a/tests/views/schema.py b/tests/views/schema.py index 14a25b0c0d..546834d052 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -1,5 +1,6 @@ import asyncio import contextlib +import inspect from enum import Enum from typing import Any, AsyncGenerator, Dict, List, Optional, Union @@ -20,10 +21,123 @@ def has_permission(self, source: Any, info: strawberry.Info, **kwargs: Any) -> b return False +class ConditionalFailPermission(BasePermission): + @property + def message(self): + return f"failed after sleep {self.sleep}" + + async def has_permission(self, source, info, **kwargs: Any) -> bool: + self.sleep = kwargs.get("sleep", None) + self.fail = kwargs.get("fail", True) + if self.sleep is not None: + await asyncio.sleep(kwargs["sleep"]) + return not self.fail + + class MyExtension(SchemaExtension): + # a counter to keep track of how many operations are active + active_counter = 0 + def get_results(self) -> Dict[str, str]: return {"example": "example"} + def resolve(self, _next, root, info: strawberry.Info, *args: Any, **kwargs: Any): + self.active_counter += 1 + try: + self.resolve_called() + return _next(root, info, *args, **kwargs) + finally: + self.active_counter -= 1 + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + def on_operation(self): + self.lifecycle_called("operation", "before") + self.active_counter += 1 + yield + self.lifecycle_called("operation", "after") + self.active_counter -= 1 + + def on_validate(self): + self.lifecycle_called("validate", "before") + self.active_counter += 1 + yield + self.lifecycle_called("validate", "after") + self.active_counter -= 1 + + def on_parse(self): + self.lifecycle_called("parse", "before") + self.active_counter += 1 + yield + self.lifecycle_called("parse", "after") + self.active_counter -= 1 + + def on_execute(self): + self.lifecycle_called("execute", "before") + self.active_counter += 1 + yield + self.lifecycle_called("execute", "after") + self.active_counter -= 1 + + +class MyAsyncExtension(SchemaExtension): + # a counter to keep track of how many operations are active + active_counter = 0 + + def get_results(self) -> Dict[str, str]: + return {"example": "example"} + + async def resolve( + self, _next, root, info: strawberry.Info, *args: Any, **kwargs: Any + ): + self.resolve_called() + self.active_counter += 1 + try: + result = _next(root, info, *args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + finally: + self.active_counter -= 1 + + def resolve_called(self): + pass + + def lifecycle_called(self, event, phase): + pass + + async def on_operation(self): + self.lifecycle_called("operation", "before") + self.active_counter += 1 + yield + self.lifecycle_called("operation", "after") + self.active_counter -= 1 + + async def on_validate(self): + self.lifecycle_called("validate", "before") + self.active_counter += 1 + yield + self.lifecycle_called("validate", "after") + self.active_counter -= 1 + + async def on_parse(self): + self.lifecycle_called("parse", "before") + self.active_counter += 1 + yield + self.lifecycle_called("parse", "after") + self.active_counter -= 1 + + async def on_execute(self): + self.lifecycle_called("execute", "before") + self.active_counter += 1 + yield + self.lifecycle_called("execute", "after") + self.active_counter -= 1 + def _read_file(text_file: Upload) -> str: with contextlib.suppress(ModuleNotFoundError): @@ -86,6 +200,12 @@ async def async_hello(self, name: Optional[str] = None, delay: float = 0) -> str def always_fail(self) -> Optional[str]: return "Hey" + @strawberry.field(permission_classes=[ConditionalFailPermission]) + def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> str: + return "Hey" + @strawberry.field async def error(self, message: str) -> AsyncGenerator[str, None]: yield GraphQLError(message) # type: ignore @@ -268,6 +388,12 @@ async def long_finalizer( finally: await asyncio.sleep(delay) + @strawberry.subscription(permission_classes=[ConditionalFailPermission]) + async def conditional_fail( + self, sleep: Optional[float] = None, fail: bool = False + ) -> AsyncGenerator[str, None]: + yield "Hey" + class Schema(strawberry.Schema): def process_errors( @@ -285,3 +411,10 @@ def process_errors( subscription=Subscription, extensions=[MyExtension], ) + +async_schema = Schema( + query=Query, + mutation=Mutation, + subscription=Subscription, + extensions=[MyAsyncExtension], +) diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 1ab738fb73..6daf7eb4db 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -7,6 +7,7 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, AsyncGenerator, Type from unittest.mock import Mock, patch +from unittest.mock import call as mock_call try: from unittest.mock import AsyncMock @@ -32,6 +33,8 @@ from tests.http.clients.base import DebuggableGraphQLTransportWSMixin from tests.views.schema import Schema +from ..views.schema import MyAsyncExtension + if TYPE_CHECKING: from ..http.clients.base import HttpClient, WebSocketClient @@ -54,6 +57,20 @@ async def ws(ws_raw: WebSocketClient) -> WebSocketClient: return ws_raw +def assert_next(response, id, data, extensions=None): + """ + Assert that the NextMessage payload contains the provided data. + If extensions is provided, it will also assert that the + extensions are present + """ + assert response["type"] == "next" + assert response["id"] == id + assert set(response["payload"].keys()) <= {"data", "errors", "extensions"} + assert response["payload"]["data"] == data + if extensions is not None: + assert response["payload"]["extensions"] == extensions + + async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw @@ -158,13 +175,7 @@ async def test_connection_init_timeout_cancellation( ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", - payload={"data": {"debug": {"isConnectionInitTimeoutTaskDone": True}}}, - ).as_dict() - ) + assert_next(response, "sub1", {"debug": {"isConnectionInitTimeoutTaskDone": True}}) @pytest.mark.xfail(reason="This test is flaky") @@ -239,11 +250,8 @@ async def test_can_send_payload_with_additional_things(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) - assert json.loads(data.data) == { - "type": "next", - "id": "1", - "payload": {"data": {"echo": "Hi"}}, - } + result = json.loads(data.data) + assert_next(result, "1", {"echo": "Hi"}) async def test_server_sent_ping(ws: WebSocketClient): @@ -260,10 +268,7 @@ async def test_server_sent_ping(ws: WebSocketClient): await ws.send_json(PongMessage().as_dict()) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"requestPing": True}}).as_dict() - ) + assert_next(response, "sub1", {"requestPing": True}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -327,9 +332,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -346,9 +349,7 @@ async def test_reused_operation_ids(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) async def test_simple_subscription(ws: WebSocketClient): @@ -362,9 +363,7 @@ async def test_simple_subscription(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response == NextMessage(id="sub1", payload={"data": {"echo": "Hi"}}).as_dict() - ) + assert_next(response, "sub1", {"echo": "Hi"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -403,11 +402,33 @@ async def test_subscription_field_errors(ws: WebSocketClient): assert response["payload"][0]["locations"] == [{"line": 1, "column": 16}] assert ( response["payload"][0]["message"] - == "The subscription field 'notASubscriptionField' is not defined." + == "Cannot query field 'notASubscriptionField' on type 'Subscription'." ) process_errors.assert_called_once() +async def test_query_field_errors(ws: WebSocketClient): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { notASubscriptionField }", + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") is None + assert response["payload"][0]["locations"] == [{"line": 1, "column": 9}] + assert ( + response["payload"][0]["message"] + == "Cannot query field 'notASubscriptionField' on type 'Query'." + ) + + async def test_subscription_cancellation(ws: WebSocketClient): await ws.send_json( SubscribeMessage( @@ -428,12 +449,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"debug": {"numActiveResultHandlers": 2}}} - ).as_dict() - ) + assert_next(response, "sub2", {"debug": {"numActiveResultHandlers": 2}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -450,12 +466,7 @@ async def test_subscription_cancellation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub3", payload={"data": {"debug": {"numActiveResultHandlers": 1}}} - ).as_dict() - ) + assert_next(response, "sub3", {"debug": {"numActiveResultHandlers": 1}}) response = await ws.receive_json() assert response == CompleteMessage(id="sub3").as_dict() @@ -543,10 +554,7 @@ async def test_single_result_query_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "Hello world"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "Hello world"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -568,12 +576,7 @@ async def test_single_result_query_operation_async(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub1", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -607,12 +610,7 @@ async def test_single_result_query_operation_overlapped(ws: WebSocketClient): # we expect the response to the second query to arrive first response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub2", payload={"data": {"asyncHello": "Hello Dolly"}} - ).as_dict() - ) + assert_next(response, "sub2", {"asyncHello": "Hello Dolly"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub2").as_dict() @@ -626,10 +624,7 @@ async def test_single_result_mutation_operation(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage(id="sub1", payload={"data": {"hello": "strawberry"}}).as_dict() - ) + assert_next(response, "sub1", {"hello": "strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -653,12 +648,7 @@ async def test_single_result_operation_selection(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"hello": "Hello Strawberry"}} - ).as_dict() - ) + assert_next(response, "sub1", {"hello": "Hello Strawberry"}) response = await ws.receive_json() assert response == CompleteMessage(id="sub1").as_dict() @@ -808,12 +798,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"connectionParams": "rocks"}} - ).as_dict() - ) + assert_next(response, "sub1", {"connectionParams": "rocks"}) await ws.send_json(CompleteMessage(id="sub1").as_dict()) @@ -862,12 +847,7 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): ) response = await ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"longFinalizer": "hello"}} - ).as_dict() - ) + assert_next(response, "sub1", {"longFinalizer": "hello"}) # now cancel the stubscription and send a new query. We expect the response # to the new query to arrive immediately, without waiting for the finalizer @@ -952,9 +932,7 @@ async def test_subscription_errors_continue(ws: WebSocketClient): ) response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "VANILLA"} + assert_next(response, "sub1", {"flavorsInvalid": "VANILLA"}) response = await ws.receive_json() assert response["type"] == NextMessage.type @@ -965,10 +943,151 @@ async def test_subscription_errors_continue(ws: WebSocketClient): process_errors.assert_called_once() response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "CHOCOLATE"} + assert_next(response, "sub1", {"flavorsInvalid": "CHOCOLATE"}) response = await ws.receive_json() assert response["type"] == CompleteMessage.type assert response["id"] == "sub1" + + +async def test_extensions(ws: WebSocketClient): + resolve_called = Mock() + lifecycle_called = Mock() + + # we must make sure that earlier requests and drained before we start + # so that their execution events don't interfere with our events + while MyAsyncExtension.active_counter > 0: + await asyncio.sleep(0.01) + + with patch.object(MyAsyncExtension, "resolve_called", resolve_called): + with patch.object(MyAsyncExtension, "lifecycle_called", lifecycle_called): + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query='subscription { echo(message: "Hi") }' + ), + ).as_dict() + ) + + response = await ws.receive_json() + assert_next( + response, "sub1", {"echo": "Hi"}, extensions={"example": "example"} + ) + response = await ws.receive_json() + assert response == CompleteMessage(id="sub1").as_dict() + + # no resolvers called + assert resolve_called.call_count == 0 + + lifecycle_calls = lifecycle_called.call_args_list + assert lifecycle_calls == [ + mock_call("operation", "before"), + mock_call("parse", "before"), + mock_call("parse", "after"), + mock_call("validate", "before"), + mock_call("validate", "after"), + mock_call("execute", "before"), + mock_call("execute", "after"), + mock_call("operation", "after"), + ] + + +async def test_validation_query(ws: WebSocketClient): + """ + Test validation for query + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_validation_subscription(ws: WebSocketClient): + """ + Test validation for subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(fail:true) }" + ), + ).as_dict() + ) + + # We expect an error message directly + response = await ws.receive_json() + assert response["type"] == ErrorMessage.type + assert response["id"] == "sub1" + assert len(response["payload"]) == 1 + assert response["payload"][0].get("path") == ["conditionalFail"] + assert response["payload"][0]["message"] == "failed after sleep None" + + +async def test_long_validation_concurrent_query(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + single-result-operation + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="query { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first query is stuck in validation + response = await ws.receive_json() + assert_next(response, "sub2", {"conditionalFail": "Hey"}) + + +async def test_long_validation_concurrent_subscription(ws: WebSocketClient): + """ + Test that the websocket is not blocked while validating a + subscription + """ + await ws.send_json( + SubscribeMessage( + id="sub1", + payload=SubscribeMessagePayload( + query="subscription { conditionalFail(sleep:0.1) }" + ), + ).as_dict() + ) + await ws.send_json( + SubscribeMessage( + id="sub2", + payload=SubscribeMessagePayload( + query="query { conditionalFail(fail:false) }" + ), + ).as_dict() + ) + + # we expect the second query to arrive first, because the + # first operation is stuck in validation + response = await ws.receive_json() + assert_next(response, "sub2", {"conditionalFail": "Hey"}) diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 83b7e66782..befe7ae797 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -2,6 +2,7 @@ import asyncio from typing import TYPE_CHECKING, AsyncGenerator +from unittest.mock import Mock, patch import pytest import pytest_asyncio @@ -19,6 +20,7 @@ GQL_START, GQL_STOP, ) +from tests.views.schema import Schema if TYPE_CHECKING: from ..http.clients.aiohttp import HttpClient, WebSocketClient @@ -198,25 +200,28 @@ async def test_subscription_cancellation(ws: WebSocketClient): async def test_subscription_errors(ws: WebSocketClient): - await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, - } - ) + process_errors = Mock() + with patch.object(Schema, "process_errors", process_errors): + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, + } + ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] is None - assert len(response["payload"]["errors"]) == 1 - assert response["payload"]["errors"][0]["path"] == ["error"] - assert response["payload"]["errors"][0]["message"] == "TEST ERR" + response = await ws.receive_json() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert response["payload"]["data"] is None + assert len(response["payload"]["errors"]) == 1 + assert response["payload"]["errors"][0]["path"] == ["error"] + assert response["payload"]["errors"][0]["message"] == "TEST ERR" + process_errors.assert_called_once() - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + response = await ws.receive_json() + assert response["type"] == GQL_COMPLETE + assert response["id"] == "demo" async def test_subscription_exceptions(ws: WebSocketClient): @@ -241,21 +246,26 @@ async def test_subscription_exceptions(ws: WebSocketClient): async def test_subscription_field_error(ws: WebSocketClient): - await ws.send_json( - { - "type": GQL_START, - "id": "invalid-field", - "payload": {"query": "subscription { notASubscriptionField }"}, - } - ) + process_errors = Mock() + with patch.object(Schema, "process_errors", process_errors): + await ws.send_json( + { + "type": GQL_START, + "id": "invalid-field", + "payload": {"query": "subscription { notASubscriptionField }"}, + } + ) - response = await ws.receive_json() - assert response["type"] == GQL_ERROR - assert response["id"] == "invalid-field" - assert response["payload"] == { - "locations": [{"line": 1, "column": 16}], - "message": ("The subscription field 'notASubscriptionField' is not defined."), - } + response = await ws.receive_json() + assert response["type"] == GQL_ERROR + assert response["id"] == "invalid-field" + assert response["payload"] == { + "locations": [{"line": 1, "column": 16}], + "message": ( + "Cannot query field 'notASubscriptionField' on type 'Subscription'." + ), + } + process_errors.assert_called_once() async def test_subscription_syntax_error(ws: WebSocketClient): @@ -556,3 +566,26 @@ async def test_rejects_connection_params(aiohttp_app_client: HttpClient): # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close assert ws.closed + + +async def test_extensions(ws: WebSocketClient): + await ws.send_json( + { + "type": GQL_START, + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) + + response = await ws.receive_json() + assert response["type"] == GQL_DATA + assert response["id"] == "demo" + assert response["payload"]["data"] == {"echo": "Hi"} + assert response["payload"]["extensions"] == {"example": "example"} + + await ws.send_json({"type": GQL_STOP, "id": "demo"}) + response = await ws.receive_json() + assert response["type"] == GQL_COMPLETE + assert response["id"] == "demo"