Skip to content

Commit

Permalink
Fix graphql ws did not ignore parsing errors (#3670)
Browse files Browse the repository at this point in the history
* Fix graphql-ws did not ignore parsing errors

* Add release file

* Test on every stage of the protocol

* Make new arg keyword-only and add defaults
  • Loading branch information
DoctorJohn authored Oct 21, 2024
1 parent 2a6d788 commit b701eb0
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 37 deletions.
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Release type: minor

This release fixes a regression in the legacy GraphQL over WebSocket protocol.
Legacy protocol implementations should ignore client message parsing errors.
During a recent refactor, Strawberry changed this behavior to match the new protocol, where parsing errors must close the WebSocket connection.
The expected behavior is restored and adequately tested in this release.
15 changes: 11 additions & 4 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
Expand Down Expand Up @@ -86,16 +90,19 @@ def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None:
self.request = request
self.ws = ws

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
async for ws_message in self.ws:
if ws_message.type == http.WSMsgType.TEXT:
try:
yield ws_message.json()
except JSONDecodeError:
raise NonJsonMessageReceived()
if not ignore_parsing_errors:
raise NonJsonMessageReceived()

elif ws_message.type == http.WSMsgType.BINARY:
raise NonJsonMessageReceived()
raise NonTextMessageReceived()

async def send_json(self, message: Mapping[str, object]) -> None:
await self.ws.send_json(message)
Expand Down
21 changes: 15 additions & 6 deletions strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
Expand Down Expand Up @@ -85,13 +89,18 @@ class ASGIWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: WebSocket, response: WebSocket) -> None:
self.ws = response

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
try:
try:
while self.ws.application_state != WebSocketState.DISCONNECTED:
while self.ws.application_state != WebSocketState.DISCONNECTED:
try:
yield await self.ws.receive_json()
except (KeyError, JSONDecodeError):
raise NonJsonMessageReceived()
except JSONDecodeError: # noqa: PERF203
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
except KeyError:
raise NonTextMessageReceived()
except WebSocketDisconnect: # pragma: no cover
pass

Expand Down
11 changes: 7 additions & 4 deletions strawberry/channels/handlers/ws_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing_extensions import TypeGuard

from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL

Expand All @@ -31,20 +31,23 @@ class ChannelsWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None:
self.ws_consumer = response

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
while True:
message = await self.ws_consumer.message_queue.get()

if message["disconnected"]:
break

if message["message"] is None:
raise NonJsonMessageReceived()
raise NonTextMessageReceived()

try:
yield json.loads(message["message"])
except json.JSONDecodeError:
raise NonJsonMessageReceived()
if not ignore_parsing_errors:
raise NonJsonMessageReceived()

async def send_json(self, message: Mapping[str, object]) -> None:
serialized_message = json.dumps(message)
Expand Down
4 changes: 3 additions & 1 deletion strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ async def get_form_data(self) -> FormData: ...

class AsyncWebSocketAdapter(abc.ABC):
@abc.abstractmethod
def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ...
def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]: ...

@abc.abstractmethod
async def send_json(self, message: Mapping[str, object]) -> None: ...
Expand Down
4 changes: 4 additions & 0 deletions strawberry/http/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ def __init__(self, status_code: int, reason: str) -> None:
self.reason = reason


class NonTextMessageReceived(Exception):
pass


class NonJsonMessageReceived(Exception):
pass

Expand Down
29 changes: 21 additions & 8 deletions strawberry/litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
import warnings
from datetime import timedelta
from typing import (
Expand Down Expand Up @@ -37,7 +38,6 @@
from litestar.di import Provide
from litestar.exceptions import (
NotFoundException,
SerializationException,
ValidationException,
WebSocketDisconnect,
)
Expand All @@ -49,7 +49,11 @@
AsyncHTTPRequestAdapter,
AsyncWebSocketAdapter,
)
from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.exceptions import (
HTTPException,
NonJsonMessageReceived,
NonTextMessageReceived,
)
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
Expand Down Expand Up @@ -192,13 +196,22 @@ class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
def __init__(self, request: WebSocket, response: WebSocket) -> None:
self.ws = response

