From 37265b230e511480a9ceace492f9f6a484be1387 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Wed, 25 Sep 2024 18:08:23 +0200 Subject: [PATCH] Disable multipart uploads by default (#3645) * Disable multipart uploads by default * Document the new option * Stop disabling Django's CSRF protection by default * Document breaking changes * Add release file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bump date * Test * Add tweet file * Shorter tweet --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- .github/workflows/test.yml | 14 ----- RELEASE.md | 7 +++ TWEET.md | 8 +++ docs/breaking-changes.md | 1 + docs/breaking-changes/0.243.0.md | 53 ++++++++++++++++ docs/integrations/aiohttp.md | 6 +- docs/integrations/asgi.md | 6 +- docs/integrations/channels.md | 4 ++ docs/integrations/django.md | 7 ++- docs/integrations/fastapi.md | 4 ++ docs/integrations/flask.md | 6 +- docs/integrations/litestar.md | 4 ++ docs/integrations/quart.md | 6 +- docs/integrations/sanic.md | 7 ++- strawberry/aiohttp/views.py | 2 + strawberry/asgi/__init__.py | 2 + strawberry/channels/handlers/http_handler.py | 2 + strawberry/django/views.py | 7 +-- strawberry/fastapi/router.py | 2 + strawberry/flask/views.py | 2 + strawberry/http/async_base_view.py | 2 +- strawberry/http/base.py | 1 + strawberry/http/sync_base_view.py | 2 +- strawberry/litestar/controller.py | 2 + strawberry/quart/views.py | 2 + strawberry/sanic/views.py | 2 + tests/http/clients/aiohttp.py | 2 + tests/http/clients/asgi.py | 2 + tests/http/clients/async_django.py | 1 + tests/http/clients/async_flask.py | 2 + tests/http/clients/base.py | 1 + tests/http/clients/chalice.py | 1 + tests/http/clients/channels.py | 4 ++ tests/http/clients/django.py | 3 + tests/http/clients/fastapi.py | 2 + tests/http/clients/flask.py | 2 + tests/http/clients/litestar.py | 2 + tests/http/clients/quart.py | 2 + tests/http/clients/sanic.py | 2 + tests/http/test_upload.py | 64 ++++++++++++++------ 40 files changed, 207 insertions(+), 44 deletions(-) create mode 100644 RELEASE.md create mode 100644 TWEET.md create mode 100644 docs/breaking-changes/0.243.0.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0a3c02a29c..0374bbe3d4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -59,20 +59,6 @@ jobs: 3.12 3.13-dev - - name: Pip and nox cache - id: cache - uses: actions/cache@v4 - with: - path: | - ~/.cache - ~/.nox - .nox - key: - ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ env.pythonLocation }}-${{ - hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }} - restore-keys: | - ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ env.pythonLocation }} - - run: pip install poetry nox nox-poetry uv - run: nox -r -t tests -s "${{ matrix.session.session }}" - uses: actions/upload-artifact@v4 diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..0d039486fc --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: minor + +Starting with this release, multipart uploads are disabled by default and Strawberry Django view is no longer implicitly exempted from Django's CSRF protection. +Both changes relieve users from implicit security implications inherited from the GraphQL multipart request specification which was enabled in Strawberry by default. + +These are breaking changes if you are using multipart uploads OR the Strawberry Django view. +Migrations guides including further information are available on the Strawberry website. diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..ec081269f6 --- /dev/null +++ b/TWEET.md @@ -0,0 +1,8 @@ +🆕 Release $version is out! Thanks to $contributor 👏 + +We've made some important security changes regarding file uploads and CSRF in +Django. + +Check out our migration guides if you're using multipart or Django view. + +👇 $release_url diff --git a/docs/breaking-changes.md b/docs/breaking-changes.md index 9f0ed20a80..3df11e2b23 100644 --- a/docs/breaking-changes.md +++ b/docs/breaking-changes.md @@ -4,6 +4,7 @@ title: List of breaking changes and deprecations # List of breaking changes and deprecations +- [Version 0.243.0 - 25 September 2024](./breaking-changes/0.243.0.md) - [Version 0.240.0 - 10 September 2024](./breaking-changes/0.240.0.md) - [Version 0.236.0 - 17 July 2024](./breaking-changes/0.236.0.md) - [Version 0.233.0 - 29 May 2024](./breaking-changes/0.233.0.md) diff --git a/docs/breaking-changes/0.243.0.md b/docs/breaking-changes/0.243.0.md new file mode 100644 index 0000000000..0eec0b9633 --- /dev/null +++ b/docs/breaking-changes/0.243.0.md @@ -0,0 +1,53 @@ +--- +title: 0.243.0 Breaking Changes +slug: breaking-changes/0.243.0 +--- + +# v0.240.0 Breaking Changes + +Release v0.240.0 comes with two breaking changes regarding multipart file +uploads and Django CSRF protection. + +## Multipart uploads disabled by default + +Previously, support for uploads via the +[GraphQL multipart request specification](https://github.com/jaydenseric/graphql-multipart-request-spec) +was enabled by default. This implicitly required Strawberry users to consider +the +[security implications outlined in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security). +Given that most Strawberry users were likely not aware of this, we're making +multipart file upload support stictly opt-in via a new +`multipart_uploads_enabled` view settings. + +To enable multipart upload support for your Strawberry view integration, please +follow the updated integration guides and enable appropriate security +measurements for your server. + +## Django CSRF protection enabled + +Previously, the Strawberry Django view integration was internally exempted from +Django's built-in CSRF protection (i.e, the `CsrfViewMiddleware` middleware). +While this is how many GraphQL APIs operate, implicitly addded exemptions can +lead to security vulnerabilities. Instead, we delegate the decision of adding an +CSRF exemption to users now. + +Note that having the CSRF protection enabled on your Strawberry Django view +potentially requires all your clients to send an CSRF token with every request. +You can learn more about this in the official Django +[Cross Site Request Forgery protection documentation](https://docs.djangoproject.com/en/dev/ref/csrf/). + +To restore the behaviour of the integration before this release, you can add the +`csrf_exempt` decorator provided by Django yourself: + +```python +from django.urls import path +from django.views.decorators.csrf import csrf_exempt + +from strawberry.django.views import GraphQLView + +from api.schema import schema + +urlpatterns = [ + path("graphql/", csrf_exempt(GraphQLView.as_view(schema=schema))), +] +``` diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md index cb622234fd..e968e43bf2 100644 --- a/docs/integrations/aiohttp.md +++ b/docs/integrations/aiohttp.md @@ -29,7 +29,7 @@ app.router.add_route("*", "/graphql", GraphQLView(schema=schema)) ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the @@ -37,6 +37,10 @@ The `GraphQLView` accepts two options at the moment: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md index eaad8a52ff..dcb51643b3 100644 --- a/docs/integrations/asgi.md +++ b/docs/integrations/asgi.md @@ -29,7 +29,7 @@ app with `uvicorn server:app` ## Options -The `GraphQL` app accepts two options at the moment: +The `GraphQL` app accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the @@ -37,6 +37,10 @@ The `GraphQL` app accepts two options at the moment: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md index 4fa4ae716f..ce4fc70fb6 100644 --- a/docs/integrations/channels.md +++ b/docs/integrations/channels.md @@ -524,6 +524,10 @@ GraphQLWebsocketCommunicator( queries via `GET` requests - `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in the GraphiQL interface, defaults to `True` +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ### Extending the consumer diff --git a/docs/integrations/django.md b/docs/integrations/django.md index 2a4e75266a..2909aba6f9 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -10,13 +10,14 @@ It provides a view that you can use to serve your GraphQL schema: ```python from django.urls import path +from django.views.decorators.csrf import csrf_exempt from strawberry.django.views import GraphQLView from api.schema import schema urlpatterns = [ - path("graphql/", GraphQLView.as_view(schema=schema)), + path("graphql/", csrf_exempt(GraphQLView.as_view(schema=schema))), ] ``` @@ -40,6 +41,10 @@ The `GraphQLView` accepts the following arguments: queries via `GET` requests - `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in the GraphiQL interface, defaults to `False`. +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Deprecated options diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 1fa144b974..18b7dae5b6 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -54,6 +54,10 @@ The `GraphQLRouter` accepts the following options: value. - `root_value_getter`: optional FastAPI dependency for providing custom root value. +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## context_getter diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index 9fe4e17ac3..b44e166efa 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -34,13 +34,17 @@ from strawberry.flask.views import AsyncGraphQLView ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL interface. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/litestar.md b/docs/integrations/litestar.md index 0ba7c1ded8..b505096626 100644 --- a/docs/integrations/litestar.md +++ b/docs/integrations/litestar.md @@ -61,6 +61,10 @@ The `make_graphql_controller` function accepts the following options: the maximum time to wait for the connection initialization message when using `graphql-transport-ws` [protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#connectioninit) +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## context_getter diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md index 05eb0cd994..fa9b071990 100644 --- a/docs/integrations/quart.md +++ b/docs/integrations/quart.md @@ -26,13 +26,17 @@ if __name__ == "__main__": ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL interface. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/sanic.md b/docs/integrations/sanic.md index 1701a74a83..0a7d885c7d 100644 --- a/docs/integrations/sanic.md +++ b/docs/integrations/sanic.md @@ -22,7 +22,7 @@ app.add_route( ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the @@ -30,7 +30,10 @@ The `GraphQLView` accepts two options at the moment: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests -- `def encode_json(self, data: GraphQLHTTPResponse) -> str` +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 0a8143657f..f2154309f6 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -111,6 +111,7 @@ def __init__( GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -119,6 +120,7 @@ def __init__( self.debug = debug self.subscription_protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index d5aae404f6..d2647c6ee6 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -106,6 +106,7 @@ def __init__( GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -114,6 +115,7 @@ def __init__( self.debug = debug self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 9169265cbb..c7264a45b3 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -168,12 +168,14 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, subscriptions_enabled: bool = True, + multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.subscriptions_enabled = subscriptions_enabled self._ide_subscriptions_enabled = subscriptions_enabled + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 0ce5bf920a..132c822f78 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -28,8 +28,7 @@ from django.template.exceptions import TemplateDoesNotExist from django.template.loader import render_to_string from django.template.response import TemplateResponse -from django.utils.decorators import classonlymethod, method_decorator -from django.views.decorators.csrf import csrf_exempt +from django.utils.decorators import classonlymethod from django.views.generic import View from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter @@ -147,11 +146,13 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, subscriptions_enabled: bool = False, + multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.subscriptions_enabled = subscriptions_enabled + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( @@ -229,7 +230,6 @@ def get_context(self, request: HttpRequest, response: HttpResponse) -> Any: def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: return TemporalHttpResponse() - @method_decorator(csrf_exempt) def dispatch( self, request: HttpRequest, *args: Any, **kwargs: Any ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: @@ -288,7 +288,6 @@ async def get_context(self, request: HttpRequest, response: HttpResponse) -> Any async def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: return TemporalHttpResponse() - @method_decorator(csrf_exempt) async def dispatch( # pyright: ignore self, request: HttpRequest, *args: Any, **kwargs: Any ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 833b656383..badcfa33e0 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -156,6 +156,7 @@ def __init__( generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), + multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -190,6 +191,7 @@ def __init__( ) self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index b855c602ea..d952eb6aa9 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -71,10 +71,12 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.graphiql = graphiql self.allow_queries_via_get = allow_queries_via_get + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 7eef89aa40..c9f1e6ae49 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -333,7 +333,7 @@ async def parse_http_body( data = self.parse_query_params(request.query_params) elif "application/json" in content_type: data = self.parse_json(await request.get_body()) - elif content_type == "multipart/form-data": + elif self.multipart_uploads_enabled and content_type == "multipart/form-data": data = await self.parse_multipart(request) else: raise HTTPException(400, "Unsupported content type") diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 7f8e1802bc..5ab57ef65d 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -23,6 +23,7 @@ def headers(self) -> Mapping[str, str]: ... class BaseView(Generic[Request]): graphql_ide: Optional[GraphQL_IDE] + multipart_uploads_enabled: bool = False # TODO: we might remove this in future :) _ide_replace_variables: bool = True diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index f1ce7ca19a..df770e0541 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -143,7 +143,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData elif "application/json" in content_type: data = self.parse_json(request.body) # TODO: multipart via get? - elif content_type == "multipart/form-data": + elif self.multipart_uploads_enabled and content_type == "multipart/form-data": data = self.parse_multipart(request) elif self._is_multipart_subscriptions(content_type, params): raise HTTPException( diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 7ff68c69ad..e5c27ffe87 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -410,6 +410,7 @@ def make_graphql_controller( GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), + multipart_uploads_enabled: bool = False, ) -> Type[GraphQLController]: # sourcery skip: move-assign if context_getter is None: custom_context_getter_ = _none_custom_context_getter @@ -456,6 +457,7 @@ class _GraphQLController(GraphQLController): _GraphQLController.schema = schema_ _GraphQLController.allow_queries_via_get = allow_queries_via_get_ _GraphQLController.graphql_ide = graphql_ide_ + _GraphQLController.multipart_uploads_enabled = multipart_uploads_enabled return _GraphQLController diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index f9db21a01d..5aafcec514 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -61,9 +61,11 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index edb30075f6..b62a63ba65 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -102,11 +102,13 @@ def __init__( allow_queries_via_get: bool = True, json_encoder: Optional[Type[json.JSONEncoder]] = None, json_dumps_params: Optional[Dict[str, Any]] = None, + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.json_encoder = json_encoder self.json_dumps_params = json_dumps_params + self.multipart_uploads_enabled = multipart_uploads_enabled if self.json_encoder is not None: # pragma: no cover warnings.warn( diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..4979d7d480 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -72,6 +72,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): view = GraphQLView( schema=schema, @@ -79,6 +80,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, ) view.result_override = result_override diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..5734c8df16 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -74,6 +74,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): view = GraphQLView( schema, @@ -81,6 +82,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, ) view.result_override = result_override diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 0e8bfca0ed..5f5caf4dd7 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -43,6 +43,7 @@ async def _do_request(self, request: RequestFactory) -> Response: graphql_ide=self.graphql_ide, allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, + multipart_uploads_enabled=self.multipart_uploads_enabled, ) try: diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index 86e8e5b7f9..4db37c8d1b 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -52,6 +52,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Flask(__name__) self.app.debug = True @@ -63,6 +64,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_url_rule( diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index daf1106359..ff31e4111e 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -103,6 +103,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): ... @abc.abstractmethod diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index eddb7d8ada..2e01c1f400 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -50,6 +50,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Chalice(app_name="TheStackBadger") diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index da981403bb..14abd5e4af 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -139,6 +139,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( schema=schema, @@ -151,6 +152,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) def create_app(self, **kwargs: Any) -> None: @@ -260,6 +262,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( schema=schema, @@ -267,6 +270,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 75efa2825c..a871d7b63b 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -48,11 +48,13 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.graphiql = graphiql self.graphql_ide = graphql_ide self.allow_queries_via_get = allow_queries_via_get self.result_override = result_override + self.multipart_uploads_enabled = multipart_uploads_enabled def _get_header_name(self, key: str) -> str: return f"HTTP_{key.upper().replace('-', '_')}" @@ -75,6 +77,7 @@ async def _do_request(self, request: RequestFactory) -> Response: graphql_ide=self.graphql_ide, allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, + multipart_uploads_enabled=self.multipart_uploads_enabled, ) try: diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 1a8148c136..cddc43032f 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -86,6 +86,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = FastAPI() @@ -97,6 +98,7 @@ def __init__( root_value_getter=get_root_value, allow_queries_via_get=allow_queries_via_get, keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, ) graphql_app.result_override = result_override self.app.include_router(graphql_app, prefix="/graphql") diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index abc0e3cec4..7da42bbaff 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -61,6 +61,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Flask(__name__) self.app.debug = True @@ -72,6 +73,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_url_rule( diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index ccf9999f7f..0b99b43729 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -59,12 +59,14 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.create_app( graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) def create_app(self, result_override: ResultOverrideFunction = None, **kwargs: Any): diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 60bc14b8c2..d9a184dfd4 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -54,6 +54,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Quart(__name__) self.app.debug = True @@ -65,6 +66,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_url_rule( diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 449aa316e8..8edc351db9 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -53,6 +53,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Sanic( f"test_{int(randint(0, 1000))}", # noqa: S311 @@ -63,6 +64,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_route( view, diff --git a/tests/http/test_upload.py b/tests/http/test_upload.py index 86871c9f2a..7a991db846 100644 --- a/tests/http/test_upload.py +++ b/tests/http/test_upload.py @@ -20,7 +20,18 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: return http_client_class() -async def test_upload(http_client: HttpClient): +@pytest.fixture() +def enabled_http_client(http_client_class: Type[HttpClient]) -> HttpClient: + with contextlib.suppress(ImportError): + from .clients.chalice import ChaliceHttpClient + + if http_client_class is ChaliceHttpClient: + pytest.xfail(reason="Chalice does not support uploads") + + return http_client_class(multipart_uploads_enabled=True) + + +async def test_multipart_uploads_are_disabled_by_default(http_client: HttpClient): f = BytesIO(b"strawberry") query = """ @@ -35,16 +46,35 @@ async def test_upload(http_client: HttpClient): files={"textFile": f}, ) + assert response.status_code == 400 + assert response.data == b"Unsupported content type" + + +async def test_upload(enabled_http_client: HttpClient): + f = BytesIO(b"strawberry") + + query = """ + mutation($textFile: Upload!) { + readText(textFile: $textFile) + } + """ + + response = await enabled_http_client.query( + query, + variables={"textFile": None}, + files={"textFile": f}, + ) + assert response.json.get("errors") is None assert response.json["data"] == {"readText": "strawberry"} -async def test_file_list_upload(http_client: HttpClient): +async def test_file_list_upload(enabled_http_client: HttpClient): query = "mutation($files: [Upload!]!) { readFiles(files: $files) }" file1 = BytesIO(b"strawberry1") file2 = BytesIO(b"strawberry2") - response = await http_client.query( + response = await enabled_http_client.query( query=query, variables={"files": [None, None]}, files={"file1": file1, "file2": file2}, @@ -57,12 +87,12 @@ async def test_file_list_upload(http_client: HttpClient): assert data["readFiles"][1] == "strawberry2" -async def test_nested_file_list(http_client: HttpClient): +async def test_nested_file_list(enabled_http_client: HttpClient): query = "mutation($folder: FolderInput!) { readFolder(folder: $folder) }" file1 = BytesIO(b"strawberry1") file2 = BytesIO(b"strawberry2") - response = await http_client.query( + response = await enabled_http_client.query( query=query, variables={"folder": {"files": [None, None]}}, files={"file1": file1, "file2": file2}, @@ -74,7 +104,7 @@ async def test_nested_file_list(http_client: HttpClient): assert data["readFolder"][1] == "strawberry2" -async def test_upload_single_and_list_file_together(http_client: HttpClient): +async def test_upload_single_and_list_file_together(enabled_http_client: HttpClient): query = """ mutation($files: [Upload!]!, $textFile: Upload!) { readFiles(files: $files) @@ -85,7 +115,7 @@ async def test_upload_single_and_list_file_together(http_client: HttpClient): file2 = BytesIO(b"strawberry2") file3 = BytesIO(b"strawberry3") - response = await http_client.query( + response = await enabled_http_client.query( query=query, variables={"files": [None, None], "textFile": None}, files={"file1": file1, "file2": file2, "textFile": file3}, @@ -98,7 +128,7 @@ async def test_upload_single_and_list_file_together(http_client: HttpClient): assert data["readText"] == "strawberry3" -async def test_upload_invalid_query(http_client: HttpClient): +async def test_upload_invalid_query(enabled_http_client: HttpClient): f = BytesIO(b"strawberry") query = """ @@ -106,7 +136,7 @@ async def test_upload_invalid_query(http_client: HttpClient): readT """ - response = await http_client.query( + response = await enabled_http_client.query( query, variables={"textFile": None}, files={"textFile": f}, @@ -122,7 +152,7 @@ async def test_upload_invalid_query(http_client: HttpClient): ] -async def test_upload_missing_file(http_client: HttpClient): +async def test_upload_missing_file(enabled_http_client: HttpClient): f = BytesIO(b"strawberry") query = """ @@ -131,7 +161,7 @@ async def test_upload_missing_file(http_client: HttpClient): } """ - response = await http_client.query( + response = await enabled_http_client.query( query, variables={"textFile": None}, # using the wrong name to simulate a missing file @@ -155,7 +185,7 @@ def value(self) -> bytes: return self.buffer.getvalue() -async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): +async def test_extra_form_data_fields_are_ignored(enabled_http_client: HttpClient): query = """mutation($textFile: Upload!) { readText(textFile: $textFile) }""" @@ -175,7 +205,7 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): data, header = encode_multipart_formdata(fields) - response = await http_client.post( + response = await enabled_http_client.post( url="/graphql", data=data, headers={ @@ -188,9 +218,9 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): assert response.json["data"] == {"readText": "strawberry"} -async def test_sending_invalid_form_data(http_client: HttpClient): +async def test_sending_invalid_form_data(enabled_http_client: HttpClient): headers = {"content-type": "multipart/form-data; boundary=----fake"} - response = await http_client.post("/graphql", headers=headers) + response = await enabled_http_client.post("/graphql", headers=headers) assert response.status_code == 400 # TODO: consolidate this, it seems only AIOHTTP returns the second error @@ -202,7 +232,7 @@ async def test_sending_invalid_form_data(http_client: HttpClient): @pytest.mark.aiohttp -async def test_sending_invalid_json_body(http_client: HttpClient): +async def test_sending_invalid_json_body(enabled_http_client: HttpClient): f = BytesIO(b"strawberry") operations = "}" file_map = json.dumps({"textFile": ["variables.textFile"]}) @@ -215,7 +245,7 @@ async def test_sending_invalid_json_body(http_client: HttpClient): data, header = encode_multipart_formdata(fields) - response = await http_client.post( + response = await enabled_http_client.post( "/graphql", data=data, headers={