Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strict graphql ws typed dicts #3689

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Release type: minor

In this release, all types of the legacy graphql-ws protocol were refactored.
The types are now much stricter and precisely model the difference between null and undefined fields.
As a result, our protocol implementation and related tests are now more robust and easier to maintain.
45 changes: 28 additions & 17 deletions strawberry/channels/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import dataclasses
import uuid
from typing import (
TYPE_CHECKING,
Expand All @@ -26,10 +25,14 @@
SubscribeMessage,
SubscribeMessagePayload,
)
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_CONNECTION_ACK,
GQL_CONNECTION_INIT,
GQL_START,
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionAckMessage as GraphQLWSConnectionAckMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitMessage as GraphQLWSConnectionInitMessage,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
StartMessage as GraphQLWSStartMessage,
)
from strawberry.types import ExecutionResult

Expand Down Expand Up @@ -112,9 +115,11 @@
assert response == ConnectionAckMessage().as_dict()
else:
assert res == (True, GRAPHQL_WS_PROTOCOL)
await self.send_json_to({"type": GQL_CONNECTION_INIT})
response = await self.receive_json_from()
assert response["type"] == GQL_CONNECTION_ACK
await self.send_json_to(
GraphQLWSConnectionInitMessage({"type": "connection_init"})
)
response: GraphQLWSConnectionAckMessage = await self.receive_json_from()
assert response["type"] == "connection_ack"

# Actual `ExecutionResult`` objects are not available client-side, since they
# get transformed into `FormattedExecutionResult` on the wire, but we attempt
Expand All @@ -123,22 +128,28 @@
self, query: str, variables: Optional[Dict] = None
) -> Union[ExecutionResult, AsyncIterator[ExecutionResult]]:
id_ = uuid.uuid4().hex
sub_payload = SubscribeMessagePayload(query=query, variables=variables)

if self.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
await self.send_json_to(
SubscribeMessage(
id=id_,
payload=sub_payload,
payload=SubscribeMessagePayload(query=query, variables=variables),
).as_dict()
)
else:
await self.send_json_to(
{
"type": GQL_START,
"id": id_,
"payload": dataclasses.asdict(sub_payload),
}
)
start_message: GraphQLWSStartMessage = {
"type": "start",
"id": id_,
"payload": {
"query": query,
},
}

if variables:
start_message["payload"]["variables"] = variables

Check warning on line 149 in strawberry/channels/testing.py

View check run for this annotation

Codecov / codecov/patch

strawberry/channels/testing.py#L149

Added line #L149 was not covered by tests

await self.send_json_to(start_message)

while True:
response = await self.receive_json_from(timeout=5)
message_type = response["type"]
Expand Down
24 changes: 0 additions & 24 deletions strawberry/subscriptions/protocols/graphql_ws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +0,0 @@
GQL_CONNECTION_INIT = "connection_init"
GQL_CONNECTION_ACK = "connection_ack"
GQL_CONNECTION_ERROR = "connection_error"
GQL_CONNECTION_TERMINATE = "connection_terminate"
GQL_CONNECTION_KEEP_ALIVE = "ka"
GQL_START = "start"
GQL_DATA = "data"
GQL_ERROR = "error"
GQL_COMPLETE = "complete"
GQL_STOP = "stop"


__all__ = [
"GQL_CONNECTION_INIT",
"GQL_CONNECTION_ACK",
"GQL_CONNECTION_ERROR",
"GQL_CONNECTION_TERMINATE",
"GQL_CONNECTION_KEEP_ALIVE",
"GQL_START",
"GQL_DATA",
"GQL_ERROR",
"GQL_COMPLETE",
"GQL_STOP",
]
103 changes: 52 additions & 51 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,18 @@
)

