diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..91b0163149 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,5 @@ +Release type: minor + +In this release, we migrated the `graphql-transport-ws` types from data classes to typed dicts. +Using typed dicts enabled us to precisely model `null` versus `undefined` values, which are common in that protocol. +As a result, we could remove custom conversion methods handling these cases and simplify the codebase. diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index 511ed34d1b..892bb4bda0 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -17,23 +17,10 @@ from channels.testing.websocket import WebsocketCommunicator from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( - ConnectionAckMessage, - ConnectionInitMessage, - ErrorMessage, - NextMessage, - SubscribeMessage, - SubscribeMessagePayload, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionAckMessage as GraphQLWSConnectionAckMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionInitMessage as GraphQLWSConnectionInitMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - StartMessage as GraphQLWSStartMessage, +from strawberry.subscriptions.protocols.graphql_transport_ws import ( + types as transport_ws_types, ) +from strawberry.subscriptions.protocols.graphql_ws import types as ws_types from strawberry.types import ExecutionResult if TYPE_CHECKING: @@ -109,19 +96,21 @@ async def gql_init(self) -> None: if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: assert res == (True, GRAPHQL_TRANSPORT_WS_PROTOCOL) await self.send_json_to( - ConnectionInitMessage(payload=self.connection_params).as_dict() + transport_ws_types.ConnectionInitMessage( + {"type": "connection_init", "payload": self.connection_params} + ) ) - graphql_transport_ws_response = await self.receive_json_from() - assert graphql_transport_ws_response == ConnectionAckMessage().as_dict() + transport_ws_connection_ack_message: transport_ws_types.ConnectionAckMessage = await self.receive_json_from() + assert transport_ws_connection_ack_message == {"type": "connection_ack"} else: assert res == (True, GRAPHQL_WS_PROTOCOL) await self.send_json_to( - GraphQLWSConnectionInitMessage({"type": "connection_init"}) + ws_types.ConnectionInitMessage({"type": "connection_init"}) ) - graphql_ws_response: GraphQLWSConnectionAckMessage = ( + ws_connection_ack_message: ws_types.ConnectionAckMessage = ( await self.receive_json_from() ) - assert graphql_ws_response["type"] == "connection_ack" + assert ws_connection_ack_message["type"] == "connection_ack" # Actual `ExecutionResult`` objects are not available client-side, since they # get transformed into `FormattedExecutionResult` on the wire, but we attempt @@ -133,13 +122,16 @@ async def subscribe( if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: await self.send_json_to( - SubscribeMessage( - id=id_, - payload=SubscribeMessagePayload(query=query, variables=variables), - ).as_dict() + transport_ws_types.SubscribeMessage( + { + "id": id_, + "type": "subscribe", + "payload": {"query": query, "variables": variables}, + } + ) ) else: - start_message: GraphQLWSStartMessage = { + start_message: ws_types.StartMessage = { "type": "start", "id": id_, "payload": { @@ -153,17 +145,18 @@ async def subscribe( await self.send_json_to(start_message) while True: - response = await self.receive_json_from(timeout=5) - message_type = response["type"] - if message_type == NextMessage.type: - payload = NextMessage(**response).payload + message: transport_ws_types.Message = await self.receive_json_from( + timeout=5 + ) + if message["type"] == "next": + payload = message["payload"] ret = ExecutionResult(payload.get("data"), None) if "errors" in payload: ret.errors = self.process_errors(payload.get("errors") or []) ret.extensions = payload.get("extensions", None) yield ret - elif message_type == ErrorMessage.type: - error_payload = ErrorMessage(**response).payload + elif message["type"] == "error": + error_payload = message["payload"] yield ExecutionResult( data=None, errors=self.process_errors(error_payload) ) diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 76c30f4005..cdda39595a 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -5,11 +5,11 @@ from contextlib import suppress from typing import ( TYPE_CHECKING, - Any, Awaitable, Dict, List, Optional, + cast, ) from graphql import GraphQLError, GraphQLSyntaxError, parse @@ -21,20 +21,16 @@ ) from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, - ConnectionAckMessage, ConnectionInitMessage, - ErrorMessage, - NextMessage, - NextPayload, + Message, + NextMessagePayload, PingMessage, PongMessage, SubscribeMessage, - SubscribeMessagePayload, ) from strawberry.types import ExecutionResult from strawberry.types.execution import PreExecutionError from strawberry.types.graphql import OperationType -from strawberry.types.unset import UNSET from strawberry.utils.debug import pretty_print_graphql_operation from strawberry.utils.operation import get_operation_type @@ -44,9 +40,6 @@ from strawberry.http.async_base_view import AsyncWebSocketAdapter from strawberry.schema import BaseSchema from strawberry.schema.subscribe import SubscriptionResult - from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( - GraphQLTransportMessage, - ) class BaseGraphQLTransportWSHandler: @@ -73,15 +66,15 @@ def __init__( self.connection_timed_out = False self.operations: Dict[str, Operation] = {} self.completed_tasks: List[asyncio.Task] = [] - self.connection_params: Optional[Dict[str, Any]] = None + self.connection_params: Optional[Dict[str, object]] = None - async def handle(self) -> Any: + async def handle(self) -> None: self.on_request_accepted() try: try: async for message in self.websocket.iter_json(): - await self.handle_message(message) + await self.handle_message(cast(Message, message)) except NonTextMessageReceived: await self.handle_invalid_message("WebSocket message type must be text") except NonJsonMessageReceived: @@ -134,39 +127,28 @@ async def handle_connection_init_timeout(self) -> None: async def handle_task_exception(self, error: Exception) -> None: # pragma: no cover self.task_logger.exception("Exception in worker task", exc_info=error) - async def handle_message(self, message: dict) -> None: + async def handle_message(self, message: Message) -> None: try: - message_type = message.pop("type") + if message["type"] == "connection_init": + await self.handle_connection_init(message) - if message_type == ConnectionInitMessage.type: - await self.handle_connection_init(ConnectionInitMessage(**message)) + elif message["type"] == "ping": + await self.handle_ping(message) - elif message_type == PingMessage.type: - await self.handle_ping(PingMessage(**message)) + elif message["type"] == "pong": + await self.handle_pong(message) - elif message_type == PongMessage.type: - await self.handle_pong(PongMessage(**message)) + elif message["type"] == "subscribe": + await self.handle_subscribe(message) - elif message_type == SubscribeMessage.type: - payload_args = message.pop("payload") - payload = SubscribeMessagePayload( - query=payload_args["query"], - operationName=payload_args.get("operationName"), - variables=payload_args.get("variables"), - extensions=payload_args.get("extensions"), - ) - await self.handle_subscribe( - SubscribeMessage(payload=payload, **message) - ) - - elif message_type == CompleteMessage.type: - await self.handle_complete(CompleteMessage(**message)) + elif message["type"] == "complete": + await self.handle_complete(message) else: - error_message = f"Unknown message type: {message_type}" + error_message = f"Unknown message type: {message['type']}" await self.handle_invalid_message(error_message) - except (KeyError, TypeError): + except KeyError: await self.handle_invalid_message("Failed to parse message") finally: await self.reap_completed_tasks() @@ -175,14 +157,11 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: if self.connection_timed_out: # No way to reliably excercise this case during testing return # pragma: no cover + if self.connection_init_timeout_task: self.connection_init_timeout_task.cancel() - payload = ( - message.payload - if message.payload is not None and message.payload is not UNSET - else {} - ) + payload = message.get("payload", {}) if not isinstance(payload, dict): await self.websocket.close( @@ -198,11 +177,11 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: return self.connection_init_received = True - await self.send_message(ConnectionAckMessage()) + await self.send_message({"type": "connection_ack"}) self.connection_acknowledged = True async def handle_ping(self, message: PingMessage) -> None: - await self.send_message(PongMessage()) + await self.send_message({"type": "pong"}) async def handle_pong(self, message: PongMessage) -> None: pass @@ -213,14 +192,14 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: return try: - graphql_document = parse(message.payload.query) + graphql_document = parse(message["payload"]["query"]) except GraphQLSyntaxError as exc: await self.websocket.close(code=4400, reason=exc.message) return try: operation_type = get_operation_type( - graphql_document, message.payload.operationName + graphql_document, message["payload"].get("operationName") ) except RuntimeError: await self.websocket.close( @@ -228,16 +207,16 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: ) return - if message.id in self.operations: - reason = f"Subscriber for {message.id} already exists" + if message["id"] in self.operations: + reason = f"Subscriber for {message['id']} already exists" await self.websocket.close(code=4409, reason=reason) return if self.debug: # pragma: no cover pretty_print_graphql_operation( - message.payload.operationName, - message.payload.query, - message.payload.variables, + message["payload"].get("operationName"), + message["payload"]["query"], + message["payload"].get("variables"), ) if isinstance(self.context, dict): @@ -247,15 +226,15 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: operation = Operation( self, - message.id, + message["id"], operation_type, - message.payload.query, - message.payload.variables, - message.payload.operationName, + message["payload"]["query"], + message["payload"].get("variables"), + message["payload"].get("operationName"), ) operation.task = asyncio.create_task(self.run_operation(operation)) - self.operations[message.id] = operation + self.operations[message["id"]] = operation async def run_operation(self, operation: Operation) -> None: """The operation task's top level method. Cleans-up and de-registers the operation once it is done.""" @@ -291,11 +270,15 @@ async def run_operation(self, operation: Operation) -> None: # that's a mutation / query result elif isinstance(first_res_or_agen, ExecutionResult): await operation.send_next(first_res_or_agen) - await operation.send_message(CompleteMessage(id=operation.id)) + await operation.send_operation_message( + {"id": operation.id, "type": "complete"} + ) else: async for result in first_res_or_agen: await operation.send_next(result) - await operation.send_message(CompleteMessage(id=operation.id)) + await operation.send_operation_message( + {"id": operation.id, "type": "complete"} + ) except BaseException as e: # pragma: no cover self.operations.pop(operation.id, None) @@ -312,14 +295,13 @@ def forget_id(self, id: str) -> None: del self.operations[id] async def handle_complete(self, message: CompleteMessage) -> None: - await self.cleanup_operation(operation_id=message.id) + await self.cleanup_operation(operation_id=message["id"]) async def handle_invalid_message(self, error_message: str) -> None: await self.websocket.close(code=4400, reason=error_message) - async def send_message(self, message: GraphQLTransportMessage) -> None: - data = message.as_dict() - await self.websocket.send_json(data) + async def send_message(self, message: Message) -> None: + await self.websocket.send_json(message) async def cleanup_operation(self, operation_id: str) -> None: if operation_id not in self.operations: @@ -358,7 +340,7 @@ def __init__( id: str, operation_type: OperationType, query: str, - variables: Optional[Dict[str, Any]], + variables: Optional[Dict[str, object]], operation_name: Optional[str], ) -> None: self.handler = handler @@ -370,10 +352,10 @@ def __init__( self.completed = False self.task: Optional[asyncio.Task] = None - async def send_message(self, message: GraphQLTransportMessage) -> None: + async def send_operation_message(self, message: Message) -> None: if self.completed: return - if isinstance(message, (CompleteMessage, ErrorMessage)): + if message["type"] == "complete" or message["type"] == "error": self.completed = True # de-register the operation _before_ sending the final message self.handler.forget_id(self.id) @@ -383,17 +365,26 @@ async def send_initial_errors(self, errors: list[GraphQLError]) -> None: # Initial errors see https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#error # "This can occur before execution starts, # usually due to validation errors, or during the execution of the request" - await self.send_message( - ErrorMessage(id=self.id, payload=[err.formatted for err in errors]) + await self.send_operation_message( + { + "id": self.id, + "type": "error", + "payload": [err.formatted for err in errors], + } ) async def send_next(self, execution_result: ExecutionResult) -> None: - next_payload: NextPayload = {"data": execution_result.data} + next_payload: NextMessagePayload = {"data": execution_result.data} + if execution_result.errors: next_payload["errors"] = [err.formatted for err in execution_result.errors] + if execution_result.extensions: next_payload["extensions"] = execution_result.extensions - await self.send_message(NextMessage(id=self.id, payload=next_payload)) + + await self.send_operation_message( + {"id": self.id, "type": "next", "payload": next_payload} + ) __all__ = ["BaseGraphQLTransportWSHandler", "Operation"] diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py index 300f9204a7..7e5a804f29 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/types.py @@ -1,108 +1,91 @@ -from __future__ import annotations +from typing import Dict, List, TypedDict, Union +from typing_extensions import Literal, NotRequired -from dataclasses import asdict, dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypedDict +from graphql import GraphQLFormattedError -from strawberry.types.unset import UNSET -if TYPE_CHECKING: - from graphql import GraphQLFormattedError - - -@dataclass -class GraphQLTransportMessage: - def as_dict(self) -> dict: - data = asdict(self) - if getattr(self, "payload", None) is UNSET: - # Unset fields must have a JSON value of "undefined" not "null" - data.pop("payload") - return data - - -@dataclass -class ConnectionInitMessage(GraphQLTransportMessage): +class ConnectionInitMessage(TypedDict): """Direction: Client -> Server.""" - payload: Optional[Dict[str, Any]] = UNSET - type: str = "connection_init" + type: Literal["connection_init"] + payload: NotRequired[Union[Dict[str, object], None]] -@dataclass -class ConnectionAckMessage(GraphQLTransportMessage): +class ConnectionAckMessage(TypedDict): """Direction: Server -> Client.""" - payload: Optional[Dict[str, Any]] = UNSET - type: str = "connection_ack" + type: Literal["connection_ack"] + payload: NotRequired[Union[Dict[str, object], None]] -@dataclass -class PingMessage(GraphQLTransportMessage): +class PingMessage(TypedDict): """Direction: bidirectional.""" - payload: Optional[Dict[str, Any]] = UNSET - type: str = "ping" + type: Literal["ping"] + payload: NotRequired[Union[Dict[str, object], None]] -@dataclass -class PongMessage(GraphQLTransportMessage): +class PongMessage(TypedDict): """Direction: bidirectional.""" - payload: Optional[Dict[str, Any]] = UNSET - type: str = "pong" + type: Literal["pong"] + payload: NotRequired[Union[Dict[str, object], None]] -@dataclass -class SubscribeMessagePayload: +class SubscribeMessagePayload(TypedDict): + operationName: NotRequired[Union[str, None]] query: str - operationName: Optional[str] = None - variables: Optional[Dict[str, Any]] = None - extensions: Optional[Dict[str, Any]] = None + variables: NotRequired[Union[Dict[str, object], None]] + extensions: NotRequired[Union[Dict[str, object], None]] -@dataclass -class SubscribeMessage(GraphQLTransportMessage): +class SubscribeMessage(TypedDict): """Direction: Client -> Server.""" id: str + type: Literal["subscribe"] payload: SubscribeMessagePayload - type: str = "subscribe" -class NextPayload(TypedDict, total=False): - data: Any +class NextMessagePayload(TypedDict): + errors: NotRequired[List[GraphQLFormattedError]] + data: NotRequired[Union[Dict[str, object], None]] + extensions: NotRequired[Dict[str, object]] - # Optional list of formatted graphql.GraphQLError objects - errors: Optional[List[GraphQLFormattedError]] - extensions: Optional[Dict[str, Any]] - -@dataclass -class NextMessage(GraphQLTransportMessage): +class NextMessage(TypedDict): """Direction: Server -> Client.""" id: str - payload: NextPayload - type: str = "next" - - def as_dict(self) -> dict: - return {"id": self.id, "payload": self.payload, "type": self.type} + type: Literal["next"] + payload: NextMessagePayload -@dataclass -class ErrorMessage(GraphQLTransportMessage): +class ErrorMessage(TypedDict): """Direction: Server -> Client.""" id: str + type: Literal["error"] payload: List[GraphQLFormattedError] - type: str = "error" -@dataclass -class CompleteMessage(GraphQLTransportMessage): +class CompleteMessage(TypedDict): """Direction: bidirectional.""" id: str - type: str = "complete" + type: Literal["complete"] + + +Message = Union[ + ConnectionInitMessage, + ConnectionAckMessage, + PingMessage, + PongMessage, + SubscribeMessage, + NextMessage, + ErrorMessage, + CompleteMessage, +] __all__ = [ @@ -114,4 +97,5 @@ class CompleteMessage(GraphQLTransportMessage): "NextMessage", "ErrorMessage", "CompleteMessage", + "Message", ] diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index 8db1205fdf..b0a2960242 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -12,7 +12,6 @@ ConnectionInitMessage, NextMessage, SubscribeMessage, - SubscribeMessagePayload, ) from tests.views.schema import schema @@ -62,25 +61,30 @@ async def test_no_layers(): async def test_channel_listen(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { listener }", - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { listener }", + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listener"] + next_message1: NextMessage = await ws.receive_json_from() + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + channel_name = next_message1["payload"]["data"]["listener"] await channel_layer.send( channel_name, @@ -90,47 +94,52 @@ async def test_channel_listen(ws: WebsocketCommunicator): }, ) - response = await ws.receive_json_from() - assert ( - response - == NextMessage( - id="sub1", - payload={ - "data": {"listener": "Hello there!"}, - "extensions": {"example": "example"}, - }, - ).as_dict() - ) + next_message2: NextMessage = await ws.receive_json_from() + assert next_message2 == { + "id": "sub1", + "type": "next", + "payload": { + "data": {"listener": "Hello there!"}, + "extensions": {"example": "example"}, + }, + } - await ws.send_json_to(CompleteMessage(id="sub1").as_dict()) + await ws.send_json_to(CompleteMessage({"id": "sub1", "type": "complete"})) async def test_channel_listen_with_confirmation(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { listenerWithConfirmation }", - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { listenerWithConfirmation }", + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - confirmation = response["payload"]["data"]["listenerWithConfirmation"] + next_message1: NextMessage = await ws.receive_json_from() + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + confirmation = next_message1["payload"]["data"]["listenerWithConfirmation"] assert confirmation is None - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listenerWithConfirmation"] + next_message2: NextMessage = await ws.receive_json_from() + assert "data" in next_message2["payload"] + assert next_message2["payload"]["data"] is not None + channel_name = next_message2["payload"]["data"]["listenerWithConfirmation"] await channel_layer.send( channel_name, @@ -140,103 +149,118 @@ async def test_channel_listen_with_confirmation(ws: WebsocketCommunicator): }, ) - response = await ws.receive_json_from() - assert ( - response - == NextMessage( - id="sub1", - payload={ - "data": {"listenerWithConfirmation": "Hello there!"}, - "extensions": {"example": "example"}, - }, - ).as_dict() - ) + next_message3: NextMessage = await ws.receive_json_from() + assert next_message3 == { + "id": "sub1", + "type": "next", + "payload": { + "data": {"listenerWithConfirmation": "Hello there!"}, + "extensions": {"example": "example"}, + }, + } - await ws.send_json_to(CompleteMessage(id="sub1").as_dict()) + await ws.send_json_to(CompleteMessage({"id": "sub1", "type": "complete"})) async def test_channel_listen_timeout(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { listener(timeout: 0.5) }", - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { listener(timeout: 0.5) }", + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listener"] + next_message: NextMessage = await ws.receive_json_from() + assert "data" in next_message["payload"] + assert next_message["payload"]["data"] is not None + channel_name = next_message["payload"]["data"]["listener"] assert channel_name - response = await ws.receive_json_from() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message = await ws.receive_json_from() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_channel_listen_timeout_cm(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { listenerWithConfirmation(timeout: 0.5) }", - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { listenerWithConfirmation(timeout: 0.5) }", + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - confirmation = response["payload"]["data"]["listenerWithConfirmation"] + next_message1: NextMessage = await ws.receive_json_from() + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + confirmation = next_message1["payload"]["data"]["listenerWithConfirmation"] assert confirmation is None - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listenerWithConfirmation"] + next_message2 = await ws.receive_json_from() + assert "data" in next_message2["payload"] + assert next_message2["payload"]["data"] is not None + channel_name = next_message2["payload"]["data"]["listenerWithConfirmation"] assert channel_name - response = await ws.receive_json_from() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message: CompleteMessage = await ws.receive_json_from() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_channel_listen_no_message_on_channel(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { listener(timeout: 0.5) }", - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { listener(timeout: 0.5) }", + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listener"] + next_message: NextMessage = await ws.receive_json_from() + assert "data" in next_message["payload"] + assert next_message["payload"]["data"] is not None + channel_name = next_message["payload"]["data"]["listener"] assert channel_name await channel_layer.send( @@ -247,36 +271,43 @@ async def test_channel_listen_no_message_on_channel(ws: WebsocketCommunicator): }, ) - response = await ws.receive_json_from() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message: CompleteMessage = await ws.receive_json_from() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_channel_listen_no_message_on_channel_cm(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { listenerWithConfirmation(timeout: 0.5) }", - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { listenerWithConfirmation(timeout: 0.5) }", + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - confirmation = response["payload"]["data"]["listenerWithConfirmation"] + next_message1: NextMessage = await ws.receive_json_from() + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + confirmation = next_message1["payload"]["data"]["listenerWithConfirmation"] assert confirmation is None - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listenerWithConfirmation"] + next_message2 = await ws.receive_json_from() + assert "data" in next_message2["payload"] + assert next_message2["payload"]["data"] is not None + channel_name = next_message2["payload"]["data"]["listenerWithConfirmation"] assert channel_name await channel_layer.send( @@ -287,32 +318,37 @@ async def test_channel_listen_no_message_on_channel_cm(ws: WebsocketCommunicator }, ) - response = await ws.receive_json_from() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message: CompleteMessage = await ws.receive_json_from() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_channel_listen_group(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { listener(group: "foobar") }', - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { listener(group: "foobar") }', + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listener"] + next_message1 = await ws.receive_json_from() + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + channel_name = next_message1["payload"]["data"]["listener"] # Sent at least once to the consumer to make sure the groups were registered await channel_layer.send( @@ -322,17 +358,16 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): "text": "Hello there!", }, ) - response = await ws.receive_json_from() - assert ( - response - == NextMessage( - id="sub1", - payload={ - "data": {"listener": "Hello there!"}, - "extensions": {"example": "example"}, - }, - ).as_dict() - ) + + next_message2: NextMessage = await ws.receive_json_from() + assert next_message2 == { + "id": "sub1", + "type": "next", + "payload": { + "data": {"listener": "Hello there!"}, + "extensions": {"example": "example"}, + }, + } await channel_layer.group_send( "foobar", @@ -342,47 +377,52 @@ async def test_channel_listen_group(ws: WebsocketCommunicator): }, ) - response = await ws.receive_json_from() - assert ( - response - == NextMessage( - id="sub1", - payload={ - "data": {"listener": "Hello there!"}, - "extensions": {"example": "example"}, - }, - ).as_dict() - ) + next_message3: NextMessage = await ws.receive_json_from() + assert next_message3 == { + "id": "sub1", + "type": "next", + "payload": { + "data": {"listener": "Hello there!"}, + "extensions": {"example": "example"}, + }, + } - await ws.send_json_to(CompleteMessage(id="sub1").as_dict()) + await ws.send_json_to(CompleteMessage({"id": "sub1", "type": "complete"})) async def test_channel_listen_group_cm(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { listenerWithConfirmation(group: "foobar") }', - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { listenerWithConfirmation(group: "foobar") }', + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer - response = await ws.receive_json_from() - confirmation = response["payload"]["data"]["listenerWithConfirmation"] + next_message1: NextMessage = await ws.receive_json_from() + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + confirmation = next_message1["payload"]["data"]["listenerWithConfirmation"] assert confirmation is None - response = await ws.receive_json_from() - channel_name = response["payload"]["data"]["listenerWithConfirmation"] + next_message2 = await ws.receive_json_from() + assert "data" in next_message2["payload"] + assert next_message2["payload"]["data"] is not None + channel_name = next_message2["payload"]["data"]["listenerWithConfirmation"] # Sent at least once to the consumer to make sure the groups were registered await channel_layer.send( @@ -392,17 +432,16 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): "text": "Hello there!", }, ) - response = await ws.receive_json_from() - assert ( - response - == NextMessage( - id="sub1", - payload={ - "data": {"listenerWithConfirmation": "Hello there!"}, - "extensions": {"example": "example"}, - }, - ).as_dict() - ) + + next_message3: NextMessage = await ws.receive_json_from() + assert next_message3 == { + "id": "sub1", + "type": "next", + "payload": { + "data": {"listenerWithConfirmation": "Hello there!"}, + "extensions": {"example": "example"}, + }, + } await channel_layer.group_send( "foobar", @@ -412,56 +451,62 @@ async def test_channel_listen_group_cm(ws: WebsocketCommunicator): }, ) - response = await ws.receive_json_from() - assert ( - response - == NextMessage( - id="sub1", - payload={ - "data": {"listenerWithConfirmation": "Hello there!"}, - "extensions": {"example": "example"}, - }, - ).as_dict() - ) + next_message4: NextMessage = await ws.receive_json_from() + assert next_message4 == { + "id": "sub1", + "type": "next", + "payload": { + "data": {"listenerWithConfirmation": "Hello there!"}, + "extensions": {"example": "example"}, + }, + } - await ws.send_json_to(CompleteMessage(id="sub1").as_dict()) + await ws.send_json_to(CompleteMessage({"id": "sub1", "type": "complete"})) async def test_channel_listen_group_twice(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { listener(group: "group1") }', - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { listener(group: "group1") }', + }, + } + ) ) await ws.send_json_to( SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload( - query='subscription { listener(group: "group2") }', - ), - ).as_dict() + { + "id": "sub2", + "type": "subscribe", + "payload": { + "query": 'subscription { listener(group: "group2") }', + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer # Wait for channel subscriptions to start - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() - ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - channel_name = response1["payload"]["data"]["listener"] + next_message1: NextMessage = await ws.receive_json_from() + next_message2: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message1["id"], next_message2["id"]} + + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + channel_name = next_message1["payload"]["data"]["listener"] # Sent at least once to the consumer to make sure the groups were registered await channel_layer.send( @@ -471,12 +516,18 @@ async def test_channel_listen_group_twice(ws: WebsocketCommunicator): "text": "Hello there!", }, ) - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() - ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - assert response1["payload"]["data"]["listener"] == "Hello there!" - assert response2["payload"]["data"]["listener"] == "Hello there!" + + next_message3: NextMessage = await ws.receive_json_from() + next_message4: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message3["id"], next_message4["id"]} + + assert "data" in next_message3["payload"] + assert next_message3["payload"]["data"] is not None + assert next_message3["payload"]["data"]["listener"] == "Hello there!" + + assert "data" in next_message4["payload"] + assert next_message4["payload"]["data"] is not None + assert next_message4["payload"]["data"]["listener"] == "Hello there!" # We now have two channel_listen AsyncGenerators waiting, one for id="sub1" # and one for id="sub2". This group message will be received by both of them @@ -491,12 +542,17 @@ async def test_channel_listen_group_twice(ws: WebsocketCommunicator): }, ) - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() - ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - assert response1["payload"]["data"]["listener"] == "Hello group 1!" - assert response2["payload"]["data"]["listener"] == "Hello group 1!" + next_message5: NextMessage = await ws.receive_json_from() + next_message6: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message5["id"], next_message6["id"]} + + assert "data" in next_message5["payload"] + assert next_message5["payload"]["data"] is not None + assert next_message5["payload"]["data"]["listener"] == "Hello group 1!" + + assert "data" in next_message6["payload"] + assert next_message6["payload"]["data"] is not None + assert next_message6["payload"]["data"]["listener"] == "Hello group 1!" await channel_layer.group_send( "group2", @@ -506,48 +562,59 @@ async def test_channel_listen_group_twice(ws: WebsocketCommunicator): }, ) - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() - ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - assert response1["payload"]["data"]["listener"] == "Hello group 2!" - assert response2["payload"]["data"]["listener"] == "Hello group 2!" + next_message7: NextMessage = await ws.receive_json_from() + next_message8: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message7["id"], next_message8["id"]} + + assert "data" in next_message7["payload"] + assert next_message7["payload"]["data"] is not None + assert next_message7["payload"]["data"]["listener"] == "Hello group 2!" - await ws.send_json_to(CompleteMessage(id="sub1").as_dict()) - await ws.send_json_to(CompleteMessage(id="sub2").as_dict()) + assert "data" in next_message8["payload"] + assert next_message8["payload"]["data"] is not None + assert next_message8["payload"]["data"]["listener"] == "Hello group 2!" + + await ws.send_json_to(CompleteMessage({"id": "sub1", "type": "complete"})) + await ws.send_json_to(CompleteMessage({"id": "sub2", "type": "complete"})) async def test_channel_listen_group_twice_cm(ws: WebsocketCommunicator): from channels.layers import get_channel_layer - await ws.send_json_to(ConnectionInitMessage().as_dict()) + await ws.send_json_to(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json_from() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json_from() + assert connection_ack_message == {"type": "connection_ack"} await ws.send_json_to( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { listenerWithConfirmation(group: "group1") }', - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { listenerWithConfirmation(group: "group1") }', + }, + } + ) ) await ws.send_json_to( SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload( - query='subscription { listenerWithConfirmation(group: "group2") }', - ), - ).as_dict() + { + "id": "sub2", + "type": "subscribe", + "payload": { + "query": 'subscription { listenerWithConfirmation(group: "group2") }', + }, + } + ) ) channel_layer = get_channel_layer() assert channel_layer # Wait for confirmation for channel subscriptions - responses = await asyncio.gather( + messages = await asyncio.gather( ws.receive_json_from(), ws.receive_json_from(), ws.receive_json_from(), @@ -555,27 +622,28 @@ async def test_channel_listen_group_twice_cm(ws: WebsocketCommunicator): ) confirmation1 = next( i - for i in responses + for i in messages if not i["payload"]["data"]["listenerWithConfirmation"] and i["id"] == "sub1" ) confirmation2 = next( i - for i in responses + for i in messages if not i["payload"]["data"]["listenerWithConfirmation"] and i["id"] == "sub2" ) channel_name1 = next( i - for i in responses + for i in messages if i["payload"]["data"]["listenerWithConfirmation"] and i["id"] == "sub1" ) channel_name2 = next( i - for i in responses + for i in messages if i["payload"]["data"]["listenerWithConfirmation"] and i["id"] == "sub2" ) + # Ensure correct ordering of responses - assert responses.index(confirmation1) < responses.index(channel_name1) - assert responses.index(confirmation2) < responses.index(channel_name2) + assert messages.index(confirmation1) < messages.index(channel_name1) + assert messages.index(confirmation2) < messages.index(channel_name2) channel_name = channel_name1["payload"]["data"]["listenerWithConfirmation"] # Sent at least once to the consumer to make sure the groups were registered @@ -586,12 +654,22 @@ async def test_channel_listen_group_twice_cm(ws: WebsocketCommunicator): "text": "Hello there!", }, ) - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() + + next_message1: NextMessage = await ws.receive_json_from() + next_message2: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message1["id"], next_message2["id"]} + + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] is not None + assert ( + next_message1["payload"]["data"]["listenerWithConfirmation"] == "Hello there!" + ) + + assert "data" in next_message2["payload"] + assert next_message2["payload"]["data"] is not None + assert ( + next_message2["payload"]["data"]["listenerWithConfirmation"] == "Hello there!" ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - assert response1["payload"]["data"]["listenerWithConfirmation"] == "Hello there!" - assert response2["payload"]["data"]["listenerWithConfirmation"] == "Hello there!" # We now have two channel_listen AsyncGenerators waiting, one for id="sub1" # and one for id="sub2". This group message will be received by both of them @@ -606,12 +684,21 @@ async def test_channel_listen_group_twice_cm(ws: WebsocketCommunicator): }, ) - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() + next_message3: NextMessage = await ws.receive_json_from() + next_message4: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message3["id"], next_message4["id"]} + + assert "data" in next_message3["payload"] + assert next_message3["payload"]["data"] is not None + assert ( + next_message3["payload"]["data"]["listenerWithConfirmation"] == "Hello group 1!" + ) + + assert "data" in next_message4["payload"] + assert next_message4["payload"]["data"] is not None + assert ( + next_message4["payload"]["data"]["listenerWithConfirmation"] == "Hello group 1!" ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - assert response1["payload"]["data"]["listenerWithConfirmation"] == "Hello group 1!" - assert response2["payload"]["data"]["listenerWithConfirmation"] == "Hello group 1!" await channel_layer.group_send( "group2", @@ -621,12 +708,21 @@ async def test_channel_listen_group_twice_cm(ws: WebsocketCommunicator): }, ) - response1, response2 = await asyncio.gather( - ws.receive_json_from(), ws.receive_json_from() + next_message5: NextMessage = await ws.receive_json_from() + next_message6: NextMessage = await ws.receive_json_from() + assert {"sub1", "sub2"} == {next_message5["id"], next_message6["id"]} + + assert "data" in next_message5["payload"] + assert next_message5["payload"]["data"] is not None + assert ( + next_message5["payload"]["data"]["listenerWithConfirmation"] == "Hello group 2!" + ) + + assert "data" in next_message6["payload"] + assert next_message6["payload"]["data"] is not None + assert ( + next_message6["payload"]["data"]["listenerWithConfirmation"] == "Hello group 2!" ) - assert {"sub1", "sub2"} == {response1["id"], response2["id"]} - assert response1["payload"]["data"]["listenerWithConfirmation"] == "Hello group 2!" - assert response2["payload"]["data"]["listenerWithConfirmation"] == "Hello group 2!" - await ws.send_json_to(CompleteMessage(id="sub1").as_dict()) - await ws.send_json_to(CompleteMessage(id="sub2").as_dict()) + await ws.send_json_to(CompleteMessage({"id": "sub1", "type": "complete"})) + await ws.send_json_to(CompleteMessage({"id": "sub2", "type": "complete"})) diff --git a/tests/fastapi/test_context.py b/tests/fastapi/test_context.py index 22e8a56765..48eebc9550 100644 --- a/tests/fastapi/test_context.py +++ b/tests/fastapi/test_context.py @@ -6,35 +6,10 @@ import strawberry from strawberry.exceptions import InvalidCustomContext from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( - CompleteMessage, - ConnectionAckMessage, - ConnectionInitMessage, - NextMessage, - SubscribeMessage, - SubscribeMessagePayload, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - CompleteMessage as GraphQLWSCompleteMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionAckMessage as GraphQLWSConnectionAckMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionInitMessage as GraphQLWSConnectionInitMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionTerminateMessage as GraphQLWSConnectionTerminateMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - DataMessage as GraphQLWSDataMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - StartMessage as GraphQLWSStartMessage, -) -from strawberry.subscriptions.protocols.graphql_ws.types import ( - StopMessage as GraphQLWSStopMessage, +from strawberry.subscriptions.protocols.graphql_transport_ws import ( + types as transport_ws_types, ) +from strawberry.subscriptions.protocols.graphql_ws import types as ws_types def test_base_context(): @@ -245,29 +220,37 @@ def get_context(context: Context = Depends()) -> Context: with test_client.websocket_connect( "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL] ) as ws: - ws.send_json(ConnectionInitMessage(payload={"strawberry": "rocks"}).as_dict()) + ws.send_json( + transport_ws_types.ConnectionInitMessage( + {"type": "connection_init", "payload": {"strawberry": "rocks"}} + ) + ) - response = ws.receive_json() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: transport_ws_types.ConnectionInitMessage = ( + ws.receive_json() + ) + assert connection_ack_message == {"type": "connection_ack"} ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { connectionParams }" - ), - ).as_dict() + transport_ws_types.SubscribeMessage( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": "subscription { connectionParams }"}, + } + ) ) - response = ws.receive_json() - assert ( - response - == NextMessage( - id="sub1", payload={"data": {"connectionParams": "rocks"}} - ).as_dict() - ) + next_message: transport_ws_types.NextMessage = ws.receive_json() + assert next_message == { + "id": "sub1", + "type": "next", + "payload": {"data": {"connectionParams": "rocks"}}, + } - ws.send_json(CompleteMessage(id="sub1").as_dict()) + ws.send_json( + transport_ws_types.CompleteMessage({"id": "sub1", "type": "complete"}) + ) ws.close() @@ -310,7 +293,7 @@ def get_context(context: Context = Depends()) -> Context: with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]) as ws: ws.send_json( - GraphQLWSConnectionInitMessage( + ws_types.ConnectionInitMessage( { "type": "connection_init", "payload": {"strawberry": "rocks"}, @@ -318,7 +301,7 @@ def get_context(context: Context = Depends()) -> Context: ) ) ws.send_json( - GraphQLWSStartMessage( + ws_types.StartMessage( { "type": "start", "id": "demo", @@ -329,22 +312,22 @@ def get_context(context: Context = Depends()) -> Context: ) ) - connection_ack_message: GraphQLWSConnectionAckMessage = ws.receive_json() + connection_ack_message: ws_types.ConnectionAckMessage = ws.receive_json() assert connection_ack_message["type"] == "connection_ack" - data_message: GraphQLWSDataMessage = ws.receive_json() + data_message: ws_types.DataMessage = ws.receive_json() assert data_message["type"] == "data" assert data_message["id"] == "demo" assert data_message["payload"]["data"] == {"connectionParams": "rocks"} - ws.send_json(GraphQLWSStopMessage({"type": "stop", "id": "demo"})) + ws.send_json(ws_types.StopMessage({"type": "stop", "id": "demo"})) - complete_message: GraphQLWSCompleteMessage = ws.receive_json() + complete_message: ws_types.CompleteMessage = ws.receive_json() assert complete_message["type"] == "complete" assert complete_message["id"] == "demo" ws.send_json( - GraphQLWSConnectionTerminateMessage({"type": "connection_terminate"}) + ws_types.ConnectionTerminateMessage({"type": "connection_terminate"}) ) # make sure the websocket is disconnected now diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 95923b9c9b..91bf0ae027 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -24,6 +24,9 @@ from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( BaseGraphQLTransportWSHandler, ) +from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( + Message as GraphQLTransportWSMessage, +) from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler from strawberry.types import ExecutionResult @@ -301,6 +304,9 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]: while not self.closed: yield await self.receive() + async def send_message(self, message: GraphQLTransportWSMessage) -> None: + await self.send_json(message) + class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): def on_init(self) -> None: diff --git a/tests/views/schema.py b/tests/views/schema.py index b0c14bfd76..ab959fbe01 100644 --- a/tests/views/schema.py +++ b/tests/views/schema.py @@ -158,7 +158,7 @@ async def echo(self, message: str, delay: float = 0) -> AsyncGenerator[str, None @strawberry.subscription async def request_ping(self, info: strawberry.Info) -> AsyncGenerator[bool, None]: ws = info.context["ws"] - await ws.send_json(PingMessage().as_dict()) + await ws.send_json(PingMessage({"type": "ping"})) yield True @strawberry.subscription diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index ffeab947a4..6cb301b012 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -5,7 +5,7 @@ import json import time from datetime import timedelta -from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Type +from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, Type, Union from unittest.mock import AsyncMock, Mock, patch import pytest @@ -22,7 +22,6 @@ PingMessage, PongMessage, SubscribeMessage, - SubscribeMessagePayload, ) from tests.http.clients.base import DebuggableGraphQLTransportWSHandler from tests.views.schema import MyExtension, Schema @@ -43,37 +42,39 @@ async def ws_raw(http_client: HttpClient) -> AsyncGenerator[WebSocketClient, Non @pytest_asyncio.fixture async def ws(ws_raw: WebSocketClient) -> WebSocketClient: - await ws_raw.send_json(ConnectionInitMessage().as_dict()) - response = await ws_raw.receive_json() - assert response == ConnectionAckMessage().as_dict() + await ws_raw.send_message({"type": "connection_init"}) + connection_ack_message: ConnectionAckMessage = await ws_raw.receive_json() + assert connection_ack_message == {"type": "connection_ack"} return ws_raw def assert_next( - response: dict[str, Any], + next_message: NextMessage, id: str, - data: Dict[str, Any], - extensions: Optional[Dict[str, Any]] = None, + data: Dict[str, object], + extensions: Optional[Dict[str, object]] = None, ): """ Assert that the NextMessage payload contains the provided data. If extensions is provided, it will also assert that the extensions are present """ - assert response["type"] == "next" - assert response["id"] == id - assert set(response["payload"].keys()) <= {"data", "errors", "extensions"} - assert response["payload"]["data"] == data + assert next_message["type"] == "next" + assert next_message["id"] == id + assert set(next_message["payload"].keys()) <= {"data", "errors", "extensions"} + assert "data" in next_message["payload"] + assert next_message["payload"]["data"] == data if extensions is not None: - assert response["payload"]["extensions"] == extensions + assert "extensions" in next_message["payload"] + assert next_message["payload"]["extensions"] == extensions async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"}) + await ws.send_message({"type": "NOT_A_MESSAGE_TYPE"}) # type: ignore - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Unknown message type: NOT_A_MESSAGE_TYPE" @@ -82,31 +83,18 @@ async def test_unknown_message_type(ws_raw: WebSocketClient): async def test_missing_message_type(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json({"notType": None}) - - data = await ws.receive(timeout=2) - assert ws.closed - assert ws.close_code == 4400 - assert ws.close_reason == "Failed to parse message" - - -async def test_parsing_an_invalid_message(ws_raw: WebSocketClient): - ws = ws_raw - - await ws.send_json({"type": "subscribe", "notPayload": None}) + await ws.send_message({"notType": None}) # type: ignore - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Failed to parse message" -async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient): - ws = ws_raw - - await ws.send_json({"type": "subscribe", "payload": {"unexpectedField": 42}}) +async def test_parsing_an_invalid_message(ws: WebSocketClient): + await ws.send_message({"type": "subscribe", "notPayload": None}) # type: ignore - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Failed to parse message" @@ -115,7 +103,9 @@ async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient): async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode()) + await ws.send_bytes( + json.dumps(ConnectionInitMessage({"type": "connection_init"})).encode() + ) await ws.receive(timeout=2) assert ws.closed @@ -137,19 +127,22 @@ async def test_non_json_ws_messages_result_in_socket_closure(ws_raw: WebSocketCl async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json(ConnectionInitMessage().as_dict()) + await ws.send_message({"type": "connection_init"}) - response = await ws.receive_json() - assert response == ConnectionAckMessage().as_dict() + ack_message: ConnectionAckMessage = await ws.receive_json() + assert ack_message == {"type": "connection_ack"} await ws.send_bytes( json.dumps( SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { debug { isConnectionInitTimeoutTaskDone } }" - ), - ).as_dict() + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { debug { isConnectionInitTimeoutTaskDone } }" + }, + } + ) ).encode() ) @@ -160,7 +153,7 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): async def test_connection_init_timeout( - request: Any, http_client_class: Type[HttpClient] + request: object, http_client_class: Type[HttpClient] ): with contextlib.suppress(ImportError): from tests.http.clients.aiohttp import AioHttpClient @@ -177,7 +170,7 @@ async def test_connection_init_timeout( async with test_client.ws_connect( "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] ) as ws: - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4408 assert ws.close_reason == "Connection initialisation timeout" @@ -190,27 +183,30 @@ async def test_connection_init_timeout_cancellation( # Verify that the timeout task is cancelled after the connection Init # message is received ws = ws_raw - await ws.send_json(ConnectionInitMessage().as_dict()) + await ws.send_message({"type": "connection_init"}) - response = await ws.receive_json() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message == {"type": "connection_ack"} - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { debug { isConnectionInitTimeoutTaskDone } }" - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { debug { isConnectionInitTimeoutTaskDone } }" + }, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"debug": {"isConnectionInitTimeoutTaskDone": True}}) + next_message: NextMessage = await ws.receive_json() + assert_next( + next_message, "sub1", {"debug": {"isConnectionInitTimeoutTaskDone": True}} + ) @pytest.mark.xfail(reason="This test is flaky") async def test_close_twice( - mocker: MockerFixture, request: Any, http_client_class: Type[HttpClient] + mocker: MockerFixture, request: object, http_client_class: Type[HttpClient] ): test_client = http_client_class() test_client.create_app(connection_init_wait_timeout=timedelta(seconds=0.25)) @@ -222,9 +218,8 @@ async def test_close_twice( # We set payload is set to "invalid value" to force a invalid payload error # which will close the connection - await ws.send_json( - ConnectionInitMessage(payload="invalid value").as_dict(), # type: ignore - ) + await ws.send_message({"type": "connection_init", "payload": "invalid value"}) # type: ignore + # Yield control so that ._close can be called await asyncio.sleep(0) @@ -245,17 +240,17 @@ async def test_close_twice( async def test_too_many_initialisation_requests(ws: WebSocketClient): - await ws.send_json(ConnectionInitMessage().as_dict()) - data = await ws.receive(timeout=2) + await ws.send_message({"type": "connection_init"}) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4429 assert ws.close_reason == "Too many initialisation requests" async def test_ping_pong(ws: WebSocketClient): - await ws.send_json(PingMessage().as_dict()) - response = await ws.receive_json() - assert response == PongMessage().as_dict() + await ws.send_message({"type": "ping"}) + pong_message: PongMessage = await ws.receive_json() + assert pong_message == {"type": "pong"} async def test_can_send_payload_with_additional_things(ws_raw: WebSocketClient): @@ -263,24 +258,26 @@ async def test_can_send_payload_with_additional_things(ws_raw: WebSocketClient): # send init - await ws.send_json(ConnectionInitMessage().as_dict()) + await ws.send_message({"type": "connection_init"}) await ws.receive(timeout=2) - await ws.send_json( + await ws.send_message( { "type": "subscribe", "payload": { "query": 'subscription { echo(message: "Hi") }', - "some": "other thing", + "extensions": { + "some": "other thing", + }, }, "id": "1", } ) - data = await ws.receive(timeout=2) + next_message: NextMessage = await ws.receive_json(timeout=2) - assert json.loads(data.data) == { + assert next_message == { "type": "next", "id": "1", "payload": {"data": {"echo": "Hi"}, "extensions": {"example": "example"}}, @@ -288,61 +285,60 @@ async def test_can_send_payload_with_additional_things(ws_raw: WebSocketClient): async def test_server_sent_ping(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query="subscription { requestPing }"), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": "subscription { requestPing }"}, + } ) - response = await ws.receive_json() - assert response == PingMessage().as_dict() + ping_message: PingMessage = await ws.receive_json() + assert ping_message == {"type": "ping"} - await ws.send_json(PongMessage().as_dict()) + await ws.send_message({"type": "pong"}) - response = await ws.receive_json() - assert_next(response, "sub1", {"requestPing": True}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub1").as_dict() + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"requestPing": True}) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_unauthorized_subscriptions(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi") }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi") }'}, + } ) - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4401 assert ws.close_reason == "Unauthorized" async def test_duplicated_operation_ids(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi", delay: 5) }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi", delay: 5) }'}, + } ) - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi", delay: 5) }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi", delay: 5) }'}, + } ) - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4409 assert ws.close_reason == "Subscriber for sub1 already exists" @@ -353,60 +349,58 @@ async def test_reused_operation_ids(ws: WebSocketClient): previously used for a completed operation. """ # Use sub1 as an id for an operation - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi") }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi") }'}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"echo": "Hi"}) + next_message1: NextMessage = await ws.receive_json() + assert_next(next_message1, "sub1", {"echo": "Hi"}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub1", "type": "complete"} # operation is now complete. Create a new operation using # the same ID - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi") }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi") }'}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"echo": "Hi"}) + next_message2: NextMessage = await ws.receive_json() + assert_next(next_message2, "sub1", {"echo": "Hi"}) async def test_simple_subscription(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi") }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi") }'}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"echo": "Hi"}) - await ws.send_json(CompleteMessage(id="sub1").as_dict()) + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"echo": "Hi"}) + await ws.send_message({"id": "sub1", "type": "complete"}) async def test_subscription_syntax_error(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query="subscription { INVALID_SYNTAX "), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": "subscription { INVALID_SYNTAX "}, + } ) - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Syntax Error: Expected Name, found ." @@ -415,86 +409,102 @@ async def test_subscription_syntax_error(ws: WebSocketClient): async def test_subscription_field_errors(ws: WebSocketClient): process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { notASubscriptionField }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { notASubscriptionField }", + }, + } ) - response = await ws.receive_json() - assert response["type"] == ErrorMessage.type - assert response["id"] == "sub1" - assert len(response["payload"]) == 1 - assert response["payload"][0].get("path") is None - assert response["payload"][0]["locations"] == [{"line": 1, "column": 16}] + error_message: ErrorMessage = await ws.receive_json() + assert error_message["type"] == "error" + assert error_message["id"] == "sub1" + assert len(error_message["payload"]) == 1 + + assert "locations" in error_message["payload"][0] + assert error_message["payload"][0]["locations"] == [{"line": 1, "column": 16}] + + assert "message" in error_message["payload"][0] assert ( - response["payload"][0]["message"] + error_message["payload"][0]["message"] == "Cannot query field 'notASubscriptionField' on type 'Subscription'." ) + process_errors.assert_called_once() async def test_subscription_cancellation(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi", delay: 99) }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi", delay: 99) }'}, + } ) - await ws.send_json( - SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload( - query="subscription { debug { numActiveResultHandlers } }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub2", + "type": "subscribe", + "payload": { + "query": "subscription { debug { numActiveResultHandlers } }", + }, + } ) - response = await ws.receive_json() - assert_next(response, "sub2", {"debug": {"numActiveResultHandlers": 2}}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub2").as_dict() + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub2", {"debug": {"numActiveResultHandlers": 2}}) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub2", "type": "complete"} - await ws.send_json(CompleteMessage(id="sub1").as_dict()) + await ws.send_message({"id": "sub1", "type": "complete"}) - await ws.send_json( - SubscribeMessage( - id="sub3", - payload=SubscribeMessagePayload( - query="subscription { debug { numActiveResultHandlers } }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub3", + "type": "subscribe", + "payload": { + "query": "subscription { debug { numActiveResultHandlers } }", + }, + } ) - response = await ws.receive_json() - assert_next(response, "sub3", {"debug": {"numActiveResultHandlers": 1}}) + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub3", {"debug": {"numActiveResultHandlers": 1}}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub3").as_dict() + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub3", "type": "complete"} async def test_subscription_errors(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { error(message: "TEST ERR") }', - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { error(message: "TEST ERR") }', + }, + } ) - response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert len(response["payload"]["errors"]) == 1 - assert response["payload"]["errors"][0]["path"] == ["error"] - assert response["payload"]["errors"][0]["message"] == "TEST ERR" + next_message: NextMessage = await ws.receive_json() + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + + assert "errors" in next_message["payload"] + payload_errors = next_message["payload"]["errors"] + assert payload_errors is not None + assert len(payload_errors) == 1 + + assert "path" in payload_errors[0] + assert payload_errors[0]["path"] == ["error"] + + assert "message" in payload_errors[0] + assert payload_errors[0]["message"] == "TEST ERR" async def test_operation_error_no_complete(ws: WebSocketClient): @@ -502,57 +512,63 @@ async def test_operation_error_no_complete(ws: WebSocketClient): # Since we don't include the operation variables, # the subscription will fail immediately. # see https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#error - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription Foo($bar: String!){ exception(message: $bar) }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription Foo($bar: String!){ exception(message: $bar) }", + }, + } ) - response = await ws.receive_json() - assert response["type"] == ErrorMessage.type - assert response["id"] == "sub1" + error_message: ErrorMessage = await ws.receive_json() + assert error_message["type"] == "error" + assert error_message["id"] == "sub1" # after an "error" message, there should be nothing more # sent regarding "sub1", not even a "complete". - await ws.send_json(PingMessage().as_dict()) - data = await ws.receive_json(timeout=1) - assert data == PongMessage().as_dict() + await ws.send_message({"type": "ping"}) + + pong_message: PongMessage = await ws.receive_json(timeout=1) + assert pong_message == {"type": "pong"} async def test_subscription_exceptions(ws: WebSocketClient): process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { exception(message: "TEST EXC") }', - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { exception(message: "TEST EXC") }', + }, + } ) - response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["errors"] == [{"message": "TEST EXC"}] + next_message: NextMessage = await ws.receive_json() + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + assert "errors" in next_message["payload"] + assert next_message["payload"]["errors"] == [{"message": "TEST EXC"}] process_errors.assert_called_once() async def test_single_result_query_operation(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query="query { hello }"), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": "query { hello }"}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"hello": "Hello world"}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub1").as_dict() + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"hello": "Hello world"}) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_single_result_query_operation_async(ws: WebSocketClient): @@ -560,67 +576,66 @@ async def test_single_result_query_operation_async(ws: WebSocketClient): `async` method in the schema, including an artificial async delay. """ - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='query { asyncHello(name: "Dolly", delay:0.01)}' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'query { asyncHello(name: "Dolly", delay:0.01)}'}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"asyncHello": "Hello Dolly"}) + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"asyncHello": "Hello Dolly"}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_single_result_query_operation_overlapped(ws: WebSocketClient): """Test that two single result queries can be in flight at the same time, just like regular queries. Start two queries with separate ids. The - first query has a delay, so we expect the response to the second + first query has a delay, so we expect the message to the second query to be delivered first. """ # first query - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='query { asyncHello(name: "Dolly", delay:1)}' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'query { asyncHello(name: "Dolly", delay:1)}'}, + } ) # second query - await ws.send_json( - SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload( - query='query { asyncHello(name: "Dolly", delay:0)}' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub2", + "type": "subscribe", + "payload": {"query": 'query { asyncHello(name: "Dolly", delay:0)}'}, + } ) - # we expect the response to the second query to arrive first - response = await ws.receive_json() - assert_next(response, "sub2", {"asyncHello": "Hello Dolly"}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub2").as_dict() + # we expect the message to the second query to arrive first + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub2", {"asyncHello": "Hello Dolly"}) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub2", "type": "complete"} async def test_single_result_mutation_operation(ws: WebSocketClient): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query="mutation { hello }"), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": "mutation { hello }"}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"hello": "strawberry"}) + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"hello": "strawberry"}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub1").as_dict() + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_single_result_operation_selection(ws: WebSocketClient): @@ -633,17 +648,19 @@ async def test_single_result_operation_selection(ws: WebSocketClient): } """ - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query=query, operationName="Query2"), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": query, "operationName": "Query2"}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"hello": "Hello Strawberry"}) - response = await ws.receive_json() - assert response == CompleteMessage(id="sub1").as_dict() + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"hello": "Hello Strawberry"}) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message == {"id": "sub1", "type": "complete"} async def test_single_result_invalid_operation_selection(ws: WebSocketClient): @@ -653,14 +670,15 @@ async def test_single_result_invalid_operation_selection(ws: WebSocketClient): } """ - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query=query, operationName="Query2"), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": query, "operationName": "Query2"}, + } ) - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Can't get GraphQL operation type" @@ -669,22 +687,30 @@ async def test_single_result_invalid_operation_selection(ws: WebSocketClient): async def test_single_result_execution_error(ws: WebSocketClient): process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="query { alwaysFail }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "query { alwaysFail }", + }, + } ) - response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - errs = response["payload"]["errors"] - assert len(errs) == 1 - assert errs[0]["path"] == ["alwaysFail"] - assert errs[0]["message"] == "You are not authorized" + next_message: NextMessage = await ws.receive_json() + assert next_message["type"] == "next" + assert next_message["id"] == "sub1" + + assert "errors" in next_message["payload"] + payload_errors = next_message["payload"]["errors"] + assert payload_errors is not None + assert len(payload_errors) == 1 + + assert "path" in payload_errors[0] + assert payload_errors[0]["path"] == ["alwaysFail"] + + assert "message" in payload_errors[0] + assert payload_errors[0]["message"] == "You are not authorized" process_errors.assert_called_once() @@ -695,21 +721,23 @@ async def test_single_result_pre_execution_error(ws: WebSocketClient): """ process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="query { IDontExist }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "query { IDontExist }", + }, + } ) - response = await ws.receive_json() - assert response["type"] == ErrorMessage.type - assert response["id"] == "sub1" - assert len(response["payload"]) == 1 + error_message: ErrorMessage = await ws.receive_json() + assert error_message["type"] == "error" + assert error_message["id"] == "sub1" + assert len(error_message["payload"]) == 1 + assert "message" in error_message["payload"][0] assert ( - response["payload"][0]["message"] + error_message["payload"][0]["message"] == "Cannot query field 'IDontExist' on type 'Query'." ) process_errors.assert_called_once() @@ -722,25 +750,25 @@ async def test_single_result_duplicate_ids_sub(ws: WebSocketClient): error due to already existing ID """ # regular subscription - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi", delay: 5) }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi", delay: 5) }'}, + } ) # single result subscription with duplicate id - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="query { hello }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "query { hello }", + }, + } ) - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4409 assert ws.close_reason == "Subscriber for sub1 already exists" @@ -752,26 +780,26 @@ async def test_single_result_duplicate_ids_query(ws: WebSocketClient): with delay, then another with same id. Expect error. """ # single result subscription 1 - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='query { asyncHello(name: "Hi", delay: 5) }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'query { asyncHello(name: "Hi", delay: 5) }'}, + } ) # single result subscription with duplicate id - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="query { hello }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "query { hello }", + }, + } ) # We expect the remote to close the socket due to duplicate ID in use - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4409 assert ws.close_reason == "Subscriber for sub1 already exists" @@ -779,29 +807,32 @@ async def test_single_result_duplicate_ids_query(ws: WebSocketClient): async def test_injects_connection_params(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json(ConnectionInitMessage(payload={"strawberry": "rocks"}).as_dict()) + await ws.send_message( + {"type": "connection_init", "payload": {"strawberry": "rocks"}} + ) - response = await ws.receive_json() - assert response == ConnectionAckMessage().as_dict() + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message == {"type": "connection_ack"} - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload(query="subscription { connectionParams }"), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": "subscription { connectionParams }"}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"connectionParams": "rocks"}) + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"connectionParams": "rocks"}) - await ws.send_json(CompleteMessage(id="sub1").as_dict()) + await ws.send_message({"id": "sub1", "type": "complete"}) async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json(ConnectionInitMessage(payload="gonna fail").as_dict()) + await ws.send_message({"type": "connection_init", "payload": "gonna fail"}) # type: ignore - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Invalid connection init payload" @@ -812,12 +843,12 @@ async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient): [[], "invalid value", 1], ) async def test_rejects_connection_params_with_wrong_type( - payload: Any, ws_raw: WebSocketClient + payload: object, ws_raw: WebSocketClient ): ws = ws_raw - await ws.send_json(ConnectionInitMessage(payload=payload).as_dict()) + await ws.send_message({"type": "connection_init", "payload": payload}) # type: ignore - data = await ws.receive(timeout=2) + await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 assert ws.close_reason == "Invalid connection init payload" @@ -831,33 +862,39 @@ async def test_subsciption_cancel_finalization_delay(ws: WebSocketClient): # while some complex finalization takes place. delay = 0.1 - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query=f"subscription {{ longFinalizer(delay: {delay}) }}" - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": f"subscription {{ longFinalizer(delay: {delay}) }}"}, + } ) - response = await ws.receive_json() - assert_next(response, "sub1", {"longFinalizer": "hello"}) + next_message: NextMessage = await ws.receive_json() + assert_next(next_message, "sub1", {"longFinalizer": "hello"}) - # now cancel the stubscription and send a new query. We expect the response + # now cancel the stubscription and send a new query. We expect the message # to the new query to arrive immediately, without waiting for the finalizer start = time.time() - await ws.send_json(CompleteMessage(id="sub1").as_dict()) - await ws.send_json( - SubscribeMessage( - id="sub2", - payload=SubscribeMessagePayload(query="query { hello }"), - ).as_dict() + await ws.send_message({"id": "sub1", "type": "complete"}) + await ws.send_message( + { + "id": "sub2", + "type": "subscribe", + "payload": {"query": "query { hello }"}, + } ) + while True: - response = await ws.receive_json() - assert response["type"] in ("next", "complete") - if response["id"] == "sub2": + next_or_complete_message: Union[ + NextMessage, CompleteMessage + ] = await ws.receive_json() + + assert next_or_complete_message["type"] in ("next", "complete") + + if next_or_complete_message["id"] == "sub2": break + end = time.time() elapsed = end - start assert elapsed < delay @@ -895,9 +932,9 @@ def on_init(_handler): "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] ) as ws: await asyncio.sleep(0.01) # wait for the timeout task to start - await ws.send_json(ConnectionInitMessage().as_dict()) - response = await ws.receive_json() - assert response == ConnectionAckMessage().as_dict() + await ws.send_message({"type": "connection_init"}) + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message == {"type": "connection_ack"} await ws.close() # the error hander should have been called @@ -914,55 +951,58 @@ async def test_subscription_errors_continue(ws: WebSocketClient): """ process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query="subscription { flavorsInvalid }", - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": "subscription { flavorsInvalid }", + }, + } ) - response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "VANILLA"} - - response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] is None - errors = response["payload"]["errors"] - assert "cannot represent value" in str(errors) + next_message1: NextMessage = await ws.receive_json() + assert next_message1["type"] == "next" + assert next_message1["id"] == "sub1" + assert "data" in next_message1["payload"] + assert next_message1["payload"]["data"] == {"flavorsInvalid": "VANILLA"} + + next_message2: NextMessage = await ws.receive_json() + assert next_message2["type"] == "next" + assert next_message2["id"] == "sub1" + assert "data" in next_message2["payload"] + assert next_message2["payload"]["data"] is None + assert "errors" in next_message2["payload"] + assert "cannot represent value" in str(next_message2["payload"]["errors"]) process_errors.assert_called_once() - response = await ws.receive_json() - assert response["type"] == NextMessage.type - assert response["id"] == "sub1" - assert response["payload"]["data"] == {"flavorsInvalid": "CHOCOLATE"} + next_message3: NextMessage = await ws.receive_json() + assert next_message3["type"] == "next" + assert next_message3["id"] == "sub1" + assert "data" in next_message3["payload"] + assert next_message3["payload"]["data"] == {"flavorsInvalid": "CHOCOLATE"} - response = await ws.receive_json() - assert response["type"] == CompleteMessage.type - assert response["id"] == "sub1" + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "sub1" @patch.object(MyExtension, MyExtension.get_results.__name__, return_value={}) async def test_no_extensions_results_wont_send_extensions_in_payload( mock: Mock, ws: WebSocketClient ): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi") }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": {"query": 'subscription { echo(message: "Hi") }'}, + } ) - response = await ws.receive_json() + next_message: NextMessage = await ws.receive_json() mock.assert_called_once() - assert_next(response, "sub1", {"echo": "Hi"}) - assert "extensions" not in response["payload"] + assert_next(next_message, "sub1", {"echo": "Hi"}) + assert "extensions" not in next_message["payload"] async def test_unexpected_client_disconnects_are_gracefully_handled( @@ -971,13 +1011,14 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( process_errors = Mock() with patch.object(Schema, "process_errors", process_errors): - await ws.send_json( - SubscribeMessage( - id="sub1", - payload=SubscribeMessagePayload( - query='subscription { echo(message: "Hi", delay: 0.5) }' - ), - ).as_dict() + await ws.send_message( + { + "id": "sub1", + "type": "subscribe", + "payload": { + "query": 'subscription { echo(message: "Hi", delay: 0.5) }' + }, + } ) await ws.close()