From 60b026e6ce597f1d75e523bc17648a809477e42a Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Tue, 19 Nov 2024 11:43:03 +0100 Subject: [PATCH] Let graphql-ws use type inference whenever possible (#3704) * Let the legacy protocol make use of type inference * Make test assertions more precise * Remove unneeded type ignores * Add mini release file --- RELEASE.md | 3 + .../protocols/graphql_ws/handlers.py | 45 +- .../protocols/graphql_ws/types.py | 1 + tests/http/clients/base.py | 4 + tests/websockets/test_graphql_transport_ws.py | 12 +- tests/websockets/test_graphql_ws.py | 396 ++++++++---------- 6 files changed, 212 insertions(+), 249 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..810cc64333 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This release refactors part of the legacy `graphql-ws` protocol implementation, making it easier to read, maintain, and extend. diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 0eda72cb56..03cf11b71f 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -12,14 +12,9 @@ from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected from strawberry.subscriptions.protocols.graphql_ws.types import ( - CompleteMessage, - ConnectionAckMessage, - ConnectionErrorMessage, ConnectionInitMessage, - ConnectionKeepAliveMessage, ConnectionTerminateMessage, DataMessage, - ErrorMessage, OperationMessage, StartMessage, StopMessage, @@ -93,15 +88,13 @@ async def handle_message( async def handle_connection_init(self, message: ConnectionInitMessage) -> None: payload = message.get("payload") if payload is not None and not isinstance(payload, dict): - error_message: ConnectionErrorMessage = {"type": "connection_error"} - await self.websocket.send_json(error_message) + await self.send_message({"type": "connection_error"}) await self.websocket.close(code=1000, reason="") return self.connection_params = payload - connection_ack_message: ConnectionAckMessage = {"type": "connection_ack"} - await self.websocket.send_json(connection_ack_message) + await self.send_message({"type": "connection_ack"}) if self.keep_alive: keep_alive_handler = self.handle_keep_alive() @@ -139,8 +132,7 @@ async def handle_stop(self, message: StopMessage) -> None: async def handle_keep_alive(self) -> None: assert self.keep_alive_interval while True: - data: ConnectionKeepAliveMessage = {"type": "ka"} - await self.websocket.send_json(data) + await self.send_message({"type": "ka"}) await asyncio.sleep(self.keep_alive_interval) async def handle_async_results( @@ -160,26 +152,22 @@ async def handle_async_results( ) if isinstance(agen_or_err, PreExecutionError): assert agen_or_err.errors - error_payload = agen_or_err.errors[0].formatted - error_message: ErrorMessage = { - "type": "error", - "id": operation_id, - "payload": error_payload, - } - await self.websocket.send_json(error_message) + await self.send_message( + { + "type": "error", + "id": operation_id, + "payload": agen_or_err.errors[0].formatted, + } + ) else: self.subscriptions[operation_id] = agen_or_err async for result in agen_or_err: - await self.send_data(result, operation_id) + await self.send_data_message(result, operation_id) - await self.websocket.send_json( - CompleteMessage({"type": "complete", "id": operation_id}) - ) + await self.send_message({"type": "complete", "id": operation_id}) except asyncio.CancelledError: - await self.websocket.send_json( - CompleteMessage({"type": "complete", "id": operation_id}) - ) + await self.send_message({"type": "complete", "id": operation_id}) async def cleanup_operation(self, operation_id: str) -> None: if operation_id in self.subscriptions: @@ -192,7 +180,7 @@ async def cleanup_operation(self, operation_id: str) -> None: await self.tasks[operation_id] del self.tasks[operation_id] - async def send_data( + async def send_data_message( self, execution_result: ExecutionResult, operation_id: str ) -> None: data_message: DataMessage = { @@ -209,7 +197,10 @@ async def send_data( if execution_result.extensions: data_message["payload"]["extensions"] = execution_result.extensions - await self.websocket.send_json(data_message) + await self.send_message(data_message) + + async def send_message(self, message: OperationMessage) -> None: + await self.websocket.send_json(message) __all__ = ["BaseGraphQLWSHandler"] diff --git a/strawberry/subscriptions/protocols/graphql_ws/types.py b/strawberry/subscriptions/protocols/graphql_ws/types.py index ee0dd06921..56aa81ab1b 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/types.py +++ b/strawberry/subscriptions/protocols/graphql_ws/types.py @@ -78,6 +78,7 @@ class ConnectionKeepAliveMessage(TypedDict): DataMessage, ErrorMessage, CompleteMessage, + ConnectionKeepAliveMessage, ] diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 91bf0ae027..c7e7f12b44 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -28,6 +28,7 @@ Message as GraphQLTransportWSMessage, ) from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler +from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage from strawberry.types import ExecutionResult logger = logging.getLogger("strawberry.test.http_client") @@ -307,6 +308,9 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]: async def send_message(self, message: GraphQLTransportWSMessage) -> None: await self.send_json(message) + async def send_legacy_message(self, message: OperationMessage) -> None: + await self.send_json(message) + class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): def on_init(self) -> None: diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 6cb301b012..09e271f681 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -72,7 +72,7 @@ def assert_next( async def test_unknown_message_type(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_message({"type": "NOT_A_MESSAGE_TYPE"}) # type: ignore + await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"}) await ws.receive(timeout=2) assert ws.closed @@ -83,7 +83,7 @@ async def test_unknown_message_type(ws_raw: WebSocketClient): async def test_missing_message_type(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_message({"notType": None}) # type: ignore + await ws.send_json({"notType": None}) await ws.receive(timeout=2) assert ws.closed @@ -92,7 +92,7 @@ async def test_missing_message_type(ws_raw: WebSocketClient): async def test_parsing_an_invalid_message(ws: WebSocketClient): - await ws.send_message({"type": "subscribe", "notPayload": None}) # type: ignore + await ws.send_json({"type": "subscribe", "notPayload": None}) await ws.receive(timeout=2) assert ws.closed @@ -218,7 +218,7 @@ async def test_close_twice( # We set payload is set to "invalid value" to force a invalid payload error # which will close the connection - await ws.send_message({"type": "connection_init", "payload": "invalid value"}) # type: ignore + await ws.send_json({"type": "connection_init", "payload": "invalid value"}) # Yield control so that ._close can be called await asyncio.sleep(0) @@ -830,7 +830,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient): async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_message({"type": "connection_init", "payload": "gonna fail"}) # type: ignore + await ws.send_json({"type": "connection_init", "payload": "gonna fail"}) await ws.receive(timeout=2) assert ws.closed @@ -846,7 +846,7 @@ async def test_rejects_connection_params_with_wrong_type( payload: object, ws_raw: WebSocketClient ): ws = ws_raw - await ws.send_message({"type": "connection_init", "payload": payload}) # type: ignore + await ws.send_json({"type": "connection_init", "payload": payload}) await ws.receive(timeout=2) assert ws.closed diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 3c3765c5ff..5564238ca3 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -15,11 +15,9 @@ ConnectionErrorMessage, ConnectionInitMessage, ConnectionKeepAliveMessage, - ConnectionTerminateMessage, DataMessage, ErrorMessage, StartMessage, - StopMessage, ) from tests.views.schema import MyExtension, Schema @@ -41,13 +39,13 @@ async def ws_raw(http_client: HttpClient) -> AsyncGenerator[WebSocketClient, Non async def ws(ws_raw: WebSocketClient) -> AsyncGenerator[WebSocketClient, None]: ws = ws_raw - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws.send_legacy_message({"type": "connection_init"}) response: ConnectionAckMessage = await ws.receive_json() assert response["type"] == "connection_ack" yield ws - await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) + await ws.send_legacy_message({"type": "connection_terminate"}) # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close assert ws.closed @@ -60,16 +58,14 @@ def aiohttp_app_client(http_client: HttpClient) -> HttpClient: async def test_simple_subscription(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } ) data_message: DataMessage = await ws.receive_json() @@ -77,7 +73,7 @@ async def test_simple_subscription(ws: WebSocketClient): assert data_message["id"] == "demo" assert data_message["payload"]["data"] == {"echo": "Hi"} - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" @@ -85,20 +81,18 @@ async def test_simple_subscription(ws: WebSocketClient): async def test_operation_selection(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": """ + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": """ subscription Subscription1 { echo(message: "Hi1") } subscription Subscription2 { echo(message: "Hi2") } """, - "operationName": "Subscription2", - }, - } - ) + "operationName": "Subscription2", + }, + } ) data_message: DataMessage = await ws.receive_json() @@ -106,7 +100,7 @@ async def test_operation_selection(ws: WebSocketClient): assert data_message["id"] == "demo" assert data_message["payload"]["data"] == {"echo": "Hi2"} - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" @@ -118,17 +112,15 @@ async def test_sends_keep_alive(aiohttp_app_client: HttpClient): async with aiohttp_app_client.ws_connect( "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] ) as ws: - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi", delay: 0.15) }', - }, - } - ) + await ws.send_legacy_message({"type": "connection_init"}) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi", delay: 0.15) }', + }, + } ) ack_message: ConnectionAckMessage = await ws.receive_json() @@ -155,30 +147,26 @@ async def test_sends_keep_alive(aiohttp_app_client: HttpClient): assert complete_message["type"] == "complete" assert complete_message["id"] == "demo" - await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) + await ws.send_legacy_message({"type": "connection_terminate"}) async def test_subscription_cancellation(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": {"query": 'subscription { echo(message: "Hi", delay: 99) }'}, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": {"query": 'subscription { echo(message: "Hi", delay: 99) }'}, + } ) - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "debug1", - "payload": { - "query": "subscription { debug { numActiveResultHandlers } }", - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "debug1", + "payload": { + "query": "subscription { debug { numActiveResultHandlers } }", + }, + } ) data_message: DataMessage = await ws.receive_json() @@ -190,22 +178,20 @@ async def test_subscription_cancellation(ws: WebSocketClient): assert complete_message1["type"] == "complete" assert complete_message1["id"] == "debug1" - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message2 = await ws.receive_json() assert complete_message2["type"] == "complete" assert complete_message2["id"] == "demo" - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "debug2", - "payload": { - "query": "subscription { debug { numActiveResultHandlers} }", - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "debug2", + "payload": { + "query": "subscription { debug { numActiveResultHandlers} }", + }, + } ) data_message2 = await ws.receive_json() @@ -219,14 +205,12 @@ async def test_subscription_cancellation(ws: WebSocketClient): async def test_subscription_errors(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": {"query": 'subscription { error(message: "TEST ERR") }'}, + } ) data_message: DataMessage = await ws.receive_json() @@ -234,11 +218,15 @@ async def test_subscription_errors(ws: WebSocketClient): assert data_message["id"] == "demo" assert data_message["payload"]["data"] is None - data_payload_errors = data_message["payload"].get("errors") - assert data_payload_errors is not None - assert len(data_payload_errors) == 1 - assert data_payload_errors[0].get("path") == ["error"] - assert data_payload_errors[0].get("message") == "TEST ERR" + assert "errors" in data_message["payload"] + assert data_message["payload"]["errors"] is not None + assert len(data_message["payload"]["errors"]) == 1 + + assert "path" in data_message["payload"]["errors"][0] + assert data_message["payload"]["errors"][0]["path"] == ["error"] + + assert "message" in data_message["payload"]["errors"][0] + assert data_message["payload"]["errors"][0]["message"] == "TEST ERR" complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" @@ -246,14 +234,12 @@ async def test_subscription_errors(ws: WebSocketClient): async def test_subscription_exceptions(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": {"query": 'subscription { exception(message: "TEST EXC") }'}, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": {"query": 'subscription { exception(message: "TEST EXC") }'}, + } ) data_message: DataMessage = await ws.receive_json() @@ -261,25 +247,23 @@ async def test_subscription_exceptions(ws: WebSocketClient): assert data_message["id"] == "demo" assert data_message["payload"]["data"] is None - payload_errors = data_message["payload"].get("errors") - assert payload_errors is not None - assert payload_errors == [{"message": "TEST EXC"}] + assert "errors" in data_message["payload"] + assert data_message["payload"]["errors"] is not None + assert data_message["payload"]["errors"] == [{"message": "TEST EXC"}] - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message = await ws.receive_json() assert complete_message["type"] == "complete" assert complete_message["id"] == "demo" async def test_subscription_field_error(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "invalid-field", - "payload": {"query": "subscription { notASubscriptionField }"}, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "invalid-field", + "payload": {"query": "subscription { notASubscriptionField }"}, + } ) error_message: ErrorMessage = await ws.receive_json() @@ -294,14 +278,12 @@ async def test_subscription_field_error(ws: WebSocketClient): async def test_subscription_syntax_error(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "syntax-error", - "payload": {"query": "subscription { example "}, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "syntax-error", + "payload": {"query": "subscription { example "}, + } ) error_message: ErrorMessage = await ws.receive_json() @@ -330,22 +312,20 @@ async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): ws = ws_raw await ws.send_text("NOT VALID JSON") - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws.send_legacy_message({"type": "connection_init"}) connection_ack_message: ConnectionAckMessage = await ws.receive_json() assert connection_ack_message["type"] == "connection_ack" await ws.send_text("NOT VALID JSON") - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } ) data_message = await ws.receive_json() @@ -354,14 +334,14 @@ async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): assert data_message["payload"]["data"] == {"echo": "Hi"} await ws.send_text("NOT VALID JSON") - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" assert complete_message["id"] == "demo" await ws.send_text("NOT VALID JSON") - await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) + await ws.send_legacy_message({"type": "connection_terminate"}) await ws.receive(timeout=2) # receive close assert ws.closed @@ -369,7 +349,7 @@ async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient): async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws.send_legacy_message({"type": "connection_init"}) connection_ack_message: ConnectionAckMessage = await ws.receive_json() assert connection_ack_message["type"] == "connection_ack" @@ -397,19 +377,17 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient): ws = ws_raw await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws.send_legacy_message({"type": "connection_init"}) await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } ) connection_ack_message: ConnectionAckMessage = await ws.receive_json() @@ -421,14 +399,14 @@ async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient): assert data_message["payload"]["data"] == {"echo": "Hi"} await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" assert complete_message["id"] == "demo" await ws.send_json({"type": "NotAProtocolMessage"}) - await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) + await ws.send_legacy_message({"type": "connection_terminate"}) # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close @@ -436,16 +414,14 @@ async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient): async def test_custom_context(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": "subscription { context }", - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { context }", + }, + } ) data_message: DataMessage = await ws.receive_json() @@ -453,7 +429,7 @@ async def test_custom_context(ws: WebSocketClient): assert data_message["id"] == "demo" assert data_message["payload"]["data"] == {"context": "a value from context"} - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" @@ -461,16 +437,14 @@ async def test_custom_context(ws: WebSocketClient): async def test_resolving_enums(ws: WebSocketClient): - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": "subscription { flavors }", - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { flavors }", + }, + } ) data_message1: DataMessage = await ws.receive_json() @@ -488,7 +462,7 @@ async def test_resolving_enums(ws: WebSocketClient): assert data_message3["id"] == "demo" assert data_message3["payload"]["data"] == {"flavors": "CHOCOLATE"} - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" @@ -533,8 +507,8 @@ def get_result_handler_tasks(): if aio: assert len(get_result_handler_tasks()) == 0 - await ws1.send_json(ConnectionInitMessage({"type": "connection_init"})) - await ws1.send_json(start_message) + await ws1.send_legacy_message({"type": "connection_init"}) + await ws1.send_legacy_message(start_message) await ws1.receive_json() # ack await ws1.receive_json() # data @@ -542,8 +516,8 @@ def get_result_handler_tasks(): if aio: assert len(get_result_handler_tasks()) == 1 - await ws2.send_json(ConnectionInitMessage({"type": "connection_init"})) - await ws2.send_json(start_message) + await ws2.send_legacy_message({"type": "connection_init"}) + await ws2.send_legacy_message(start_message) await ws2.receive_json() await ws2.receive_json() @@ -551,30 +525,28 @@ def get_result_handler_tasks(): if aio: assert len(get_result_handler_tasks()) == 2 - await ws1.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws1.send_legacy_message({"type": "stop", "id": "demo"}) await ws1.receive_json() # complete # 1 active result handler tasks if aio: assert len(get_result_handler_tasks()) == 1 - await ws2.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws2.send_legacy_message({"type": "stop", "id": "demo"}) await ws2.receive_json() # complete # 0 active result handler tasks if aio: assert len(get_result_handler_tasks()) == 0 - await ws1.send_json( - StartMessage( - { - "type": "start", - "id": "debug1", - "payload": { - "query": "subscription { debug { numActiveResultHandlers } }", - }, - } - ) + await ws1.send_legacy_message( + { + "type": "start", + "id": "debug1", + "payload": { + "query": "subscription { debug { numActiveResultHandlers } }", + }, + } ) data_message: DataMessage = await ws1.receive_json() @@ -595,24 +567,20 @@ async def test_injects_connection_params(aiohttp_app_client: HttpClient): async with aiohttp_app_client.ws_connect( "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] ) as ws: - await ws.send_json( - ConnectionInitMessage( - { - "type": "connection_init", - "payload": {"strawberry": "rocks"}, - } - ) + await ws.send_legacy_message( + { + "type": "connection_init", + "payload": {"strawberry": "rocks"}, + } ) - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": "subscription { connectionParams }", - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": "subscription { connectionParams }", + }, + } ) connection_ack_message: ConnectionAckMessage = await ws.receive_json() @@ -623,13 +591,13 @@ async def test_injects_connection_params(aiohttp_app_client: HttpClient): assert data_message["id"] == "demo" assert data_message["payload"]["data"] == {"connectionParams": "rocks"} - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) complete_message: CompleteMessage = await ws.receive_json() assert complete_message["type"] == "complete" assert complete_message["id"] == "demo" - await ws.send_json(ConnectionTerminateMessage({"type": "connection_terminate"})) + await ws.send_legacy_message({"type": "connection_terminate"}) # make sure the WebSocket is disconnected now await ws.receive(timeout=2) # receive close @@ -663,17 +631,15 @@ async def test_no_extensions_results_wont_send_extensions_in_payload( async with aiohttp_app_client.ws_connect( "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] ) as ws: - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "demo", - "payload": { - "query": 'subscription { echo(message: "Hi") }', - }, - } - ) + await ws.send_legacy_message({"type": "connection_init"}) + await ws.send_legacy_message( + { + "type": "start", + "id": "demo", + "payload": { + "query": 'subscription { echo(message: "Hi") }', + }, + } ) connection_ack_message = await ws.receive_json() @@ -685,7 +651,7 @@ async def test_no_extensions_results_wont_send_extensions_in_payload( assert data_message["id"] == "demo" assert "extensions" not in data_message["payload"] - await ws.send_json(StopMessage({"type": "stop", "id": "demo"})) + await ws.send_legacy_message({"type": "stop", "id": "demo"}) await ws.receive_json() @@ -696,21 +662,19 @@ async def test_unexpected_client_disconnects_are_gracefully_handled( process_errors = mock.Mock() with mock.patch.object(Schema, "process_errors", process_errors): - await ws.send_json(ConnectionInitMessage({"type": "connection_init"})) + await ws.send_legacy_message({"type": "connection_init"}) connection_ack_message: ConnectionAckMessage = await ws.receive_json() assert connection_ack_message["type"] == "connection_ack" - await ws.send_json( - StartMessage( - { - "type": "start", - "id": "sub1", - "payload": { - "query": 'subscription { echo(message: "Hi", delay: 0.5) }', - }, - } - ) + await ws.send_legacy_message( + { + "type": "start", + "id": "sub1", + "payload": { + "query": 'subscription { echo(message: "Hi", delay: 0.5) }', + }, + } ) await ws.close()