Skip to content

Commit

Permalink
Strict graphql ws typed dicts (#3689)
Browse files Browse the repository at this point in the history
* Refactor graphql-ws types to be much stricter

* Add release file

* Fix imprecise check
  • Loading branch information
DoctorJohn authored Nov 7, 2024
1 parent 74ad18c commit ac52f2f
Show file tree
Hide file tree
Showing 8 changed files with 580 additions and 465 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: minor

In this release, all types of the legacy graphql-ws protocol were refactored.
The types are now much stricter and precisely model the difference between null and undefined fields.
As a result, our protocol implementation and related tests are now more robust and easier to maintain.
45 changes: 28 additions & 17 deletions strawberry/channels/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import dataclasses
import uuid
from typing import (
TYPE_CHECKING,
Expand All @@ -26,10 +25,14 @@
SubscribeMessage,
SubscribeMessagePayload,
)
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_CONNECTION_ACK,
GQL_CONNECTION_INIT,
GQL_START,
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionAckMessage as GraphQLWSConnectionAckMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitMessage as GraphQLWSConnectionInitMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
StartMessage as GraphQLWSStartMessage,
)
from strawberry.types import ExecutionResult

Expand Down Expand Up @@ -112,9 +115,11 @@ async def gql_init(self) -> None:
assert response == ConnectionAckMessage().as_dict()
else:
assert res == (True, GRAPHQL_WS_PROTOCOL)
await self.send_json_to({"type": GQL_CONNECTION_INIT})
response = await self.receive_json_from()
assert response["type"] == GQL_CONNECTION_ACK
await self.send_json_to(
GraphQLWSConnectionInitMessage({"type": "connection_init"})
)
response: GraphQLWSConnectionAckMessage = await self.receive_json_from()
assert response["type"] == "connection_ack"

# Actual `ExecutionResult`` objects are not available client-side, since they
# get transformed into `FormattedExecutionResult` on the wire, but we attempt
Expand All @@ -123,22 +128,28 @@ async def subscribe(
self, query: str, variables: Optional[Dict] = None
) -> Union[ExecutionResult, AsyncIterator[ExecutionResult]]:
id_ = uuid.uuid4().hex
sub_payload = SubscribeMessagePayload(query=query, variables=variables)

if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
await self.send_json_to(
SubscribeMessage(
id=id_,
payload=sub_payload,
payload=SubscribeMessagePayload(query=query, variables=variables),
).as_dict()
)
else:
await self.send_json_to(
{
"type": GQL_START,
"id": id_,
"payload": dataclasses.asdict(sub_payload),
}
)
start_message: GraphQLWSStartMessage = {
"type": "start",
"id": id_,
"payload": {
"query": query,
},
}

if variables is not None:
start_message["payload"]["variables"] = variables

await self.send_json_to(start_message)

while True:
response = await self.receive_json_from(timeout=5)
message_type = response["type"]
Expand Down
24 changes: 0 additions & 24 deletions strawberry/subscriptions/protocols/graphql_ws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +0,0 @@
GQL_CONNECTION_INIT = "connection_init"
GQL_CONNECTION_ACK = "connection_ack"
GQL_CONNECTION_ERROR = "connection_error"
GQL_CONNECTION_TERMINATE = "connection_terminate"
GQL_CONNECTION_KEEP_ALIVE = "ka"
GQL_START = "start"
GQL_DATA = "data"
GQL_ERROR = "error"
GQL_COMPLETE = "complete"
GQL_STOP = "stop"


__all__ = [
"GQL_CONNECTION_INIT",
"GQL_CONNECTION_ACK",
"GQL_CONNECTION_ERROR",
"GQL_CONNECTION_TERMINATE",
"GQL_CONNECTION_KEEP_ALIVE",
"GQL_START",
"GQL_DATA",
"GQL_ERROR",
"GQL_COMPLETE",
"GQL_STOP",
]
103 changes: 52 additions & 51 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,18 @@
)

