Skip to content

Commit

Permalink
Let graphql-ws use type inference whenever possible (#3704)
Browse files Browse the repository at this point in the history
* Let the legacy protocol make use of type inference

* Make test assertions more precise

* Remove unneeded type ignores

* Add mini release file
  • Loading branch information
DoctorJohn authored Nov 19, 2024
1 parent 4551c04 commit 60b026e
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 249 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This release refactors part of the legacy `graphql-ws` protocol implementation, making it easier to read, maintain, and extend.
45 changes: 18 additions & 27 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,9 @@

from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.subscriptions.protocols.graphql_ws.types import (
CompleteMessage,
ConnectionAckMessage,
ConnectionErrorMessage,
ConnectionInitMessage,
ConnectionKeepAliveMessage,
ConnectionTerminateMessage,
DataMessage,
ErrorMessage,
OperationMessage,
StartMessage,
StopMessage,
Expand Down Expand Up @@ -93,15 +88,13 @@ async def handle_message(
async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
payload = message.get("payload")
if payload is not None and not isinstance(payload, dict):
error_message: ConnectionErrorMessage = {"type": "connection_error"}
await self.websocket.send_json(error_message)
await self.send_message({"type": "connection_error"})
await self.websocket.close(code=1000, reason="")
return

self.connection_params = payload

connection_ack_message: ConnectionAckMessage = {"type": "connection_ack"}
await self.websocket.send_json(connection_ack_message)
await self.send_message({"type": "connection_ack"})

if self.keep_alive:
keep_alive_handler = self.handle_keep_alive()
Expand Down Expand Up @@ -139,8 +132,7 @@ async def handle_stop(self, message: StopMessage) -> None:
async def handle_keep_alive(self) -> None:
assert self.keep_alive_interval
while True:
data: ConnectionKeepAliveMessage = {"type": "ka"}
await self.websocket.send_json(data)
await self.send_message({"type": "ka"})
await asyncio.sleep(self.keep_alive_interval)

async def handle_async_results(
Expand All @@ -160,26 +152,22 @@ async def handle_async_results(
)
if isinstance(agen_or_err, PreExecutionError):
assert agen_or_err.errors
error_payload = agen_or_err.errors[0].formatted
error_message: ErrorMessage = {
"type": "error",
"id": operation_id,
"payload": error_payload,
}
await self.websocket.send_json(error_message)
await self.send_message(
{
"type": "error",
"id": operation_id,
"payload": agen_or_err.errors[0].formatted,
}
)
else:
self.subscriptions[operation_id] = agen_or_err

async for result in agen_or_err:
await self.send_data(result, operation_id)
await self.send_data_message(result, operation_id)

await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)
await self.send_message({"type": "complete", "id": operation_id})
except asyncio.CancelledError:
await self.websocket.send_json(
CompleteMessage({"type": "complete", "id": operation_id})
)
await self.send_message({"type": "complete", "id": operation_id})

async def cleanup_operation(self, operation_id: str) -> None:
if operation_id in self.subscriptions:
Expand All @@ -192,7 +180,7 @@ async def cleanup_operation(self, operation_id: str) -> None:
await self.tasks[operation_id]
del self.tasks[operation_id]

async def send_data(
async def send_data_message(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
data_message: DataMessage = {
Expand All @@ -209,7 +197,10 @@ async def send_data(
if execution_result.extensions:
data_message["payload"]["extensions"] = execution_result.extensions

await self.websocket.send_json(data_message)
await self.send_message(data_message)

async def send_message(self, message: OperationMessage) -> None:
await self.websocket.send_json(message)


__all__ = ["BaseGraphQLWSHandler"]
1 change: 1 addition & 0 deletions strawberry/subscriptions/protocols/graphql_ws/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ConnectionKeepAliveMessage(TypedDict):
DataMessage,
ErrorMessage,
CompleteMessage,
ConnectionKeepAliveMessage,
]


Expand Down
4 changes: 4 additions & 0 deletions tests/http/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Message as GraphQLTransportWSMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage
from strawberry.types import ExecutionResult

logger = logging.getLogger("strawberry.test.http_client")
Expand Down Expand Up @@ -307,6 +308,9 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]:
async def send_message(self, message: GraphQLTransportWSMessage) -> None:
await self.send_json(message)

async def send_legacy_message(self, message: OperationMessage) -> None:
await self.send_json(message)


class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
def on_init(self) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def assert_next(
async def test_unknown_message_type(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_message({"type": "NOT_A_MESSAGE_TYPE"}) # type: ignore
await ws.send_json({"type": "NOT_A_MESSAGE_TYPE"})

await ws.receive(timeout=2)
assert ws.closed
Expand All @@ -83,7 +83,7 @@ async def test_unknown_message_type(ws_raw: WebSocketClient):
async def test_missing_message_type(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_message({"notType": None}) # type: ignore
await ws.send_json({"notType": None})

await ws.receive(timeout=2)
assert ws.closed
Expand All @@ -92,7 +92,7 @@ async def test_missing_message_type(ws_raw: WebSocketClient):


async def test_parsing_an_invalid_message(ws: WebSocketClient):
await ws.send_message({"type": "subscribe", "notPayload": None}) # type: ignore
await ws.send_json({"type": "subscribe", "notPayload": None})

await ws.receive(timeout=2)
assert ws.closed
Expand Down Expand Up @@ -218,7 +218,7 @@ async def test_close_twice(

# We set payload is set to "invalid value" to force a invalid payload error
# which will close the connection
await ws.send_message({"type": "connection_init", "payload": "invalid value"}) # type: ignore
await ws.send_json({"type": "connection_init", "payload": "invalid value"})

# Yield control so that ._close can be called
await asyncio.sleep(0)
Expand Down Expand Up @@ -830,7 +830,7 @@ async def test_injects_connection_params(ws_raw: WebSocketClient):

async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient):
ws = ws_raw
await ws.send_message({"type": "connection_init", "payload": "gonna fail"}) # type: ignore
await ws.send_json({"type": "connection_init", "payload": "gonna fail"})

await ws.receive(timeout=2)
assert ws.closed
Expand All @@ -846,7 +846,7 @@ async def test_rejects_connection_params_with_wrong_type(
payload: object, ws_raw: WebSocketClient
):
ws = ws_raw
await ws.send_message({"type": "connection_init", "payload": payload}) # type: ignore
await ws.send_json({"type": "connection_init", "payload": payload})

await ws.receive(timeout=2)
assert ws.closed
Expand Down
Loading

0 comments on commit 60b026e

Please sign in to comment.