from strawberry.http.exceptions import NonTextMessageReceived, WebSocketDisconnected
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
GQL_CONNECTION_ERROR,
GQL_CONNECTION_INIT,
GQL_CONNECTION_KEEP_ALIVE,
GQL_CONNECTION_TERMINATE,
GQL_DATA,
GQL_ERROR,
GQL_START,
GQL_STOP,
)
from strawberry.subscriptions.protocols.graphql_ws.types import (
ConnectionInitPayload,
DataPayload,
CompleteMessage,
ConnectionAckMessage,
ConnectionErrorMessage,
ConnectionInitMessage,
ConnectionKeepAliveMessage,
ConnectionTerminateMessage,
DataMessage,
ErrorMessage,
OperationMessage,
OperationMessagePayload,
StartPayload,
StartMessage,
StopMessage,
)
from strawberry.types.execution import ExecutionResult, PreExecutionError
from strawberry.utils.debug import pretty_print_graphql_operation
Expand Down Expand Up @@ -59,7 +53,7 @@ def __init__(
self.keep_alive_task: Optional[asyncio.Task] = None
self.subscriptions: Dict[str, AsyncGenerator] = {}
self.tasks: Dict[str, asyncio.Task] = {}
self.connection_params: Optional[ConnectionInitPayload] = None
self.connection_params: Optional[Dict[str, object]] = None

async def handle(self) -> None:
try:
Expand Down Expand Up @@ -87,41 +81,40 @@ async def handle_message(
self,
message: OperationMessage,
) -> None:
message_type = message["type"]

if message_type == GQL_CONNECTION_INIT:
if message["type"] == "connection_init":
await self.handle_connection_init(message)
elif message_type == GQL_CONNECTION_TERMINATE:
elif message["type"] == "connection_terminate":
await self.handle_connection_terminate(message)
elif message_type == GQL_START:
elif message["type"] == "start":
await self.handle_start(message)
elif message_type == GQL_STOP:
elif message["type"] == "stop":
await self.handle_stop(message)

async def handle_connection_init(self, message: OperationMessage) -> None:
async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
payload = message.get("payload")
if payload is not None and not isinstance(payload, dict):
error_message: OperationMessage = {"type": GQL_CONNECTION_ERROR}
error_message: ConnectionErrorMessage = {"type": "connection_error"}
await self.websocket.send_json(error_message)
await self.websocket.close(code=1000, reason="")
return

payload = cast(Optional["ConnectionInitPayload"], payload)
self.connection_params = payload

acknowledge_message: OperationMessage = {"type": GQL_CONNECTION_ACK}
await self.websocket.send_json(acknowledge_message)
connection_ack_message: ConnectionAckMessage = {"type": "connection_ack"}
await self.websocket.send_json(connection_ack_message)

if self.keep_alive:
keep_alive_handler = self.handle_keep_alive()
self.keep_alive_task = asyncio.create_task(keep_alive_handler)

async def handle_connection_terminate(self, message: OperationMessage) -> None:
async def handle_connection_terminate(
self, message: ConnectionTerminateMessage
) -> None:
await self.websocket.close(code=1000, reason="")

async def handle_start(self, message: OperationMessage) -> None:
async def handle_start(self, message: StartMessage) -> None:
operation_id = message["id"]
payload = cast("StartPayload", message["payload"])
payload = message["payload"]
query = payload["query"]
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
operation_name = payload.get("operationName")
variables = payload.get("variables")
Expand All @@ -139,14 +132,14 @@ async def handle_start(self, message: OperationMessage) -> None:
)
self.tasks[operation_id] = asyncio.create_task(result_handler)

async def handle_stop(self, message: OperationMessage) -> None:
async def handle_stop(self, message: StopMessage) -> None:
operation_id = message["id"]
await self.cleanup_operation(operation_id)

async def handle_keep_alive(self) -> None:
assert self.keep_alive_interval
while True:
data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE}
data: ConnectionKeepAliveMessage = {"type": "ka"}
await self.websocket.send_json(data)
await asyncio.sleep(self.keep_alive_interval)

Expand All @@ -168,14 +161,24 @@ async def handle_async_results(
if isinstance(agen_or_err, PreExecutionError):
assert agen_or_err.errors
error_payload = agen_or_err.errors[0].formatted
await self.send_message(GQL_ERROR, operation_id, error_payload)
error_message: ErrorMessage = {
DoctorJohn marked this conversation as resolved.
Show resolved Hide resolved
"type": "error",
"id": operation_id,
"payload": error_payload,
}
await self.websocket.send_json(error_message)
else:
self.subscriptions[operation_id] = agen_or_err
async for result in agen_or_err:
await self.send_data(result, operation_id)
await self.send_message(GQL_COMPLETE, operation_id, None)
complete_message: CompleteMessage = {
"type": "complete",
"id": operation_id,
}
await self.websocket.send_json(complete_message)
except asyncio.CancelledError:
await self.send_message(GQL_COMPLETE, operation_id, None)
complete_message: CompleteMessage = {"type": "complete", "id": operation_id}
await self.websocket.send_json(complete_message)

async def cleanup_operation(self, operation_id: str) -> None:
if operation_id in self.subscriptions:
Expand All @@ -188,26 +191,24 @@ async def cleanup_operation(self, operation_id: str) -> None:
await self.tasks[operation_id]
del self.tasks[operation_id]

async def send_message(
self,
type_: str,
operation_id: str,
payload: Optional[OperationMessagePayload] = None,
) -> None:
data: OperationMessage = {"type": type_, "id": operation_id}
if payload is not None:
data["payload"] = payload
await self.websocket.send_json(data)

async def send_data(
self, execution_result: ExecutionResult, operation_id: str
) -> None:
payload: DataPayload = {"data": execution_result.data}
data_message: DataMessage = {
"type": "data",
"id": operation_id,
"payload": {"data": execution_result.data},
}

if execution_result.errors:
payload["errors"] = [err.formatted for err in execution_result.errors]
data_message["payload"]["errors"] = [
err.formatted for err in execution_result.errors
]

if execution_result.extensions:
payload["extensions"] = execution_result.extensions
await self.send_message(GQL_DATA, operation_id, payload)
data_message["payload"]["extensions"] = execution_result.extensions

await self.websocket.send_json(data_message)


__all__ = ["BaseGraphQLWSHandler"]
Loading
Loading