From 18f0f5d61478bf956b3123cba84b70a6e3f88d3e Mon Sep 17 00:00:00 2001 From: Emil 'Skeen' Madsen Date: Mon, 23 Sep 2024 16:53:59 +0200 Subject: [PATCH 01/31] Fix missing return annotation in get_customer docs (#3642) Using the documentation as-is produces this error: ``` MissingReturnAnnotationError: Return annotation missing for field "get_customer", did you forget to add it? ``` --- docs/types/schema.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/types/schema.md b/docs/types/schema.md index f82f4fe7cb..a0d14e8b51 100644 --- a/docs/types/schema.md +++ b/docs/types/schema.md @@ -91,7 +91,7 @@ class Query: @strawberry.field def get_customer( self, id: strawberry.ID - ): # -> Customer note we're returning the interface here + ) -> Customer: # note we're returning the interface here if id == "mark": return Individual(name="Mark", date_of_birth=date(1984, 5, 14)) From 37265b230e511480a9ceace492f9f6a484be1387 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Wed, 25 Sep 2024 18:08:23 +0200 Subject: [PATCH 02/31] Disable multipart uploads by default (#3645) * Disable multipart uploads by default * Document the new option * Stop disabling Django's CSRF protection by default * Document breaking changes * Add release file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bump date * Test * Add tweet file * Shorter tweet --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- .github/workflows/test.yml | 14 ----- RELEASE.md | 7 +++ TWEET.md | 8 +++ docs/breaking-changes.md | 1 + docs/breaking-changes/0.243.0.md | 53 ++++++++++++++++ docs/integrations/aiohttp.md | 6 +- docs/integrations/asgi.md | 6 +- docs/integrations/channels.md | 4 ++ docs/integrations/django.md | 7 ++- docs/integrations/fastapi.md | 4 ++ docs/integrations/flask.md | 6 +- docs/integrations/litestar.md | 4 ++ docs/integrations/quart.md | 6 +- docs/integrations/sanic.md | 7 ++- strawberry/aiohttp/views.py | 2 + strawberry/asgi/__init__.py | 2 + strawberry/channels/handlers/http_handler.py | 2 + strawberry/django/views.py | 7 +-- strawberry/fastapi/router.py | 2 + strawberry/flask/views.py | 2 + strawberry/http/async_base_view.py | 2 +- strawberry/http/base.py | 1 + strawberry/http/sync_base_view.py | 2 +- strawberry/litestar/controller.py | 2 + strawberry/quart/views.py | 2 + strawberry/sanic/views.py | 2 + tests/http/clients/aiohttp.py | 2 + tests/http/clients/asgi.py | 2 + tests/http/clients/async_django.py | 1 + tests/http/clients/async_flask.py | 2 + tests/http/clients/base.py | 1 + tests/http/clients/chalice.py | 1 + tests/http/clients/channels.py | 4 ++ tests/http/clients/django.py | 3 + tests/http/clients/fastapi.py | 2 + tests/http/clients/flask.py | 2 + tests/http/clients/litestar.py | 2 + tests/http/clients/quart.py | 2 + tests/http/clients/sanic.py | 2 + tests/http/test_upload.py | 64 ++++++++++++++------ 40 files changed, 207 insertions(+), 44 deletions(-) create mode 100644 RELEASE.md create mode 100644 TWEET.md create mode 100644 docs/breaking-changes/0.243.0.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0a3c02a29c..0374bbe3d4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -59,20 +59,6 @@ jobs: 3.12 3.13-dev - - name: Pip and nox cache - id: cache - uses: actions/cache@v4 - with: - path: | - ~/.cache - ~/.nox - .nox - key: - ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ env.pythonLocation }}-${{ - hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }} - restore-keys: | - ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ env.pythonLocation }} - - run: pip install poetry nox nox-poetry uv - run: nox -r -t tests -s "${{ matrix.session.session }}" - uses: actions/upload-artifact@v4 diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..0d039486fc --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: minor + +Starting with this release, multipart uploads are disabled by default and Strawberry Django view is no longer implicitly exempted from Django's CSRF protection. +Both changes relieve users from implicit security implications inherited from the GraphQL multipart request specification which was enabled in Strawberry by default. + +These are breaking changes if you are using multipart uploads OR the Strawberry Django view. +Migrations guides including further information are available on the Strawberry website. diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..ec081269f6 --- /dev/null +++ b/TWEET.md @@ -0,0 +1,8 @@ +πŸ†• Release $version is out! Thanks to $contributor πŸ‘ + +We've made some important security changes regarding file uploads and CSRF in +Django. + +Check out our migration guides if you're using multipart or Django view. + +πŸ‘‡ $release_url diff --git a/docs/breaking-changes.md b/docs/breaking-changes.md index 9f0ed20a80..3df11e2b23 100644 --- a/docs/breaking-changes.md +++ b/docs/breaking-changes.md @@ -4,6 +4,7 @@ title: List of breaking changes and deprecations # List of breaking changes and deprecations +- [Version 0.243.0 - 25 September 2024](./breaking-changes/0.243.0.md) - [Version 0.240.0 - 10 September 2024](./breaking-changes/0.240.0.md) - [Version 0.236.0 - 17 July 2024](./breaking-changes/0.236.0.md) - [Version 0.233.0 - 29 May 2024](./breaking-changes/0.233.0.md) diff --git a/docs/breaking-changes/0.243.0.md b/docs/breaking-changes/0.243.0.md new file mode 100644 index 0000000000..0eec0b9633 --- /dev/null +++ b/docs/breaking-changes/0.243.0.md @@ -0,0 +1,53 @@ +--- +title: 0.243.0 Breaking Changes +slug: breaking-changes/0.243.0 +--- + +# v0.240.0 Breaking Changes + +Release v0.240.0 comes with two breaking changes regarding multipart file +uploads and Django CSRF protection. + +## Multipart uploads disabled by default + +Previously, support for uploads via the +[GraphQL multipart request specification](https://github.com/jaydenseric/graphql-multipart-request-spec) +was enabled by default. This implicitly required Strawberry users to consider +the +[security implications outlined in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security). +Given that most Strawberry users were likely not aware of this, we're making +multipart file upload support stictly opt-in via a new +`multipart_uploads_enabled` view settings. + +To enable multipart upload support for your Strawberry view integration, please +follow the updated integration guides and enable appropriate security +measurements for your server. + +## Django CSRF protection enabled + +Previously, the Strawberry Django view integration was internally exempted from +Django's built-in CSRF protection (i.e, the `CsrfViewMiddleware` middleware). +While this is how many GraphQL APIs operate, implicitly addded exemptions can +lead to security vulnerabilities. Instead, we delegate the decision of adding an +CSRF exemption to users now. + +Note that having the CSRF protection enabled on your Strawberry Django view +potentially requires all your clients to send an CSRF token with every request. +You can learn more about this in the official Django +[Cross Site Request Forgery protection documentation](https://docs.djangoproject.com/en/dev/ref/csrf/). + +To restore the behaviour of the integration before this release, you can add the +`csrf_exempt` decorator provided by Django yourself: + +```python +from django.urls import path +from django.views.decorators.csrf import csrf_exempt + +from strawberry.django.views import GraphQLView + +from api.schema import schema + +urlpatterns = [ + path("graphql/", csrf_exempt(GraphQLView.as_view(schema=schema))), +] +``` diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md index cb622234fd..e968e43bf2 100644 --- a/docs/integrations/aiohttp.md +++ b/docs/integrations/aiohttp.md @@ -29,7 +29,7 @@ app.router.add_route("*", "/graphql", GraphQLView(schema=schema)) ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the @@ -37,6 +37,10 @@ The `GraphQLView` accepts two options at the moment: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md index eaad8a52ff..dcb51643b3 100644 --- a/docs/integrations/asgi.md +++ b/docs/integrations/asgi.md @@ -29,7 +29,7 @@ app with `uvicorn server:app` ## Options -The `GraphQL` app accepts two options at the moment: +The `GraphQL` app accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the @@ -37,6 +37,10 @@ The `GraphQL` app accepts two options at the moment: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md index 4fa4ae716f..ce4fc70fb6 100644 --- a/docs/integrations/channels.md +++ b/docs/integrations/channels.md @@ -524,6 +524,10 @@ GraphQLWebsocketCommunicator( queries via `GET` requests - `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in the GraphiQL interface, defaults to `True` +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ### Extending the consumer diff --git a/docs/integrations/django.md b/docs/integrations/django.md index 2a4e75266a..2909aba6f9 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -10,13 +10,14 @@ It provides a view that you can use to serve your GraphQL schema: ```python from django.urls import path +from django.views.decorators.csrf import csrf_exempt from strawberry.django.views import GraphQLView from api.schema import schema urlpatterns = [ - path("graphql/", GraphQLView.as_view(schema=schema)), + path("graphql/", csrf_exempt(GraphQLView.as_view(schema=schema))), ] ``` @@ -40,6 +41,10 @@ The `GraphQLView` accepts the following arguments: queries via `GET` requests - `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in the GraphiQL interface, defaults to `False`. +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Deprecated options diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 1fa144b974..18b7dae5b6 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -54,6 +54,10 @@ The `GraphQLRouter` accepts the following options: value. - `root_value_getter`: optional FastAPI dependency for providing custom root value. +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## context_getter diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index 9fe4e17ac3..b44e166efa 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -34,13 +34,17 @@ from strawberry.flask.views import AsyncGraphQLView ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL interface. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/litestar.md b/docs/integrations/litestar.md index 0ba7c1ded8..b505096626 100644 --- a/docs/integrations/litestar.md +++ b/docs/integrations/litestar.md @@ -61,6 +61,10 @@ The `make_graphql_controller` function accepts the following options: the maximum time to wait for the connection initialization message when using `graphql-transport-ws` [protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#connectioninit) +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## context_getter diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md index 05eb0cd994..fa9b071990 100644 --- a/docs/integrations/quart.md +++ b/docs/integrations/quart.md @@ -26,13 +26,17 @@ if __name__ == "__main__": ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL interface. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/docs/integrations/sanic.md b/docs/integrations/sanic.md index 1701a74a83..0a7d885c7d 100644 --- a/docs/integrations/sanic.md +++ b/docs/integrations/sanic.md @@ -22,7 +22,7 @@ app.add_route( ## Options -The `GraphQLView` accepts two options at the moment: +The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. - `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the @@ -30,7 +30,10 @@ The `GraphQLView` accepts two options at the moment: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests -- `def encode_json(self, data: GraphQLHTTPResponse) -> str` +- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether + to enable multipart uploads. Please make sure to consider the + [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) + when enabling this feature. ## Extending the view diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 0a8143657f..f2154309f6 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -111,6 +111,7 @@ def __init__( GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -119,6 +120,7 @@ def __init__( self.debug = debug self.subscription_protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index d5aae404f6..d2647c6ee6 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -106,6 +106,7 @@ def __init__( GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get @@ -114,6 +115,7 @@ def __init__( self.debug = debug self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 9169265cbb..c7264a45b3 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -168,12 +168,14 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, subscriptions_enabled: bool = True, + multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.subscriptions_enabled = subscriptions_enabled self._ide_subscriptions_enabled = subscriptions_enabled + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 0ce5bf920a..132c822f78 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -28,8 +28,7 @@ from django.template.exceptions import TemplateDoesNotExist from django.template.loader import render_to_string from django.template.response import TemplateResponse -from django.utils.decorators import classonlymethod, method_decorator -from django.views.decorators.csrf import csrf_exempt +from django.utils.decorators import classonlymethod from django.views.generic import View from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter @@ -147,11 +146,13 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, subscriptions_enabled: bool = False, + multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.subscriptions_enabled = subscriptions_enabled + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( @@ -229,7 +230,6 @@ def get_context(self, request: HttpRequest, response: HttpResponse) -> Any: def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: return TemporalHttpResponse() - @method_decorator(csrf_exempt) def dispatch( self, request: HttpRequest, *args: Any, **kwargs: Any ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: @@ -288,7 +288,6 @@ async def get_context(self, request: HttpRequest, response: HttpResponse) -> Any async def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse: return TemporalHttpResponse() - @method_decorator(csrf_exempt) async def dispatch( # pyright: ignore self, request: HttpRequest, *args: Any, **kwargs: Any ) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]: diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index 833b656383..badcfa33e0 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -156,6 +156,7 @@ def __init__( generate_unique_id_function: Callable[[APIRoute], str] = Default( generate_unique_id ), + multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -190,6 +191,7 @@ def __init__( ) self.protocols = subscription_protocols self.connection_init_wait_timeout = connection_init_wait_timeout + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index b855c602ea..d952eb6aa9 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -71,10 +71,12 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.graphiql = graphiql self.allow_queries_via_get = allow_queries_via_get + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index 7eef89aa40..c9f1e6ae49 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -333,7 +333,7 @@ async def parse_http_body( data = self.parse_query_params(request.query_params) elif "application/json" in content_type: data = self.parse_json(await request.get_body()) - elif content_type == "multipart/form-data": + elif self.multipart_uploads_enabled and content_type == "multipart/form-data": data = await self.parse_multipart(request) else: raise HTTPException(400, "Unsupported content type") diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 7f8e1802bc..5ab57ef65d 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -23,6 +23,7 @@ def headers(self) -> Mapping[str, str]: ... class BaseView(Generic[Request]): graphql_ide: Optional[GraphQL_IDE] + multipart_uploads_enabled: bool = False # TODO: we might remove this in future :) _ide_replace_variables: bool = True diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py index f1ce7ca19a..df770e0541 100644 --- a/strawberry/http/sync_base_view.py +++ b/strawberry/http/sync_base_view.py @@ -143,7 +143,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData elif "application/json" in content_type: data = self.parse_json(request.body) # TODO: multipart via get? - elif content_type == "multipart/form-data": + elif self.multipart_uploads_enabled and content_type == "multipart/form-data": data = self.parse_multipart(request) elif self._is_multipart_subscriptions(content_type, params): raise HTTPException( diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index 7ff68c69ad..e5c27ffe87 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -410,6 +410,7 @@ def make_graphql_controller( GRAPHQL_WS_PROTOCOL, ), connection_init_wait_timeout: timedelta = timedelta(minutes=1), + multipart_uploads_enabled: bool = False, ) -> Type[GraphQLController]: # sourcery skip: move-assign if context_getter is None: custom_context_getter_ = _none_custom_context_getter @@ -456,6 +457,7 @@ class _GraphQLController(GraphQLController): _GraphQLController.schema = schema_ _GraphQLController.allow_queries_via_get = allow_queries_via_get_ _GraphQLController.graphql_ide = graphql_ide_ + _GraphQLController.multipart_uploads_enabled = multipart_uploads_enabled return _GraphQLController diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index f9db21a01d..5aafcec514 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -61,9 +61,11 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index edb30075f6..b62a63ba65 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -102,11 +102,13 @@ def __init__( allow_queries_via_get: bool = True, json_encoder: Optional[Type[json.JSONEncoder]] = None, json_dumps_params: Optional[Dict[str, Any]] = None, + multipart_uploads_enabled: bool = False, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get self.json_encoder = json_encoder self.json_dumps_params = json_dumps_params + self.multipart_uploads_enabled = multipart_uploads_enabled if self.json_encoder is not None: # pragma: no cover warnings.warn( diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index cd552e877c..4979d7d480 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -72,6 +72,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): view = GraphQLView( schema=schema, @@ -79,6 +80,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, ) view.result_override = result_override diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 72d9e95aa6..5734c8df16 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -74,6 +74,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): view = GraphQLView( schema, @@ -81,6 +82,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, ) view.result_override = result_override diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py index 0e8bfca0ed..5f5caf4dd7 100644 --- a/tests/http/clients/async_django.py +++ b/tests/http/clients/async_django.py @@ -43,6 +43,7 @@ async def _do_request(self, request: RequestFactory) -> Response: graphql_ide=self.graphql_ide, allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, + multipart_uploads_enabled=self.multipart_uploads_enabled, ) try: diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py index 86e8e5b7f9..4db37c8d1b 100644 --- a/tests/http/clients/async_flask.py +++ b/tests/http/clients/async_flask.py @@ -52,6 +52,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Flask(__name__) self.app.debug = True @@ -63,6 +64,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_url_rule( diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index daf1106359..ff31e4111e 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -103,6 +103,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): ... @abc.abstractmethod diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py index eddb7d8ada..2e01c1f400 100644 --- a/tests/http/clients/chalice.py +++ b/tests/http/clients/chalice.py @@ -50,6 +50,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Chalice(app_name="TheStackBadger") diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index da981403bb..14abd5e4af 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -139,6 +139,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( schema=schema, @@ -151,6 +152,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) def create_app(self, **kwargs: Any) -> None: @@ -260,6 +262,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi( schema=schema, @@ -267,6 +270,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py index 75efa2825c..a871d7b63b 100644 --- a/tests/http/clients/django.py +++ b/tests/http/clients/django.py @@ -48,11 +48,13 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.graphiql = graphiql self.graphql_ide = graphql_ide self.allow_queries_via_get = allow_queries_via_get self.result_override = result_override + self.multipart_uploads_enabled = multipart_uploads_enabled def _get_header_name(self, key: str) -> str: return f"HTTP_{key.upper().replace('-', '_')}" @@ -75,6 +77,7 @@ async def _do_request(self, request: RequestFactory) -> Response: graphql_ide=self.graphql_ide, allow_queries_via_get=self.allow_queries_via_get, result_override=self.result_override, + multipart_uploads_enabled=self.multipart_uploads_enabled, ) try: diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index 1a8148c136..cddc43032f 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -86,6 +86,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = FastAPI() @@ -97,6 +98,7 @@ def __init__( root_value_getter=get_root_value, allow_queries_via_get=allow_queries_via_get, keep_alive=False, + multipart_uploads_enabled=multipart_uploads_enabled, ) graphql_app.result_override = result_override self.app.include_router(graphql_app, prefix="/graphql") diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py index abc0e3cec4..7da42bbaff 100644 --- a/tests/http/clients/flask.py +++ b/tests/http/clients/flask.py @@ -61,6 +61,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Flask(__name__) self.app.debug = True @@ -72,6 +73,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_url_rule( diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index ccf9999f7f..0b99b43729 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -59,12 +59,14 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.create_app( graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) def create_app(self, result_override: ResultOverrideFunction = None, **kwargs: Any): diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py index 60bc14b8c2..d9a184dfd4 100644 --- a/tests/http/clients/quart.py +++ b/tests/http/clients/quart.py @@ -54,6 +54,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Quart(__name__) self.app.debug = True @@ -65,6 +66,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_url_rule( diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py index 449aa316e8..8edc351db9 100644 --- a/tests/http/clients/sanic.py +++ b/tests/http/clients/sanic.py @@ -53,6 +53,7 @@ def __init__( graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, result_override: ResultOverrideFunction = None, + multipart_uploads_enabled: bool = False, ): self.app = Sanic( f"test_{int(randint(0, 1000))}", # noqa: S311 @@ -63,6 +64,7 @@ def __init__( graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, result_override=result_override, + multipart_uploads_enabled=multipart_uploads_enabled, ) self.app.add_route( view, diff --git a/tests/http/test_upload.py b/tests/http/test_upload.py index 86871c9f2a..7a991db846 100644 --- a/tests/http/test_upload.py +++ b/tests/http/test_upload.py @@ -20,7 +20,18 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient: return http_client_class() -async def test_upload(http_client: HttpClient): +@pytest.fixture() +def enabled_http_client(http_client_class: Type[HttpClient]) -> HttpClient: + with contextlib.suppress(ImportError): + from .clients.chalice import ChaliceHttpClient + + if http_client_class is ChaliceHttpClient: + pytest.xfail(reason="Chalice does not support uploads") + + return http_client_class(multipart_uploads_enabled=True) + + +async def test_multipart_uploads_are_disabled_by_default(http_client: HttpClient): f = BytesIO(b"strawberry") query = """ @@ -35,16 +46,35 @@ async def test_upload(http_client: HttpClient): files={"textFile": f}, ) + assert response.status_code == 400 + assert response.data == b"Unsupported content type" + + +async def test_upload(enabled_http_client: HttpClient): + f = BytesIO(b"strawberry") + + query = """ + mutation($textFile: Upload!) { + readText(textFile: $textFile) + } + """ + + response = await enabled_http_client.query( + query, + variables={"textFile": None}, + files={"textFile": f}, + ) + assert response.json.get("errors") is None assert response.json["data"] == {"readText": "strawberry"} -async def test_file_list_upload(http_client: HttpClient): +async def test_file_list_upload(enabled_http_client: HttpClient): query = "mutation($files: [Upload!]!) { readFiles(files: $files) }" file1 = BytesIO(b"strawberry1") file2 = BytesIO(b"strawberry2") - response = await http_client.query( + response = await enabled_http_client.query( query=query, variables={"files": [None, None]}, files={"file1": file1, "file2": file2}, @@ -57,12 +87,12 @@ async def test_file_list_upload(http_client: HttpClient): assert data["readFiles"][1] == "strawberry2" -async def test_nested_file_list(http_client: HttpClient): +async def test_nested_file_list(enabled_http_client: HttpClient): query = "mutation($folder: FolderInput!) { readFolder(folder: $folder) }" file1 = BytesIO(b"strawberry1") file2 = BytesIO(b"strawberry2") - response = await http_client.query( + response = await enabled_http_client.query( query=query, variables={"folder": {"files": [None, None]}}, files={"file1": file1, "file2": file2}, @@ -74,7 +104,7 @@ async def test_nested_file_list(http_client: HttpClient): assert data["readFolder"][1] == "strawberry2" -async def test_upload_single_and_list_file_together(http_client: HttpClient): +async def test_upload_single_and_list_file_together(enabled_http_client: HttpClient): query = """ mutation($files: [Upload!]!, $textFile: Upload!) { readFiles(files: $files) @@ -85,7 +115,7 @@ async def test_upload_single_and_list_file_together(http_client: HttpClient): file2 = BytesIO(b"strawberry2") file3 = BytesIO(b"strawberry3") - response = await http_client.query( + response = await enabled_http_client.query( query=query, variables={"files": [None, None], "textFile": None}, files={"file1": file1, "file2": file2, "textFile": file3}, @@ -98,7 +128,7 @@ async def test_upload_single_and_list_file_together(http_client: HttpClient): assert data["readText"] == "strawberry3" -async def test_upload_invalid_query(http_client: HttpClient): +async def test_upload_invalid_query(enabled_http_client: HttpClient): f = BytesIO(b"strawberry") query = """ @@ -106,7 +136,7 @@ async def test_upload_invalid_query(http_client: HttpClient): readT """ - response = await http_client.query( + response = await enabled_http_client.query( query, variables={"textFile": None}, files={"textFile": f}, @@ -122,7 +152,7 @@ async def test_upload_invalid_query(http_client: HttpClient): ] -async def test_upload_missing_file(http_client: HttpClient): +async def test_upload_missing_file(enabled_http_client: HttpClient): f = BytesIO(b"strawberry") query = """ @@ -131,7 +161,7 @@ async def test_upload_missing_file(http_client: HttpClient): } """ - response = await http_client.query( + response = await enabled_http_client.query( query, variables={"textFile": None}, # using the wrong name to simulate a missing file @@ -155,7 +185,7 @@ def value(self) -> bytes: return self.buffer.getvalue() -async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): +async def test_extra_form_data_fields_are_ignored(enabled_http_client: HttpClient): query = """mutation($textFile: Upload!) { readText(textFile: $textFile) }""" @@ -175,7 +205,7 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): data, header = encode_multipart_formdata(fields) - response = await http_client.post( + response = await enabled_http_client.post( url="/graphql", data=data, headers={ @@ -188,9 +218,9 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient): assert response.json["data"] == {"readText": "strawberry"} -async def test_sending_invalid_form_data(http_client: HttpClient): +async def test_sending_invalid_form_data(enabled_http_client: HttpClient): headers = {"content-type": "multipart/form-data; boundary=----fake"} - response = await http_client.post("/graphql", headers=headers) + response = await enabled_http_client.post("/graphql", headers=headers) assert response.status_code == 400 # TODO: consolidate this, it seems only AIOHTTP returns the second error @@ -202,7 +232,7 @@ async def test_sending_invalid_form_data(http_client: HttpClient): @pytest.mark.aiohttp -async def test_sending_invalid_json_body(http_client: HttpClient): +async def test_sending_invalid_json_body(enabled_http_client: HttpClient): f = BytesIO(b"strawberry") operations = "}" file_map = json.dumps({"textFile": ["variables.textFile"]}) @@ -215,7 +245,7 @@ async def test_sending_invalid_json_body(http_client: HttpClient): data, header = encode_multipart_formdata(fields) - response = await http_client.post( + response = await enabled_http_client.post( "/graphql", data=data, headers={ From 2f210673ee9e2b64128ce9eb47569aa644a61b96 Mon Sep 17 00:00:00 2001 From: Botberry Date: Wed, 25 Sep 2024 16:10:27 +0000 Subject: [PATCH 03/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.243.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 12 ++++++++++++ RELEASE.md | 7 ------- pyproject.toml | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 4943c0113a..6ac89bb427 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,18 @@ CHANGELOG ========= +0.243.0 - 2024-09-25 +-------------------- + +Starting with this release, multipart uploads are disabled by default and Strawberry Django view is no longer implicitly exempted from Django's CSRF protection. +Both changes relieve users from implicit security implications inherited from the GraphQL multipart request specification which was enabled in Strawberry by default. + +These are breaking changes if you are using multipart uploads OR the Strawberry Django view. +Migrations guides including further information are available on the Strawberry website. + +Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3645](https://github.com/strawberry-graphql/strawberry/pull/3645/) + + 0.242.0 - 2024-09-19 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 0d039486fc..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,7 +0,0 @@ -Release type: minor - -Starting with this release, multipart uploads are disabled by default and Strawberry Django view is no longer implicitly exempted from Django's CSRF protection. -Both changes relieve users from implicit security implications inherited from the GraphQL multipart request specification which was enabled in Strawberry by default. - -These are breaking changes if you are using multipart uploads OR the Strawberry Django view. -Migrations guides including further information are available on the Strawberry website. diff --git a/pyproject.toml b/pyproject.toml index ef445cf7ac..8251617067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.242.0" +version = "0.243.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From d803c408f3e9af4ba580f4dec6c2c6a90e14f5f0 Mon Sep 17 00:00:00 2001 From: Strawberry GraphQL Bot Date: Wed, 25 Sep 2024 16:13:30 +0000 Subject: [PATCH 04/31] Remove TWEET.md --- TWEET.md | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 TWEET.md diff --git a/TWEET.md b/TWEET.md deleted file mode 100644 index ec081269f6..0000000000 --- a/TWEET.md +++ /dev/null @@ -1,8 +0,0 @@ -πŸ†• Release $version is out! Thanks to $contributor πŸ‘ - -We've made some important security changes regarding file uploads and CSRF in -Django. - -Check out our migration guides if you're using multipart or Django view. - -πŸ‘‡ $release_url From b270fe9bbca37bfbdffbdcb936352a929bec061c Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 25 Sep 2024 18:25:01 +0200 Subject: [PATCH 05/31] Fix version typo --- docs/breaking-changes/0.243.0.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/breaking-changes/0.243.0.md b/docs/breaking-changes/0.243.0.md index 0eec0b9633..a56f09f637 100644 --- a/docs/breaking-changes/0.243.0.md +++ b/docs/breaking-changes/0.243.0.md @@ -3,7 +3,7 @@ title: 0.243.0 Breaking Changes slug: breaking-changes/0.243.0 --- -# v0.240.0 Breaking Changes +# v0.243.0 Breaking Changes Release v0.240.0 comes with two breaking changes regarding multipart file uploads and Django CSRF protection. From 8e92e2b9527a8899683369e16034200d28f450af Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 25 Sep 2024 18:25:27 +0200 Subject: [PATCH 06/31] Fix another typo --- docs/breaking-changes/0.243.0.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/breaking-changes/0.243.0.md b/docs/breaking-changes/0.243.0.md index a56f09f637..4eca524f94 100644 --- a/docs/breaking-changes/0.243.0.md +++ b/docs/breaking-changes/0.243.0.md @@ -5,7 +5,7 @@ slug: breaking-changes/0.243.0 # v0.243.0 Breaking Changes -Release v0.240.0 comes with two breaking changes regarding multipart file +Release v0.243.0 comes with two breaking changes regarding multipart file uploads and Django CSRF protection. ## Multipart uploads disabled by default From 3eb3b20157bdadb48645fc2ba4cd4e4ac2c0fb98 Mon Sep 17 00:00:00 2001 From: Krisque <63082743+chrisemke@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:22:21 -0300 Subject: [PATCH 07/31] fix: pydantic ^2.9.0 needs 2 extra fields on to_argument (#3632) * fix: pydantic ^2.9.0 needs 2 extra fields on to_argument * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docs: add RELEASE.md file * Update .alexrc to allow crash * Update RELEASE.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- .alexrc | 3 ++- RELEASE.md | 3 +++ strawberry/ext/mypy_plugin.py | 9 ++++++++- 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 RELEASE.md diff --git a/.alexrc b/.alexrc index db2914316a..587b769682 100644 --- a/.alexrc +++ b/.alexrc @@ -11,6 +11,7 @@ "execution", "special", "primitive", - "invalid" + "invalid", + "crash", ] } diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..9e2deb1dac --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This releases adds support for Pydantic 2.9.0's Mypy plugin diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py index 5d5c3d08ef..77a8a406f0 100644 --- a/strawberry/ext/mypy_plugin.py +++ b/strawberry/ext/mypy_plugin.py @@ -481,7 +481,14 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None: # Based on pydantic's default value # https://github.com/pydantic/pydantic/pull/9606/files#diff-469037bbe55bbf9aa359480a16040d368c676adad736e133fb07e5e20d6ac523R1066 extra["force_typevars_invariant"] = False - + if PYDANTIC_VERSION >= (2, 9, 0): + extra["model_strict"] = model_type.type.metadata[ + PYDANTIC_METADATA_KEY + ]["config"].get("strict", False) + extra["is_root_model_root"] = any( + "pydantic.root_model.RootModel" in base.fullname + for base in model_type.type.mro[:-1] + ) add_method( ctx, "to_pydantic", From 7c41804c49ccefa5119dcf930453fe56b87c5783 Mon Sep 17 00:00:00 2001 From: Botberry Date: Thu, 26 Sep 2024 12:23:44 +0000 Subject: [PATCH 08/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.243.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 8 ++++++++ RELEASE.md | 3 --- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ac89bb427..375a4c9206 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ CHANGELOG ========= +0.243.1 - 2024-09-26 +-------------------- + +This releases adds support for Pydantic 2.9.0's Mypy plugin + +Contributed by [Krisque](https://github.com/chrisemke) via [PR #3632](https://github.com/strawberry-graphql/strawberry/pull/3632/) + + 0.243.0 - 2024-09-25 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 9e2deb1dac..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,3 +0,0 @@ -Release type: patch - -This releases adds support for Pydantic 2.9.0's Mypy plugin diff --git a/pyproject.toml b/pyproject.toml index 8251617067..2159916a19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.243.0" +version = "0.243.1" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 92ae942ab073eaafe728624403f9374ed5dc442f Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 26 Sep 2024 17:57:10 +0200 Subject: [PATCH 09/31] update doc for Flask and Quart (#3647) the graphiql -> graphql_ide had not been made for those views --- docs/integrations/flask.md | 5 +++-- docs/integrations/quart.md | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index b44e166efa..2e9c5eb389 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -37,8 +37,9 @@ from strawberry.flask.views import AsyncGraphQLView The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. -- `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL - interface. +- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the + GraphQL IDE interface (one of `graphiql`, `apollo-sandbox` or `pathfinder`) or + to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests - `multipart_uploads_enabled`: optional, defaults to `False`, controls whether diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md index fa9b071990..a6d7061ba8 100644 --- a/docs/integrations/quart.md +++ b/docs/integrations/quart.md @@ -29,8 +29,9 @@ if __name__ == "__main__": The `GraphQLView` accepts the following options at the moment: - `schema`: mandatory, the schema created by `strawberry.Schema`. -- `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL - interface. +- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the + GraphQL IDE interface (one of `graphiql`, `apollo-sandbox` or `pathfinder`) or + to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests - `multipart_uploads_enabled`: optional, defaults to `False`, controls whether From 36bd99ea38925e955b2893db7b7a12d5931ad349 Mon Sep 17 00:00:00 2001 From: topher-g Date: Sat, 28 Sep 2024 04:38:05 -0500 Subject: [PATCH 10/31] Updated schema-configurations.md to correct the last code example (#3651) Co-authored-by: Chris Gill --- docs/types/schema-configurations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/types/schema-configurations.md b/docs/types/schema-configurations.md index 5ad12c9c6c..396524938c 100644 --- a/docs/types/schema-configurations.md +++ b/docs/types/schema-configurations.md @@ -123,5 +123,5 @@ class CustomInfo(Info): return self.context["response"].headers -schema = strawberry.Schema(query=Query, info_class=CustomInfo) +schema = strawberry.Schema(query=Query, config=StrawberryConfig(info_class=CustomInfo)) ``` From 943736fcd651a9aa6994588c92e6e7380a34af39 Mon Sep 17 00:00:00 2001 From: Bingdom Date: Wed, 2 Oct 2024 01:15:30 +1000 Subject: [PATCH 11/31] Added subscription unsubscribing info (#3654) * Added subscription unsubscribing info * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typos * Follow python naming convention --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jonathan Ehwald --- docs/general/subscriptions.md | 49 +++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/docs/general/subscriptions.md b/docs/general/subscriptions.md index 7d4161bdc4..6363fc3051 100644 --- a/docs/general/subscriptions.md +++ b/docs/general/subscriptions.md @@ -254,6 +254,55 @@ schema = strawberry.Schema(query=Query, subscription=Subscription) [pep-525]: https://www.python.org/dev/peps/pep-0525/ +## Unsubscribing subscriptions + +In GraphQL, it is possible to unsubscribe from a subscription. Strawberry +supports this behaviour, and is done using a `try...except` block. + +In Apollo-client, closing a subscription can be achieved like the following: + +```javascript +const client = useApolloClient(); +const subscriber = client.subscribe({query: ...}).subscribe({...}) +// ... +// done with subscription. now unsubscribe +subscriber.unsubscribe(); +``` + +Strawberry can easily capture when a subscriber unsubscribes using an +`asyncio.CancelledError` exception. + +```python +import asyncio +from typing import AsyncGenerator +from uuid import uuid4 + +import strawberry + +# track active subscribers +event_messages = {} + + +@strawberry.type +class Subscription: + @strawberry.subscription + async def message(self) -> AsyncGenerator[int, None]: + try: + subscription_id = uuid4() + + event_messages[subscription_id] = [] + + while True: + if len(event_messages[subscription_id]) > 0: + yield event_messages[subscription_id] + event_messages[subscription_id].clear() + + await asyncio.sleep(1) + except asyncio.CancelledError: + # stop listening to events + del event_messages[subscription_id] +``` + ## GraphQL over WebSocket protocols Strawberry support both the legacy From 1b33547c9e8d0430b172b93faea8bafef68d67be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:01:07 +0200 Subject: [PATCH 12/31] chore: upgrade ruff version (#3644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.5 β†’ v0.6.8](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.5...v0.6.8) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a5f0843ca3..556ad1132b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.5 + rev: v0.6.8 hooks: - id: ruff-format exclude: ^tests/\w+/snapshots/ From 40a05049732de6c0eea5cf6473f82798436908f5 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sat, 5 Oct 2024 22:19:59 +0200 Subject: [PATCH 13/31] Integrate websockets into the async base view (#3638) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- RELEASE.md | 4 + TWEET.md | 7 + strawberry/aiohttp/handlers/__init__.py | 6 - .../handlers/graphql_transport_ws_handler.py | 62 ------- .../aiohttp/handlers/graphql_ws_handler.py | 69 ------- strawberry/aiohttp/views.py | 102 ++++++----- strawberry/asgi/__init__.py | 123 ++++++------- strawberry/asgi/handlers/__init__.py | 6 - .../handlers/graphql_transport_ws_handler.py | 66 ------- .../asgi/handlers/graphql_ws_handler.py | 71 -------- strawberry/channels/__init__.py | 7 +- strawberry/channels/handlers/base.py | 4 +- .../handlers/graphql_transport_ws_handler.py | 62 ------- .../channels/handlers/graphql_ws_handler.py | 72 -------- strawberry/channels/handlers/http_handler.py | 19 +- strawberry/channels/handlers/ws_handler.py | 171 ++++++++++++------ strawberry/django/views.py | 20 +- strawberry/fastapi/handlers/__init__.py | 6 - .../handlers/graphql_transport_ws_handler.py | 20 -- .../fastapi/handlers/graphql_ws_handler.py | 18 -- strawberry/fastapi/router.py | 72 +++----- strawberry/flask/views.py | 16 +- strawberry/http/async_base_view.py | 152 ++++++++++++++-- strawberry/http/exceptions.py | 4 + strawberry/http/typevars.py | 12 +- strawberry/litestar/controller.py | 163 ++++++++--------- strawberry/litestar/handlers/__init__.py | 0 .../handlers/graphql_transport_ws_handler.py | 60 ------ .../litestar/handlers/graphql_ws_handler.py | 66 ------- strawberry/quart/views.py | 16 +- strawberry/sanic/views.py | 22 ++- .../graphql_transport_ws/handlers.py | 81 +++++---- .../protocols/graphql_ws/handlers.py | 91 +++++----- tests/aiohttp/__init__.py | 0 tests/aiohttp/app.py | 46 ----- tests/aiohttp/test_websockets.py | 110 ----------- tests/asgi/app.py | 24 --- tests/asgi/test_websockets.py | 94 ---------- tests/channels/test_layers.py | 4 +- tests/channels/test_testing.py | 4 +- tests/channels/test_ws_handler.py | 54 ------ tests/fastapi/test_websockets.py | 125 ------------- tests/http/clients/aiohttp.py | 22 +-- tests/http/clients/asgi.py | 26 ++- tests/http/clients/base.py | 58 ++++-- tests/http/clients/channels.py | 58 +++--- tests/http/clients/fastapi.py | 15 +- tests/http/clients/litestar.py | 22 +-- tests/litestar/test_websockets.py | 89 --------- tests/websockets/test_graphql_transport_ws.py | 15 +- tests/websockets/test_graphql_ws.py | 11 ++ tests/websockets/test_websockets.py | 82 +++++++++ 52 files changed, 891 insertions(+), 1638 deletions(-) create mode 100644 RELEASE.md create mode 100644 TWEET.md delete mode 100644 strawberry/aiohttp/handlers/__init__.py delete mode 100644 strawberry/aiohttp/handlers/graphql_transport_ws_handler.py delete mode 100644 strawberry/aiohttp/handlers/graphql_ws_handler.py delete mode 100644 strawberry/asgi/handlers/__init__.py delete mode 100644 strawberry/asgi/handlers/graphql_transport_ws_handler.py delete mode 100644 strawberry/asgi/handlers/graphql_ws_handler.py delete mode 100644 strawberry/channels/handlers/graphql_transport_ws_handler.py delete mode 100644 strawberry/channels/handlers/graphql_ws_handler.py delete mode 100644 strawberry/fastapi/handlers/__init__.py delete mode 100644 strawberry/fastapi/handlers/graphql_transport_ws_handler.py delete mode 100644 strawberry/fastapi/handlers/graphql_ws_handler.py delete mode 100644 strawberry/litestar/handlers/__init__.py delete mode 100644 strawberry/litestar/handlers/graphql_transport_ws_handler.py delete mode 100644 strawberry/litestar/handlers/graphql_ws_handler.py delete mode 100644 tests/aiohttp/__init__.py delete mode 100644 tests/aiohttp/app.py delete mode 100644 tests/aiohttp/test_websockets.py delete mode 100644 tests/asgi/app.py delete mode 100644 tests/asgi/test_websockets.py delete mode 100644 tests/channels/test_ws_handler.py delete mode 100644 tests/fastapi/test_websockets.py delete mode 100644 tests/litestar/test_websockets.py create mode 100644 tests/websockets/test_websockets.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..7f194a4b85 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +Starting with this release, WebSocket logic now lives in the base class shared between all HTTP integrations. +This makes the behaviour of WebSockets much more consistent between integrations and easier to maintain. diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..0437a68073 --- /dev/null +++ b/TWEET.md @@ -0,0 +1,7 @@ +πŸš€ Starting with Strawberry $version, WebSocket logic now lives in the base +class shared across all HTTP integrations. More consistent behavior and easier +maintenance for WebSockets across integrations. πŸŽ‰ + +Thanks to $contributor for the PR πŸ‘ + +$release_url diff --git a/strawberry/aiohttp/handlers/__init__.py b/strawberry/aiohttp/handlers/__init__.py deleted file mode 100644 index c769c4eec7..0000000000 --- a/strawberry/aiohttp/handlers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from strawberry.aiohttp.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, -) -from strawberry.aiohttp.handlers.graphql_ws_handler import GraphQLWSHandler - -__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"] diff --git a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py b/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index 52350199f7..0000000000 --- a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Dict - -from aiohttp import http, web -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - -if TYPE_CHECKING: - from datetime import timedelta - - from strawberry.schema import BaseSchema - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable[..., Dict[str, Any]], - get_root_value: Any, - request: web.Request, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._request = request - self._ws = web.WebSocketResponse(protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - - async def get_context(self) -> Any: - return await self._get_context(request=self._request, response=self._ws) # type: ignore - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._request) - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int, reason: str) -> None: - await self._ws.close(code=code, message=reason.encode()) - - async def handle_request(self) -> web.StreamResponse: - await self._ws.prepare(self._request) - self.on_request_accepted() - - try: - async for ws_message in self._ws: # type: http.WSMessage - if ws_message.type == http.WSMsgType.TEXT: - await self.handle_message(ws_message.json()) - else: - error_message = "WebSocket message type must be text" - await self.handle_invalid_message(error_message) - finally: - await self.shutdown() - - return self._ws - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/aiohttp/handlers/graphql_ws_handler.py b/strawberry/aiohttp/handlers/graphql_ws_handler.py deleted file mode 100644 index 677dd34884..0000000000 --- a/strawberry/aiohttp/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Optional - -from aiohttp import http, web -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler - -if TYPE_CHECKING: - from strawberry.schema import BaseSchema - from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - request: web.Request, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._request = request - self._ws = web.WebSocketResponse(protocols=[GRAPHQL_WS_PROTOCOL]) - - async def get_context(self) -> Any: - return await self._get_context(request=self._request, response=self._ws) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._request) - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - message = reason.encode() if reason else b"" - await self._ws.close(code=code, message=message) - - async def handle_request(self) -> Any: - await self._ws.prepare(self._request) - - try: - async for ws_message in self._ws: # type: http.WSMessage - if ws_message.type == http.WSMsgType.TEXT: - message: OperationMessage = ws_message.json() - await self.handle_message(message) - else: - await self.close( - code=1002, reason="WebSocket message type must be text" - ) - finally: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - return self._ws - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index f2154309f6..884264dcb0 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -4,6 +4,7 @@ import warnings from datetime import timedelta from io import BytesIO +from json.decoder import JSONDecodeError from typing import ( TYPE_CHECKING, Any, @@ -16,15 +17,16 @@ Union, cast, ) +from typing_extensions import TypeGuard -from aiohttp import web +from aiohttp import http, web from aiohttp.multipart import BodyPartReader -from strawberry.aiohttp.handlers import ( - GraphQLTransportWSHandler, - GraphQLWSHandler, +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, ) -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( Context, @@ -79,11 +81,36 @@ def content_type(self) -> Optional[str]: return self.headers.get("content-type") +class AioHTTPWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None: + self.request = request + self.ws = ws + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + async for ws_message in self.ws: + if ws_message.type == http.WSMsgType.TEXT: + try: + yield ws_message.json() + except JSONDecodeError: + raise NonJsonMessageReceived() + + elif ws_message.type == http.WSMsgType.BINARY: + raise NonJsonMessageReceived() + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send_json(message) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code=code, message=reason.encode()) + + class GraphQLView( AsyncBaseHTTPView[ web.Request, Union[web.Response, web.StreamResponse], web.Response, + web.Request, + web.WebSocketResponse, Context, RootValue, ] @@ -92,10 +119,9 @@ class GraphQLView( # bare handler function. _is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined] - graphql_transport_ws_handler_class = GraphQLTransportWSHandler - graphql_ws_handler_class = GraphQLWSHandler allow_queries_via_get = True request_adapter_class = AioHTTPRequestAdapter + websocket_adapter_class = AioHTTPWebSocketAdapter def __init__( self, @@ -138,48 +164,36 @@ async def render_graphql_ide(self, request: web.Request) -> web.Response: async def get_sub_response(self, request: web.Request) -> web.Response: return web.Response() - async def __call__(self, request: web.Request) -> web.StreamResponse: + def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]: ws = web.WebSocketResponse(protocols=self.subscription_protocols) - ws_test = ws.can_prepare(request) - - if not ws_test.ok: - try: - return await self.run(request=request) - except HTTPException as e: - return web.Response( - body=e.reason, - status=e.status_code, - ) - - if ws_test.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - return await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=self.get_context, # type: ignore - get_root_value=self.get_root_value, - request=request, - ).handle() - elif ws_test.protocol == GRAPHQL_WS_PROTOCOL: - return await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=self.get_context, - get_root_value=self.get_root_value, - request=request, - ).handle() - else: - await ws.prepare(request) - await ws.close(code=4406, message=b"Subprotocol not acceptable") - return ws + return ws.can_prepare(request).ok + + async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]: + ws = web.WebSocketResponse(protocols=self.subscription_protocols) + return ws.can_prepare(request).protocol + + async def create_websocket_response( + self, request: web.Request, subprotocol: Optional[str] + ) -> web.WebSocketResponse: + protocols = [subprotocol] if subprotocol else [] + ws = web.WebSocketResponse(protocols=protocols) + await ws.prepare(request) + return ws + + async def __call__(self, request: web.Request) -> web.StreamResponse: + try: + return await self.run(request=request) + except HTTPException as e: + return web.Response( + body=e.reason, + status=e.status_code, + ) async def get_root_value(self, request: web.Request) -> Optional[RootValue]: return None async def get_context( - self, request: web.Request, response: web.Response + self, request: web.Request, response: Union[web.Response, web.WebSocketResponse] ) -> Context: return {"request": request, "response": response} # type: ignore diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index d2647c6ee6..d10d207987 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -2,9 +2,11 @@ import warnings from datetime import timedelta +from json import JSONDecodeError from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, AsyncIterator, Callable, Dict, @@ -14,6 +16,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from starlette import status from starlette.requests import Request @@ -23,14 +26,14 @@ Response, StreamingResponse, ) -from starlette.websockets import WebSocket +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState -from strawberry.asgi.handlers import ( - GraphQLTransportWSHandler, - GraphQLWSHandler, +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, ) -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import ( Context, @@ -78,19 +81,41 @@ async def get_form_data(self) -> FormData: ) +class ASGIWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: WebSocket, response: WebSocket) -> None: + self.ws = response + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + try: + try: + while self.ws.application_state != WebSocketState.DISCONNECTED: + yield await self.ws.receive_json() + except (KeyError, JSONDecodeError): + raise NonJsonMessageReceived() + except WebSocketDisconnect: # pragma: no cover + pass + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send_json(message) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code=code, reason=reason) + + class GraphQL( AsyncBaseHTTPView[ - Union[Request, WebSocket], + Request, Response, Response, + WebSocket, + WebSocket, Context, RootValue, ] ): - graphql_transport_ws_handler_class = GraphQLTransportWSHandler - graphql_ws_handler_class = GraphQLWSHandler allow_queries_via_get = True - request_adapter_class = ASGIRequestAdapter # pyright: ignore + request_adapter_class = ASGIRequestAdapter + websocket_adapter_class = ASGIWebSocketAdapter def __init__( self, @@ -129,51 +154,25 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": - return await self.handle_http(scope, receive, send) + http_request = Request(scope=scope, receive=receive) - elif scope["type"] == "websocket": - ws = WebSocket(scope, receive=receive, send=send) - preferred_protocol = self.pick_preferred_protocol(ws) - - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=ws, - ).handle() - - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=ws, - ).handle() - - else: - # Subprotocol not acceptable - await ws.close(code=4406) + try: + response = await self.run(http_request) + except HTTPException as e: + response = PlainTextResponse(e.reason, status_code=e.status_code) + await response(scope, receive, send) + elif scope["type"] == "websocket": + ws_request = WebSocket(scope, receive=receive, send=send) + await self.run(ws_request) else: # pragma: no cover raise ValueError("Unknown scope type: {!r}".format(scope["type"])) - def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]: - protocols = ws["subprotocols"] - intersection = set(protocols) & set(self.protocols) - sorted_intersection = sorted(intersection, key=protocols.index) - return next(iter(sorted_intersection), None) - async def get_root_value(self, request: Union[Request, WebSocket]) -> Optional[Any]: return None async def get_context( - self, request: Union[Request, WebSocket], response: Response + self, request: Union[Request, WebSocket], response: Union[Response, WebSocket] ) -> Context: return {"request": request, "response": response} # type: ignore @@ -187,21 +186,6 @@ async def get_sub_response( return sub_response - async def handle_http( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - request = Request(scope=scope, receive=receive) - - try: - response = await self.run(request) - except HTTPException as e: - response = PlainTextResponse(e.reason, status_code=e.status_code) # pyright: ignore - - await response(scope, receive, send) - async def render_graphql_ide(self, request: Union[Request, WebSocket]) -> Response: return HTMLResponse(self.graphql_ide_html) @@ -239,3 +223,20 @@ async def create_streaming_response( **headers, }, ) + + def is_websocket_request( + self, request: Union[Request, WebSocket] + ) -> TypeGuard[WebSocket]: + return request.scope["type"] == "websocket" + + async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + protocols = request["subprotocols"] + intersection = set(protocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=protocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: WebSocket, subprotocol: Optional[str] + ) -> WebSocket: + await request.accept(subprotocol=subprotocol) + return request diff --git a/strawberry/asgi/handlers/__init__.py b/strawberry/asgi/handlers/__init__.py deleted file mode 100644 index 1891a06ad0..0000000000 --- a/strawberry/asgi/handlers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from strawberry.asgi.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, -) -from strawberry.asgi.handlers.graphql_ws_handler import GraphQLWSHandler - -__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"] diff --git a/strawberry/asgi/handlers/graphql_transport_ws_handler.py b/strawberry/asgi/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index 7cec132ffd..0000000000 --- a/strawberry/asgi/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable - -from starlette.websockets import WebSocketDisconnect, WebSocketState - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - -if TYPE_CHECKING: - from datetime import timedelta - - from starlette.websockets import WebSocket - - from strawberry.schema import BaseSchema - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context(request=self._ws, response=None) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int, reason: str) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> None: - await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL) - self.on_request_accepted() - - try: - while self._ws.application_state != WebSocketState.DISCONNECTED: - try: - message = await self._ws.receive_json() - except KeyError: # noqa: PERF203 - error_message = "WebSocket message type must be text" - await self.handle_invalid_message(error_message) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - await self.shutdown() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/asgi/handlers/graphql_ws_handler.py b/strawberry/asgi/handlers/graphql_ws_handler.py deleted file mode 100644 index 00a314bbd0..0000000000 --- a/strawberry/asgi/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Optional - -from starlette.websockets import WebSocketDisconnect, WebSocketState - -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler - -if TYPE_CHECKING: - from starlette.websockets import WebSocket - - from strawberry.schema import BaseSchema - from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context(request=self._ws, response=None) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL) - - try: - while self._ws.application_state != WebSocketState.DISCONNECTED: - try: - message = await self._ws.receive_json() - except KeyError: # noqa: PERF203 - await self.close( - code=1002, reason="WebSocket message type must be text" - ) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/channels/__init__.py b/strawberry/channels/__init__.py index 455513babb..f67fb25a82 100644 --- a/strawberry/channels/__init__.py +++ b/strawberry/channels/__init__.py @@ -1,6 +1,4 @@ -from .handlers.base import ChannelsConsumer, ChannelsWSConsumer -from .handlers.graphql_transport_ws_handler import GraphQLTransportWSHandler -from .handlers.graphql_ws_handler import GraphQLWSHandler +from .handlers.base import ChannelsConsumer from .handlers.http_handler import ( ChannelsRequest, GraphQLHTTPConsumer, @@ -12,10 +10,7 @@ __all__ = [ "ChannelsConsumer", "ChannelsRequest", - "ChannelsWSConsumer", "GraphQLProtocolTypeRouter", - "GraphQLWSHandler", - "GraphQLTransportWSHandler", "GraphQLHTTPConsumer", "GraphQLWSConsumer", "SyncGraphQLHTTPConsumer", diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py index ec2ffe6b2c..769ec569e5 100644 --- a/strawberry/channels/handlers/base.py +++ b/strawberry/channels/handlers/base.py @@ -16,7 +16,7 @@ from weakref import WeakSet from channels.consumer import AsyncConsumer -from channels.generic.websocket import AsyncJsonWebsocketConsumer +from channels.generic.websocket import AsyncWebsocketConsumer class ChannelsMessage(TypedDict, total=False): @@ -210,7 +210,7 @@ async def _listen_to_channel_generator( return -class ChannelsWSConsumer(ChannelsConsumer, AsyncJsonWebsocketConsumer): +class ChannelsWSConsumer(ChannelsConsumer, AsyncWebsocketConsumer): """Base channels websocket async consumer.""" diff --git a/strawberry/channels/handlers/graphql_transport_ws_handler.py b/strawberry/channels/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index db290f4ef8..0000000000 --- a/strawberry/channels/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Optional - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - -if TYPE_CHECKING: - from datetime import timedelta - - from strawberry.channels.handlers.base import ChannelsWSConsumer - from strawberry.schema import BaseSchema - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable, - get_root_value: Callable, - ws: ChannelsWSConsumer, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context( - request=self._ws, connection_params=self.connection_params - ) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - # TODO: We are using `self._ws.base_send` directly instead of `self._ws.close` - # because the later doesn't accept the `reason` argument. - await self._ws.base_send( - { - "type": "websocket.close", - "code": code, - "reason": reason or "", - } - ) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL) - self.on_request_accepted() - - async def handle_disconnect(self, code: int) -> None: - await self.shutdown() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/channels/handlers/graphql_ws_handler.py b/strawberry/channels/handlers/graphql_ws_handler.py deleted file mode 100644 index 6d967a1d15..0000000000 --- a/strawberry/channels/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Optional - -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler - -if TYPE_CHECKING: - from strawberry.channels.handlers.base import ChannelsWSConsumer - from strawberry.schema import BaseSchema - from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - ws: ChannelsWSConsumer, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context( - request=self._ws, connection_params=self.connection_params - ) - - async def get_root_value(self) -> Any: - return await self._get_root_value(request=self._ws) - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - # TODO: We are using `self._ws.base_send` directly instead of `self._ws.close` - # because the latler doesn't accept the `reason` argument. - await self._ws.base_send( - { - "type": "websocket.close", - "code": code, - "reason": reason or "", - } - ) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL) - - async def handle_disconnect(self, code: int) -> None: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - async def handle_invalid_message(self, error_message: str) -> None: - # This is not part of the BaseGraphQLWSHandler's interface, but the - # channels integration is a high level wrapper that forwards this to - # both us and the BaseGraphQLTransportWSHandler. - await self.close(code=1002, reason=error_message) - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index c7264a45b3..8d682eea74 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -15,7 +15,7 @@ Optional, Union, ) -from typing_extensions import assert_never +from typing_extensions import TypeGuard, assert_never from urllib.parse import parse_qs from django.conf import settings @@ -233,6 +233,8 @@ class GraphQLHTTPConsumer( ChannelsRequest, Union[ChannelsResponse, MultipartChannelsResponse], TemporalResponse, + ChannelsRequest, + TemporalResponse, Context, RootValue, ], @@ -298,6 +300,21 @@ async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse content=self.graphql_ide_html.encode(), content_type="text/html" ) + def is_websocket_request( + self, request: ChannelsRequest + ) -> TypeGuard[ChannelsRequest]: + return False + + async def pick_websocket_subprotocol( + self, request: ChannelsRequest + ) -> Optional[str]: + return None + + async def create_websocket_response( + self, request: ChannelsRequest, subprotocol: Optional[str] + ) -> TemporalResponse: + raise NotImplementedError + class SyncGraphQLHTTPConsumer( BaseGraphQLHTTPConsumer, diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py index 2991059afd..b267f7ea9b 100644 --- a/strawberry/channels/handlers/ws_handler.py +++ b/strawberry/channels/handlers/ws_handler.py @@ -1,20 +1,76 @@ from __future__ import annotations +import asyncio import datetime -from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union - +import json +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + Dict, + Mapping, + Optional, + Tuple, + TypedDict, + Union, +) +from typing_extensions import TypeGuard + +from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter +from strawberry.http.exceptions import NonJsonMessageReceived +from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from .base import ChannelsConsumer, ChannelsWSConsumer -from .graphql_transport_ws_handler import GraphQLTransportWSHandler -from .graphql_ws_handler import GraphQLWSHandler +from .base import ChannelsWSConsumer if TYPE_CHECKING: - from strawberry.http.typevars import Context, RootValue + from strawberry.http import GraphQLHTTPResponse from strawberry.schema import BaseSchema -class GraphQLWSConsumer(ChannelsWSConsumer): +class ChannelsWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None: + self.ws_consumer = response + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + while True: + message = await self.ws_consumer.message_queue.get() + + if message["disconnected"]: + break + + if message["message"] is None: + raise NonJsonMessageReceived() + + try: + yield json.loads(message["message"]) + except json.JSONDecodeError: + raise NonJsonMessageReceived() + + async def send_json(self, message: Mapping[str, object]) -> None: + serialized_message = json.dumps(message) + await self.ws_consumer.send(serialized_message) + + async def close(self, code: int, reason: str) -> None: + await self.ws_consumer.close(code=code, reason=reason) + + +class MessageQueueData(TypedDict): + message: Union[str, None] + disconnected: bool + + +class GraphQLWSConsumer( + ChannelsWSConsumer, + AsyncBaseHTTPView[ + "GraphQLWSConsumer", + "GraphQLWSConsumer", + "GraphQLWSConsumer", + "GraphQLWSConsumer", + "GraphQLWSConsumer", + Context, + RootValue, + ], +): """A channels websocket consumer for GraphQL. This handles the connections, then hands off to the appropriate @@ -39,9 +95,7 @@ class GraphQLWSConsumer(ChannelsWSConsumer): ``` """ - graphql_transport_ws_handler_class = GraphQLTransportWSHandler - graphql_ws_handler_class = GraphQLWSHandler - _handler: Union[GraphQLWSHandler, GraphQLTransportWSHandler] + websocket_adapter_class = ChannelsWebSocketAdapter def __init__( self, @@ -63,70 +117,71 @@ def __init__( self.keep_alive_interval = keep_alive_interval self.debug = debug self.protocols = subscription_protocols + self.message_queue: asyncio.Queue[MessageQueueData] = asyncio.Queue() + self.run_task: Optional[asyncio.Task] = None super().__init__() - def pick_preferred_protocol( - self, accepted_subprotocols: Sequence[str] - ) -> Optional[str]: - intersection = set(accepted_subprotocols) & set(self.protocols) - sorted_intersection = sorted(intersection, key=accepted_subprotocols.index) - return next(iter(sorted_intersection), None) - async def connect(self) -> None: - preferred_protocol = self.pick_preferred_protocol(self.scope["subprotocols"]) - - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - self._handler = self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=self, - ) - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - self._handler = self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=self.get_context, - get_root_value=self.get_root_value, - ws=self, - ) - else: - # Subprotocol not acceptable - return await self.close(code=4406) + self.run_task = asyncio.create_task(self.run(self)) - await self._handler.handle() - return None - - async def receive(self, *args: str, **kwargs: Any) -> None: - # Overriding this so that we can pass the errors to handle_invalid_message - try: - await super().receive(*args, **kwargs) - except ValueError: - reason = "WebSocket message type must be text" - await self._handler.handle_invalid_message(reason) - - async def receive_json(self, content: Any, **kwargs: Any) -> None: - await self._handler.handle_message(content) + async def receive( + self, text_data: Optional[str] = None, bytes_data: Optional[bytes] = None + ) -> None: + if text_data: + self.message_queue.put_nowait({"message": text_data, "disconnected": False}) + else: + self.message_queue.put_nowait({"message": None, "disconnected": False}) async def disconnect(self, code: int) -> None: - await self._handler.handle_disconnect(code) + self.message_queue.put_nowait({"message": None, "disconnected": True}) + assert self.run_task + await self.run_task - async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]: + async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]: return None async def get_context( - self, request: ChannelsConsumer, connection_params: Any + self, request: GraphQLWSConsumer, response: GraphQLWSConsumer ) -> Context: return { "request": request, - "connection_params": connection_params, "ws": request, } # type: ignore + @property + def allow_queries_via_get(self) -> bool: + return False + + async def get_sub_response(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer: + raise NotImplementedError + + def create_response( + self, response_data: GraphQLHTTPResponse, sub_response: GraphQLWSConsumer + ) -> GraphQLWSConsumer: + raise NotImplementedError + + async def render_graphql_ide(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer: + raise NotImplementedError + + def is_websocket_request( + self, request: GraphQLWSConsumer + ) -> TypeGuard[GraphQLWSConsumer]: + return True + + async def pick_websocket_subprotocol( + self, request: GraphQLWSConsumer + ) -> Optional[str]: + protocols = request.scope["subprotocols"] + intersection = set(protocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=protocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: GraphQLWSConsumer, subprotocol: Optional[str] + ) -> GraphQLWSConsumer: + await request.accept(subprotocol=subprotocol) + return request + __all__ = ["GraphQLWSConsumer"] diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 132c822f78..457314d93b 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -13,6 +13,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from asgiref.sync import markcoroutinefunction from django.core.serializers.json import DjangoJSONEncoder @@ -258,7 +259,13 @@ def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: class AsyncGraphQLView( BaseView, AsyncBaseHTTPView[ - HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue + HttpRequest, + HttpResponseBase, + TemporalHttpResponse, + HttpRequest, + TemporalHttpResponse, + Context, + RootValue, ], View, ): @@ -312,5 +319,16 @@ async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: return response + def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]: + return False + + async def pick_websocket_subprotocol(self, request: HttpRequest) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: HttpRequest, subprotocol: Optional[str] + ) -> TemporalHttpResponse: + raise NotImplementedError + __all__ = ["GraphQLView", "AsyncGraphQLView"] diff --git a/strawberry/fastapi/handlers/__init__.py b/strawberry/fastapi/handlers/__init__.py deleted file mode 100644 index 20f336f5ff..0000000000 --- a/strawberry/fastapi/handlers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from strawberry.fastapi.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, -) -from strawberry.fastapi.handlers.graphql_ws_handler import GraphQLWSHandler - -__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"] diff --git a/strawberry/fastapi/handlers/graphql_transport_ws_handler.py b/strawberry/fastapi/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index 817f6996ac..0000000000 --- a/strawberry/fastapi/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from strawberry.asgi.handlers import ( - GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler, -) -from strawberry.fastapi.context import BaseContext - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - async def get_context(self) -> Any: - context = await self._get_context() - if isinstance(context, BaseContext): - context.connection_params = self.connection_params - return context - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/fastapi/handlers/graphql_ws_handler.py b/strawberry/fastapi/handlers/graphql_ws_handler.py deleted file mode 100644 index 0c43bbbd6e..0000000000 --- a/strawberry/fastapi/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any - -from strawberry.asgi.handlers import GraphQLWSHandler as BaseGraphQLWSHandler -from strawberry.fastapi.context import BaseContext - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - async def get_context(self) -> Any: - context = await self._get_context() - if isinstance(context, BaseContext): - context.connection_params = self.connection_params - return context - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index badcfa33e0..3ed8e6a4a0 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -17,6 +17,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from starlette import status from starlette.background import BackgroundTasks # noqa: TCH002 @@ -34,10 +35,9 @@ from fastapi.datastructures import Default from fastapi.routing import APIRoute from fastapi.utils import generate_unique_id -from strawberry.asgi import ASGIRequestAdapter +from strawberry.asgi import ASGIRequestAdapter, ASGIWebSocketAdapter from strawberry.exceptions import InvalidCustomContext from strawberry.fastapi.context import BaseContext, CustomContext -from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import process_result from strawberry.http.async_base_view import AsyncBaseHTTPView from strawberry.http.exceptions import HTTPException @@ -58,12 +58,14 @@ class GraphQLRouter( - AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], APIRouter + AsyncBaseHTTPView[ + Request, Response, Response, WebSocket, WebSocket, Context, RootValue + ], + APIRouter, ): - graphql_ws_handler_class = GraphQLWSHandler - graphql_transport_ws_handler_class = GraphQLTransportWSHandler allow_queries_via_get = True request_adapter_class = ASGIRequestAdapter + websocket_adapter_class = ASGIWebSocketAdapter @staticmethod async def __get_root_value() -> None: @@ -261,44 +263,7 @@ async def websocket_endpoint( # pyright: ignore context: Context = Depends(self.context_getter), root_value: RootValue = Depends(self.root_value_getter), ) -> None: - async def _get_context() -> Context: - return context - - async def _get_root_value() -> RootValue: - return root_value - - preferred_protocol = self.pick_preferred_protocol(websocket) - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=_get_context, - get_root_value=_get_root_value, - ws=websocket, - ).handle() - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=_get_context, - get_root_value=_get_root_value, - ws=websocket, - ).handle() - else: - # Code 4406 is "Subprotocol not acceptable" - await websocket.close(code=4406) - - def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]: - protocols = ws["subprotocols"] - intersection = set(protocols) & set(self.protocols) - return min( - intersection, - key=lambda i: protocols.index(i), - default=None, - ) + await self.run(request=websocket, context=context, root_value=root_value) async def render_graphql_ide(self, request: Request) -> HTMLResponse: return HTMLResponse(self.graphql_ide_html) @@ -309,12 +274,12 @@ async def process_result( return process_result(result) async def get_context( - self, request: Request, response: Response + self, request: Union[Request, WebSocket], response: Union[Response, WebSocket] ) -> Context: # pragma: no cover raise ValueError("`get_context` is not used by FastAPI GraphQL Router") async def get_root_value( - self, request: Request + self, request: Union[Request, WebSocket] ) -> Optional[RootValue]: # pragma: no cover raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router") @@ -350,5 +315,22 @@ async def create_streaming_response( }, ) + def is_websocket_request( + self, request: Union[Request, WebSocket] + ) -> TypeGuard[WebSocket]: + return request.scope["type"] == "websocket" + + async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + protocols = request["subprotocols"] + intersection = set(protocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=protocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: WebSocket, subprotocol: Optional[str] + ) -> WebSocket: + await request.accept(subprotocol=subprotocol) + return request + __all__ = ["GraphQLRouter"] diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index d952eb6aa9..2dc15d6d6c 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -9,6 +9,7 @@ Union, cast, ) +from typing_extensions import TypeGuard from flask import Request, Response, render_template_string, request from flask.views import View @@ -159,7 +160,9 @@ async def get_form_data(self) -> FormData: class AsyncGraphQLView( BaseGraphQLView, - AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], + AsyncBaseHTTPView[ + Request, Response, Response, Request, Response, Context, RootValue + ], View, ): methods = ["GET", "POST"] @@ -187,6 +190,17 @@ async def dispatch_request(self) -> ResponseReturnValue: # type: ignore async def render_graphql_ide(self, request: Request) -> Response: return render_template_string(self.graphql_ide_html) # type: ignore + def is_websocket_request(self, request: Request) -> TypeGuard[Request]: + return False + + async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: Request, subprotocol: Optional[str] + ) -> Response: + raise NotImplementedError + __all__ = [ "GraphQLView", diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index c9f1e6ae49..a7666018ef 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -2,6 +2,7 @@ import asyncio import contextlib import json +from datetime import timedelta from typing import ( Any, AsyncGenerator, @@ -13,8 +14,10 @@ Optional, Tuple, Union, + cast, + overload, ) -from typing_extensions import Literal +from typing_extensions import Literal, TypeGuard from graphql import GraphQLError @@ -29,6 +32,11 @@ from strawberry.http.ides import GraphQL_IDE from strawberry.schema.base import BaseSchema from strawberry.schema.exceptions import InvalidOperationTypeError +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) +from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler from strawberry.types import ExecutionResult, SubscriptionExecutionResult from strawberry.types.graphql import OperationType @@ -36,7 +44,15 @@ from .exceptions import HTTPException from .parse_content_type import parse_content_type from .types import FormData, HTTPMethod, QueryParams -from .typevars import Context, Request, Response, RootValue, SubResponse +from .typevars import ( + Context, + Request, + Response, + RootValue, + SubResponse, + WebSocketRequest, + WebSocketResponse, +) class AsyncHTTPRequestAdapter(abc.ABC): @@ -63,14 +79,42 @@ async def get_body(self) -> Union[str, bytes]: ... async def get_form_data(self) -> FormData: ... +class AsyncWebSocketAdapter(abc.ABC): + @abc.abstractmethod + def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ... + + @abc.abstractmethod + async def send_json(self, message: Mapping[str, object]) -> None: ... + + @abc.abstractmethod + async def close(self, code: int, reason: str) -> None: ... + + class AsyncBaseHTTPView( abc.ABC, BaseView[Request], - Generic[Request, Response, SubResponse, Context, RootValue], + Generic[ + Request, + Response, + SubResponse, + WebSocketRequest, + WebSocketResponse, + Context, + RootValue, + ], ): schema: BaseSchema graphql_ide: Optional[GraphQL_IDE] + debug: bool + keep_alive = False + keep_alive_interval: Optional[float] = None + connection_init_wait_timeout: timedelta = timedelta(minutes=1) request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter] + websocket_adapter_class: Callable[ + [WebSocketRequest, WebSocketResponse], AsyncWebSocketAdapter + ] + graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler + graphql_ws_handler_class = BaseGraphQLWSHandler @property @abc.abstractmethod @@ -80,10 +124,16 @@ def allow_queries_via_get(self) -> bool: ... async def get_sub_response(self, request: Request) -> SubResponse: ... @abc.abstractmethod - async def get_context(self, request: Request, response: SubResponse) -> Context: ... + async def get_context( + self, + request: Union[Request, WebSocketRequest], + response: Union[SubResponse, WebSocketResponse], + ) -> Context: ... @abc.abstractmethod - async def get_root_value(self, request: Request) -> Optional[RootValue]: ... + async def get_root_value( + self, request: Union[Request, WebSocketRequest] + ) -> Optional[RootValue]: ... @abc.abstractmethod def create_response( @@ -102,6 +152,21 @@ async def create_streaming_response( ) -> Response: raise ValueError("Multipart responses are not supported") + @abc.abstractmethod + def is_websocket_request( + self, request: Union[Request, WebSocketRequest] + ) -> TypeGuard[WebSocketRequest]: ... + + @abc.abstractmethod + async def pick_websocket_subprotocol( + self, request: WebSocketRequest + ) -> Optional[str]: ... + + @abc.abstractmethod + async def create_websocket_response( + self, request: WebSocketRequest, subprotocol: Optional[str] + ) -> WebSocketResponse: ... + async def execute_operation( self, request: Request, context: Context, root_value: Optional[RootValue] ) -> Union[ExecutionResult, SubscriptionExecutionResult]: @@ -167,35 +232,90 @@ def _handle_errors( ) -> None: """Hook to allow custom handling of errors, used by the Sentry Integration.""" + @overload async def run( self, request: Request, context: Optional[Context] = UNSET, root_value: Optional[RootValue] = UNSET, - ) -> Response: - request_adapter = self.request_adapter_class(request) + ) -> Response: ... - if not self.is_request_allowed(request_adapter): - raise HTTPException(405, "GraphQL only supports GET and POST requests.") + @overload + async def run( + self, + request: WebSocketRequest, + context: Optional[Context] = UNSET, + root_value: Optional[RootValue] = UNSET, + ) -> WebSocketResponse: ... - if self.should_render_graphql_ide(request_adapter): - if self.graphql_ide: - return await self.render_graphql_ide(request) + async def run( + self, + request: Union[Request, WebSocketRequest], + context: Optional[Context] = UNSET, + root_value: Optional[RootValue] = UNSET, + ) -> Union[Response, WebSocketResponse]: + root_value = ( + await self.get_root_value(request) if root_value is UNSET else root_value + ) + + if self.is_websocket_request(request): + websocket_subprotocol = await self.pick_websocket_subprotocol(request) + websocket_response = await self.create_websocket_response( + request, websocket_subprotocol + ) + websocket = self.websocket_adapter_class(request, websocket_response) + + context = ( + await self.get_context(request, response=websocket_response) + if context is UNSET + else context + ) + + if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: + await self.graphql_transport_ws_handler_class( + websocket=websocket, + context=context, + root_value=root_value, + schema=self.schema, + debug=self.debug, + connection_init_wait_timeout=self.connection_init_wait_timeout, + ).handle() + elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL: + await self.graphql_ws_handler_class( + websocket=websocket, + context=context, + root_value=root_value, + schema=self.schema, + debug=self.debug, + keep_alive=self.keep_alive, + keep_alive_interval=self.keep_alive_interval, + ).handle() else: - raise HTTPException(404, "Not Found") + await websocket.close(4406, "Subprotocol not acceptable") + + return websocket_response + else: + request = cast(Request, request) + request_adapter = self.request_adapter_class(request) sub_response = await self.get_sub_response(request) context = ( await self.get_context(request, response=sub_response) if context is UNSET else context ) - root_value = ( - await self.get_root_value(request) if root_value is UNSET else root_value - ) assert context + if not self.is_request_allowed(request_adapter): + raise HTTPException(405, "GraphQL only supports GET and POST requests.") + + if self.should_render_graphql_ide(request_adapter): + if self.graphql_ide: + return await self.render_graphql_ide(request) + else: + raise HTTPException(404, "Not Found") + try: result = await self.execute_operation( request=request, context=context, root_value=root_value diff --git a/strawberry/http/exceptions.py b/strawberry/http/exceptions.py index d934696806..feddf77631 100644 --- a/strawberry/http/exceptions.py +++ b/strawberry/http/exceptions.py @@ -4,4 +4,8 @@ def __init__(self, status_code: int, reason: str) -> None: self.reason = reason +class NonJsonMessageReceived(Exception): + pass + + __all__ = ["HTTPException"] diff --git a/strawberry/http/typevars.py b/strawberry/http/typevars.py index a48cba848e..53a5d5ac33 100644 --- a/strawberry/http/typevars.py +++ b/strawberry/http/typevars.py @@ -3,8 +3,18 @@ Request = TypeVar("Request", contravariant=True) Response = TypeVar("Response") SubResponse = TypeVar("SubResponse") +WebSocketRequest = TypeVar("WebSocketRequest") +WebSocketResponse = TypeVar("WebSocketResponse") Context = TypeVar("Context") RootValue = TypeVar("RootValue") -__all__ = ["Request", "Response", "SubResponse", "Context", "RootValue"] +__all__ = [ + "Request", + "Response", + "SubResponse", + "WebSocketRequest", + "WebSocketResponse", + "Context", + "RootValue", +] diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index e5c27ffe87..fed3d2d45f 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -7,19 +7,19 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, AsyncIterator, Callable, Dict, FrozenSet, - List, Optional, - Set, Tuple, Type, TypedDict, Union, cast, ) +from typing_extensions import TypeGuard from msgspec import Struct @@ -35,23 +35,24 @@ ) from litestar.background_tasks import BackgroundTasks from litestar.di import Provide -from litestar.exceptions import NotFoundException, ValidationException +from litestar.exceptions import ( + NotFoundException, + SerializationException, + ValidationException, + WebSocketDisconnect, +) from litestar.response.streaming import Stream from litestar.status_codes import HTTP_200_OK from strawberry.exceptions import InvalidCustomContext -from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter -from strawberry.http.exceptions import HTTPException +from strawberry.http.async_base_view import ( + AsyncBaseHTTPView, + AsyncHTTPRequestAdapter, + AsyncWebSocketAdapter, +) +from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived from strawberry.http.types import FormData, HTTPMethod, QueryParams from strawberry.http.typevars import Context, RootValue from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws import ( - WS_4406_PROTOCOL_NOT_ACCEPTABLE, -) - -from .handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler, -) -from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler if TYPE_CHECKING: from collections.abc import Mapping @@ -152,22 +153,6 @@ class GraphQLResource(Struct): extensions: Optional[dict[str, object]] -class GraphQLWSHandler(BaseGraphQLWSHandler): - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - class LitestarRequestAdapter(AsyncHTTPRequestAdapter): def __init__(self, request: Request[Any, Any, Any]) -> None: self.request = request @@ -203,10 +188,37 @@ async def get_form_data(self) -> FormData: return FormData(form=multipart_data, files=multipart_data) +class LitestarWebSocketAdapter(AsyncWebSocketAdapter): + def __init__(self, request: WebSocket, response: WebSocket) -> None: + self.ws = response + + async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: + try: + try: + while self.ws.connection_state != "disconnect": + yield await self.ws.receive_json() + except (SerializationException, ValueError): + raise NonJsonMessageReceived() + except WebSocketDisconnect: + pass + + async def send_json(self, message: Mapping[str, object]) -> None: + await self.ws.send_json(message) + + async def close(self, code: int, reason: str) -> None: + await self.ws.close(code=code, reason=reason) + + class GraphQLController( Controller, AsyncBaseHTTPView[ - Request[Any, Any, Any], Response[Any], Response[Any], Context, RootValue + Request[Any, Any, Any], + Response[Any], + Response[Any], + WebSocket, + WebSocket, + Context, + RootValue, ], ): path: str = "" @@ -219,10 +231,7 @@ class GraphQLController( } request_adapter_class = LitestarRequestAdapter - graphql_ws_handler_class: Type[GraphQLWSHandler] = GraphQLWSHandler - graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = ( - GraphQLTransportWSHandler - ) + websocket_adapter_class = LitestarWebSocketAdapter allow_queries_via_get: bool = True graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"}) @@ -236,6 +245,23 @@ class GraphQLController( keep_alive: bool = False keep_alive_interval: float = 1 + def is_websocket_request( + self, request: Union[Request, WebSocket] + ) -> TypeGuard[WebSocket]: + return isinstance(request, WebSocket) + + async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]: + subprotocols = request.scope["subprotocols"] + intersection = set(subprotocols) & set(self.protocols) + sorted_intersection = sorted(intersection, key=subprotocols.index) + return next(iter(sorted_intersection), None) + + async def create_websocket_response( + self, request: WebSocket, subprotocol: Optional[str] + ) -> WebSocket: + await request.accept(subprotocols=subprotocol) + return request + async def execute_request( self, request: Request[Any, Any, Any], @@ -245,8 +271,6 @@ async def execute_request( try: return await self.run( request, - # TODO: check the dependency, above, can we make it so that - # we don't need to type ignore here? context=context, root_value=root_value, ) @@ -328,14 +352,29 @@ async def handle_http_post( root_value=root_value, ) + @websocket() + async def websocket_endpoint( + self, + socket: WebSocket, + context_ws: Any, + root_value: Any, + ) -> None: + await self.run( + request=socket, + context=context_ws, + root_value=root_value, + ) + async def get_context( - self, request: Request[Any, Any, Any], response: Response + self, + request: Union[Request[Any, Any, Any], WebSocket], + response: Union[Response, WebSocket], ) -> Context: # pragma: no cover msg = "`get_context` is not used by Litestar's controller" raise ValueError(msg) async def get_root_value( - self, request: Request[Any, Any, Any] + self, request: Union[Request[Any, Any, Any], WebSocket] ) -> RootValue | None: # pragma: no cover msg = "`get_root_value` is not used by Litestar's controller" raise ValueError(msg) @@ -343,54 +382,6 @@ async def get_root_value( async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response: return self.temporal_response - @websocket() - async def websocket_endpoint( - self, - socket: WebSocket, - context_ws: Any, - root_value: Any, - ) -> None: - async def _get_context() -> Any: - return context_ws - - async def _get_root_value() -> Any: - return root_value - - preferred_protocol = self.pick_preferred_protocol(socket) - if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL: - await self.graphql_transport_ws_handler_class( - schema=self.schema, - debug=self.debug, - connection_init_wait_timeout=self.connection_init_wait_timeout, - get_context=_get_context, - get_root_value=_get_root_value, - ws=socket, - ).handle() - elif preferred_protocol == GRAPHQL_WS_PROTOCOL: - await self.graphql_ws_handler_class( - schema=self.schema, - debug=self.debug, - keep_alive=self.keep_alive, - keep_alive_interval=self.keep_alive_interval, - get_context=_get_context, - get_root_value=_get_root_value, - ws=socket, - ).handle() - else: - await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE) - - def pick_preferred_protocol(self, socket: WebSocket) -> str | None: - protocols: List[str] = socket.scope["subprotocols"] - intersection: Set[str] = set(protocols) & set(self.protocols) - return ( - min( - intersection, - key=lambda i: protocols.index(i) if i else "", - default=None, - ) - or None - ) - def make_graphql_controller( schema: BaseSchema, diff --git a/strawberry/litestar/handlers/__init__.py b/strawberry/litestar/handlers/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/strawberry/litestar/handlers/graphql_transport_ws_handler.py b/strawberry/litestar/handlers/graphql_transport_ws_handler.py deleted file mode 100644 index b5aa915d08..0000000000 --- a/strawberry/litestar/handlers/graphql_transport_ws_handler.py +++ /dev/null @@ -1,60 +0,0 @@ -from collections.abc import Callable -from datetime import timedelta -from typing import Any - -from litestar import WebSocket -from litestar.exceptions import SerializationException, WebSocketDisconnect -from strawberry.schema import BaseSchema -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( - BaseGraphQLTransportWSHandler, -) - - -class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - connection_init_wait_timeout: timedelta, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, connection_init_wait_timeout) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - async def send_json(self, data: dict) -> None: - await self._ws.send_json(data) - - async def close(self, code: int, reason: str) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> None: - await self._ws.accept(subprotocols=GRAPHQL_TRANSPORT_WS_PROTOCOL) - self.on_request_accepted() - - try: - while self._ws.connection_state != "disconnect": - try: - message = await self._ws.receive_json() - except (SerializationException, ValueError): # noqa: PERF203 - error_message = "WebSocket message type must be text" - await self.handle_invalid_message(error_message) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - await self.shutdown() - - -__all__ = ["GraphQLTransportWSHandler"] diff --git a/strawberry/litestar/handlers/graphql_ws_handler.py b/strawberry/litestar/handlers/graphql_ws_handler.py deleted file mode 100644 index ada421922f..0000000000 --- a/strawberry/litestar/handlers/graphql_ws_handler.py +++ /dev/null @@ -1,66 +0,0 @@ -from collections.abc import Callable -from contextlib import suppress -from typing import Any, Optional - -from litestar import WebSocket -from litestar.exceptions import SerializationException, WebSocketDisconnect -from strawberry.schema import BaseSchema -from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL -from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler -from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage - - -class GraphQLWSHandler(BaseGraphQLWSHandler): - def __init__( - self, - schema: BaseSchema, - debug: bool, - keep_alive: bool, - keep_alive_interval: float, - get_context: Callable, - get_root_value: Callable, - ws: WebSocket, - ) -> None: - super().__init__(schema, debug, keep_alive, keep_alive_interval) - self._get_context = get_context - self._get_root_value = get_root_value - self._ws = ws - - async def get_context(self) -> Any: - return await self._get_context() - - async def get_root_value(self) -> Any: - return await self._get_root_value() - - async def send_json(self, data: OperationMessage) -> None: - await self._ws.send_json(data) - - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - await self._ws.close(code=code, reason=reason) - - async def handle_request(self) -> Any: - await self._ws.accept(subprotocols=GRAPHQL_WS_PROTOCOL) - - try: - while self._ws.connection_state != "disconnect": - try: - message = await self._ws.receive_json() - except (SerializationException, ValueError): # noqa: PERF203 - await self.close( - code=1002, reason="WebSocket message type must be text" - ) - else: - await self.handle_message(message) - except WebSocketDisconnect: # pragma: no cover - pass - finally: - if self.keep_alive_task: - self.keep_alive_task.cancel() - with suppress(BaseException): - await self.keep_alive_task - - for operation_id in list(self.subscriptions.keys()): - await self.cleanup_operation(operation_id) - - -__all__ = ["GraphQLWSHandler"] diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index 5aafcec514..c7dc1257fd 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,7 @@ import warnings from collections.abc import Mapping from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast +from typing_extensions import TypeGuard from quart import Request, Response, request from quart.views import View @@ -46,7 +47,9 @@ async def get_form_data(self) -> FormData: class GraphQLView( - AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], + AsyncBaseHTTPView[ + Request, Response, Response, Request, Response, Context, RootValue + ], View, ): _ide_subscription_enabled = False @@ -121,5 +124,16 @@ async def create_streaming_response( }, ) + def is_websocket_request(self, request: Request) -> TypeGuard[Request]: + return False + + async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: Request, subprotocol: Optional[str] + ) -> Response: + raise NotImplementedError + __all__ = ["GraphQLView"] diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index b62a63ba65..ee76d2e946 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -13,6 +13,7 @@ Type, cast, ) +from typing_extensions import TypeGuard from sanic.request import Request from sanic.response import HTTPResponse, html @@ -71,7 +72,15 @@ async def get_form_data(self) -> FormData: class GraphQLView( - AsyncBaseHTTPView[Request, HTTPResponse, TemporalResponse, Context, RootValue], + AsyncBaseHTTPView[ + Request, + HTTPResponse, + TemporalResponse, + Request, + TemporalResponse, + Context, + RootValue, + ], HTTPMethodView, ): """Class based view to handle GraphQL HTTP Requests. @@ -206,5 +215,16 @@ async def create_streaming_response( # corner case return None # type: ignore + def is_websocket_request(self, request: Request) -> TypeGuard[Request]: + return False + + async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]: + raise NotImplementedError + + async def create_websocket_response( + self, request: Request, subprotocol: Optional[str] + ) -> TemporalResponse: + raise NotImplementedError + __all__ = ["GraphQLView"] diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py index 7d19db8e98..f74e9def3d 100644 --- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py @@ -2,7 +2,6 @@ import asyncio import logging -from abc import ABC, abstractmethod from contextlib import suppress from typing import ( TYPE_CHECKING, @@ -16,6 +15,7 @@ from graphql import GraphQLError, GraphQLSyntaxError, parse +from strawberry.http.exceptions import NonJsonMessageReceived from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( CompleteMessage, ConnectionAckMessage, @@ -38,6 +38,7 @@ if TYPE_CHECKING: from datetime import timedelta + from strawberry.http.async_base_view import AsyncWebSocketAdapter from strawberry.schema import BaseSchema from strawberry.schema.subscribe import SubscriptionResult from strawberry.subscriptions.protocols.graphql_transport_ws.types import ( @@ -45,15 +46,21 @@ ) -class BaseGraphQLTransportWSHandler(ABC): +class BaseGraphQLTransportWSHandler: task_logger: logging.Logger = logging.getLogger("strawberry.ws.task") def __init__( self, + websocket: AsyncWebSocketAdapter, + context: object, + root_value: object, schema: BaseSchema, debug: bool, connection_init_wait_timeout: timedelta, ) -> None: + self.websocket = websocket + self.context = context + self.root_value = root_value self.schema = schema self.debug = debug self.connection_init_wait_timeout = connection_init_wait_timeout @@ -65,28 +72,16 @@ def __init__( self.completed_tasks: List[asyncio.Task] = [] self.connection_params: Optional[Dict[str, Any]] = None - @abstractmethod - async def get_context(self) -> Any: - """Return the operations context.""" - - @abstractmethod - async def get_root_value(self) -> Any: - """Return the schemas root value.""" - - @abstractmethod - async def send_json(self, data: dict) -> None: - """Send the data JSON encoded to the WebSocket client.""" - - @abstractmethod - async def close(self, code: int, reason: str) -> None: - """Close the WebSocket with the passed code and reason.""" - - @abstractmethod - async def handle_request(self) -> Any: - """Handle the request this instance was created for.""" - async def handle(self) -> Any: - return await self.handle_request() + self.on_request_accepted() + + try: + async for message in self.websocket.iter_json(): + await self.handle_message(message) + except NonJsonMessageReceived: + await self.handle_invalid_message("WebSocket message type must be text") + finally: + await self.shutdown() async def shutdown(self) -> None: if self.connection_init_timeout_task: @@ -118,7 +113,7 @@ async def handle_connection_init_timeout(self) -> None: self.connection_timed_out = True reason = "Connection initialisation timeout" - await self.close(code=4408, reason=reason) + await self.websocket.close(code=4408, reason=reason) except Exception as error: await self.handle_task_exception(error) # pragma: no cover finally: @@ -189,14 +184,16 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None: ) if not isinstance(payload, dict): - await self.close(code=4400, reason="Invalid connection init payload") + await self.websocket.close( + code=4400, reason="Invalid connection init payload" + ) return self.connection_params = payload if self.connection_init_received: reason = "Too many initialisation requests" - await self.close(code=4429, reason=reason) + await self.websocket.close(code=4429, reason=reason) return self.connection_init_received = True @@ -211,13 +208,13 @@ async def handle_pong(self, message: PongMessage) -> None: async def handle_subscribe(self, message: SubscribeMessage) -> None: if not self.connection_acknowledged: - await self.close(code=4401, reason="Unauthorized") + await self.websocket.close(code=4401, reason="Unauthorized") return try: graphql_document = parse(message.payload.query) except GraphQLSyntaxError as exc: - await self.close(code=4400, reason=exc.message) + await self.websocket.close(code=4400, reason=exc.message) return try: @@ -225,12 +222,14 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: graphql_document, message.payload.operationName ) except RuntimeError: - await self.close(code=4400, reason="Can't get GraphQL operation type") + await self.websocket.close( + code=4400, reason="Can't get GraphQL operation type" + ) return if message.id in self.operations: reason = f"Subscriber for {message.id} already exists" - await self.close(code=4409, reason=reason) + await self.websocket.close(code=4409, reason=reason) return if self.debug: # pragma: no cover @@ -240,26 +239,28 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None: message.payload.variables, ) - context = await self.get_context() - if isinstance(context, dict): - context["connection_params"] = self.connection_params - root_value = await self.get_root_value() + if isinstance(self.context, dict): + self.context["connection_params"] = self.connection_params + elif hasattr(self.context, "connection_params"): + self.context.connection_params = self.connection_params + result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult] + # Get an AsyncGenerator yielding the results if operation_type == OperationType.SUBSCRIPTION: result_source = self.schema.subscribe( query=message.payload.query, variable_values=message.payload.variables, operation_name=message.payload.operationName, - context_value=context, - root_value=root_value, + context_value=self.context, + root_value=self.root_value, ) else: result_source = self.schema.execute( query=message.payload.query, variable_values=message.payload.variables, - context_value=context, - root_value=root_value, + context_value=self.context, + root_value=self.root_value, operation_name=message.payload.operationName, ) @@ -312,11 +313,11 @@ async def handle_complete(self, message: CompleteMessage) -> None: await self.cleanup_operation(operation_id=message.id) async def handle_invalid_message(self, error_message: str) -> None: - await self.close(code=4400, reason=error_message) + await self.websocket.close(code=4400, reason=error_message) async def send_message(self, message: GraphQLTransportMessage) -> None: data = message.as_dict() - await self.send_json(data) + await self.websocket.send_json(data) async def cleanup_operation(self, operation_id: str) -> None: if operation_id not in self.operations: diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py index 0451e0934b..fda3db829f 100644 --- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py +++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py @@ -1,11 +1,9 @@ from __future__ import annotations import asyncio -from abc import ABC, abstractmethod from contextlib import suppress from typing import ( TYPE_CHECKING, - Any, AsyncGenerator, Awaitable, Dict, @@ -13,6 +11,7 @@ cast, ) +from strawberry.http.exceptions import NonJsonMessageReceived from strawberry.subscriptions.protocols.graphql_ws import ( GQL_COMPLETE, GQL_CONNECTION_ACK, @@ -25,29 +24,36 @@ GQL_START, GQL_STOP, ) +from strawberry.subscriptions.protocols.graphql_ws.types import ( + ConnectionInitPayload, + DataPayload, + OperationMessage, + OperationMessagePayload, + StartPayload, +) from strawberry.types.execution import ExecutionResult, PreExecutionError from strawberry.utils.debug import pretty_print_graphql_operation if TYPE_CHECKING: + from strawberry.http.async_base_view import AsyncWebSocketAdapter from strawberry.schema import BaseSchema from strawberry.schema.subscribe import SubscriptionResult - from strawberry.subscriptions.protocols.graphql_ws.types import ( - ConnectionInitPayload, - DataPayload, - OperationMessage, - OperationMessagePayload, - StartPayload, - ) -class BaseGraphQLWSHandler(ABC): +class BaseGraphQLWSHandler: def __init__( self, + websocket: AsyncWebSocketAdapter, + context: object, + root_value: object, schema: BaseSchema, debug: bool, keep_alive: bool, - keep_alive_interval: float, + keep_alive_interval: Optional[float], ) -> None: + self.websocket = websocket + self.context = context + self.root_value = root_value self.schema = schema self.debug = debug self.keep_alive = keep_alive @@ -57,28 +63,22 @@ def __init__( self.tasks: Dict[str, asyncio.Task] = {} self.connection_params: Optional[ConnectionInitPayload] = None - @abstractmethod - async def get_context(self) -> Any: - """Return the operations context.""" - - @abstractmethod - async def get_root_value(self) -> Any: - """Return the schemas root value.""" - - @abstractmethod - async def send_json(self, data: OperationMessage) -> None: - """Send the data JSON encoded to the WebSocket client.""" - - @abstractmethod - async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: - """Close the WebSocket with the passed code and reason.""" - - @abstractmethod - async def handle_request(self) -> Any: - """Handle the request this instance was created for.""" - - async def handle(self) -> Any: - return await self.handle_request() + async def handle(self) -> None: + try: + async for message in self.websocket.iter_json(): + await self.handle_message(cast(OperationMessage, message)) + except NonJsonMessageReceived: + await self.websocket.close( + code=1002, reason="WebSocket message type must be text" + ) + finally: + if self.keep_alive_task: + self.keep_alive_task.cancel() + with suppress(BaseException): + await self.keep_alive_task + + for operation_id in list(self.subscriptions.keys()): + await self.cleanup_operation(operation_id) async def handle_message( self, @@ -99,22 +99,22 @@ async def handle_connection_init(self, message: OperationMessage) -> None: payload = message.get("payload") if payload is not None and not isinstance(payload, dict): error_message: OperationMessage = {"type": GQL_CONNECTION_ERROR} - await self.send_json(error_message) - await self.close() + await self.websocket.send_json(error_message) + await self.websocket.close(code=1000, reason="") return payload = cast(Optional["ConnectionInitPayload"], payload) self.connection_params = payload acknowledge_message: OperationMessage = {"type": GQL_CONNECTION_ACK} - await self.send_json(acknowledge_message) + await self.websocket.send_json(acknowledge_message) if self.keep_alive: keep_alive_handler = self.handle_keep_alive() self.keep_alive_task = asyncio.create_task(keep_alive_handler) async def handle_connection_terminate(self, message: OperationMessage) -> None: - await self.close() + await self.websocket.close(code=1000, reason="") async def handle_start(self, message: OperationMessage) -> None: operation_id = message["id"] @@ -123,10 +123,10 @@ async def handle_start(self, message: OperationMessage) -> None: operation_name = payload.get("operationName") variables = payload.get("variables") - context = await self.get_context() - if isinstance(context, dict): - context["connection_params"] = self.connection_params - root_value = await self.get_root_value() + if isinstance(self.context, dict): + self.context["connection_params"] = self.connection_params + elif hasattr(self.context, "connection_params"): + self.context.connection_params = self.connection_params if self.debug: pretty_print_graphql_operation(operation_name, query, variables) @@ -135,8 +135,8 @@ async def handle_start(self, message: OperationMessage) -> None: query=query, variable_values=variables, operation_name=operation_name, - context_value=context, - root_value=root_value, + context_value=self.context, + root_value=self.root_value, ) result_handler = self.handle_async_results(result_source, operation_id) @@ -147,9 +147,10 @@ async def handle_stop(self, message: OperationMessage) -> None: await self.cleanup_operation(operation_id) async def handle_keep_alive(self) -> None: + assert self.keep_alive_interval while True: data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE} - await self.send_json(data) + await self.websocket.send_json(data) await asyncio.sleep(self.keep_alive_interval) async def handle_async_results( @@ -191,7 +192,7 @@ async def send_message( data: OperationMessage = {"type": type_, "id": operation_id} if payload is not None: data["payload"] = payload - await self.send_json(data) + await self.websocket.send_json(data) async def send_data( self, execution_result: ExecutionResult, operation_id: str diff --git a/tests/aiohttp/__init__.py b/tests/aiohttp/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/aiohttp/app.py b/tests/aiohttp/app.py deleted file mode 100644 index ba43cea8fc..0000000000 --- a/tests/aiohttp/app.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Any - -from aiohttp import web -from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler -from strawberry.aiohttp.views import GraphQLView -from tests.views.schema import Query, schema - - -class DebuggableGraphQLTransportWSHandler(GraphQLTransportWSHandler): - def get_tasks(self) -> list: - return [op.task for op in self.operations.values()] - - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = self.connection_init_timeout_task - return context - - -class DebuggableGraphQLWSHandler(GraphQLWSHandler): - def get_tasks(self) -> list: - return list(self.tasks.values()) - - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = None - return context - - -class MyGraphQLView(GraphQLView): - graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler - graphql_ws_handler_class = DebuggableGraphQLWSHandler - - async def get_root_value(self, request: web.Request) -> Query: - await super().get_root_value(request) # for coverage - return Query() - - -def create_app(**kwargs: Any) -> web.Application: - app = web.Application() - app.router.add_route("*", "/graphql", MyGraphQLView(schema=schema, **kwargs)) - - return app diff --git a/tests/aiohttp/test_websockets.py b/tests/aiohttp/test_websockets.py deleted file mode 100644 index 82d6a5db00..0000000000 --- a/tests/aiohttp/test_websockets.py +++ /dev/null @@ -1,110 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Awaitable, Callable - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - -if TYPE_CHECKING: - from aiohttp.test_utils import TestClient - - -async def test_turning_off_graphql_ws( - aiohttp_client: Callable[..., Awaitable[TestClient]], -) -> None: - from .app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_turning_off_graphql_transport_ws( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_turning_off_all_ws_protocols( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app(subscription_protocols=[]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_unsupported_ws_protocol( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app(subscription_protocols=[]) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=["imaginary-protocol"] - ) as ws: - data = await ws.receive(timeout=2) - assert ws.protocol is None - assert ws.closed - assert ws.close_code == 4406 - assert data.extra == "Subprotocol not acceptable" - - -async def test_clients_can_prefer_protocols( - aiohttp_client: Callable[..., Awaitable[TestClient]], -): - from .app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - aiohttp_app_client = await aiohttp_client(app) - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - async with aiohttp_app_client.ws_connect( - "/graphql", protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.protocol == GRAPHQL_WS_PROTOCOL diff --git a/tests/asgi/app.py b/tests/asgi/app.py deleted file mode 100644 index a179a3210f..0000000000 --- a/tests/asgi/app.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Any, Dict, Optional, Union - -from starlette.requests import Request -from starlette.responses import Response -from starlette.websockets import WebSocket - -from strawberry.asgi import GraphQL as BaseGraphQL -from tests.views.schema import Query, schema - - -class GraphQL(BaseGraphQL): - async def get_root_value(self, request) -> Query: - return Query() - - async def get_context( - self, - request: Union[Request, WebSocket], - response: Optional[Response] = None, - ) -> Dict[str, Union[Request, WebSocket, Response, str, None]]: - return {"request": request, "response": response, "custom_value": "Hi"} - - -def create_app(**kwargs: Any) -> GraphQL: - return GraphQL(schema, **kwargs) diff --git a/tests/asgi/test_websockets.py b/tests/asgi/test_websockets.py deleted file mode 100644 index 511661358a..0000000000 --- a/tests/asgi/test_websockets.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - - -def test_turning_off_graphql_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_graphql_transport_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_all_ws_protocols(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_unsupported_ws_protocol(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.asgi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/", ["imaginary-protocol"]): - pass - - assert exc.value.code == 4406 - - -def test_clients_can_prefer_protocols(): - from starlette.testclient import TestClient - - from tests.asgi.app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - test_client = TestClient(app) - - with test_client.websocket_connect( - "/", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - with test_client.websocket_connect( - "/", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py index 40a3eed7b6..8db1205fdf 100644 --- a/tests/channels/test_layers.py +++ b/tests/channels/test_layers.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Generator +from typing import TYPE_CHECKING, AsyncGenerator import pytest @@ -21,7 +21,7 @@ @pytest.fixture -async def ws() -> Generator[WebsocketCommunicator, None, None]: +async def ws() -> AsyncGenerator[WebsocketCommunicator, None]: from channels.testing import WebsocketCommunicator from strawberry.channels import GraphQLWSConsumer diff --git a/tests/channels/test_testing.py b/tests/channels/test_testing.py index 921d8abf35..99aa9dd6c8 100644 --- a/tests/channels/test_testing.py +++ b/tests/channels/test_testing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generator +from typing import TYPE_CHECKING, Any, AsyncGenerator import pytest @@ -14,7 +14,7 @@ @pytest.fixture(params=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]) async def communicator( request: Any, -) -> Generator[GraphQLWebsocketCommunicator, None, None]: +) -> AsyncGenerator[GraphQLWebsocketCommunicator, None]: from strawberry.channels import GraphQLWSConsumer from strawberry.channels.testing import GraphQLWebsocketCommunicator diff --git a/tests/channels/test_ws_handler.py b/tests/channels/test_ws_handler.py deleted file mode 100644 index 88310bb617..0000000000 --- a/tests/channels/test_ws_handler.py +++ /dev/null @@ -1,54 +0,0 @@ -import pytest - -from tests.views.schema import schema - -try: - from channels.testing.websocket import WebsocketCommunicator - from strawberry.channels.handlers.graphql_transport_ws_handler import ( - GraphQLTransportWSHandler, - ) - from strawberry.channels.handlers.graphql_ws_handler import GraphQLWSHandler - from strawberry.channels.handlers.ws_handler import GraphQLWSConsumer -except ImportError: - pytestmark = pytest.mark.skip("Channels is not installed") - GraphQLWSHandler = None - GraphQLTransportWSHandler = None - -from strawberry.subscriptions import ( - GRAPHQL_TRANSPORT_WS_PROTOCOL, - GRAPHQL_WS_PROTOCOL, -) - - -async def test_wrong_protocol(): - GraphQLWSConsumer.as_asgi(schema=schema) - client = WebsocketCommunicator( - GraphQLWSConsumer.as_asgi(schema=schema), - "/graphql", - subprotocols=[ - "non-existing", - ], - ) - res = await client.connect() - assert res == (False, 4406) - - -@pytest.mark.parametrize( - ("protocol", "handler"), - [ - (GRAPHQL_TRANSPORT_WS_PROTOCOL, GraphQLTransportWSHandler), - (GRAPHQL_WS_PROTOCOL, GraphQLWSHandler), - ], -) -async def test_correct_protocol(protocol, handler): - consumer = GraphQLWSConsumer(schema=schema) - client = WebsocketCommunicator( - consumer, - "/graphql", - subprotocols=[ - protocol, - ], - ) - res = await client.connect() - assert res == (True, protocol) - assert isinstance(consumer._handler, handler) diff --git a/tests/fastapi/test_websockets.py b/tests/fastapi/test_websockets.py deleted file mode 100644 index de729c9f32..0000000000 --- a/tests/fastapi/test_websockets.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import Any - -import pytest - -import strawberry -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - - -def test_turning_off_graphql_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_graphql_transport_ws(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_all_ws_protocols(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_unsupported_ws_protocol(): - from starlette.testclient import TestClient - from starlette.websockets import WebSocketDisconnect - - from tests.fastapi.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", ["imaginary-protocol"]): - pass - - assert exc.value.code == 4406 - - -def test_clients_can_prefer_protocols(): - from starlette.testclient import TestClient - - from tests.fastapi.app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - test_client = TestClient(app) - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL - - -def test_with_custom_encode_json(): - from starlette.testclient import TestClient - - from fastapi import FastAPI - from strawberry.fastapi.router import GraphQLRouter - - @strawberry.type - class Query: - @strawberry.field - def abc(self) -> str: - return "abc" - - class MyRouter(GraphQLRouter[None, None]): - def encode_json(self, response_data: Any): - return '"custom"' - - app = FastAPI() - schema = strawberry.Schema(query=Query) - graphql_app = MyRouter(schema=schema) - app.include_router(graphql_app, prefix="/graphql") - - test_client = TestClient(app) - response = test_client.post("/graphql", json={"query": "{ abc }"}) - - assert response.status_code == 200 - assert response.json() == "custom" diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index 4979d7d480..dcf7abb5bd 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -10,7 +10,6 @@ from aiohttp.client_ws import ClientWebSocketResponse from aiohttp.http_websocket import WSMsgType from aiohttp.test_utils import TestClient, TestServer -from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE @@ -20,8 +19,8 @@ from ..context import get_context from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -30,16 +29,6 @@ ) -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - class GraphQLView(BaseGraphQLView): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler @@ -194,6 +183,9 @@ def __init__(self, ws: ClientWebSocketResponse): self.ws = ws self._reason: Optional[str] = None + async def send_text(self, payload: str) -> None: + await self.ws.send_str(payload) + async def send_json(self, payload: Dict[str, Any]) -> None: await self.ws.send_json(payload) @@ -213,6 +205,10 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any: async def close(self) -> None: await self.ws.close() + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.protocol + @property def closed(self) -> bool: return self.ws.closed diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index 5734c8df16..b8f51bbab1 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -8,11 +8,10 @@ from starlette.requests import Request from starlette.responses import Response as StarletteResponse -from starlette.testclient import TestClient +from starlette.testclient import TestClient, WebSocketTestSession from starlette.websockets import WebSocket, WebSocketDisconnect from strawberry.asgi import GraphQL as BaseGraphQLView -from strawberry.asgi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult @@ -21,8 +20,8 @@ from ..context import get_context from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -31,16 +30,6 @@ ) -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - class GraphQLView(BaseGraphQLView): result_override: ResultOverrideFunction = None graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler @@ -181,7 +170,7 @@ async def ws_connect( class AsgiWebSocketClient(WebSocketClient): - def __init__(self, ws: Any): + def __init__(self, ws: WebSocketTestSession): self.ws = ws self._closed: bool = False self._close_code: Optional[int] = None @@ -192,6 +181,9 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None: self._close_code = exc.code self._close_reason = exc.reason + async def send_text(self, payload: str) -> None: + self.ws.send_text(payload) + async def send_json(self, payload: Dict[str, Any]) -> None: self.ws.send_json(payload) @@ -224,6 +216,10 @@ async def close(self) -> None: self.ws.close() self._closed = True + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.accepted_subprotocol + @property def closed(self) -> bool: return self._closed diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index ff31e4111e..c1156fc3a6 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -21,6 +21,10 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE +from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import ( + BaseGraphQLTransportWSHandler, +) +from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler from strawberry.types import ExecutionResult logger = logging.getLogger("strawberry.test.http_client") @@ -237,7 +241,7 @@ def create_app(self, **kwargs: Any) -> None: """For use by websocket tests.""" raise NotImplementedError - async def ws_connect( + def ws_connect( self, url: str, *, @@ -260,6 +264,9 @@ class WebSocketClient(abc.ABC): def name(self) -> str: return "" + @abc.abstractmethod + async def send_text(self, payload: str) -> None: ... + @abc.abstractmethod async def send_json(self, payload: Dict[str, Any]) -> None: ... @@ -274,6 +281,10 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any: ... @abc.abstractmethod async def close(self) -> None: ... + @property + @abc.abstractmethod + def accepted_subprotocol(self) -> Optional[str]: ... + @property @abc.abstractmethod def closed(self) -> bool: ... @@ -290,7 +301,7 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]: yield await self.receive() -class DebuggableGraphQLTransportWSMixin: +class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler): def on_init(self) -> None: """This method can be patched by unit tests to get the instance of the transport handler when it is initialized. @@ -298,26 +309,41 @@ def on_init(self) -> None: def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) - DebuggableGraphQLTransportWSMixin.on_init(self) + self.original_context = kwargs.get("context", {}) + DebuggableGraphQLTransportWSHandler.on_init(self) def get_tasks(self) -> List: return [op.task for op in self.operations.values()] - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = self.connection_init_timeout_task - return context + @property + def context(self): + self.original_context["ws"] = self.websocket + self.original_context["get_tasks"] = self.get_tasks + self.original_context["connectionInitTimeoutTask"] = ( + self.connection_init_timeout_task + ) + return self.original_context + + @context.setter + def context(self, value): + self.original_context = value + +class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self.original_context = self.context -class DebuggableGraphQLWSMixin: def get_tasks(self) -> List: return list(self.tasks.values()) - async def get_context(self) -> object: - context = await super().get_context() - context["ws"] = self._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = None - return context + @property + def context(self): + self.original_context["ws"] = self.websocket + self.original_context["get_tasks"] = self.get_tasks + self.original_context["connectionInitTimeoutTask"] = None + return self.original_context + + @context.setter + def context(self, value): + self.original_context = value diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 14abd5e4af..0e53ce8f82 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -23,6 +23,8 @@ from ..context import get_context from .base import ( JSON, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -63,25 +65,6 @@ def create_multipart_request_body( return headers, request_body -class DebuggableGraphQLTransportWSConsumer(GraphQLWSConsumer): - def get_tasks(self) -> List[Any]: - if hasattr(self._handler, "operations"): - return [op.task for op in self._handler.operations.values()] - else: - return list(self._handler.tasks.values()) - - async def get_context(self, *args: str, **kwargs: Any) -> object: - context = await super().get_context(*args, **kwargs) - context["ws"] = self._handler._ws - context["get_tasks"] = self.get_tasks - context["connectionInitTimeoutTask"] = getattr( - self._handler, "connection_init_timeout_task", None - ) - for key, val in get_context({}).items(): - context[key] = val - return context - - class DebuggableGraphQLHTTPConsumer(GraphQLHTTPConsumer): result_override: ResultOverrideFunction = None @@ -130,6 +113,16 @@ def process_result( return super().process_result(request, result) +class DebuggableGraphQLWSConsumer(GraphQLWSConsumer): + graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler + graphql_ws_handler_class = DebuggableGraphQLWSHandler + + async def get_context(self, request, response): + context = await super().get_context(request, response) + + return get_context(context) + + class ChannelsHttpClient(HttpClient): """A client to test websockets over channels.""" @@ -141,7 +134,7 @@ def __init__( result_override: ResultOverrideFunction = None, multipart_uploads_enabled: bool = False, ): - self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( + self.ws_app = DebuggableGraphQLWSConsumer.as_asgi( schema=schema, keep_alive=False, ) @@ -156,9 +149,7 @@ def __init__( ) def create_app(self, **kwargs: Any) -> None: - self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi( - schema=schema, **kwargs - ) + self.ws_app = DebuggableGraphQLWSConsumer.as_asgi(schema=schema, **kwargs) async def _graphql_request( self, @@ -247,10 +238,13 @@ async def ws_connect( ) -> AsyncGenerator[WebSocketClient, None]: client = WebsocketCommunicator(self.ws_app, url, subprotocols=protocols) - res = await client.connect() - assert res == (True, protocols[0]) + connected, subprotocol_or_close_code = await client.connect() + assert connected + try: - yield ChannelsWebSocketClient(client) + yield ChannelsWebSocketClient( + client, accepted_subprotocol=subprotocol_or_close_code + ) finally: await client.disconnect() @@ -275,15 +269,21 @@ def __init__( class ChannelsWebSocketClient(WebSocketClient): - def __init__(self, client: WebsocketCommunicator): + def __init__( + self, client: WebsocketCommunicator, accepted_subprotocol: Optional[str] + ): self.ws = client self._closed: bool = False self._close_code: Optional[int] = None self._close_reason: Optional[str] = None + self._accepted_subprotocol = accepted_subprotocol def name(self) -> str: return "channels" + async def send_text(self, payload: str) -> None: + await self.ws.send_to(text_data=payload) + async def send_json(self, payload: Dict[str, Any]) -> None: await self.ws.send_json_to(payload) @@ -311,6 +311,10 @@ async def close(self) -> None: await self.ws.disconnect() self._closed = True + @property + def accepted_subprotocol(self) -> Optional[str]: + return self._accepted_subprotocol + @property def closed(self) -> bool: return self._closed diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py index cddc43032f..b1b80625fa 100644 --- a/tests/http/clients/fastapi.py +++ b/tests/http/clients/fastapi.py @@ -11,7 +11,6 @@ from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket from fastapi.testclient import TestClient from strawberry.fastapi import GraphQLRouter as BaseGraphQLRouter -from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.types import ExecutionResult @@ -21,8 +20,8 @@ from .asgi import AsgiWebSocketClient from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Response, ResultOverrideFunction, @@ -30,16 +29,6 @@ ) -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - def custom_context_dependency() -> str: return "Hi!" diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 0b99b43729..065b04f395 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -13,15 +13,14 @@ from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE from strawberry.litestar import make_graphql_controller -from strawberry.litestar.controller import GraphQLTransportWSHandler, GraphQLWSHandler from strawberry.types import ExecutionResult from tests.views.schema import Query, schema from ..context import get_context from .base import ( JSON, - DebuggableGraphQLTransportWSMixin, - DebuggableGraphQLWSMixin, + DebuggableGraphQLTransportWSHandler, + DebuggableGraphQLWSHandler, HttpClient, Message, Response, @@ -42,16 +41,6 @@ async def get_root_value(request: Request = None): return Query() -class DebuggableGraphQLTransportWSHandler( - DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler -): - pass - - -class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler): - pass - - class LitestarHttpClient(HttpClient): def __init__( self, @@ -190,6 +179,9 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None: self._closed = True self._close_code = exc.code + async def send_text(self, payload: str) -> None: + self.ws.send_text(payload) + async def send_json(self, payload: Dict[str, Any]) -> None: self.ws.send_json(payload) @@ -229,6 +221,10 @@ async def close(self) -> None: self.ws.close() self._closed = True + @property + def accepted_subprotocol(self) -> Optional[str]: + return self.ws.accepted_subprotocol + @property def closed(self) -> bool: return self._closed diff --git a/tests/litestar/test_websockets.py b/tests/litestar/test_websockets.py deleted file mode 100644 index b5554cb264..0000000000 --- a/tests/litestar/test_websockets.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest - -from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL - - -def test_turning_off_graphql_ws(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_graphql_transport_ws(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_turning_off_all_ws_protocols(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]): - pass - - assert exc.value.code == 4406 - - -def test_unsupported_ws_protocol(): - from litestar.exceptions import WebSocketDisconnect - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app(subscription_protocols=[]) - test_client = TestClient(app) - - with pytest.raises(WebSocketDisconnect) as exc: - with test_client.websocket_connect("/graphql", ["imaginary-protocol"]): - pass - - assert exc.value.code == 4406 - - -def test_clients_can_prefer_protocols(): - from litestar.testing import TestClient - from tests.litestar.app import create_app - - app = create_app( - subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) - test_client = TestClient(app) - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL - - with test_client.websocket_connect( - "/graphql", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] - ) as ws: - assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 02f8366852..49e9c7ce32 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -24,7 +24,7 @@ SubscribeMessage, SubscribeMessagePayload, ) -from tests.http.clients.base import DebuggableGraphQLTransportWSMixin +from tests.http.clients.base import DebuggableGraphQLTransportWSHandler from tests.views.schema import MyExtension, Schema if TYPE_CHECKING: @@ -123,6 +123,17 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): ws.assert_reason("WebSocket message type must be text") +async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): + ws = ws_raw + + await ws.send_text("not valid json") + + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4400 + ws.assert_reason("WebSocket message type must be text") + + async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw @@ -879,7 +890,7 @@ def on_init(_handler): # cause an attribute error in the timeout task handler.connection_init_wait_timeout = None - with patch.object(DebuggableGraphQLTransportWSMixin, "on_init", on_init): + with patch.object(DebuggableGraphQLTransportWSHandler, "on_init", on_init): async with http_client.ws_connect( "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] ) as ws: diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 6752eaa7f8..1ad38c6d19 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -292,6 +292,17 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): ws.assert_reason("WebSocket message type must be text") +async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): + ws = ws_raw + + await ws.send_text("not valid json") + + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 1002 + ws.assert_reason("WebSocket message type must be text") + + async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): ws = ws_raw diff --git a/tests/websockets/test_websockets.py b/tests/websockets/test_websockets.py new file mode 100644 index 0000000000..767617d727 --- /dev/null +++ b/tests/websockets/test_websockets.py @@ -0,0 +1,82 @@ +from typing import Type + +from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL +from tests.http.clients.base import HttpClient + + +async def test_turning_off_graphql_ws(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL]) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app(subscription_protocols=[]) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_generally_unsupported_subprotocols_are_rejected(http_client: HttpClient): + async with http_client.ws_connect( + "/graphql", protocols=["imaginary-protocol"] + ) as ws: + await ws.receive(timeout=2) + assert ws.closed + assert ws.close_code == 4406 + ws.assert_reason("Subprotocol not acceptable") + + +async def test_clients_can_prefer_subprotocols(http_client_class: Type[HttpClient]): + http_client = http_client_class() + http_client.create_app( + subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL] + ) as ws: + assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL + await ws.close() + assert ws.closed + + async with http_client.ws_connect( + "/graphql", protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL] + ) as ws: + assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL + await ws.close() + assert ws.closed From 2ad6bc3eac3e6713c4925a82a2facb4b1d1dd69f Mon Sep 17 00:00:00 2001 From: Botberry Date: Sat, 5 Oct 2024 20:20:45 +0000 Subject: [PATCH 14/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.244.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 9 +++++++++ RELEASE.md | 4 ---- pyproject.toml | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 375a4c9206..15006fcc18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ CHANGELOG ========= +0.244.0 - 2024-10-05 +-------------------- + +Starting with this release, WebSocket logic now lives in the base class shared between all HTTP integrations. +This makes the behaviour of WebSockets much more consistent between integrations and easier to maintain. + +Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3638](https://github.com/strawberry-graphql/strawberry/pull/3638/) + + 0.243.1 - 2024-09-26 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 7f194a4b85..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,4 +0,0 @@ -Release type: minor - -Starting with this release, WebSocket logic now lives in the base class shared between all HTTP integrations. -This makes the behaviour of WebSockets much more consistent between integrations and easier to maintain. diff --git a/pyproject.toml b/pyproject.toml index 2159916a19..b1d1c54544 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.243.1" +version = "0.244.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 34b40d4032b94d6c02b755bc36de03a19f7b4989 Mon Sep 17 00:00:00 2001 From: Strawberry GraphQL Bot Date: Sat, 5 Oct 2024 20:24:25 +0000 Subject: [PATCH 15/31] Remove TWEET.md --- TWEET.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 TWEET.md diff --git a/TWEET.md b/TWEET.md deleted file mode 100644 index 0437a68073..0000000000 --- a/TWEET.md +++ /dev/null @@ -1,7 +0,0 @@ -πŸš€ Starting with Strawberry $version, WebSocket logic now lives in the base -class shared across all HTTP integrations. More consistent behavior and easier -maintenance for WebSockets across integrations. πŸŽ‰ - -Thanks to $contributor for the PR πŸ‘ - -$release_url From ffdd85b8cfbb666ca7d274dc67dc64542b6e4311 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sun, 6 Oct 2024 16:56:44 +0200 Subject: [PATCH 16/31] Add more advice on file upload security (#3657) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/guides/file-upload.md | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/docs/guides/file-upload.md b/docs/guides/file-upload.md index c9c556b9bc..e2a5ce44eb 100644 --- a/docs/guides/file-upload.md +++ b/docs/guides/file-upload.md @@ -8,6 +8,21 @@ All Strawberry integrations support multipart uploads as described in the [GraphQL multipart request specification](https://github.com/jaydenseric/graphql-multipart-request-spec). This includes support for uploading single files as well as lists of files. +## Security + +Note that multipart file upload support is disabled by default in all +integrations. Before enabling multipart file upload support, make sure you +address the +[security implications outlined in the specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security). +Usually, this entails enabling CSRF protection in your server framework (e.g., +the `CsrfViewMiddleware` middleware in Django). + +To enable file upload support, pass `multipart_uploads_enabled=True` to your +integration's view class. Refer to the integration-specific documentation for +more details on how to do this. + +## Upload Scalar + Uploads can be used in mutations via the `Upload` scalar. The type passed at runtime depends on the integration: @@ -15,10 +30,10 @@ runtime depends on the integration: | ----------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- | | [AIOHTTP](/docs/integrations/aiohttp) | [`io.BytesIO`](https://docs.python.org/3/library/io.html#io.BytesIO) | | [ASGI](/docs/integrations/asgi) | [`starlette.datastructures.UploadFile`](https://www.starlette.io/requests/#request-files) | -| [Channels](/docs/integrations/channels) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/3.2/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) | -| [Django](/docs/integrations/django) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/3.2/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) | +| [Channels](/docs/integrations/channels) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/dev/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) | +| [Django](/docs/integrations/django) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/dev/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) | | [FastAPI](/docs/integrations/fastapi) | [`fastapi.UploadFile`](https://fastapi.tiangolo.com/tutorial/request-files/#file-parameters-with-uploadfile) | -| [Flask](/docs/integrations/flask) | [`werkzeug.datastructures.FileStorage`](https://werkzeug.palletsprojects.com/en/2.0.x/datastructures/#werkzeug.datastructures.FileStorage) | +| [Flask](/docs/integrations/flask) | [`werkzeug.datastructures.FileStorage`](https://werkzeug.palletsprojects.com/en/latest/datastructures/#werkzeug.datastructures.FileStorage) | | [Quart](/docs/integrations/quart) | [`quart.datastructures.FileStorage`](https://github.com/pallets/quart/blob/main/src/quart/datastructures.py) | | [Sanic](/docs/integrations/sanic) | [`sanic.request.File`](https://sanic.readthedocs.io/en/stable/sanic/api/core.html#sanic.request.File) | | [Starlette](/docs/integrations/starlette) | [`starlette.datastructures.UploadFile`](https://www.starlette.io/requests/#request-files) | From 5c0f0b0d1ac9ce9f085aaf3e87b37b07bb358253 Mon Sep 17 00:00:00 2001 From: Jacob Allen Date: Sun, 6 Oct 2024 11:32:30 -0600 Subject: [PATCH 17/31] Fix codegen crash on nullable lists of non-scalars (#3653) Co-authored-by: Jacob Allen --- RELEASE.md | 3 +++ strawberry/codegen/query_codegen.py | 14 ++++++++------ tests/codegen/conftest.py | 1 + .../queries/nullable_list_of_non_scalars.graphql | 6 ++++++ .../python/nullable_list_of_non_scalars.py | 8 ++++++++ .../typescript/nullable_list_of_non_scalars.ts | 8 ++++++++ 6 files changed, 34 insertions(+), 6 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/codegen/queries/nullable_list_of_non_scalars.graphql create mode 100644 tests/codegen/snapshots/python/nullable_list_of_non_scalars.py create mode 100644 tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..087ca107ed --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +Fixes an issue where the codegen tool would crash when working with a nullable list of types. diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py index 92d93d98db..583f7fbe39 100644 --- a/strawberry/codegen/query_codegen.py +++ b/strawberry/codegen/query_codegen.py @@ -645,20 +645,22 @@ def _unwrap_type( ) -> Tuple[ Union[type, StrawberryType], Optional[Callable[[GraphQLType], GraphQLType]] ]: - wrapper = None + wrapper: Optional[Callable[[GraphQLType], GraphQLType]] = None if isinstance(type_, StrawberryOptional): - type_, wrapper = self._unwrap_type(type_.of_type) + type_, previous_wrapper = self._unwrap_type(type_.of_type) wrapper = ( GraphQLOptional - if wrapper is None - else lambda t: GraphQLOptional(wrapper(t)) # type: ignore[misc] + if previous_wrapper is None + else lambda t: GraphQLOptional(previous_wrapper(t)) # type: ignore[misc] ) elif isinstance(type_, StrawberryList): - type_, wrapper = self._unwrap_type(type_.of_type) + type_, previous_wrapper = self._unwrap_type(type_.of_type) wrapper = ( - GraphQLList if wrapper is None else lambda t: GraphQLList(wrapper(t)) + GraphQLList + if previous_wrapper is None + else lambda t: GraphQLList(previous_wrapper(t)) ) elif isinstance(type_, LazyType): diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py index 1b7f80e467..d776a474fd 100644 --- a/tests/codegen/conftest.py +++ b/tests/codegen/conftest.py @@ -101,6 +101,7 @@ class Query: person: Person optional_person: Optional[Person] list_of_people: List[Person] + optional_list_of_people: Optional[List[Person]] enum: Color json: JSON union: PersonOrAnimal diff --git a/tests/codegen/queries/nullable_list_of_non_scalars.graphql b/tests/codegen/queries/nullable_list_of_non_scalars.graphql new file mode 100644 index 0000000000..8bdedbe6a2 --- /dev/null +++ b/tests/codegen/queries/nullable_list_of_non_scalars.graphql @@ -0,0 +1,6 @@ +query OperationName { + optionalListOfPeople { + name + age + } +} diff --git a/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py b/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py new file mode 100644 index 0000000000..f5efcaa7f8 --- /dev/null +++ b/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py @@ -0,0 +1,8 @@ +from typing import List, Optional + +class OperationNameResultOptionalListOfPeople: + name: str + age: int + +class OperationNameResult: + optional_list_of_people: Optional[List[OperationNameResultOptionalListOfPeople]] diff --git a/tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts b/tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts new file mode 100644 index 0000000000..7498cac29d --- /dev/null +++ b/tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts @@ -0,0 +1,8 @@ +type OperationNameResultOptionalListOfPeople = { + name: string + age: number +} + +type OperationNameResult = { + optional_list_of_people: OperationNameResultOptionalListOfPeople[] | undefined +} From d7e8a9b78945f9e14e306b39ab1cff4db3b8e811 Mon Sep 17 00:00:00 2001 From: Botberry Date: Sun, 6 Oct 2024 17:33:16 +0000 Subject: [PATCH 18/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.244.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 8 ++++++++ RELEASE.md | 3 --- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 15006fcc18..88cf43815c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ CHANGELOG ========= +0.244.1 - 2024-10-06 +-------------------- + +Fixes an issue where the codegen tool would crash when working with a nullable list of types. + +Contributed by [Jacob Allen](https://github.com/enoua5) via [PR #3653](https://github.com/strawberry-graphql/strawberry/pull/3653/) + + 0.244.0 - 2024-10-05 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 087ca107ed..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,3 +0,0 @@ -Release type: patch - -Fixes an issue where the codegen tool would crash when working with a nullable list of types. diff --git a/pyproject.toml b/pyproject.toml index b1d1c54544..171b9a5393 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.244.0" +version = "0.244.1" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 348448077c59b3de1c8f27f0b7a6b6332910da58 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sun, 6 Oct 2024 20:56:10 +0200 Subject: [PATCH 19/31] Make it easier to debug close reason assertions (#3659) --- tests/http/clients/aiohttp.py | 5 +-- tests/http/clients/asgi.py | 5 +-- tests/http/clients/base.py | 3 +- tests/http/clients/channels.py | 5 +-- tests/http/clients/litestar.py | 5 +-- tests/websockets/test_graphql_transport_ws.py | 36 +++++++++---------- tests/websockets/test_graphql_ws.py | 6 ++-- tests/websockets/test_websockets.py | 10 +++--- 8 files changed, 40 insertions(+), 35 deletions(-) diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py index dcf7abb5bd..89b0c718e8 100644 --- a/tests/http/clients/aiohttp.py +++ b/tests/http/clients/aiohttp.py @@ -218,5 +218,6 @@ def close_code(self) -> int: assert self.ws.close_code is not None return self.ws.close_code - def assert_reason(self, reason: str) -> None: - assert self._reason == reason + @property + def close_reason(self) -> Optional[str]: + return self._reason diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py index b8f51bbab1..7910e02f73 100644 --- a/tests/http/clients/asgi.py +++ b/tests/http/clients/asgi.py @@ -229,5 +229,6 @@ def close_code(self) -> int: assert self._close_code is not None return self._close_code - def assert_reason(self, reason: str) -> None: - assert self._close_reason == reason + @property + def close_reason(self) -> Optional[str]: + return self._close_reason diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index c1156fc3a6..c85b2efe00 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -293,8 +293,9 @@ def closed(self) -> bool: ... @abc.abstractmethod def close_code(self) -> int: ... + @property @abc.abstractmethod - def assert_reason(self, reason: str) -> None: ... + def close_reason(self) -> Optional[str]: ... async def __aiter__(self) -> AsyncGenerator[Message, None]: while not self.closed: diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py index 0e53ce8f82..bde2364128 100644 --- a/tests/http/clients/channels.py +++ b/tests/http/clients/channels.py @@ -324,5 +324,6 @@ def close_code(self) -> int: assert self._close_code is not None return self._close_code - def assert_reason(self, reason: str) -> None: - assert self._close_reason == reason + @property + def close_reason(self) -> Optional[str]: + return self._close_reason diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py index 065b04f395..2548dc563c 100644 --- a/tests/http/clients/litestar.py +++ b/tests/http/clients/litestar.py @@ -234,5 +234,6 @@ def close_code(self) -> int: assert self._close_code is not None return self._close_code - def assert_reason(self, reason: str) -> None: - assert self._close_reason == reason + @property + def close_reason(self) -> Optional[str]: + return self._close_reason diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py index 49e9c7ce32..4dbea524f4 100644 --- a/tests/websockets/test_graphql_transport_ws.py +++ b/tests/websockets/test_graphql_transport_ws.py @@ -76,7 +76,7 @@ async def test_unknown_message_type(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Unknown message type: NOT_A_MESSAGE_TYPE") + assert ws.close_reason == "Unknown message type: NOT_A_MESSAGE_TYPE" async def test_missing_message_type(ws_raw: WebSocketClient): @@ -87,7 +87,7 @@ async def test_missing_message_type(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Failed to parse message") + assert ws.close_reason == "Failed to parse message" async def test_parsing_an_invalid_message(ws_raw: WebSocketClient): @@ -98,7 +98,7 @@ async def test_parsing_an_invalid_message(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Failed to parse message") + assert ws.close_reason == "Failed to parse message" async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient): @@ -109,7 +109,7 @@ async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Failed to parse message") + assert ws.close_reason == "Failed to parse message" async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): @@ -120,7 +120,7 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("WebSocket message type must be text") + assert ws.close_reason == "WebSocket message type must be text" async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): @@ -131,7 +131,7 @@ async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("WebSocket message type must be text") + assert ws.close_reason == "WebSocket message type must be text" async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): @@ -156,7 +156,7 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("WebSocket message type must be text") + assert ws.close_reason == "WebSocket message type must be text" async def test_connection_init_timeout( @@ -180,7 +180,7 @@ async def test_connection_init_timeout( data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4408 - ws.assert_reason("Connection initialisation timeout") + assert ws.close_reason == "Connection initialisation timeout" @pytest.mark.flaky @@ -240,7 +240,7 @@ async def test_close_twice( await ws.receive(timeout=0.5) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Invalid connection init payload") + assert ws.close_reason == "Invalid connection init payload" transport_close.assert_not_called() @@ -249,7 +249,7 @@ async def test_too_many_initialisation_requests(ws: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4429 - ws.assert_reason("Too many initialisation requests") + assert ws.close_reason == "Too many initialisation requests" async def test_ping_pong(ws: WebSocketClient): @@ -320,7 +320,7 @@ async def test_unauthorized_subscriptions(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4401 - ws.assert_reason("Unauthorized") + assert ws.close_reason == "Unauthorized" async def test_duplicated_operation_ids(ws: WebSocketClient): @@ -345,7 +345,7 @@ async def test_duplicated_operation_ids(ws: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4409 - ws.assert_reason("Subscriber for sub1 already exists") + assert ws.close_reason == "Subscriber for sub1 already exists" async def test_reused_operation_ids(ws: WebSocketClient): @@ -409,7 +409,7 @@ async def test_subscription_syntax_error(ws: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Syntax Error: Expected Name, found .") + assert ws.close_reason == "Syntax Error: Expected Name, found ." async def test_subscription_field_errors(ws: WebSocketClient): @@ -663,7 +663,7 @@ async def test_single_result_invalid_operation_selection(ws: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Can't get GraphQL operation type") + assert ws.close_reason == "Can't get GraphQL operation type" async def test_single_result_execution_error(ws: WebSocketClient): @@ -743,7 +743,7 @@ async def test_single_result_duplicate_ids_sub(ws: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4409 - ws.assert_reason("Subscriber for sub1 already exists") + assert ws.close_reason == "Subscriber for sub1 already exists" async def test_single_result_duplicate_ids_query(ws: WebSocketClient): @@ -774,7 +774,7 @@ async def test_single_result_duplicate_ids_query(ws: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4409 - ws.assert_reason("Subscriber for sub1 already exists") + assert ws.close_reason == "Subscriber for sub1 already exists" async def test_injects_connection_params(ws_raw: WebSocketClient): @@ -804,7 +804,7 @@ async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient): data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Invalid connection init payload") + assert ws.close_reason == "Invalid connection init payload" @pytest.mark.parametrize( @@ -820,7 +820,7 @@ async def test_rejects_connection_params_with_wrong_type( data = await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4400 - ws.assert_reason("Invalid connection init payload") + assert ws.close_reason == "Invalid connection init payload" # timings can sometimes fail currently. Until this test is rewritten when diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py index 1ad38c6d19..daf56f90fb 100644 --- a/tests/websockets/test_graphql_ws.py +++ b/tests/websockets/test_graphql_ws.py @@ -289,7 +289,7 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 1002 - ws.assert_reason("WebSocket message type must be text") + assert ws.close_reason == "WebSocket message type must be text" async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): @@ -300,7 +300,7 @@ async def test_ws_messages_must_be_json(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 1002 - ws.assert_reason("WebSocket message type must be text") + assert ws.close_reason == "WebSocket message type must be text" async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): @@ -326,7 +326,7 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 1002 - ws.assert_reason("WebSocket message type must be text") + assert ws.close_reason == "WebSocket message type must be text" async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient): diff --git a/tests/websockets/test_websockets.py b/tests/websockets/test_websockets.py index 767617d727..2018316be0 100644 --- a/tests/websockets/test_websockets.py +++ b/tests/websockets/test_websockets.py @@ -14,7 +14,7 @@ async def test_turning_off_graphql_ws(http_client_class: Type[HttpClient]): await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4406 - ws.assert_reason("Subprotocol not acceptable") + assert ws.close_reason == "Subprotocol not acceptable" async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClient]): @@ -27,7 +27,7 @@ async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClie await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4406 - ws.assert_reason("Subprotocol not acceptable") + assert ws.close_reason == "Subprotocol not acceptable" async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]): @@ -40,7 +40,7 @@ async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]) await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4406 - ws.assert_reason("Subprotocol not acceptable") + assert ws.close_reason == "Subprotocol not acceptable" async with http_client.ws_connect( "/graphql", protocols=[GRAPHQL_WS_PROTOCOL] @@ -48,7 +48,7 @@ async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]) await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4406 - ws.assert_reason("Subprotocol not acceptable") + assert ws.close_reason == "Subprotocol not acceptable" async def test_generally_unsupported_subprotocols_are_rejected(http_client: HttpClient): @@ -58,7 +58,7 @@ async def test_generally_unsupported_subprotocols_are_rejected(http_client: Http await ws.receive(timeout=2) assert ws.closed assert ws.close_code == 4406 - ws.assert_reason("Subprotocol not acceptable") + assert ws.close_reason == "Subprotocol not acceptable" async def test_clients_can_prefer_subprotocols(http_client_class: Type[HttpClient]): From 1228f466fac9be09a5df1f5cf299e865591dd634 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Mon, 7 Oct 2024 02:06:27 +0200 Subject: [PATCH 20/31] Stop micromanaging WS support in IDEs (#3660) * Stop conditionally disabling WS support in IDEs * Add release file * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move import into type-checking block --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- RELEASE.md | 4 +++ docs/integrations/channels.md | 2 -- docs/integrations/django.md | 4 --- strawberry/chalice/views.py | 1 - strawberry/channels/handlers/http_handler.py | 3 -- strawberry/django/views.py | 31 +++++--------------- strawberry/flask/views.py | 1 - strawberry/http/base.py | 10 +------ strawberry/http/ides.py | 8 ----- strawberry/quart/views.py | 2 -- strawberry/static/graphiql.html | 5 +--- 11 files changed, 14 insertions(+), 57 deletions(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..ae000d7c0e --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +This release removes the dated `subscriptions_enabled` setting from the Django and Channels integrations. +Instead, WebSocket support is now enabled by default in all GraphQL IDEs. diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md index ce4fc70fb6..910e4a1ffc 100644 --- a/docs/integrations/channels.md +++ b/docs/integrations/channels.md @@ -522,8 +522,6 @@ GraphQLWebsocketCommunicator( to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests -- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in - the GraphiQL interface, defaults to `True` - `multipart_uploads_enabled`: optional, defaults to `False`, controls whether to enable multipart uploads. Please make sure to consider the [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) diff --git a/docs/integrations/django.md b/docs/integrations/django.md index 2909aba6f9..6a127f055a 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -39,8 +39,6 @@ The `GraphQLView` accepts the following arguments: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests -- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in - the GraphiQL interface, defaults to `False`. - `multipart_uploads_enabled`: optional, defaults to `False`, controls whether to enable multipart uploads. Please make sure to consider the [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) @@ -182,8 +180,6 @@ The `AsyncGraphQLView` accepts the following arguments: to disable it by passing `None`. - `allow_queries_via_get`: optional, defaults to `True`, whether to enable queries via `GET` requests -- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in - the GraphiQL interface, defaults to `False`. ## Extending the view diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py index 3da2838eaa..9c131f202d 100644 --- a/strawberry/chalice/views.py +++ b/strawberry/chalice/views.py @@ -54,7 +54,6 @@ class GraphQLView( ): allow_queries_via_get: bool = True request_adapter_class = ChaliceHTTPRequestAdapter - _ide_subscription_enabled = False def __init__( self, diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index 8d682eea74..a60bf2789e 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -167,14 +167,11 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, - subscriptions_enabled: bool = True, multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get - self.subscriptions_enabled = subscriptions_enabled - self._ide_subscriptions_enabled = subscriptions_enabled self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: diff --git a/strawberry/django/views.py b/strawberry/django/views.py index 457314d93b..23fa07e886 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -25,10 +25,8 @@ StreamingHttpResponse, ) from django.http.response import HttpResponseBase -from django.template import RequestContext, Template from django.template.exceptions import TemplateDoesNotExist from django.template.loader import render_to_string -from django.template.response import TemplateResponse from django.utils.decorators import classonlymethod from django.views.generic import View @@ -44,6 +42,8 @@ from .context import StrawberryDjangoContext if TYPE_CHECKING: + from django.template.response import TemplateResponse + from strawberry.http import GraphQLHTTPResponse from strawberry.http.ides import GraphQL_IDE @@ -137,7 +137,6 @@ async def get_form_data(self) -> FormData: class BaseView: - _ide_replace_variables = False graphql_ide_html: str def __init__( @@ -146,13 +145,11 @@ def __init__( graphiql: Optional[str] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, - subscriptions_enabled: bool = False, multipart_uploads_enabled: bool = False, **kwargs: Any, ) -> None: self.schema = schema self.allow_queries_via_get = allow_queries_via_get - self.subscriptions_enabled = subscriptions_enabled self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: @@ -215,7 +212,6 @@ class GraphQLView( ], View, ): - subscriptions_enabled = False graphiql: Optional[bool] = None graphql_ide: Optional[GraphQL_IDE] = "graphiql" allow_queries_via_get = True @@ -244,16 +240,11 @@ def dispatch( def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: try: - template = Template(render_to_string("graphql/graphiql.html")) + content = render_to_string("graphql/graphiql.html") except TemplateDoesNotExist: - template = Template(self.graphql_ide_html) - - context = {"SUBSCRIPTION_ENABLED": json.dumps(self.subscriptions_enabled)} + content = self.graphql_ide_html - response = TemplateResponse(request=request, template=None, context=context) - response.content = template.render(RequestContext(request, context)) - - return response + return HttpResponse(content) class AsyncGraphQLView( @@ -269,7 +260,6 @@ class AsyncGraphQLView( ], View, ): - subscriptions_enabled = False graphiql: Optional[bool] = None graphql_ide: Optional[GraphQL_IDE] = "graphiql" allow_queries_via_get = True @@ -308,16 +298,11 @@ async def dispatch( # pyright: ignore async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: try: - template = Template(render_to_string("graphql/graphiql.html")) + content = render_to_string("graphql/graphiql.html") except TemplateDoesNotExist: - template = Template(self.graphql_ide_html) + content = self.graphql_ide_html - context = {"SUBSCRIPTION_ENABLED": json.dumps(self.subscriptions_enabled)} - - response = TemplateResponse(request=request, template=None, context=context) - response.content = template.render(RequestContext(request, context)) - - return response + return HttpResponse(content=content) def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]: return False diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index 2dc15d6d6c..053003ac91 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -63,7 +63,6 @@ def content_type(self) -> Optional[str]: class BaseGraphQLView: - _ide_subscription_enabled = False graphql_ide: Optional[GraphQL_IDE] def __init__( diff --git a/strawberry/http/base.py b/strawberry/http/base.py index 5ab57ef65d..a906849ddf 100644 --- a/strawberry/http/base.py +++ b/strawberry/http/base.py @@ -25,10 +25,6 @@ class BaseView(Generic[Request]): graphql_ide: Optional[GraphQL_IDE] multipart_uploads_enabled: bool = False - # TODO: we might remove this in future :) - _ide_replace_variables: bool = True - _ide_subscription_enabled: bool = True - def should_render_graphql_ide(self, request: BaseRequestProtocol) -> bool: return ( request.method == "GET" @@ -64,11 +60,7 @@ def parse_query_params(self, params: QueryParams) -> Dict[str, Any]: @property def graphql_ide_html(self) -> str: - return get_graphql_ide_html( - subscription_enabled=self._ide_subscription_enabled, - replace_variables=self._ide_replace_variables, - graphql_ide=self.graphql_ide, - ) + return get_graphql_ide_html(graphql_ide=self.graphql_ide) def _is_multipart_subscriptions( self, content_type: str, params: Dict[str, str] diff --git a/strawberry/http/ides.py b/strawberry/http/ides.py index d9c52fb716..9680a0277a 100644 --- a/strawberry/http/ides.py +++ b/strawberry/http/ides.py @@ -1,4 +1,3 @@ -import json import pathlib from typing import Optional from typing_extensions import Literal @@ -7,8 +6,6 @@ def get_graphql_ide_html( - subscription_enabled: bool = True, - replace_variables: bool = True, graphql_ide: Optional[GraphQL_IDE] = "graphiql", ) -> str: here = pathlib.Path(__file__).parents[1] @@ -22,11 +19,6 @@ def get_graphql_ide_html( template = path.read_text(encoding="utf-8") - if replace_variables: - template = template.replace( - "{{ SUBSCRIPTION_ENABLED }}", json.dumps(subscription_enabled) - ) - return template diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index c7dc1257fd..528a987abc 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -52,8 +52,6 @@ class GraphQLView( ], View, ): - _ide_subscription_enabled = False - methods = ["GET", "POST"] allow_queries_via_get: bool = True request_adapter_class = QuartHTTPRequestAdapter diff --git a/strawberry/static/graphiql.html b/strawberry/static/graphiql.html index 95e34c1709..b66082a97f 100644 --- a/strawberry/static/graphiql.html +++ b/strawberry/static/graphiql.html @@ -131,10 +131,7 @@ headers["x-csrftoken"] = csrfToken; } - const subscriptionsEnabled = JSON.parse("{{ SUBSCRIPTION_ENABLED }}"); - const subscriptionUrl = subscriptionsEnabled - ? httpUrlToWebSockeUrl(fetchURL) - : null; + const subscriptionUrl = httpUrlToWebSockeUrl(fetchURL); const fetcher = GraphiQL.createFetcher({ url: fetchURL, From b30f5f19fcbb86265930e021d272b0bfedbe4d16 Mon Sep 17 00:00:00 2001 From: Botberry Date: Mon, 7 Oct 2024 00:07:15 +0000 Subject: [PATCH 21/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.245.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 9 +++++++++ RELEASE.md | 4 ---- pyproject.toml | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 88cf43815c..994c2f7ef3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ CHANGELOG ========= +0.245.0 - 2024-10-07 +-------------------- + +This release removes the dated `subscriptions_enabled` setting from the Django and Channels integrations. +Instead, WebSocket support is now enabled by default in all GraphQL IDEs. + +Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3660](https://github.com/strawberry-graphql/strawberry/pull/3660/) + + 0.244.1 - 2024-10-06 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index ae000d7c0e..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,4 +0,0 @@ -Release type: minor - -This release removes the dated `subscriptions_enabled` setting from the Django and Channels integrations. -Instead, WebSocket support is now enabled by default in all GraphQL IDEs. diff --git a/pyproject.toml b/pyproject.toml index 171b9a5393..d0a8182c53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.244.1" +version = "0.245.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From c25bb39dad9f4c04d8e0129bb63d79f654c0ff7f Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Mon, 7 Oct 2024 11:03:32 +0200 Subject: [PATCH 22/31] Rename misleading test client argument name (#3661) --- RELEASE.md | 4 ++ strawberry/aiohttp/test/client.py | 18 ++++++- strawberry/test/client.py | 18 ++++++- tests/django/test_graphql_test_client.py | 7 --- tests/test/__init__.py | 0 tests/test/conftest.py | 66 ++++++++++++++++++++++++ tests/test/test_client.py | 34 ++++++++++++ 7 files changed, 136 insertions(+), 11 deletions(-) create mode 100644 RELEASE.md delete mode 100644 tests/django/test_graphql_test_client.py create mode 100644 tests/test/__init__.py create mode 100644 tests/test/conftest.py create mode 100644 tests/test/test_client.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..ec1563a653 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +The AIOHTTP, ASGI, and Django test clients' `asserts_errors` option has been renamed to `assert_no_errors` to better reflect its purpose. +This change is backwards-compatible, but the old option name will raise a deprecation warning. diff --git a/strawberry/aiohttp/test/client.py b/strawberry/aiohttp/test/client.py index 6f30b7e7aa..0d25f4043a 100644 --- a/strawberry/aiohttp/test/client.py +++ b/strawberry/aiohttp/test/client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import ( Any, Dict, @@ -16,8 +17,9 @@ async def query( query: str, variables: Optional[Dict[str, Mapping]] = None, headers: Optional[Dict[str, object]] = None, - asserts_errors: Optional[bool] = True, + asserts_errors: Optional[bool] = None, files: Optional[Dict[str, object]] = None, + assert_no_errors: Optional[bool] = True, ) -> Response: body = self._build_body(query, variables, files) @@ -29,7 +31,19 @@ async def query( data=data.get("data"), extensions=data.get("extensions"), ) - if asserts_errors: + + if asserts_errors is not None: + warnings.warn( + "The `asserts_errors` argument has been renamed to `assert_no_errors`", + DeprecationWarning, + stacklevel=2, + ) + + assert_no_errors = ( + assert_no_errors if asserts_errors is None else asserts_errors + ) + + if assert_no_errors: assert resp.status == 200 assert response.errors is None diff --git a/strawberry/test/client.py b/strawberry/test/client.py index 9127799763..243ac3aadb 100644 --- a/strawberry/test/client.py +++ b/strawberry/test/client.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Mapping, Optional, Union @@ -36,8 +37,9 @@ def query( query: str, variables: Optional[Dict[str, Mapping]] = None, headers: Optional[Dict[str, object]] = None, - asserts_errors: Optional[bool] = True, + asserts_errors: Optional[bool] = None, files: Optional[Dict[str, object]] = None, + assert_no_errors: Optional[bool] = True, ) -> Union[Coroutine[Any, Any, Response], Response]: body = self._build_body(query, variables, files) @@ -49,7 +51,19 @@ def query( data=data.get("data"), extensions=data.get("extensions"), ) - if asserts_errors: + + if asserts_errors is not None: + warnings.warn( + "The `asserts_errors` argument has been renamed to `assert_no_errors`", + DeprecationWarning, + stacklevel=2, + ) + + assert_no_errors = ( + assert_no_errors if asserts_errors is None else asserts_errors + ) + + if assert_no_errors: assert response.errors is None return response diff --git a/tests/django/test_graphql_test_client.py b/tests/django/test_graphql_test_client.py deleted file mode 100644 index 9a839b71b2..0000000000 --- a/tests/django/test_graphql_test_client.py +++ /dev/null @@ -1,7 +0,0 @@ -def test_assertion_error_not_raised_when_asserts_errors_is_false(graphql_client): - query = "{ }" - - try: - graphql_client.query(query, asserts_errors=False) - except AssertionError: - raise AssertionError diff --git a/tests/test/__init__.py b/tests/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test/conftest.py b/tests/test/conftest.py new file mode 100644 index 0000000000..30ecd326bd --- /dev/null +++ b/tests/test/conftest.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, AsyncGenerator + +import pytest + +from tests.views.schema import schema + +if TYPE_CHECKING: + from strawberry.test import BaseGraphQLTestClient + + +@asynccontextmanager +async def aiohttp_graphql_client() -> AsyncGenerator[BaseGraphQLTestClient]: + try: + from aiohttp import web + from aiohttp.test_utils import TestClient, TestServer + from strawberry.aiohttp.test import GraphQLTestClient + from strawberry.aiohttp.views import GraphQLView + except ImportError: + pytest.skip("Aiohttp not installed") + + view = GraphQLView(schema=schema) + app = web.Application() + app.router.add_route("*", "/graphql/", view) + + async with TestClient(TestServer(app)) as client: + yield GraphQLTestClient(client) + + +@asynccontextmanager +async def asgi_graphql_client() -> AsyncGenerator[BaseGraphQLTestClient]: + try: + from starlette.testclient import TestClient + + from strawberry.asgi import GraphQL + from strawberry.asgi.test import GraphQLTestClient + except ImportError: + pytest.skip("Starlette not installed") + + yield GraphQLTestClient(TestClient(GraphQL(schema))) + + +@asynccontextmanager +async def django_graphql_client() -> AsyncGenerator[BaseGraphQLTestClient]: + try: + from django.test.client import Client + + from strawberry.django.test import GraphQLTestClient + except ImportError: + pytest.skip("Django not installed") + + yield GraphQLTestClient(Client()) + + +@pytest.fixture( + params=[ + pytest.param(aiohttp_graphql_client, marks=[pytest.mark.aiohttp]), + pytest.param(asgi_graphql_client, marks=[pytest.mark.asgi]), + pytest.param(django_graphql_client, marks=[pytest.mark.django]), + ] +) +async def graphql_client(request) -> AsyncGenerator[BaseGraphQLTestClient]: + async with request.param() as graphql_client: + yield graphql_client diff --git a/tests/test/test_client.py b/tests/test/test_client.py new file mode 100644 index 0000000000..0e03d12f77 --- /dev/null +++ b/tests/test/test_client.py @@ -0,0 +1,34 @@ +from contextlib import nullcontext + +import pytest + +from strawberry.utils.await_maybe import await_maybe + + +@pytest.mark.parametrize("asserts_errors", [True, False]) +async def test_query_asserts_errors_option_is_deprecated( + graphql_client, asserts_errors +): + with pytest.warns( + DeprecationWarning, + match="The `asserts_errors` argument has been renamed to `assert_no_errors`", + ): + await await_maybe( + graphql_client.query("{ hello }", asserts_errors=asserts_errors) + ) + + +@pytest.mark.parametrize("option_name", ["asserts_errors", "assert_no_errors"]) +@pytest.mark.parametrize( + ("assert_no_errors", "expectation"), + [(True, pytest.raises(AssertionError)), (False, nullcontext())], +) +async def test_query_with_assert_no_errors_option( + graphql_client, option_name, assert_no_errors, expectation +): + query = "{ ThisIsNotAValidQuery }" + + with expectation: + await await_maybe( + graphql_client.query(query, **{option_name: assert_no_errors}) + ) From 690c61ca29a31f7d0fe11726c9d752ebe7817897 Mon Sep 17 00:00:00 2001 From: Botberry Date: Mon, 7 Oct 2024 09:04:21 +0000 Subject: [PATCH 23/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.246.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 9 +++++++++ RELEASE.md | 4 ---- pyproject.toml | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 994c2f7ef3..ddaccdc0ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,15 @@ CHANGELOG ========= +0.246.0 - 2024-10-07 +-------------------- + +The AIOHTTP, ASGI, and Django test clients' `asserts_errors` option has been renamed to `assert_no_errors` to better reflect its purpose. +This change is backwards-compatible, but the old option name will raise a deprecation warning. + +Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3661](https://github.com/strawberry-graphql/strawberry/pull/3661/) + + 0.245.0 - 2024-10-07 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index ec1563a653..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,4 +0,0 @@ -Release type: minor - -The AIOHTTP, ASGI, and Django test clients' `asserts_errors` option has been renamed to `assert_no_errors` to better reflect its purpose. -This change is backwards-compatible, but the old option name will raise a deprecation warning. diff --git a/pyproject.toml b/pyproject.toml index d0a8182c53..9d34ba1396 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.245.0" +version = "0.246.0" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From 7feeea9cf8c0941f392ea39df468b5ff1d8cef3e Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Mon, 7 Oct 2024 17:50:41 +0200 Subject: [PATCH 24/31] Fix heading levels of options/subclassing docs (#3662) --- docs/integrations/aiohttp.md | 8 ++++---- docs/integrations/asgi.md | 12 ++++++------ docs/integrations/chalice.md | 8 ++++---- docs/integrations/django.md | 8 ++++---- docs/integrations/fastapi.md | 10 +++++----- docs/integrations/flask.md | 8 ++++---- docs/integrations/litestar.md | 6 +++--- docs/integrations/quart.md | 8 ++++---- docs/integrations/sanic.md | 8 ++++---- 9 files changed, 38 insertions(+), 38 deletions(-) diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md index e968e43bf2..207f51b9b9 100644 --- a/docs/integrations/aiohttp.md +++ b/docs/integrations/aiohttp.md @@ -52,7 +52,7 @@ methods: - `async process_result(self, request: aiohttp.web.Request, result: ExecutionResult) -> GraphQLHTTPResponse` - `def encode_json(self, data: GraphQLHTTPResponse) -> str` -## get_context +### get_context By overriding `GraphQLView.get_context` you can provide a custom context object for your resolvers. You can return anything here; by default GraphQLView returns @@ -83,7 +83,7 @@ called `"example"`. Then we can use the context in a resolver. In this case the resolver will return `1`. -## get_root_value +### get_root_value By overriding `GraphQLView.get_root_value` you can provide a custom root value for your schema. This is probably not used a lot but it might be useful in @@ -110,7 +110,7 @@ class Query: Here we configure a Query where requesting the `name` field will return `"Patrick"` through the custom root value. -## process_result +### process_result By overriding `GraphQLView.process_result` you can customize and/or process results before they are sent to a client. This can be useful for logging errors, @@ -141,7 +141,7 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md index dcb51643b3..3c9d78bd50 100644 --- a/docs/integrations/asgi.md +++ b/docs/integrations/asgi.md @@ -51,7 +51,7 @@ We allow to extend the base `GraphQL` app, by overriding the following methods: - `async process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse` - `def encode_json(self, response_data: GraphQLHTTPResponse) -> str` -## get_context +### get_context `get_context` allows to provide a custom context object that can be used in your resolver. You can return anything here, by default we return a dictionary with @@ -78,7 +78,7 @@ called "example". Then we use the context in a resolver, the resolver will return "1" in this case. -### Setting response headers +#### Setting response headers It is possible to use `get_context` to set response headers. A common use case might be cookie-based user authentication, where your login mutation resolver @@ -97,7 +97,7 @@ class Mutation: return True ``` -### Setting background tasks +#### Setting background tasks Similarly, [background tasks](https://www.starlette.io/background/) can be set on the response via the context: @@ -116,7 +116,7 @@ class Mutation: info.context["response"].background = BackgroundTask(notify_new_flavour, name) ``` -## get_root_value +### get_root_value `get_root_value` allows to provide a custom root value for your schema, this is probably not used a lot but it might be useful in certain situations. @@ -137,7 +137,7 @@ class Query: Here we are returning a Query where the name is "Patrick", so we when requesting the field name we'll return "Patrick" in this case. -## process_result +### process_result `process_result` allows to customize and/or process results before they are sent to the clients. This can be useful logging errors or hiding them (for example to @@ -166,7 +166,7 @@ class MyGraphQL(GraphQL): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/chalice.md b/docs/integrations/chalice.md index 8c5671c33d..5608eff948 100644 --- a/docs/integrations/chalice.md +++ b/docs/integrations/chalice.md @@ -77,7 +77,7 @@ We allow to extend the base `GraphQLView`, by overriding the following methods: - `process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse` - `encode_json(self, response_data: GraphQLHTTPResponse) -> str` -## get_context +### get_context `get_context` allows to provide a custom context object that can be used in your resolver. You can return anything here, by default we return a dictionary with @@ -103,7 +103,7 @@ called "example". Then we use the context in a resolver, the resolver will return "1" in this case. -## get_root_value +### get_root_value `get_root_value` allows to provide a custom root value for your schema, this is probably not used a lot but it might be useful in certain situations. @@ -124,7 +124,7 @@ class Query: Here we are returning a Query where the name is "Patrick", so we when requesting the field name we'll return "Patrick" in this case. -## process_result +### process_result `process_result` allows to customize and/or process results before they are sent to the clients. This can be useful logging errors or hiding them (for example to @@ -151,7 +151,7 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/django.md b/docs/integrations/django.md index 6a127f055a..281b40c34a 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -191,7 +191,7 @@ methods: - `async process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse` - `def encode_json(self, data: GraphQLHTTPResponse) -> str` -## get_context +### get_context `get_context` allows to provide a custom context object that can be used in your resolver. You can return anything here, by default we return a dictionary with @@ -216,7 +216,7 @@ called "example". Then we use the context in a resolver, the resolver will return "1" in this case. -## get_root_value +### get_root_value `get_root_value` allows to provide a custom root value for your schema, this is probably not used a lot but it might be useful in certain situations. @@ -237,7 +237,7 @@ class Query: Here we are returning a Query where the name is "Patrick", so we when requesting the field name we'll return "Patrick" in this case. -## process_result +### process_result `process_result` allows to customize and/or process results before they are sent to the clients. This can be useful logging errors or hiding them (for example to @@ -266,7 +266,7 @@ class MyGraphQLView(AsyncGraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 18b7dae5b6..4c25ca651b 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -59,7 +59,7 @@ The `GraphQLRouter` accepts the following options: [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) when enabling this feature. -## context_getter +### context_getter The `context_getter` option allows you to provide a custom context object that can be used in your resolver. `context_getter` is a @@ -178,7 +178,7 @@ requires `.request` indexing. Then we use the context in a resolver. The resolver will return β€œHello John, you rock!” in this case. -### Setting background tasks +#### Setting background tasks Similarly, [background tasks](https://fastapi.tiangolo.com/tutorial/background-tasks/?h=background) @@ -221,7 +221,7 @@ app.include_router(graphql_app, prefix="/graphql") If using a custom context class, then background tasks should be stored within the class object as `.background_tasks`. -## root_value_getter +### root_value_getter The `root_value_getter` option allows you to provide a custom root value for your schema. This is most likely a rare usecase but might be useful in certain @@ -259,7 +259,7 @@ app.include_router(graphql_app, prefix="/graphql") Here we are returning a Query where the name is "Patrick", so when we request the field name we'll return "Patrick". -## process_result +### process_result The `process_result` option allows you to customize and/or process results before they are sent to the clients. This can be useful for logging errors or @@ -290,7 +290,7 @@ class MyGraphQLRouter(GraphQLRouter): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index 2e9c5eb389..69576e535c 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -64,7 +64,7 @@ async functions. -## get_context +### get_context `get_context` allows to provide a custom context object that can be used in your resolver. You can return anything here, by default we return a dictionary with @@ -90,7 +90,7 @@ called "example". Then we use the context in a resolver, the resolver will return "1" in this case. -## get_root_value +### get_root_value `get_root_value` allows to provide a custom root value for your schema, this is probably not used a lot but it might be useful in certain situations. @@ -111,7 +111,7 @@ class Query: Here we are returning a Query where the name is "Patrick", so we when requesting the field name we'll return "Patrick" in this case. -## process_result +### process_result `process_result` allows to customize and/or process results before they are sent to the clients. This can be useful logging errors or hiding them (for example to @@ -138,7 +138,7 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/litestar.md b/docs/integrations/litestar.md index b505096626..54ea4e239d 100644 --- a/docs/integrations/litestar.md +++ b/docs/integrations/litestar.md @@ -66,7 +66,7 @@ The `make_graphql_controller` function accepts the following options: [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security) when enabling this feature. -## context_getter +### context_getter The `context_getter` option allows you to provide a Litestar dependency that return a custom context object that can be used in your resolver. @@ -188,7 +188,7 @@ GraphQLController = make_graphql_controller( app = Litestar(route_handlers=[GraphQLController]) ``` -### Context typing +#### Context typing In our previous example using class based context, the actual runtime context a `CustomContext` type. Because it inherits from `BaseContext`, the `request`, @@ -283,7 +283,7 @@ GraphQLController = make_graphql_controller( app = Litestar(route_handlers=[GraphQLController]) ``` -## root_value_getter +### root_value_getter The `root_value_getter` option allows you to provide a custom root value that can be used in your resolver diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md index a6d7061ba8..91b8b2a934 100644 --- a/docs/integrations/quart.md +++ b/docs/integrations/quart.md @@ -48,7 +48,7 @@ We allow to extend the base `GraphQLView`, by overriding the following methods: - `process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` - `encode_json(self, response_data: GraphQLHTTPResponse) -> str` -## get_context +### get_context `get_context` allows to provide a custom context object that can be used in your resolver. You can return anything here, by default we return a dictionary with @@ -74,7 +74,7 @@ called "example". Then we use the context in a resolver, the resolver will return "1" in this case. -## get_root_value +### get_root_value `get_root_value` allows to provide a custom root value for your schema, this is probably not used a lot but it might be useful in certain situations. @@ -95,7 +95,7 @@ class Query: Here we are returning a Query where the name is "Patrick", so we when requesting the field name we'll return "Patrick" in this case. -## process_result +### process_result `process_result` allows to customize and/or process results before they are sent to the clients. This can be useful logging errors or hiding them (for example to @@ -122,7 +122,7 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. diff --git a/docs/integrations/sanic.md b/docs/integrations/sanic.md index 0a7d885c7d..dfff4c0502 100644 --- a/docs/integrations/sanic.md +++ b/docs/integrations/sanic.md @@ -44,7 +44,7 @@ methods: - `async get_root_value(self, request: Request) -> Any` - `async process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` -## get_context +### get_context By overriding `GraphQLView.get_context` you can provide a custom context object for your resolvers. You can return anything here; by default GraphQLView returns @@ -69,7 +69,7 @@ called `"example"`. Then we can use the context in a resolver. In this case the resolver will return `1`. -## get_root_value +### get_root_value By overriding `GraphQLView.get_root_value` you can provide a custom root value for your schema. This is probably not used a lot but it might be useful in @@ -91,7 +91,7 @@ class Query: Here we configure a Query where requesting the `name` field will return `"Patrick"` through the custom root value. -## process_result +### process_result By overriding `GraphQLView.process_result` you can customize and/or process results before they are sent to a client. This can be useful for logging errors, @@ -120,7 +120,7 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. -## encode_json +### encode_json `encode_json` allows to customize the encoding of the JSON response. By default we use `json.dumps` but you can override this method to use a different encoder. From 0bc97735368c23f84c8909283ccfacd66fb150b2 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Mon, 7 Oct 2024 19:41:18 +0200 Subject: [PATCH 25/31] Document how to render a custom graphql IDE (#3664) --- docs/integrations/aiohttp.md | 18 +++++++++++++ docs/integrations/asgi.md | 18 +++++++++++++ docs/integrations/chalice.md | 18 +++++++++++++ docs/integrations/channels.md | 21 +++++++++++++-- docs/integrations/django.md | 50 ++++++++++++++++++++++++++++++----- docs/integrations/fastapi.md | 17 ++++++++++++ docs/integrations/flask.md | 26 +++++++++++++++--- docs/integrations/litestar.md | 40 ++++++++++++++++++++++++++++ docs/integrations/quart.md | 26 +++++++++++++++--- docs/integrations/sanic.md | 19 +++++++++++++ strawberry/asgi/__init__.py | 2 +- 11 files changed, 238 insertions(+), 17 deletions(-) diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md index 207f51b9b9..937bb31889 100644 --- a/docs/integrations/aiohttp.md +++ b/docs/integrations/aiohttp.md @@ -51,6 +51,7 @@ methods: - `async get_root_value(self, request: aiohttp.web.Request) -> object` - `async process_result(self, request: aiohttp.web.Request, result: ExecutionResult) -> GraphQLHTTPResponse` - `def encode_json(self, data: GraphQLHTTPResponse) -> str` +- `async def render_graphql_ide(self, request: aiohttp.web.Request) -> aiohttp.web.Response` ### get_context @@ -151,3 +152,20 @@ class MyGraphQLView(GraphQLView): def encode_json(self, data: GraphQLHTTPResponse) -> str: return json.dumps(data, indent=2) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from aiohttp import web +from strawberry.aiohttp.views import GraphQLView + + +class MyGraphQLView(GraphQLView): + async def render_graphql_ide(self, request: web.Request) -> web.Response: + custom_html = """

Custom GraphQL IDE

""" + + return web.Response(text=custom_html, content_type="text/html") +``` diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md index 3c9d78bd50..1d5097e8d4 100644 --- a/docs/integrations/asgi.md +++ b/docs/integrations/asgi.md @@ -50,6 +50,7 @@ We allow to extend the base `GraphQL` app, by overriding the following methods: - `async get_root_value(self, request: Request) -> Any` - `async process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse` - `def encode_json(self, response_data: GraphQLHTTPResponse) -> str` +- `async def render_graphql_ide(self, request: Request) -> Response` ### get_context @@ -176,3 +177,20 @@ class MyGraphQLView(GraphQL): def encode_json(self, data: GraphQLHTTPResponse) -> str: return json.dumps(data, indent=2) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.asgi import GraphQL +from starlette.responses import HTMLResponse, Response + + +class MyGraphQL(GraphQL): + async def render_graphql_ide(self, request: Request) -> Response: + custom_html = """

Custom GraphQL IDE

""" + + return HTMLResponse(custom_html) +``` diff --git a/docs/integrations/chalice.md b/docs/integrations/chalice.md index 5608eff948..04b8c2e23d 100644 --- a/docs/integrations/chalice.md +++ b/docs/integrations/chalice.md @@ -76,6 +76,7 @@ We allow to extend the base `GraphQLView`, by overriding the following methods: - `get_root_value(self, request: Request) -> Any` - `process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse` - `encode_json(self, response_data: GraphQLHTTPResponse) -> str` +- `def render_graphql_ide(self, request: Request) -> Response` ### get_context @@ -161,3 +162,20 @@ class MyGraphQLView(GraphQLView): def encode_json(self, data: GraphQLHTTPResponse) -> str: return json.dumps(data, indent=2) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.chalice.views import GraphQLView +from chalice.app import Request, Response + + +class MyGraphQLView(GraphQLView): + def render_graphql_ide(self, request: Request) -> Response: + custom_html = """

Custom GraphQL IDE

""" + + return Response(custom_html, headers={"Content-Type": "text/html"}) +``` diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md index 910e4a1ffc..af495a1ed9 100644 --- a/docs/integrations/channels.md +++ b/docs/integrations/channels.md @@ -533,9 +533,10 @@ We allow to extend `GraphQLHTTPConsumer`, by overriding the following methods: - `async def get_context(self, request: ChannelsRequest, response: TemporalResponse) -> Context` - `async def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]` -- `async def process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse:`. +- `async def process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse`. +- `async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse` -### Context +#### Context The default context returned by `get_context()` is a `dict` that includes the following keys by default: @@ -552,6 +553,22 @@ following keys by default: errors (defaults to `200`) - `headers`: Any additional headers that should be send with the response +#### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.channels import GraphQLHTTPConsumer, ChannelsRequest, ChannelsResponse + + +class MyGraphQLHTTPConsumer(GraphQLHTTPConsumer): + async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse: + custom_html = """

Custom GraphQL IDE

""" + + return ChannelsResponse(content=custom_html, content_type="text/html") +``` + ## GraphQLWSConsumer (WebSockets / Subscriptions) ### Options diff --git a/docs/integrations/django.md b/docs/integrations/django.md index 281b40c34a..ee582c4797 100644 --- a/docs/integrations/django.md +++ b/docs/integrations/django.md @@ -62,11 +62,12 @@ encoding process. We allow to extend the base `GraphQLView`, by overriding the following methods: -- `get_context(self, request: HttpRequest, response: HttpResponse) -> Any` -- `get_root_value(self, request: HttpRequest) -> Any` -- `process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse` +- `def get_context(self, request: HttpRequest, response: HttpResponse) -> Any` +- `def get_root_value(self, request: HttpRequest) -> Any` +- `def process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse` +- `def render_graphql_ide(self, request: HttpRequest) -> HttpResponse` -## get_context +### get_context `get_context` allows to provide a custom context object that can be used in your resolver. You can return anything here, by default we return a @@ -101,7 +102,7 @@ called "example". Then we use the context in a resolver, the resolver will return "1" in this case. -## get_root_value +### get_root_value `get_root_value` allows to provide a custom root value for your schema, this is probably not used a lot but it might be useful in certain situations. @@ -122,7 +123,7 @@ class Query: Here we are returning a Query where the name is "Patrick", so we when requesting the field name we'll return "Patrick" in this case. -## process_result +### process_result `process_result` allows to customize and/or process results before they are sent to the clients. This can be useful logging errors or hiding them (for example to @@ -151,6 +152,24 @@ class MyGraphQLView(GraphQLView): In this case we are doing the default processing of the result, but it can be tweaked based on your needs. +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.django.views import GraphQLView +from django.http import HttpResponse +from django.template.loader import render_to_string + + +class MyGraphQLView(GraphQLView): + def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: + content = render_to_string("myapp/my_graphql_ide_template.html") + + return HttpResponse(content) +``` + # Async Django Strawberry also provides an async view that you can use with Django 3.1+ @@ -190,6 +209,7 @@ methods: - `async get_root_value(self, request: HttpRequest) -> Any` - `async process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse` - `def encode_json(self, data: GraphQLHTTPResponse) -> str` +- `async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse` ### get_context @@ -277,6 +297,24 @@ class MyGraphQLView(AsyncGraphQLView): return json.dumps(data, indent=2) ``` +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.django.views import AsyncGraphQLView +from django.http import HttpResponse +from django.template.loader import render_to_string + + +class MyGraphQLView(AsyncGraphQLView): + async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse: + content = render_to_string("myapp/my_graphql_ide_template.html") + + return HttpResponse(content) +``` + ## Subscriptions Subscriptions run over websockets and thus depend on diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md index 4c25ca651b..f9ffc1de3f 100644 --- a/docs/integrations/fastapi.md +++ b/docs/integrations/fastapi.md @@ -301,3 +301,20 @@ class MyGraphQLRouter(GraphQLRouter): def encode_json(self, data: GraphQLHTTPResponse) -> bytes: return orjson.dumps(data) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.fastapi import GraphQLRouter +from starlette.responses import HTMLResponse, Response + + +class MyGraphQLRouter(GraphQLRouter): + async def render_graphql_ide(self, request: Request) -> HTMLResponse: + custom_html = """

Custom GraphQL IDE

""" + + return HTMLResponse(custom_html) +``` diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md index 69576e535c..a8dde13b8a 100644 --- a/docs/integrations/flask.md +++ b/docs/integrations/flask.md @@ -51,10 +51,11 @@ The `GraphQLView` accepts the following options at the moment: We allow to extend the base `GraphQLView`, by overriding the following methods: -- `get_context(self, request: Request, response: Response) -> Any` -- `get_root_value(self, request: Request) -> Any` -- `process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` -- `encode_json(self, response_data: GraphQLHTTPResponse) -> str` +- `def get_context(self, request: Request, response: Response) -> Any` +- `def get_root_value(self, request: Request) -> Any` +- `def process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` +- `def encode_json(self, response_data: GraphQLHTTPResponse) -> str` +- `def render_graphql_ide(self, request: Request) -> Response` @@ -148,3 +149,20 @@ class MyGraphQLView(GraphQLView): def encode_json(self, data: GraphQLHTTPResponse) -> str: return json.dumps(data, indent=2) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.flask.views import GraphQLView +from flask import Request, Response + + +class MyGraphQLView(GraphQLView): + def render_graphql_ide(self, request: Request) -> Response: + custom_html = """

Custom GraphQL IDE

""" + + return Response(custom_html, status=200, content_type="text/html") +``` diff --git a/docs/integrations/litestar.md b/docs/integrations/litestar.md index 54ea4e239d..86ee0bc179 100644 --- a/docs/integrations/litestar.md +++ b/docs/integrations/litestar.md @@ -317,3 +317,43 @@ GraphQLController = make_graphql_controller( app = Litestar(route_handlers=[GraphQLController]) ``` + +## Extending the controller + +The `make_graphql_controller` function returns a `GraphQLController` class that +can be extended by overriding the following methods: + +1. `async def render_graphql_ide(self, request: Request) -> Response` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +import strawberry +from strawberry.litestar import make_graphql_controller +from litestar import MediaType, Request, Response + + +@strawberry.type +class Query: + @strawberry.field + def hello(self) -> str: + return "world" + + +schema = strawberry.Schema(Query) + +GraphQLController = make_graphql_controller( + schema, + path="/graphql", +) + + +class MyGraphQLController(GraphQLController): + async def render_graphql_ide(self, request: Request) -> Response: + custom_html = """

Custom GraphQL IDE

""" + + return Response(custom_html, media_type=MediaType.HTML) +``` diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md index 91b8b2a934..d863777b04 100644 --- a/docs/integrations/quart.md +++ b/docs/integrations/quart.md @@ -43,10 +43,11 @@ The `GraphQLView` accepts the following options at the moment: We allow to extend the base `GraphQLView`, by overriding the following methods: -- `get_context(self, request: Request, response: Response) -> Any` -- `get_root_value(self, request: Request) -> Any` -- `process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` -- `encode_json(self, response_data: GraphQLHTTPResponse) -> str` +- `async def get_context(self, request: Request, response: Response) -> Any` +- `async def get_root_value(self, request: Request) -> Any` +- `async def process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` +- `def encode_json(self, response_data: GraphQLHTTPResponse) -> str` +- `async def render_graphql_ide(self, request: Request) -> Response` ### get_context @@ -132,3 +133,20 @@ class MyGraphQLView(GraphQLView): def encode_json(self, data: GraphQLHTTPResponse) -> str: return json.dumps(data, indent=2) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.quart.views import GraphQLView +from quart import Request, Response + + +class MyGraphQLView(GraphQLView): + async def render_graphql_ide(self, request: Request) -> Response: + custom_html = """

Custom GraphQL IDE

""" + + return Response(self.graphql_ide_html) +``` diff --git a/docs/integrations/sanic.md b/docs/integrations/sanic.md index dfff4c0502..2cae238b30 100644 --- a/docs/integrations/sanic.md +++ b/docs/integrations/sanic.md @@ -43,6 +43,7 @@ methods: - `async get_context(self, request: Request, response: Response) -> Any` - `async get_root_value(self, request: Request) -> Any` - `async process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse` +- `async def render_graphql_ide(self, request: Request) -> HTTPResponse` ### get_context @@ -130,3 +131,21 @@ class MyGraphQLView(GraphQLView): def encode_json(self, data: GraphQLHTTPResponse) -> str: return json.dumps(data, indent=2) ``` + +### render_graphql_ide + +In case you need more control over the rendering of the GraphQL IDE than the +`graphql_ide` option provides, you can override the `render_graphql_ide` method. + +```python +from strawberry.sanic.views import GraphQLView +from sanic.request import Request +from sanic.response import HTTPResponse, html + + +class MyGraphQLView(GraphQLView): + async def render_graphql_ide(self, request: Request) -> HTTPResponse: + custom_html = """

Custom GraphQL IDE

""" + + return html(custom_html) +``` diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index d10d207987..5a3f01203d 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -186,7 +186,7 @@ async def get_sub_response( return sub_response - async def render_graphql_ide(self, request: Union[Request, WebSocket]) -> Response: + async def render_graphql_ide(self, request: Request) -> Response: return HTMLResponse(self.graphql_ide_html) def create_response( From 596461b13c3267960ac432b815fbe7bdaf68fba7 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 9 Oct 2024 09:19:09 +0100 Subject: [PATCH 26/31] Bump 3.13 to non dev (#3665) --- .github/workflows/test.yml | 2 +- noxfile.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0374bbe3d4..570700cf10 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -57,7 +57,7 @@ jobs: 3.10 3.11 3.12 - 3.13-dev + 3.13 - run: pip install poetry nox nox-poetry uv - run: nox -r -t tests -s "${{ matrix.session.session }}" diff --git a/noxfile.py b/noxfile.py index 1a36339397..7e817ea3cb 100644 --- a/noxfile.py +++ b/noxfile.py @@ -80,7 +80,7 @@ def tests(session: Session, gql_core: str) -> None: ) -@session(python=["3.11", "3.12"], name="Django tests", tags=["tests"]) +@session(python=["3.12"], name="Django tests", tags=["tests"]) @with_gql_core_parametrize("django", ["4.2.0", "4.1.0", "4.0.0", "3.2.0"]) def tests_django(session: Session, django: str, gql_core: str) -> None: session.run_always("poetry", "install", external=True) From 0d880b7bb11906302daf1d0b6c479df656e8d364 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D7=A0=D7=99=D7=A8?= <88795475+nrbnlulu@users.noreply.github.com> Date: Wed, 9 Oct 2024 13:34:40 +0300 Subject: [PATCH 27/31] Add support for using raw Python enum types in schema (#3639) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Patrick Arminio --- .alexrc | 2 +- .pre-commit-config.yaml | 6 +- RELEASE.md | 25 +++++++ docs/errors/not-a-strawberry-enum.md | 62 ---------------- docs/general/subscriptions.md | 2 +- docs/types/enums.md | 11 +++ strawberry/annotation.py | 4 +- strawberry/exceptions/__init__.py | 2 - .../exceptions/not_a_strawberry_enum.py | 38 ---------- tests/enums/test_enum.py | 71 +++++++++++++------ 10 files changed, 94 insertions(+), 129 deletions(-) create mode 100644 RELEASE.md delete mode 100644 docs/errors/not-a-strawberry-enum.md delete mode 100644 strawberry/exceptions/not_a_strawberry_enum.py diff --git a/.alexrc b/.alexrc index 587b769682..ea3756ce5a 100644 --- a/.alexrc +++ b/.alexrc @@ -12,6 +12,6 @@ "special", "primitive", "invalid", - "crash", + "crash" ] } diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 556ad1132b..6c4163edbd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.8 + rev: v0.6.9 hooks: - id: ruff-format exclude: ^tests/\w+/snapshots/ @@ -20,7 +20,7 @@ repos: files: '^docs/.*\.mdx?$' - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-merge-conflict @@ -31,7 +31,7 @@ repos: args: ["--branch", "main"] - repo: https://github.com/adamchainz/blacken-docs - rev: 1.18.0 + rev: 1.19.0 hooks: - id: blacken-docs args: [--skip-errors] diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..fedbf0c263 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,25 @@ +Release type: patch + +This release adds support for using raw Python enum types in your schema +(enums that are not decorated with `@strawberry.enum`) + +This is useful if you have enum types from other places in your code +that you want to use in strawberry. +i.e +```py +# somewhere.py +from enum import Enum + + +class AnimalKind(Enum): + AXOLOTL, CAPYBARA = range(2) + + +# gql/animals +from somewhere import AnimalKind + + +@strawberry.type +class AnimalType: + kind: AnimalKind +``` diff --git a/docs/errors/not-a-strawberry-enum.md b/docs/errors/not-a-strawberry-enum.md deleted file mode 100644 index ceaf00c53d..0000000000 --- a/docs/errors/not-a-strawberry-enum.md +++ /dev/null @@ -1,62 +0,0 @@ ---- -title: Not a Strawberry Enum Error ---- - -# Not a Strawberry Enum Error - -## Description - -This error is thrown when trying to use an enum that is not a Strawberry enum, -for example the following code will throw this error: - -```python -import strawberry - - -# note the lack of @strawberry.enum here: -class IceCreamFlavour(Enum): - VANILLA = strawberry.enum_value("vanilla") - STRAWBERRY = strawberry.enum_value( - "strawberry", - description="Our favourite", - ) - CHOCOLATE = "chocolate" - - -@strawberry.type -class Query: - field: IceCreamFlavour - - -schema = strawberry.Schema(query=Query) -``` - -This happens because Strawberry expects all enums to be decorated with -`@strawberry.enum`. - -## How to fix this error - -You can fix this error by making sure the enum you're using is decorated with -`@strawberry.enum`. For example, the following code will fix this error: - -```python -import strawberry - - -@strawberry.enum -class IceCreamFlavour(Enum): - VANILLA = strawberry.enum_value("vanilla") - STRAWBERRY = strawberry.enum_value( - "strawberry", - description="Our favourite", - ) - CHOCOLATE = "chocolate" - - -@strawberry.type -class Query: - field: IceCreamFlavour - - -schema = strawberry.Schema(query=Query) -``` diff --git a/docs/general/subscriptions.md b/docs/general/subscriptions.md index 6363fc3051..25675c08d0 100644 --- a/docs/general/subscriptions.md +++ b/docs/general/subscriptions.md @@ -269,7 +269,7 @@ const subscriber = client.subscribe({query: ...}).subscribe({...}) subscriber.unsubscribe(); ``` -Strawberry can easily capture when a subscriber unsubscribes using an +Strawberry can capture when a subscriber unsubscribes using an `asyncio.CancelledError` exception. ```python diff --git a/docs/types/enums.md b/docs/types/enums.md index cb2e99d216..743e91f80b 100644 --- a/docs/types/enums.md +++ b/docs/types/enums.md @@ -42,6 +42,17 @@ class IceCreamFlavour(Enum): CHOCOLATE = "chocolate" ``` + + +In some cases you already have an enum defined elsewhere in your code. You can +safely use it in your schema and strawberry will generate a default graphql +implementation of it. + +The only drawback is that it is not currently possible to configure it +(documentation / renaming or using `strawberry.enum_value` on it). + + + Let's see how we can use Enums in our schema. ```python diff --git a/strawberry/annotation.py b/strawberry/annotation.py index 86d06d73e5..dff708a1a1 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -20,7 +20,6 @@ ) from typing_extensions import Annotated, Self, get_args, get_origin -from strawberry.exceptions.not_a_strawberry_enum import NotAStrawberryEnumError from strawberry.types.base import ( StrawberryList, StrawberryObjectDefinition, @@ -30,6 +29,7 @@ has_object_definition, ) from strawberry.types.enum import EnumDefinition +from strawberry.types.enum import enum as strawberry_enum from strawberry.types.lazy_type import LazyType from strawberry.types.private import is_private from strawberry.types.scalar import ScalarDefinition @@ -187,7 +187,7 @@ def create_enum(self, evaled_type: Any) -> EnumDefinition: try: return evaled_type._enum_definition except AttributeError: - raise NotAStrawberryEnumError(evaled_type) + return strawberry_enum(evaled_type)._enum_definition def create_list(self, evaled_type: Any) -> StrawberryList: item_type, *_ = get_args(evaled_type) diff --git a/strawberry/exceptions/__init__.py b/strawberry/exceptions/__init__.py index 7d3b3ccb1b..ee331af721 100644 --- a/strawberry/exceptions/__init__.py +++ b/strawberry/exceptions/__init__.py @@ -15,7 +15,6 @@ from .missing_dependencies import MissingOptionalDependenciesError from .missing_field_annotation import MissingFieldAnnotationError from .missing_return_annotation import MissingReturnAnnotationError -from .not_a_strawberry_enum import NotAStrawberryEnumError from .object_is_not_a_class import ObjectIsNotClassError from .object_is_not_an_enum import ObjectIsNotAnEnumError from .private_strawberry_field import PrivateStrawberryFieldError @@ -174,7 +173,6 @@ class StrawberryGraphQLError(GraphQLError): "UnresolvedFieldTypeError", "PrivateStrawberryFieldError", "MultipleStrawberryArgumentsError", - "NotAStrawberryEnumError", "ScalarAlreadyRegisteredError", "WrongNumberOfResultsReturned", "FieldWithResolverAndDefaultValueError", diff --git a/strawberry/exceptions/not_a_strawberry_enum.py b/strawberry/exceptions/not_a_strawberry_enum.py deleted file mode 100644 index dd0c94d48f..0000000000 --- a/strawberry/exceptions/not_a_strawberry_enum.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import annotations - -from functools import cached_property -from typing import TYPE_CHECKING, Optional - -from .exception import StrawberryException -from .utils.source_finder import SourceFinder - -if TYPE_CHECKING: - from enum import EnumMeta - - from .exception_source import ExceptionSource - - -class NotAStrawberryEnumError(StrawberryException): - def __init__(self, enum: EnumMeta) -> None: - self.enum = enum - - self.message = f'Enum "{enum.__name__}" is not a Strawberry enum.' - self.rich_message = ( - f"Enum `[underline]{enum.__name__}[/]` is not a Strawberry enum." - ) - self.suggestion = ( - "To fix this error you can declare the enum using `@strawberry.enum`." - ) - - self.annotation_message = "enum defined here" - - super().__init__(self.message) - - @cached_property - def exception_source(self) -> Optional[ExceptionSource]: - if self.enum is None: - return None # pragma: no cover - - source_finder = SourceFinder() - - return source_finder.find_class_from_object(self.enum) diff --git a/tests/enums/test_enum.py b/tests/enums/test_enum.py index 2189fba3d6..c0dea47813 100644 --- a/tests/enums/test_enum.py +++ b/tests/enums/test_enum.py @@ -4,7 +4,7 @@ import strawberry from strawberry.exceptions import ObjectIsNotAnEnumError -from strawberry.exceptions.not_a_strawberry_enum import NotAStrawberryEnumError +from strawberry.types.base import get_object_definition from strawberry.types.enum import EnumDefinition @@ -120,25 +120,6 @@ class IceCreamFlavour(Enum): assert definition.values[2].description is None -@pytest.mark.raises_strawberry_exception( - NotAStrawberryEnumError, match='Enum "IceCreamFlavour" is not a Strawberry enum' -) -def test_raises_error_when_using_enum_not_decorated(): - class IceCreamFlavour(Enum): - VANILLA = strawberry.enum_value("vanilla") - STRAWBERRY = strawberry.enum_value( - "strawberry", - description="Our favourite", - ) - CHOCOLATE = "chocolate" - - @strawberry.type - class Query: - flavour: IceCreamFlavour - - strawberry.Schema(query=Query) - - def test_can_use_enum_values(): @strawberry.enum class TestEnum(Enum): @@ -169,3 +150,53 @@ class TestEnum(IntEnum): assert TestEnum.D.value == 4 assert [x.value for x in TestEnum.__members__.values()] == [1, 2, 3, 4] + + +def test_default_enum_implementation() -> None: + class Foo(Enum): + BAR = "bar" + BAZ = "baz" + + @strawberry.type + class Query: + @strawberry.field + def foo(self, foo: Foo) -> Foo: + return foo + + schema = strawberry.Schema(Query) + res = schema.execute_sync("{ foo(foo: BAR) }") + assert not res.errors + assert res.data + assert res.data["foo"] == "BAR" + + +def test_default_enum_reuse() -> None: + class Foo(Enum): + BAR = "bar" + BAZ = "baz" + + @strawberry.type + class SomeType: + foo: Foo + bar: Foo + + definition = get_object_definition(SomeType, strict=True) + assert definition.fields[1].type is definition.fields[1].type + + +def test_default_enum_with_enum_value() -> None: + class Foo(Enum): + BAR = "bar" + BAZ = strawberry.enum_value("baz") + + @strawberry.type + class Query: + @strawberry.field + def foo(self, foo: Foo) -> str: + return foo.value + + schema = strawberry.Schema(Query) + res = schema.execute_sync("{ foo(foo: BAZ) }") + assert not res.errors + assert res.data + assert res.data["foo"] == "baz" From d5044283a853176df478d82dd126cc3f23980bde Mon Sep 17 00:00:00 2001 From: Botberry Date: Wed, 9 Oct 2024 10:36:32 +0000 Subject: [PATCH 28/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.246.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 30 ++++++++++++++++++++++++++++++ RELEASE.md | 25 ------------------------- pyproject.toml | 2 +- 3 files changed, 31 insertions(+), 26 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index ddaccdc0ad..3bd92715e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,36 @@ CHANGELOG ========= +0.246.1 - 2024-10-09 +-------------------- + +This release adds support for using raw Python enum types in your schema +(enums that are not decorated with `@strawberry.enum`) + +This is useful if you have enum types from other places in your code +that you want to use in strawberry. +i.e +```py +# somewhere.py +from enum import Enum + + +class AnimalKind(Enum): + AXOLOTL, CAPYBARA = range(2) + + +# gql/animals +from somewhere import AnimalKind + + +@strawberry.type +class AnimalType: + kind: AnimalKind +``` + +Contributed by [Χ Χ™Χ¨](https://github.com/nrbnlulu) via [PR #3639](https://github.com/strawberry-graphql/strawberry/pull/3639/) + + 0.246.0 - 2024-10-07 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index fedbf0c263..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,25 +0,0 @@ -Release type: patch - -This release adds support for using raw Python enum types in your schema -(enums that are not decorated with `@strawberry.enum`) - -This is useful if you have enum types from other places in your code -that you want to use in strawberry. -i.e -```py -# somewhere.py -from enum import Enum - - -class AnimalKind(Enum): - AXOLOTL, CAPYBARA = range(2) - - -# gql/animals -from somewhere import AnimalKind - - -@strawberry.type -class AnimalType: - kind: AnimalKind -``` diff --git a/pyproject.toml b/pyproject.toml index 9d34ba1396..fe2177b5d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.246.0" +version = "0.246.1" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT" From daec2360c0fffec10c60e7de27334170aa230937 Mon Sep 17 00:00:00 2001 From: Luis Gustavo Date: Sat, 12 Oct 2024 14:02:02 -0300 Subject: [PATCH 29/31] Fix CI with new Ubuntu LTS (#3667) * fix ci * fix precommit * change pip to uv * fix uv * fix pre-commit --- .github/workflows/test.yml | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 570700cf10..e45bbec64a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -24,10 +24,14 @@ jobs: sessions: ${{ steps.set-matrix.outputs.sessions }} steps: - uses: actions/checkout@v4 - - run: pip install poetry nox nox-poetry + - name: Install uv + uses: astral-sh/setup-uv@v3 + - run: uv venv + - run: uv pip install poetry nox nox-poetry - id: set-matrix shell: bash run: | + . .venv/bin/activate echo sessions=$( nox --json -t tests -l | jq 'map( @@ -82,7 +86,9 @@ jobs: benchmarks: name: πŸ“ˆ Benchmarks - runs-on: ubuntu-latest + + # Using this version because CodSpeed doesn't support Ubuntu 24.04 LTS yet + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 @@ -99,7 +105,7 @@ jobs: if: steps.setup-python.outputs.cache-hit != 'true' - name: Run benchmarks - uses: CodSpeedHQ/action@v2 + uses: CodSpeedHQ/action@v3 with: token: ${{ secrets.CODSPEED_TOKEN }} run: poetry run pytest tests/benchmarks --codspeed From 153da5e3d2904614d26fdc24916affb1dc0257e6 Mon Sep 17 00:00:00 2001 From: Jonathan Ehwald Date: Sat, 12 Oct 2024 19:19:28 +0200 Subject: [PATCH 30/31] Use explicit response code to avoid type error (#3666) --- RELEASE.md | 3 +++ strawberry/flask/views.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..6902fd7d51 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This release tweaks the Flask integration's `render_graphql_ide` method to be stricter typed internally, making type checkers ever so slightly happier. diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py index 053003ac91..d93f480b1e 100644 --- a/strawberry/flask/views.py +++ b/strawberry/flask/views.py @@ -187,7 +187,8 @@ async def dispatch_request(self) -> ResponseReturnValue: # type: ignore ) async def render_graphql_ide(self, request: Request) -> Response: - return render_template_string(self.graphql_ide_html) # type: ignore + content = render_template_string(self.graphql_ide_html) + return Response(content, status=200, content_type="text/html") def is_websocket_request(self, request: Request) -> TypeGuard[Request]: return False From 56172dc6ccd8d0a0b3ba6e13c9363d7b42afe33f Mon Sep 17 00:00:00 2001 From: Botberry Date: Sat, 12 Oct 2024 17:20:09 +0000 Subject: [PATCH 31/31] =?UTF-8?q?Release=20=F0=9F=8D=93=200.246.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 8 ++++++++ RELEASE.md | 3 --- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 4 deletions(-) delete mode 100644 RELEASE.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bd92715e2..54aa166998 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ CHANGELOG ========= +0.246.2 - 2024-10-12 +-------------------- + +This release tweaks the Flask integration's `render_graphql_ide` method to be stricter typed internally, making type checkers ever so slightly happier. + +Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3666](https://github.com/strawberry-graphql/strawberry/pull/3666/) + + 0.246.1 - 2024-10-09 -------------------- diff --git a/RELEASE.md b/RELEASE.md deleted file mode 100644 index 6902fd7d51..0000000000 --- a/RELEASE.md +++ /dev/null @@ -1,3 +0,0 @@ -Release type: patch - -This release tweaks the Flask integration's `render_graphql_ide` method to be stricter typed internally, making type checkers ever so slightly happier. diff --git a/pyproject.toml b/pyproject.toml index fe2177b5d0..29e173ab7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "strawberry-graphql" packages = [ { include = "strawberry" } ] -version = "0.246.1" +version = "0.246.2" description = "A library for creating GraphQL APIs" authors = ["Patrick Arminio "] license = "MIT"