from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
GQL_CONNECTION_ERROR,
GQL_CONNECTION_INIT,
GQL_CONNECTION_KEEP_ALIVE,
GQL_CONNECTION_TERMINATE,
GQL_DATA,
GQL_ERROR,
GQL_START,
GQL_STOP,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitPayload,
DataPayload,
CompleteMessage,
ConnectionAckMessage,
ConnectionErrorMessage,
ConnectionInitMessage,
ConnectionKeepAliveMessage,
ConnectionTerminateMessage,
DataMessage,
ErrorMessage,
OperationMessage,
OperationMessagePayload,
StartPayload,
StartMessage,
StopMessage,
)
from strawberry.types.execution import ExecutionResult, PreExecutionError
from strawberry.utils.debug import pretty_print_graphql_operation
Expand Down Expand Up @@ -59,7 +53,7 @@ def __init__(
self.keep_alive_task: Optional[asyncio.Task] = None
self.subscriptions: Dict[str, AsyncGenerator] = {}
self.tasks: Dict[str, asyncio.Task] = {}
self.connection_params: Optional[ConnectionInitPayload] = None
self.connection_params: Optional[Dict[str, object]] = None

async def handle(self) -> None:
try:
Expand Down Expand Up @@ -87,41 +81,40 @@ async def handle_message(
self,
message: OperationMessage,
) -> None:
message_type = message["type"]

if message_type == GQL_CONNECTION_INIT:
if message["type"] == "connection_init":
await self.handle_connection_init(message)
elif message_type == GQL_CONNECTION_TERMINATE:
elif message["type"] == "connection_terminate":
await self.handle_connection_terminate(message)
elif message_type == GQL_START:
elif message["type"] == "start":
await self.handle_start(message)
elif message_type == GQL_STOP:
elif message["type"] == "stop":
await self.handle_stop(message)

async def handle_connection_init(self, message: OperationMessage) -> None:
async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
payload = message.get("payload")
if payload is not None and not isinstance(payload, dict):
error_message: OperationMessage = {"type": GQL_CONNECTION_ERROR}
error_message: ConnectionErrorMessage = {"type": "connection_error"}
await self.websocket.send_json(error_message)
await self.websocket.close(code=1000, reason="")
return

payload = cast(Optional["ConnectionInitPayload"], payload)
self.connection_params = payload

acknowledge_message: OperationMessage = {"type": GQL_CONNECTION_ACK}
await self.websocket.send_json(acknowledge_message)
connection_ack_message: ConnectionAckMessage = {"type": "connection_ack"}
await self.websocket.send_json(connection_ack_message)

if self.keep_alive:
keep_alive_handler = self.handle_keep_alive()
self.keep_alive_task = asyncio.create_task(keep_alive_handler)

async def handle_connection_terminate(self, message: OperationMessage) -> None:
async def handle_connection_terminate(
self, message: ConnectionTerminateMessage
) -> None:
await self.websocket.close(code=1000, reason="")

async def handle_start(self, message: OperationMessage) -> None:
async def handle_start(self, message: StartMessage) -> None:
operation_id = message["id"]
payload = cast("StartPayload", message["payload"])
payload = message["payload"]
query = payload["query"]
operation_name = payload.get("operationName")
variables = payload.get("variables")
Expand All @@ -139,14 +132,14 @@ async def handle_start(self, message: OperationMessage) -> None:
)
self.tasks[operation_id] = asyncio.create_task(result_handler)

async def handle_stop(self, message: OperationMessage) -> None:
async def handle_stop(self, message: StopMessage) -> None:
operation_id = message["id"]
await self.cleanup_operation(operation_id)

async def handle_keep_alive(self) -> None:
assert self.keep_alive_interval
while True:
data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE}
data: ConnectionKeepAliveMessage = {"type": "ka"}
await self.websocket.send_json(data)
await asyncio.sleep(self.keep_alive_interval)

Expand All @@ -168,14 +161,24 @@ async def handle_async_results(
if isinstance(agen_or_err, PreExecutionError):
assert agen_or_err.errors
error_payload = agen_or_err.errors[0].formatted
await self.send_message(GQL_ERROR, operation_id, error_payload)
error_message: ErrorMessage = {
"type": "error",
"id": operation_id,
"payload": error_payload,
}
await self.websocket.send_json(error_message)
else:
self.subscriptions[operation_id] = agen_or_err
async for result in agen_or_err:
await self.send_data(result, operation_id)
await self.send_message(GQL_COMPLETE, operation_id, None)
complete_message: CompleteMessage = {
"type": "complete",
"id": operation_id,
}
await self.websocket.send_json(complete_message)
except asyncio.CancelledError:
await self.send_message(GQL_COMPLETE, operation_id, None)
complete_message: CompleteMessage = {"type": "complete", "id": operation_id}
await self.websocket.send_json(complete_message)

async def cleanup_operation(self, operation_id: str) -> None:
if operation_id in self.subscriptions:
Expand All @@ -188,26 +191,24 @@ async def cleanup_operation(self, operation_id: str) -> None:
await self.tasks[operation_id]
del self.tasks[operation_id]

async def send_message(
self,
type_: str,
operation_id: str,
payload: Optional[OperationMessagePayload] = None,
) -> None:
data: OperationMessage = {"type": type_, "id": operation_id}
if payload is not None:
data["payload"] = payload
await self.websocket.send_json(data)

async def send_data(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
payload: DataPayload = {"data": execution_result.data}
data_message: DataMessage = {
"type": "data",
"id": operation_id,
"payload": {"data": execution_result.data},
}

if execution_result.errors:
payload["errors"] = [err.formatted for err in execution_result.errors]
data_message["payload"]["errors"] = [
err.formatted for err in execution_result.errors
]

if execution_result.extensions:
payload["extensions"] = execution_result.extensions
await self.send_message(GQL_DATA, operation_id, payload)
data_message["payload"]["extensions"] = execution_result.extensions

await self.websocket.send_json(data_message)


__all__ = ["BaseGraphQLWSHandler"]
Loading

0 comments on commit ac52f2f

Please sign in to comment.