diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..0ad55f9b02 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,5 @@ +Release type: minor + +In this release, all types of the legacy graphql-ws protocol were refactored. +The types are now much stricter and precisely model the difference between null and undefined fields. +As a result, our protocol implementation and related tests are now more robust and easier to maintain. diff --git a/strawberry/channels/testing.py b/strawberry/channels/testing.py index 890c7147d2..ec129263b5 100644 --- a/strawberry/channels/testing.py +++ b/strawberry/channels/testing.py @@ -1,6 +1,5 @@ from __future__ import annotations -import dataclasses import uuid from typing import ( TYPE_CHECKING, @@ -26,10 +25,14 @@ SubscribeMessage, SubscribeMessagePayload, ) -from strawberry.subscriptions.protocols.graphql_ws import ( - GQL_CONNECTION_ACK, - GQL_CONNECTION_INIT, - GQL_START, +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionAckMessage as GraphQLWSConnectionAckMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionInitMessage as GraphQLWSConnectionInitMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + StartMessage as GraphQLWSStartMessage, ) from strawberry.types import ExecutionResult @@ -112,9 +115,11 @@ async def gql_init(self) -> None: assert response == ConnectionAckMessage().as_dict() else: assert res == (True, GRAPHQL_WS_PROTOCOL) - await self.send_json_to({"type": GQL_CONNECTION_INIT}) - response = await self.receive_json_from() - assert response["type"] == GQL_CONNECTION_ACK + await self.send_json_to( + GraphQLWSConnectionInitMessage({"type": "connection_init"}) + ) + response: GraphQLWSConnectionAckMessage = await self.receive_json_from() + assert response["type"] == "connection_ack" # Actual `ExecutionResult`` objects are not available client-side, since they # get transformed into `FormattedExecutionResult` on the wire, but we attempt @@ -123,22 +128,28 @@ async def subscribe( self, query: str, variables: Optional[Dict] = None ) -> Union[ExecutionResult, AsyncIterator[ExecutionResult]]: id_ = uuid.uuid4().hex - sub_payload = SubscribeMessagePayload(query=query, variables=variables) + if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: await self.send_json_to( SubscribeMessage( id=id_, - payload=sub_payload, + payload=SubscribeMessagePayload(query=query, variables=variables), ).as_dict() ) else: - await self.send_json_to( - { - "type": GQL_START, - "id": id_, - "payload": dataclasses.asdict(sub_payload), - } - ) + start_message: GraphQLWSStartMessage = { + "type": "start", + "id": id_, + "payload": { + "query": query, + }, + } + + if variables is not None: + start_message["payload"]["variables"] = variables + + await self.send_json_to(start_message) + while True: response = await self.receive_json_from(timeout=5) message_type = response["type"] diff --git a/strawberry/subscriptions/protocols/graphql_ws/__init__.py b/strawberry/subscriptions/protocols/graphql_ws/__init__.py index 2b3696e822..e69de29bb2 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/__init__.py +++ b/strawberry/subscriptions/protocols/graphql_ws/__init__.py @@ -1,24 +0,0 @@ -GQL_CONNECTION_INIT = "connection_init" -GQL_CONNECTION_ACK = "connection_ack" -GQL_CONNECTION_ERROR = "connection_error" -GQL_CONNECTION_TERMINATE = "connection_terminate" -GQL_CONNECTION_KEEP_ALIVE = "ka" -GQL_START = "start" -GQL_DATA = "data" -GQL_ERROR = "error" -GQL_COMPLETE = "complete" -GQL_STOP = "stop" - - -__all__ = [ - "GQL_CONNECTION_INIT", - "GQL_CONNECTION_ACK", - "GQL_CONNECTION_ERROR", - "GQL_CONNECTION_TERMINATE", - "GQL_CONNECTION_KEEP_ALIVE", - "GQL_START", - "GQL_DATA", - "GQL_ERROR", - "GQL_COMPLETE", - "GQL_STOP", -] diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 3237bade18..9b2eddbf85 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -11,24 +11,18 @@ ) from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected -from strawberry.subscriptions.protocols.graphql_ws import ( - GQL_COMPLETE, - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_CONNECTION_INIT, - GQL_CONNECTION_KEEP_ALIVE, - GQL_CONNECTION_TERMINATE, - GQL_DATA, - GQL_ERROR, - GQL_START, - GQL_STOP, -) from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionInitPayload, - DataPayload, + CompleteMessage, + ConnectionAckMessage, + ConnectionErrorMessage, + ConnectionInitMessage, + ConnectionKeepAliveMessage, + ConnectionTerminateMessage, + DataMessage, + ErrorMessage, OperationMessage, - OperationMessagePayload, - StartPayload, + StartMessage, + StopMessage, ) from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.utils.debug import pretty_print_graphql_operation @@ -59,7 +53,7 @@ def __init__( self.keep_alive_task: Optional[asyncio.Task] = None self.subscriptions: Dict[str, AsyncGenerator] = {} self.tasks: Dict[str, asyncio.Task] = {} - self.connection_params: Optional[ConnectionInitPayload] = None + self.connection_params: Optional[Dict[str, object]] = None async def handle(self) -> None: try: @@ -87,41 +81,40 @@ async def handle_message( self, message: OperationMessage, ) -> None: - message_type = message["type"] - - if message_type == GQL_CONNECTION_INIT: + if message["type"] == "connection_init": await self.handle_connection_init(message) - elif message_type == GQL_CONNECTION_TERMINATE: + elif message["type"] == "connection_terminate": await self.handle_connection_terminate(message) - elif message_type == GQL_START: + elif message["type"] == "start": await self.handle_start(message) - elif message_type == GQL_STOP: + elif message["type"] == "stop": await self.handle_stop(message) - async def handle_connection_init(self, message: OperationMessage) -> None: + async def handle_connection_init(self, message: ConnectionInitMessage) -> None: payload = message.get("payload") if payload is not None and not isinstance(payload, dict): - error_message: OperationMessage = {"type": GQL_CONNECTION_ERROR} + error_message: ConnectionErrorMessage = {"type": "connection_error"} await self.websocket.send_json(error_message) await self.websocket.close(code=1000, reason="") return - payload = cast(Optional["ConnectionInitPayload"], payload) self.connection_params = payload - acknowledge_message: OperationMessage = {"type": GQL_CONNECTION_ACK} - await self.websocket.send_json(acknowledge_message) + connection_ack_message: ConnectionAckMessage = {"type": "connection_ack"} + await self.websocket.send_json(connection_ack_message) if self.keep_alive: keep_alive_handler = self.handle_keep_alive() self.keep_alive_task = asyncio.create_task(keep_alive_handler) - async def handle_connection_terminate(self, message: OperationMessage) -> None: + async def handle_connection_terminate( + self, message: ConnectionTerminateMessage + ) -> None: await self.websocket.close(code=1000, reason="") - async def handle_start(self, message: OperationMessage) -> None: + async def handle_start(self, message: StartMessage) -> None: operation_id = message["id"] - payload = cast("StartPayload", message["payload"]) + payload = message["payload"] query = payload["query"] operation_name = payload.get("operationName") variables = payload.get("variables") @@ -139,14 +132,14 @@ async def handle_start(self, message: OperationMessage) -> None: ) self.tasks[operation_id] = asyncio.create_task(result_handler) - async def handle_stop(self, message: OperationMessage) -> None: + async def handle_stop(self, message: StopMessage) -> None: operation_id = message["id"] await self.cleanup_operation(operation_id) async def handle_keep_alive(self) -> None: assert self.keep_alive_interval while True: - data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE} + data: ConnectionKeepAliveMessage = {"type": "ka"} await self.websocket.send_json(data) await asyncio.sleep(self.keep_alive_interval) @@ -168,14 +161,24 @@ async def handle_async_results( if isinstance(agen_or_err, PreExecutionError): assert agen_or_err.errors error_payload = agen_or_err.errors[0].formatted - await self.send_message(GQL_ERROR, operation_id, error_payload) + error_message: ErrorMessage = { + "type": "error", + "id": operation_id, + "payload": error_payload, + } + await self.websocket.send_json(error_message) else: self.subscriptions[operation_id] = agen_or_err async for result in agen_or_err: await self.send_data(result, operation_id) - await self.send_message(GQL_COMPLETE, operation_id, None) + complete_message: CompleteMessage = { + "type": "complete", + "id": operation_id, + } + await self.websocket.send_json(complete_message) except asyncio.CancelledError: - await self.send_message(GQL_COMPLETE, operation_id, None) + complete_message: CompleteMessage = {"type": "complete", "id": operation_id} + await self.websocket.send_json(complete_message) async def cleanup_operation(self, operation_id: str) -> None: if operation_id in self.subscriptions: @@ -188,26 +191,24 @@ async def cleanup_operation(self, operation_id: str) -> None: await self.tasks[operation_id] del self.tasks[operation_id] - async def send_message( - self, - type_: str, - operation_id: str, - payload: Optional[OperationMessagePayload] = None, - ) -> None: - data: OperationMessage = {"type": type_, "id": operation_id} - if payload is not None: - data["payload"] = payload - await self.websocket.send_json(data) - async def send_data( self, execution_result: ExecutionResult, operation_id: str ) -> None: - payload: DataPayload = {"data": execution_result.data} + data_message: DataMessage = { + "type": "data", + "id": operation_id, + "payload": {"data": execution_result.data}, + } + if execution_result.errors: - payload["errors"] = [err.formatted for err in execution_result.errors] + data_message["payload"]["errors"] = [ + err.formatted for err in execution_result.errors + ] + if execution_result.extensions: - payload["extensions"] = execution_result.extensions - await self.send_message(GQL_DATA, operation_id, payload) + data_message["payload"]["extensions"] = execution_result.extensions + + await self.websocket.send_json(data_message) __all__ = ["BaseGraphQLWSHandler"] diff --git a/strawberry/subscriptions/protocols/graphql_ws/types.py b/strawberry/subscriptions/protocols/graphql_ws/types.py index 5ada0b100a..ee0dd06921 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_ws/types.py @@ -1,52 +1,96 @@ -from typing import Any, Dict, List, Optional, Union -from typing_extensions import TypedDict +from typing import Dict, List, TypedDict, Union +from typing_extensions import Literal, NotRequired from graphql import GraphQLFormattedError -ConnectionInitPayload = Dict[str, Any] +class ConnectionInitMessage(TypedDict): + type: Literal["connection_init"] + payload: NotRequired[Dict[str, object]] -ConnectionErrorPayload = Dict[str, Any] - -class StartPayload(TypedDict, total=False): +class StartMessagePayload(TypedDict): query: str - variables: Optional[Dict[str, Any]] - operationName: Optional[str] + variables: NotRequired[Dict[str, object]] + operationName: NotRequired[str] + + +class StartMessage(TypedDict): + type: Literal["start"] + id: str + payload: StartMessagePayload -class DataPayload(TypedDict, total=False): - data: Any +class StopMessage(TypedDict): + type: Literal["stop"] + id: str - # Optional list of formatted graphql.GraphQLError objects - errors: Optional[List[GraphQLFormattedError]] - extensions: Optional[Dict[str, Any]] +class ConnectionTerminateMessage(TypedDict): + type: Literal["connection_terminate"] -ErrorPayload = GraphQLFormattedError +class ConnectionErrorMessage(TypedDict): + type: Literal["connection_error"] + payload: NotRequired[Dict[str, object]] -OperationMessagePayload = Union[ - ConnectionInitPayload, - ConnectionErrorPayload, - StartPayload, - DataPayload, - ErrorPayload, -] +class ConnectionAckMessage(TypedDict): + type: Literal["connection_ack"] -class OperationMessage(TypedDict, total=False): - type: str + +class DataMessagePayload(TypedDict): + data: object + errors: NotRequired[List[GraphQLFormattedError]] + + # Non-standard field: + extensions: NotRequired[Dict[str, object]] + + +class DataMessage(TypedDict): + type: Literal["data"] id: str - payload: OperationMessagePayload + payload: DataMessagePayload + + +class ErrorMessage(TypedDict): + type: Literal["error"] + id: str + payload: GraphQLFormattedError + + +class CompleteMessage(TypedDict): + type: Literal["complete"] + id: str + + +class ConnectionKeepAliveMessage(TypedDict): + type: Literal["ka"] + + +OperationMessage = Union[ + ConnectionInitMessage, + StartMessage, + StopMessage, + ConnectionTerminateMessage, + ConnectionErrorMessage, + ConnectionAckMessage, + DataMessage, + ErrorMessage, + CompleteMessage, +] __all__ = [ - "ConnectionInitPayload", - "ConnectionErrorPayload", - "StartPayload", - "DataPayload", - "ErrorPayload", - "OperationMessagePayload", + "ConnectionInitMessage", + "StartMessage", + "StopMessage", + "ConnectionTerminateMessage", + "ConnectionErrorMessage", + "ConnectionAckMessage", + "DataMessage", + "ErrorMessage", + "CompleteMessage", + "ConnectionKeepAliveMessage", "OperationMessage", ] diff --git a/tests/fastapi/test_context.py b/tests/fastapi/test_context.py index a5b6ea9ec8..22e8a56765 100644 --- a/tests/fastapi/test_context.py +++ b/tests/fastapi/test_context.py @@ -14,14 +14,26 @@ SubscribeMessage, SubscribeMessagePayload, ) -from strawberry.subscriptions.protocols.graphql_ws import ( - GQL_COMPLETE, - GQL_CONNECTION_ACK, - GQL_CONNECTION_INIT, - GQL_CONNECTION_TERMINATE, - GQL_DATA, - GQL_START, - GQL_STOP, +from strawberry.subscriptions.protocols.graphql_ws.types import ( + CompleteMessage as GraphQLWSCompleteMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionAckMessage as GraphQLWSConnectionAckMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionInitMessage as GraphQLWSConnectionInitMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionTerminateMessage as GraphQLWSConnectionTerminateMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + DataMessage as GraphQLWSDataMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + StartMessage as GraphQLWSStartMessage, +) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + StopMessage as GraphQLWSStopMessage, ) @@ -298,36 +310,42 @@ def get_context(context: Context = Depends()) -> Context: with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]) as ws: ws.send_json( - { - "type": GQL_CONNECTION_INIT, - "id": "demo", - "payload": {"strawberry": "rocks"}, - } + GraphQLWSConnectionInitMessage( + { + "type": "connection_init", + "payload": {"strawberry": "rocks"}, + } + ) ) ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": "subscription { connectionParams }", - }, - } + GraphQLWSStartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { connectionParams }", + }, + } + ) ) - response = ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + connection_ack_message: GraphQLWSConnectionAckMessage = ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" - response = ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"connectionParams": "rocks"} + data_message: GraphQLWSDataMessage = ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"connectionParams": "rocks"} - ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + ws.send_json(GraphQLWSStopMessage({"type": "stop", "id": "demo"})) - ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + complete_message: GraphQLWSCompleteMessage = ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" + + ws.send_json( + GraphQLWSConnectionTerminateMessage({"type": "connection_terminate"}) + ) # make sure the websocket is disconnected now with pytest.raises(WebSocketDisconnect): diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index c85b2efe00..95923b9c9b 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -268,7 +268,7 @@ def name(self) -> str: async def send_text(self, payload: str) -> None: ... @abc.abstractmethod - async def send_json(self, payload: Dict[str, Any]) -> None: ... + async def send_json(self, payload: Mapping[str, object]) -> None: ... @abc.abstractmethod async def send_bytes(self, payload: bytes) -> None: ... diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 73c302e161..3c3765c5ff 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -2,24 +2,24 @@ import asyncio import json -from typing import TYPE_CHECKING, AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator, Union from unittest import mock import pytest import pytest_asyncio from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws import ( - GQL_COMPLETE, - GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR, - GQL_CONNECTION_INIT, - GQL_CONNECTION_KEEP_ALIVE, - GQL_CONNECTION_TERMINATE, - GQL_DATA, - GQL_ERROR, - GQL_START, - GQL_STOP, +from strawberry.subscriptions.protocols.graphql_ws.types import ( + CompleteMessage, + ConnectionAckMessage, + ConnectionErrorMessage, + ConnectionInitMessage, + ConnectionKeepAliveMessage, + ConnectionTerminateMessage, + DataMessage, + ErrorMessage, + StartMessage, + StopMessage, ) from tests.views.schema import MyExtension, Schema @@ -40,13 +40,14 @@ async def ws_raw(http_client: HttpClient) -> AsyncGenerator[WebSocketClient, Non @pytest_asyncio.fixture async def ws(ws_raw: WebSocketClient) -> AsyncGenerator[WebSocketClient, None]: ws = ws_raw - await ws.send_json({"type": GQL_CONNECTION_INIT}) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + response: ConnectionAckMessage = await ws.receive_json() + assert response["type"] == "connection_ack" yield ws - await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close assert ws.closed @@ -60,50 +61,56 @@ def aiohttp_app_client(http_client: HttpClient) -> HttpClient: async def test_simple_subscription(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"echo": "Hi"} + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"echo": "Hi"} + + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" async def test_operation_selection(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": """ + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": """ subscription Subscription1 { echo(message: "Hi1") } subscription Subscription2 { echo(message: "Hi2") } """, - "operationName": "Subscription2", - }, - } + "operationName": "Subscription2", + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"echo": "Hi2"} + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"echo": "Hi2"} + + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" async def test_sends_keep_alive(aiohttp_app_client: HttpClient): @@ -111,151 +118,174 @@ async def test_sends_keep_alive(aiohttp_app_client: HttpClient): async with aiohttp_app_client.ws_connect( "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] ) as ws: - await ws.send_json({"type": GQL_CONNECTION_INIT}) + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi", delay: 0.15) }', - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi", delay: 0.15) }', + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + ack_message: ConnectionAckMessage = await ws.receive_json() + assert ack_message["type"] == "connection_ack" # we can't be sure how many keep-alives exactly we # get but they should be more than one. keepalive_count = 0 while True: - response = await ws.receive_json() - if response["type"] == GQL_CONNECTION_KEEP_ALIVE: + ka_or_data_message: Union[ + ConnectionKeepAliveMessage, DataMessage + ] = await ws.receive_json() + if ka_or_data_message["type"] == "ka": keepalive_count += 1 else: break assert keepalive_count >= 1 - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"echo": "Hi"} + assert ka_or_data_message["type"] == "data" + assert ka_or_data_message["id"] == "demo" + assert ka_or_data_message["payload"]["data"] == {"echo": "Hi"} - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" - await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) async def test_subscription_cancellation(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": {"query": 'subscription { echo(message: "Hi", delay: 99) }'}, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": {"query": 'subscription { echo(message: "Hi", delay: 99) }'}, + } + ) ) await ws.send_json( - { - "type": GQL_START, - "id": "debug1", - "payload": { - "query": "subscription { debug { numActiveResultHandlers } }", - }, - } + StartMessage( + { + "type": "start", + "id": "debug1", + "payload": { + "query": "subscription { debug { numActiveResultHandlers } }", + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "debug1" - assert response["payload"]["data"] == {"debug": {"numActiveResultHandlers": 2}} + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "debug1" + assert data_message["payload"]["data"] == {"debug": {"numActiveResultHandlers": 2}} + + complete_message1 = await ws.receive_json() + assert complete_message1["type"] == "complete" + assert complete_message1["id"] == "debug1" - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "debug1" + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + complete_message2 = await ws.receive_json() + assert complete_message2["type"] == "complete" + assert complete_message2["id"] == "demo" await ws.send_json( - { - "type": GQL_START, - "id": "debug2", - "payload": { - "query": "subscription { debug { numActiveResultHandlers} }", - }, - } + StartMessage( + { + "type": "start", + "id": "debug2", + "payload": { + "query": "subscription { debug { numActiveResultHandlers} }", + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "debug2" - assert response["payload"]["data"] == {"debug": {"numActiveResultHandlers": 1}} + data_message2 = await ws.receive_json() + assert data_message2["type"] == "data" + assert data_message2["id"] == "debug2" + assert data_message2["payload"]["data"] == {"debug": {"numActiveResultHandlers": 1}} - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "debug2" + complete_message3: CompleteMessage = await ws.receive_json() + assert complete_message3["type"] == "complete" + assert complete_message3["id"] == "debug2" async def test_subscription_errors(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] is None - assert len(response["payload"]["errors"]) == 1 - assert response["payload"]["errors"][0]["path"] == ["error"] - assert response["payload"]["errors"][0]["message"] == "TEST ERR" + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] is None + + data_payload_errors = data_message["payload"].get("errors") + assert data_payload_errors is not None + assert len(data_payload_errors) == 1 + assert data_payload_errors[0].get("path") == ["error"] + assert data_payload_errors[0].get("message") == "TEST ERR" - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" async def test_subscription_exceptions(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": {"query": 'subscription { exception(message: "TEST EXC") }'}, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": {"query": 'subscription { exception(message: "TEST EXC") }'}, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] is None - assert response["payload"]["errors"] == [{"message": "TEST EXC"}] + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] is None + + payload_errors = data_message["payload"].get("errors") + assert payload_errors is not None + assert payload_errors == [{"message": "TEST EXC"}] - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + complete_message = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" async def test_subscription_field_error(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "invalid-field", - "payload": {"query": "subscription { notASubscriptionField }"}, - } + StartMessage( + { + "type": "start", + "id": "invalid-field", + "payload": {"query": "subscription { notASubscriptionField }"}, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_ERROR - assert response["id"] == "invalid-field" - assert response["payload"] == { + error_message: ErrorMessage = await ws.receive_json() + assert error_message["type"] == "error" + assert error_message["id"] == "invalid-field" + assert error_message["payload"] == { "locations": [{"line": 1, "column": 16}], "message": ( "Cannot query field 'notASubscriptionField' on type 'Subscription'." @@ -265,17 +295,19 @@ async def test_subscription_field_error(ws: WebSocketClient): async def test_subscription_syntax_error(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "syntax-error", - "payload": {"query": "subscription { example "}, - } + StartMessage( + { + "type": "start", + "id": "syntax-error", + "payload": {"query": "subscription { example "}, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_ERROR - assert response["id"] == "syntax-error" - assert response["payload"] == { + error_message: ErrorMessage = await ws.receive_json() + assert error_message["type"] == "error" + assert error_message["id"] == "syntax-error" + assert error_message["payload"] == { "locations": [{"line": 1, "column": 24}], "message": "Syntax Error: Expected Name, found .", } @@ -284,7 +316,9 @@ async def test_subscription_syntax_error(ws: WebSocketClient): async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_bytes(json.dumps({"type": GQL_CONNECTION_INIT}).encode()) + await ws.send_bytes( + json.dumps(ConnectionInitMessage({"type": "connection_init"})).encode() + ) await ws.receive(timeout=2) assert ws.closed @@ -296,36 +330,38 @@ async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): ws = ws_raw await ws.send_text("NOT VALID JSON") - await ws.send_json({"type": GQL_CONNECTION_INIT}) + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" await ws.send_text("NOT VALID JSON") await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"echo": "Hi"} + data_message = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"echo": "Hi"} await ws.send_text("NOT VALID JSON") - await ws.send_json({"type": GQL_STOP, "id": "demo"}) + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" await ws.send_text("NOT VALID JSON") - await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) await ws.receive(timeout=2) # receive close assert ws.closed @@ -333,20 +369,22 @@ async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json({"type": GQL_CONNECTION_INIT}) + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" await ws.send_bytes( json.dumps( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) ).encode() ) @@ -359,35 +397,38 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient): ws = ws_raw await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json({"type": GQL_CONNECTION_INIT}) + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) await ws.send_json({"type": "NotAProtocolMessage"}) await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"echo": "Hi"} + data_message = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"echo": "Hi"} await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close @@ -396,56 +437,62 @@ async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient): async def test_custom_context(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": "subscription { context }", - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { context }", + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"context": "a value from context"} + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"context": "a value from context"} - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" async def test_resolving_enums(ws: WebSocketClient): await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": "subscription { flavors }", - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { flavors }", + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"flavors": "VANILLA"} + data_message1: DataMessage = await ws.receive_json() + assert data_message1["type"] == "data" + assert data_message1["id"] == "demo" + assert data_message1["payload"]["data"] == {"flavors": "VANILLA"} - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"flavors": "STRAWBERRY"} + data_message2: DataMessage = await ws.receive_json() + assert data_message2["type"] == "data" + assert data_message2["id"] == "demo" + assert data_message2["payload"]["data"] == {"flavors": "STRAWBERRY"} - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"flavors": "CHOCOLATE"} + data_message3: DataMessage = await ws.receive_json() + assert data_message3["type"] == "data" + assert data_message3["id"] == "demo" + assert data_message3["payload"]["data"] == {"flavors": "CHOCOLATE"} - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" @pytest.mark.xfail(reason="flaky test") @@ -476,8 +523,8 @@ def get_result_handler_tasks(): ) async with connection1 as ws1, connection2 as ws2: - start_payload = { - "type": GQL_START, + start_message: StartMessage = { + "type": "start", "id": "demo", "payload": {"query": 'subscription { infinity(message: "Hi") }'}, } @@ -486,8 +533,8 @@ def get_result_handler_tasks(): if aio: assert len(get_result_handler_tasks()) == 0 - await ws1.send_json({"type": GQL_CONNECTION_INIT}) - await ws1.send_json(start_payload) + await ws1.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws1.send_json(start_message) await ws1.receive_json() # ack await ws1.receive_json() # data @@ -495,8 +542,8 @@ def get_result_handler_tasks(): if aio: assert len(get_result_handler_tasks()) == 1 - await ws2.send_json({"type": GQL_CONNECTION_INIT}) - await ws2.send_json(start_payload) + await ws2.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws2.send_json(start_message) await ws2.receive_json() await ws2.receive_json() @@ -504,14 +551,14 @@ def get_result_handler_tasks(): if aio: assert len(get_result_handler_tasks()) == 2 - await ws1.send_json({"type": GQL_STOP, "id": "demo"}) + await ws1.send_json(StopMessage({"type": "stop", "id": "demo"})) await ws1.receive_json() # complete # 1 active result handler tasks if aio: assert len(get_result_handler_tasks()) == 1 - await ws2.send_json({"type": GQL_STOP, "id": "demo"}) + await ws2.send_json(StopMessage({"type": "stop", "id": "demo"})) await ws2.receive_json() # complete # 0 active result handler tasks @@ -519,25 +566,29 @@ def get_result_handler_tasks(): assert len(get_result_handler_tasks()) == 0 await ws1.send_json( - { - "type": GQL_START, - "id": "debug1", - "payload": { - "query": "subscription { debug { numActiveResultHandlers } }", - }, - } + StartMessage( + { + "type": "start", + "id": "debug1", + "payload": { + "query": "subscription { debug { numActiveResultHandlers } }", + }, + } + ) ) - response = await ws1.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "debug1" + data_message: DataMessage = await ws1.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "debug1" # The one active result handler is the one for this debug subscription - assert response["payload"]["data"] == {"debug": {"numActiveResultHandlers": 1}} + assert data_message["payload"]["data"] == { + "debug": {"numActiveResultHandlers": 1} + } - response = await ws1.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "debug1" + complete_message: CompleteMessage = await ws1.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "debug1" async def test_injects_connection_params(aiohttp_app_client: HttpClient): @@ -545,36 +596,40 @@ async def test_injects_connection_params(aiohttp_app_client: HttpClient): "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] ) as ws: await ws.send_json( - { - "type": GQL_CONNECTION_INIT, - "id": "demo", - "payload": {"strawberry": "rocks"}, - } + ConnectionInitMessage( + { + "type": "connection_init", + "payload": {"strawberry": "rocks"}, + } + ) ) await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": "subscription { connectionParams }", - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { connectionParams }", + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" - response = await ws.receive_json() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert response["payload"]["data"] == {"connectionParams": "rocks"} + data_message: DataMessage = await ws.receive_json() + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert data_message["payload"]["data"] == {"connectionParams": "rocks"} - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() - assert response["type"] == GQL_COMPLETE - assert response["id"] == "demo" + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) - await ws.send_json({"type": GQL_CONNECTION_TERMINATE}) + complete_message: CompleteMessage = await ws.receive_json() + assert complete_message["type"] == "complete" + assert complete_message["id"] == "demo" + + await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close @@ -587,14 +642,14 @@ async def test_rejects_connection_params(aiohttp_app_client: HttpClient): ) as ws: await ws.send_json( { - "type": GQL_CONNECTION_INIT, + "type": "connection_init", "id": "demo", "payload": "gonna fail", } ) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ERROR + connection_error_message: ConnectionErrorMessage = await ws.receive_json() + assert connection_error_message["type"] == "connection_error" # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close @@ -608,28 +663,30 @@ async def test_no_extensions_results_wont_send_extensions_in_payload( async with aiohttp_app_client.ws_connect( "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] ) as ws: - await ws.send_json({"type": GQL_CONNECTION_INIT}) + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) await ws.send_json( - { - "type": GQL_START, - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } + StartMessage( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } + ) ) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + connection_ack_message = await ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" - response = await ws.receive_json() + data_message: DataMessage = await ws.receive_json() mock.assert_called_once() - assert response["type"] == GQL_DATA - assert response["id"] == "demo" - assert "extensions" not in response["payload"] + assert data_message["type"] == "data" + assert data_message["id"] == "demo" + assert "extensions" not in data_message["payload"] - await ws.send_json({"type": GQL_STOP, "id": "demo"}) - response = await ws.receive_json() + await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.receive_json() async def test_unexpected_client_disconnects_are_gracefully_handled( @@ -639,18 +696,21 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( process_errors = mock.Mock() with mock.patch.object(Schema, "process_errors", process_errors): - await ws.send_json({"type": GQL_CONNECTION_INIT}) - response = await ws.receive_json() - assert response["type"] == GQL_CONNECTION_ACK + await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message["type"] == "connection_ack" await ws.send_json( - { - "type": GQL_START, - "id": "sub1", - "payload": { - "query": 'subscription { echo(message: "Hi", delay: 0.5) }', - }, - } + StartMessage( + { + "type": "start", + "id": "sub1", + "payload": { + "query": 'subscription { echo(message: "Hi", delay: 0.5) }', + }, + } + ) ) await ws.close()