diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..8be54812a9 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: minor + +The view classes of all integrations now have a `decode_json` method that allows +you to customize the decoding of HTTP JSON requests. + +This is useful if you want to use a different JSON decoder, for example, to +optimize performance. diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md index 6da5644892..ceed995299 100644 --- a/docs/integrations/aiohttp.md +++ b/docs/integrations/aiohttp.md @@ -142,6 +142,27 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP and WebSocket JSON +requests. By default we use `json.loads` but you can override this method to use +a different decoder. + +```python +from strawberry.aiohttp.views import GraphQLView +from typing import Union +import orjson + + +class MyGraphQLView(GraphQLView): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md index 7078dd8ac1..4687ef05a6 100644 --- a/docs/integrations/asgi.md +++ b/docs/integrations/asgi.md @@ -49,7 +49,8 @@ We allow to extend the base `GraphQL` app, by overriding the following methods: - `async get_context(self, request: Union[Request, WebSocket], response: Optional[Response] = None) -> Any` - `async get_root_value(self, request: Request) -> Any` - `async process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse` -- `def encode_json(self, response_data: object) -> str` +- `def decode_json(self, data: Union[str, bytes]) -> object` +- `def encode_json(self, data: object) -> str` - `async def render_graphql_ide(self, request: Request) -> Response` ### get_context @@ -167,6 +168,26 @@ class MyGraphQL(GraphQL): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP JSON requests. By default +we use `json.loads` but you can override this method to use a different decoder. + +```python +from strawberry.asgi import GraphQL +from typing import Union +import orjson + + +class MyGraphQLView(GraphQL): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/chalice.md b/docs/integrations/chalice.md index 14011de5d1..f2e9ef8edf 100644 --- a/docs/integrations/chalice.md +++ b/docs/integrations/chalice.md @@ -152,6 +152,26 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP JSON requests. By default +we use `json.loads` but you can override this method to use a different decoder. + +```python +from strawberry.chalice.views import GraphQLView +from typing import Union +import orjson + + +class MyGraphQLView(GraphQLView): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/django.md b/docs/integrations/django.md index bdc6483bdf..408772d86b 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -286,6 +286,27 @@ class MyGraphQLView(AsyncGraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP and WebSocket JSON +requests. By default we use `json.loads` but you can override this method to use +a different decoder. + +```python +from strawberry.django.views import AsyncGraphQLView +from typing import Union +import orjson + + +class MyGraphQLView(AsyncGraphQLView): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 1f6733f22f..dd95440683 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -290,6 +290,27 @@ class MyGraphQLRouter(GraphQLRouter): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP and WebSocket JSON +requests. By default we use `json.loads` but you can override this method to use +a different decoder. + +```python +from strawberry.fastapi import GraphQLRouter +from typing import Union +import orjson + + +class MyGraphQLRouter(GraphQLRouter): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index 3ed41c65b2..bacd6b4011 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -139,6 +139,26 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP JSON requests. By default +we use `json.loads` but you can override this method to use a different decoder. + +```python +from strawberry.flask.views import GraphQLView +from typing import Union +import orjson + + +class MyGraphQLView(GraphQLView): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md index cf91ffd034..feecf7b4f5 100644 --- a/docs/integrations/quart.md +++ b/docs/integrations/quart.md @@ -123,6 +123,26 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP JSON requests. By default +we use `json.loads` but you can override this method to use a different decoder. + +```python +from strawberry.quart.views import GraphQLView +from typing import Union +import orjson + + +class MyGraphQLView(GraphQLView): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/docs/integrations/sanic.md b/docs/integrations/sanic.md index c771d86cff..f10ef60c02 100644 --- a/docs/integrations/sanic.md +++ b/docs/integrations/sanic.md @@ -121,6 +121,26 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### decode_json + +`decode_json` allows to customize the decoding of HTTP JSON requests. By default +we use `json.loads` but you can override this method to use a different decoder. + +```python +from strawberry.sanic.views import GraphQLView +from typing import Union +import orjson + + +class MyGraphQLView(GraphQLView): + def decode_json(self, data: Union[str, bytes]) -> object: + return orjson.loads(data) +``` + +Make sure your code raises `json.JSONDecodeError` or a subclass of it if the +JSON cannot be decoded. The library shown in the example above, `orjson`, does +this by default. + ### encode_json `encode_json` allows to customize the encoding of HTTP and WebSocket JSON diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index d3f2d50031..aa07e34d8b 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -96,11 +96,11 @@ def __init__( async def iter_json( self, *, ignore_parsing_errors: bool = False - ) -> AsyncGenerator[Dict[str, object], None]: + ) -> AsyncGenerator[object, None]: async for ws_message in self.ws: if ws_message.type == http.WSMsgType.TEXT: try: - yield ws_message.json() + yield self.view.decode_json(ws_message.data) except JSONDecodeError: if not ignore_parsing_errors: raise NonJsonMessageReceived() diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 0cc701fb2c..90eacd8bd2 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -95,11 +95,12 @@ def __init__( async def iter_json( self, *, ignore_parsing_errors: bool = False - ) -> AsyncGenerator[Dict[str, object], None]: + ) -> AsyncGenerator[object, None]: try: while self.ws.application_state != WebSocketState.DISCONNECTED: try: - yield await self.ws.receive_json() + text = await self.ws.receive_text() + yield self.view.decode_json(text) except JSONDecodeError: # noqa: PERF203 if not ignore_parsing_errors: raise NonJsonMessageReceived() diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 4d53122216..34ba50f8bb 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -6,7 +6,6 @@ from typing import ( TYPE_CHECKING, AsyncGenerator, - Dict, Mapping, Optional, Tuple, @@ -39,7 +38,7 @@ def __init__( async def iter_json( self, *, ignore_parsing_errors: bool = False - ) -> AsyncGenerator[Dict[str, object], None]: + ) -> AsyncGenerator[object, None]: while True: message = await self.ws_consumer.message_queue.get() @@ -50,7 +49,7 @@ async def iter_json( raise NonTextMessageReceived() try: - yield json.loads(message["message"]) + yield self.view.decode_json(message["message"]) except json.JSONDecodeError: if not ignore_parsing_errors: raise NonJsonMessageReceived() diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 9ed3419e7a..57a307b2e8 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -86,7 +86,7 @@ def __init__(self, view: "AsyncBaseHTTPView") -> None: @abc.abstractmethod def iter_json( self, *, ignore_parsing_errors: bool = False - ) -> AsyncGenerator[Dict[str, object], None]: ... + ) -> AsyncGenerator[object, None]: ... @abc.abstractmethod async def send_json(self, message: Mapping[str, object]) -> None: ... diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 5f9d353f8c..ffb41bf751 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -39,10 +39,13 @@ def is_request_allowed(self, request: BaseRequestProtocol) -> bool: def parse_json(self, data: Union[str, bytes]) -> Any: try: - return json.loads(data) + return self.decode_json(data) except json.JSONDecodeError as e: raise HTTPException(400, "Unable to parse request body as JSON") from e + def decode_json(self, data: Union[str, bytes]) -> object: + return json.loads(data) + def encode_json(self, data: object) -> str: return json.dumps(data) diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index ad80e4af8b..2ffc456df8 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -202,7 +202,7 @@ def __init__( async def iter_json( self, *, ignore_parsing_errors: bool = False - ) -> AsyncGenerator[Dict[str, object], None]: + ) -> AsyncGenerator[object, None]: try: while self.ws.connection_state != "disconnect": text = await self.ws.receive_text() @@ -212,7 +212,7 @@ async def iter_json( raise NonTextMessageReceived() try: - yield json.loads(text) + yield self.view.decode_json(text) except json.JSONDecodeError: if not ignore_parsing_errors: raise NonJsonMessageReceived() diff --git a/tests/http/test_http.py b/tests/http/test_http.py index 6c4d781aa7..7c318ce7eb 100644 --- a/tests/http/test_http.py +++ b/tests/http/test_http.py @@ -1,5 +1,7 @@ import pytest +from strawberry.http.base import BaseView + from .clients.base import HttpClient @@ -11,3 +13,18 @@ async def test_does_only_allow_get_and_post( response = await http_client.request(url="/graphql", method=method) # type: ignore assert response.status_code == 405 + + +async def test_the_http_handler_uses_the_views_decode_json_method( + http_client: HttpClient, mocker +): + spy = mocker.spy(BaseView, "decode_json") + + response = await http_client.query(query="{ hello }") + assert response.status_code == 200 + + data = response.json["data"] + assert isinstance(data, dict) + assert data["hello"] == "Hello world" + + assert spy.call_count == 1 diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index 98379949ee..b680e9267a 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -4,6 +4,8 @@ import pytest +from strawberry.http.base import BaseView + from .clients.base import HttpClient @@ -81,3 +83,32 @@ async def test_multipart_subscription( ] assert response.status_code == 200 + + +async def test_multipart_subscription_use_the_views_decode_json_method( + http_client: HttpClient, mocker +): + spy = mocker.spy(BaseView, "decode_json") + + response = await http_client.query( + query='subscription { echo(message: "Hello world", delay: 0.2) }', + headers={ + "accept": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + "content-type": "application/json", + }, + ) + + data = [d async for d in response.streaming_json()] + + assert data == [ + { + "payload": { + "data": {"echo": "Hello world"}, + "extensions": {"example": "example"}, + } + } + ] + + assert response.status_code == 200 + + assert spy.call_count == 1 diff --git a/tests/websockets/test_websockets.py b/tests/websockets/test_websockets.py index af8420b3cf..d85eda42d5 100644 --- a/tests/websockets/test_websockets.py +++ b/tests/websockets/test_websockets.py @@ -2,6 +2,9 @@ from strawberry.http.async_base_view import AsyncBaseHTTPView from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( + ConnectionAckMessage, +) from tests.http.clients.base import HttpClient @@ -84,19 +87,37 @@ async def test_clients_can_prefer_subprotocols(http_client_class: Type[HttpClien async def test_handlers_use_the_views_encode_json_method( - http_client: HttpClient, monkeypatch + http_client: HttpClient, mocker ): - def mock_encode_json(self, data): - return '{"custom": "json"}' - - monkeypatch.setattr(AsyncBaseHTTPView, "encode_json", mock_encode_json) + spy = mocker.spy(AsyncBaseHTTPView, "encode_json") async with http_client.ws_connect( "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] ) as ws: await ws.send_json({"type": "connection_init"}) - message = await ws.receive_json() - assert message == {"custom": "json"} + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message == {"type": "connection_ack"} + + await ws.close() + assert ws.closed + + assert spy.call_count == 1 + + +async def test_handlers_use_the_views_decode_json_method( + http_client: HttpClient, mocker +): + spy = mocker.spy(AsyncBaseHTTPView, "decode_json") + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.send_message({"type": "connection_init"}) + + connection_ack_message: ConnectionAckMessage = await ws.receive_json() + assert connection_ack_message == {"type": "connection_ack"} await ws.close() assert ws.closed + + assert spy.call_count == 1