async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
async def iter_json(
self, *, ignore_parsing_errors: bool = False
) -> AsyncGenerator[Dict[str, object], None]:
try:
try:
while self.ws.connection_state != "disconnect":
yield await self.ws.receive_json()
except (SerializationException, ValueError):
raise NonJsonMessageReceived()
while self.ws.connection_state != "disconnect":
text = await self.ws.receive_text()

# Litestar internally defaults to an empty string for non-text messages
if text == "":
raise NonTextMessageReceived()

try:
yield json.loads(text)
except json.JSONDecodeError:
if not ignore_parsing_errors:
raise NonJsonMessageReceived()
except WebSocketDisconnect:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from graphql import GraphQLError, GraphQLSyntaxError, parse

from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.http.exceptions import NonJsonMessageReceived, NonTextMessageReceived
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
Expand Down Expand Up @@ -78,8 +78,10 @@ async def handle(self) -> Any:
try:
async for message in self.websocket.iter_json():
await self.handle_message(message)
except NonJsonMessageReceived:
except NonTextMessageReceived:
await self.handle_invalid_message("WebSocket message type must be text")
except NonJsonMessageReceived:
await self.handle_invalid_message("WebSocket message must be valid JSON")
finally:
await self.shutdown()

Expand Down
6 changes: 3 additions & 3 deletions strawberry/subscriptions/protocols/graphql_ws/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
cast,
)

from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.http.exceptions import NonTextMessageReceived
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
Expand Down Expand Up @@ -65,9 +65,9 @@ def __init__(

async def handle(self) -> None:
try:
async for message in self.websocket.iter_json():
async for message in self.websocket.iter_json(ignore_parsing_errors=True):
await self.handle_message(cast(OperationMessage, message))
except NonJsonMessageReceived:
except NonTextMessageReceived:
await self.websocket.close(
code=1002, reason="WebSocket message type must be text"
)
Expand Down
6 changes: 3 additions & 3 deletions tests/websockets/test_graphql_transport_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient):
assert ws.close_reason == "Failed to parse message"


async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_bytes(json.dumps(ConnectionInitMessage().as_dict()).encode())
Expand All @@ -123,15 +123,15 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
assert ws.close_reason == "WebSocket message type must be text"


async def test_ws_messages_must_be_json(ws_raw: WebSocketClient):
async def test_non_json_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_text("not valid json")

await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
assert ws.close_reason == "WebSocket message type must be text"
assert ws.close_reason == "WebSocket message must be valid JSON"


async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
Expand Down
39 changes: 33 additions & 6 deletions tests/websockets/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ async def test_subscription_syntax_error(ws: WebSocketClient):
}


async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
async def test_non_text_ws_messages_result_in_socket_closure(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_bytes(json.dumps({"type": GQL_CONNECTION_INIT}).encode())
Expand All @@ -292,15 +292,42 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
assert ws.close_reason == "WebSocket message type must be text"


async def test_ws_messages_must_be_json(ws_raw: WebSocketClient):
async def test_non_json_ws_messages_are_ignored(ws_raw: WebSocketClient):
ws = ws_raw

await ws.send_text("not valid json")
await ws.send_text("NOT VALID JSON")
await ws.send_json({"type": GQL_CONNECTION_INIT})

await ws.receive(timeout=2)
response = await ws.receive_json()
assert response["type"] == GQL_CONNECTION_ACK

await ws.send_text("NOT VALID JSON")
await ws.send_json(
{
"type": GQL_START,
"id": "demo",
"payload": {
"query": 'subscription { echo(message: "Hi") }',
},
}
)

response = await ws.receive_json()
assert response["type"] == GQL_DATA
assert response["id"] == "demo"
assert response["payload"]["data"] == {"echo": "Hi"}

await ws.send_text("NOT VALID JSON")
await ws.send_json({"type": GQL_STOP, "id": "demo"})

response = await ws.receive_json()
assert response["type"] == GQL_COMPLETE
assert response["id"] == "demo"

await ws.send_text("NOT VALID JSON")
await ws.send_json({"type": GQL_CONNECTION_TERMINATE})
await ws.receive(timeout=2) # receive close
assert ws.closed
assert ws.close_code == 1002
assert ws.close_reason == "WebSocket message type must be text"


async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
Expand Down

0 comments on commit b701eb0

Please sign in to comment.