diff --git a/.alexrc b/.alexrc
index db2914316a..ea3756ce5a 100644
--- a/.alexrc
+++ b/.alexrc
@@ -11,6 +11,7 @@
"execution",
"special",
"primitive",
- "invalid"
+ "invalid",
+ "crash"
]
}
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 0a3c02a29c..e45bbec64a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -24,10 +24,14 @@ jobs:
sessions: ${{ steps.set-matrix.outputs.sessions }}
steps:
- uses: actions/checkout@v4
- - run: pip install poetry nox nox-poetry
+ - name: Install uv
+ uses: astral-sh/setup-uv@v3
+ - run: uv venv
+ - run: uv pip install poetry nox nox-poetry
- id: set-matrix
shell: bash
run: |
+ . .venv/bin/activate
echo sessions=$(
nox --json -t tests -l |
jq 'map(
@@ -57,21 +61,7 @@ jobs:
3.10
3.11
3.12
- 3.13-dev
-
- - name: Pip and nox cache
- id: cache
- uses: actions/cache@v4
- with:
- path: |
- ~/.cache
- ~/.nox
- .nox
- key:
- ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ env.pythonLocation }}-${{
- hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}
- restore-keys: |
- ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ env.pythonLocation }}
+ 3.13
- run: pip install poetry nox nox-poetry uv
- run: nox -r -t tests -s "${{ matrix.session.session }}"
@@ -96,7 +86,9 @@ jobs:
benchmarks:
name: 📈 Benchmarks
- runs-on: ubuntu-latest
+
+ # Using this version because CodSpeed doesn't support Ubuntu 24.04 LTS yet
+ runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
@@ -113,7 +105,7 @@ jobs:
if: steps.setup-python.outputs.cache-hit != 'true'
- name: Run benchmarks
- uses: CodSpeedHQ/action@v2
+ uses: CodSpeedHQ/action@v3
with:
token: ${{ secrets.CODSPEED_TOKEN }}
run: poetry run pytest tests/benchmarks --codspeed
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8b8a2d6a19..743ab03501 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.6.5
+ rev: v0.6.9
hooks:
- id: ruff-format
exclude: ^tests/\w+/snapshots/
@@ -14,7 +14,7 @@ repos:
exclude: (CHANGELOG|TWEET).md
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.6.0
+ rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-merge-conflict
@@ -25,7 +25,7 @@ repos:
args: ["--branch", "main"]
- repo: https://github.com/adamchainz/blacken-docs
- rev: 1.18.0
+ rev: 1.19.0
hooks:
- id: blacken-docs
args: [--skip-errors]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4943c0113a..54aa166998 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,99 @@
CHANGELOG
=========
+0.246.2 - 2024-10-12
+--------------------
+
+This release tweaks the Flask integration's `render_graphql_ide` method to be stricter typed internally, making type checkers ever so slightly happier.
+
+Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3666](https://github.com/strawberry-graphql/strawberry/pull/3666/)
+
+
+0.246.1 - 2024-10-09
+--------------------
+
+This release adds support for using raw Python enum types in your schema
+(enums that are not decorated with `@strawberry.enum`)
+
+This is useful if you have enum types from other places in your code
+that you want to use in strawberry.
+i.e
+```py
+# somewhere.py
+from enum import Enum
+
+
+class AnimalKind(Enum):
+ AXOLOTL, CAPYBARA = range(2)
+
+
+# gql/animals
+from somewhere import AnimalKind
+
+
+@strawberry.type
+class AnimalType:
+ kind: AnimalKind
+```
+
+Contributed by [ניר](https://github.com/nrbnlulu) via [PR #3639](https://github.com/strawberry-graphql/strawberry/pull/3639/)
+
+
+0.246.0 - 2024-10-07
+--------------------
+
+The AIOHTTP, ASGI, and Django test clients' `asserts_errors` option has been renamed to `assert_no_errors` to better reflect its purpose.
+This change is backwards-compatible, but the old option name will raise a deprecation warning.
+
+Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3661](https://github.com/strawberry-graphql/strawberry/pull/3661/)
+
+
+0.245.0 - 2024-10-07
+--------------------
+
+This release removes the dated `subscriptions_enabled` setting from the Django and Channels integrations.
+Instead, WebSocket support is now enabled by default in all GraphQL IDEs.
+
+Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3660](https://github.com/strawberry-graphql/strawberry/pull/3660/)
+
+
+0.244.1 - 2024-10-06
+--------------------
+
+Fixes an issue where the codegen tool would crash when working with a nullable list of types.
+
+Contributed by [Jacob Allen](https://github.com/enoua5) via [PR #3653](https://github.com/strawberry-graphql/strawberry/pull/3653/)
+
+
+0.244.0 - 2024-10-05
+--------------------
+
+Starting with this release, WebSocket logic now lives in the base class shared between all HTTP integrations.
+This makes the behaviour of WebSockets much more consistent between integrations and easier to maintain.
+
+Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3638](https://github.com/strawberry-graphql/strawberry/pull/3638/)
+
+
+0.243.1 - 2024-09-26
+--------------------
+
+This releases adds support for Pydantic 2.9.0's Mypy plugin
+
+Contributed by [Krisque](https://github.com/chrisemke) via [PR #3632](https://github.com/strawberry-graphql/strawberry/pull/3632/)
+
+
+0.243.0 - 2024-09-25
+--------------------
+
+Starting with this release, multipart uploads are disabled by default and Strawberry Django view is no longer implicitly exempted from Django's CSRF protection.
+Both changes relieve users from implicit security implications inherited from the GraphQL multipart request specification which was enabled in Strawberry by default.
+
+These are breaking changes if you are using multipart uploads OR the Strawberry Django view.
+Migrations guides including further information are available on the Strawberry website.
+
+Contributed by [Jonathan Ehwald](https://github.com/DoctorJohn) via [PR #3645](https://github.com/strawberry-graphql/strawberry/pull/3645/)
+
+
0.242.0 - 2024-09-19
--------------------
diff --git a/docs/breaking-changes.md b/docs/breaking-changes.md
index 9f0ed20a80..3df11e2b23 100644
--- a/docs/breaking-changes.md
+++ b/docs/breaking-changes.md
@@ -4,6 +4,7 @@ title: List of breaking changes and deprecations
# List of breaking changes and deprecations
+- [Version 0.243.0 - 25 September 2024](./breaking-changes/0.243.0.md)
- [Version 0.240.0 - 10 September 2024](./breaking-changes/0.240.0.md)
- [Version 0.236.0 - 17 July 2024](./breaking-changes/0.236.0.md)
- [Version 0.233.0 - 29 May 2024](./breaking-changes/0.233.0.md)
diff --git a/docs/breaking-changes/0.243.0.md b/docs/breaking-changes/0.243.0.md
new file mode 100644
index 0000000000..4eca524f94
--- /dev/null
+++ b/docs/breaking-changes/0.243.0.md
@@ -0,0 +1,53 @@
+---
+title: 0.243.0 Breaking Changes
+slug: breaking-changes/0.243.0
+---
+
+# v0.243.0 Breaking Changes
+
+Release v0.243.0 comes with two breaking changes regarding multipart file
+uploads and Django CSRF protection.
+
+## Multipart uploads disabled by default
+
+Previously, support for uploads via the
+[GraphQL multipart request specification](https://github.com/jaydenseric/graphql-multipart-request-spec)
+was enabled by default. This implicitly required Strawberry users to consider
+the
+[security implications outlined in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security).
+Given that most Strawberry users were likely not aware of this, we're making
+multipart file upload support stictly opt-in via a new
+`multipart_uploads_enabled` view settings.
+
+To enable multipart upload support for your Strawberry view integration, please
+follow the updated integration guides and enable appropriate security
+measurements for your server.
+
+## Django CSRF protection enabled
+
+Previously, the Strawberry Django view integration was internally exempted from
+Django's built-in CSRF protection (i.e, the `CsrfViewMiddleware` middleware).
+While this is how many GraphQL APIs operate, implicitly addded exemptions can
+lead to security vulnerabilities. Instead, we delegate the decision of adding an
+CSRF exemption to users now.
+
+Note that having the CSRF protection enabled on your Strawberry Django view
+potentially requires all your clients to send an CSRF token with every request.
+You can learn more about this in the official Django
+[Cross Site Request Forgery protection documentation](https://docs.djangoproject.com/en/dev/ref/csrf/).
+
+To restore the behaviour of the integration before this release, you can add the
+`csrf_exempt` decorator provided by Django yourself:
+
+```python
+from django.urls import path
+from django.views.decorators.csrf import csrf_exempt
+
+from strawberry.django.views import GraphQLView
+
+from api.schema import schema
+
+urlpatterns = [
+ path("graphql/", csrf_exempt(GraphQLView.as_view(schema=schema))),
+]
+```
diff --git a/docs/errors/not-a-strawberry-enum.md b/docs/errors/not-a-strawberry-enum.md
deleted file mode 100644
index ceaf00c53d..0000000000
--- a/docs/errors/not-a-strawberry-enum.md
+++ /dev/null
@@ -1,62 +0,0 @@
----
-title: Not a Strawberry Enum Error
----
-
-# Not a Strawberry Enum Error
-
-## Description
-
-This error is thrown when trying to use an enum that is not a Strawberry enum,
-for example the following code will throw this error:
-
-```python
-import strawberry
-
-
-# note the lack of @strawberry.enum here:
-class IceCreamFlavour(Enum):
- VANILLA = strawberry.enum_value("vanilla")
- STRAWBERRY = strawberry.enum_value(
- "strawberry",
- description="Our favourite",
- )
- CHOCOLATE = "chocolate"
-
-
-@strawberry.type
-class Query:
- field: IceCreamFlavour
-
-
-schema = strawberry.Schema(query=Query)
-```
-
-This happens because Strawberry expects all enums to be decorated with
-`@strawberry.enum`.
-
-## How to fix this error
-
-You can fix this error by making sure the enum you're using is decorated with
-`@strawberry.enum`. For example, the following code will fix this error:
-
-```python
-import strawberry
-
-
-@strawberry.enum
-class IceCreamFlavour(Enum):
- VANILLA = strawberry.enum_value("vanilla")
- STRAWBERRY = strawberry.enum_value(
- "strawberry",
- description="Our favourite",
- )
- CHOCOLATE = "chocolate"
-
-
-@strawberry.type
-class Query:
- field: IceCreamFlavour
-
-
-schema = strawberry.Schema(query=Query)
-```
diff --git a/docs/general/subscriptions.md b/docs/general/subscriptions.md
index 7d4161bdc4..25675c08d0 100644
--- a/docs/general/subscriptions.md
+++ b/docs/general/subscriptions.md
@@ -254,6 +254,55 @@ schema = strawberry.Schema(query=Query, subscription=Subscription)
[pep-525]: https://www.python.org/dev/peps/pep-0525/
+## Unsubscribing subscriptions
+
+In GraphQL, it is possible to unsubscribe from a subscription. Strawberry
+supports this behaviour, and is done using a `try...except` block.
+
+In Apollo-client, closing a subscription can be achieved like the following:
+
+```javascript
+const client = useApolloClient();
+const subscriber = client.subscribe({query: ...}).subscribe({...})
+// ...
+// done with subscription. now unsubscribe
+subscriber.unsubscribe();
+```
+
+Strawberry can capture when a subscriber unsubscribes using an
+`asyncio.CancelledError` exception.
+
+```python
+import asyncio
+from typing import AsyncGenerator
+from uuid import uuid4
+
+import strawberry
+
+# track active subscribers
+event_messages = {}
+
+
+@strawberry.type
+class Subscription:
+ @strawberry.subscription
+ async def message(self) -> AsyncGenerator[int, None]:
+ try:
+ subscription_id = uuid4()
+
+ event_messages[subscription_id] = []
+
+ while True:
+ if len(event_messages[subscription_id]) > 0:
+ yield event_messages[subscription_id]
+ event_messages[subscription_id].clear()
+
+ await asyncio.sleep(1)
+ except asyncio.CancelledError:
+ # stop listening to events
+ del event_messages[subscription_id]
+```
+
## GraphQL over WebSocket protocols
Strawberry support both the legacy
diff --git a/docs/guides/file-upload.md b/docs/guides/file-upload.md
index c9c556b9bc..e2a5ce44eb 100644
--- a/docs/guides/file-upload.md
+++ b/docs/guides/file-upload.md
@@ -8,6 +8,21 @@ All Strawberry integrations support multipart uploads as described in the
[GraphQL multipart request specification](https://github.com/jaydenseric/graphql-multipart-request-spec).
This includes support for uploading single files as well as lists of files.
+## Security
+
+Note that multipart file upload support is disabled by default in all
+integrations. Before enabling multipart file upload support, make sure you
+address the
+[security implications outlined in the specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security).
+Usually, this entails enabling CSRF protection in your server framework (e.g.,
+the `CsrfViewMiddleware` middleware in Django).
+
+To enable file upload support, pass `multipart_uploads_enabled=True` to your
+integration's view class. Refer to the integration-specific documentation for
+more details on how to do this.
+
+## Upload Scalar
+
Uploads can be used in mutations via the `Upload` scalar. The type passed at
runtime depends on the integration:
@@ -15,10 +30,10 @@ runtime depends on the integration:
| ----------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
| [AIOHTTP](/docs/integrations/aiohttp) | [`io.BytesIO`](https://docs.python.org/3/library/io.html#io.BytesIO) |
| [ASGI](/docs/integrations/asgi) | [`starlette.datastructures.UploadFile`](https://www.starlette.io/requests/#request-files) |
-| [Channels](/docs/integrations/channels) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/3.2/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) |
-| [Django](/docs/integrations/django) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/3.2/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) |
+| [Channels](/docs/integrations/channels) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/dev/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) |
+| [Django](/docs/integrations/django) | [`django.core.files.uploadedfile.UploadedFile`](https://docs.djangoproject.com/en/dev/ref/files/uploads/#django.core.files.uploadedfile.UploadedFile) |
| [FastAPI](/docs/integrations/fastapi) | [`fastapi.UploadFile`](https://fastapi.tiangolo.com/tutorial/request-files/#file-parameters-with-uploadfile) |
-| [Flask](/docs/integrations/flask) | [`werkzeug.datastructures.FileStorage`](https://werkzeug.palletsprojects.com/en/2.0.x/datastructures/#werkzeug.datastructures.FileStorage) |
+| [Flask](/docs/integrations/flask) | [`werkzeug.datastructures.FileStorage`](https://werkzeug.palletsprojects.com/en/latest/datastructures/#werkzeug.datastructures.FileStorage) |
| [Quart](/docs/integrations/quart) | [`quart.datastructures.FileStorage`](https://github.com/pallets/quart/blob/main/src/quart/datastructures.py) |
| [Sanic](/docs/integrations/sanic) | [`sanic.request.File`](https://sanic.readthedocs.io/en/stable/sanic/api/core.html#sanic.request.File) |
| [Starlette](/docs/integrations/starlette) | [`starlette.datastructures.UploadFile`](https://www.starlette.io/requests/#request-files) |
diff --git a/docs/integrations/aiohttp.md b/docs/integrations/aiohttp.md
index cb622234fd..937bb31889 100644
--- a/docs/integrations/aiohttp.md
+++ b/docs/integrations/aiohttp.md
@@ -29,7 +29,7 @@ app.router.add_route("*", "/graphql", GraphQLView(schema=schema))
## Options
-The `GraphQLView` accepts two options at the moment:
+The `GraphQLView` accepts the following options at the moment:
- `schema`: mandatory, the schema created by `strawberry.Schema`.
- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the
@@ -37,6 +37,10 @@ The `GraphQLView` accepts two options at the moment:
to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
## Extending the view
@@ -47,8 +51,9 @@ methods:
- `async get_root_value(self, request: aiohttp.web.Request) -> object`
- `async process_result(self, request: aiohttp.web.Request, result: ExecutionResult) -> GraphQLHTTPResponse`
- `def encode_json(self, data: GraphQLHTTPResponse) -> str`
+- `async def render_graphql_ide(self, request: aiohttp.web.Request) -> aiohttp.web.Response`
-## get_context
+### get_context
By overriding `GraphQLView.get_context` you can provide a custom context object
for your resolvers. You can return anything here; by default GraphQLView returns
@@ -79,7 +84,7 @@ called `"example"`.
Then we can use the context in a resolver. In this case the resolver will return
`1`.
-## get_root_value
+### get_root_value
By overriding `GraphQLView.get_root_value` you can provide a custom root value
for your schema. This is probably not used a lot but it might be useful in
@@ -106,7 +111,7 @@ class Query:
Here we configure a Query where requesting the `name` field will return
`"Patrick"` through the custom root value.
-## process_result
+### process_result
By overriding `GraphQLView.process_result` you can customize and/or process
results before they are sent to a client. This can be useful for logging errors,
@@ -137,7 +142,7 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -147,3 +152,20 @@ class MyGraphQLView(GraphQLView):
def encode_json(self, data: GraphQLHTTPResponse) -> str:
return json.dumps(data, indent=2)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from aiohttp import web
+from strawberry.aiohttp.views import GraphQLView
+
+
+class MyGraphQLView(GraphQLView):
+ async def render_graphql_ide(self, request: web.Request) -> web.Response:
+ custom_html = """
Custom GraphQL IDE
"""
+
+ return web.Response(text=custom_html, content_type="text/html")
+```
diff --git a/docs/integrations/asgi.md b/docs/integrations/asgi.md
index eaad8a52ff..1d5097e8d4 100644
--- a/docs/integrations/asgi.md
+++ b/docs/integrations/asgi.md
@@ -29,7 +29,7 @@ app with `uvicorn server:app`
## Options
-The `GraphQL` app accepts two options at the moment:
+The `GraphQL` app accepts the following options at the moment:
- `schema`: mandatory, the schema created by `strawberry.Schema`.
- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the
@@ -37,6 +37,10 @@ The `GraphQL` app accepts two options at the moment:
to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
## Extending the view
@@ -46,8 +50,9 @@ We allow to extend the base `GraphQL` app, by overriding the following methods:
- `async get_root_value(self, request: Request) -> Any`
- `async process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse`
- `def encode_json(self, response_data: GraphQLHTTPResponse) -> str`
+- `async def render_graphql_ide(self, request: Request) -> Response`
-## get_context
+### get_context
`get_context` allows to provide a custom context object that can be used in your
resolver. You can return anything here, by default we return a dictionary with
@@ -74,7 +79,7 @@ called "example".
Then we use the context in a resolver, the resolver will return "1" in this
case.
-### Setting response headers
+#### Setting response headers
It is possible to use `get_context` to set response headers. A common use case
might be cookie-based user authentication, where your login mutation resolver
@@ -93,7 +98,7 @@ class Mutation:
return True
```
-### Setting background tasks
+#### Setting background tasks
Similarly, [background tasks](https://www.starlette.io/background/) can be set
on the response via the context:
@@ -112,7 +117,7 @@ class Mutation:
info.context["response"].background = BackgroundTask(notify_new_flavour, name)
```
-## get_root_value
+### get_root_value
`get_root_value` allows to provide a custom root value for your schema, this is
probably not used a lot but it might be useful in certain situations.
@@ -133,7 +138,7 @@ class Query:
Here we are returning a Query where the name is "Patrick", so we when requesting
the field name we'll return "Patrick" in this case.
-## process_result
+### process_result
`process_result` allows to customize and/or process results before they are sent
to the clients. This can be useful logging errors or hiding them (for example to
@@ -162,7 +167,7 @@ class MyGraphQL(GraphQL):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -172,3 +177,20 @@ class MyGraphQLView(GraphQL):
def encode_json(self, data: GraphQLHTTPResponse) -> str:
return json.dumps(data, indent=2)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.asgi import GraphQL
+from starlette.responses import HTMLResponse, Response
+
+
+class MyGraphQL(GraphQL):
+ async def render_graphql_ide(self, request: Request) -> Response:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return HTMLResponse(custom_html)
+```
diff --git a/docs/integrations/chalice.md b/docs/integrations/chalice.md
index 8c5671c33d..04b8c2e23d 100644
--- a/docs/integrations/chalice.md
+++ b/docs/integrations/chalice.md
@@ -76,8 +76,9 @@ We allow to extend the base `GraphQLView`, by overriding the following methods:
- `get_root_value(self, request: Request) -> Any`
- `process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse`
- `encode_json(self, response_data: GraphQLHTTPResponse) -> str`
+- `def render_graphql_ide(self, request: Request) -> Response`
-## get_context
+### get_context
`get_context` allows to provide a custom context object that can be used in your
resolver. You can return anything here, by default we return a dictionary with
@@ -103,7 +104,7 @@ called "example".
Then we use the context in a resolver, the resolver will return "1" in this
case.
-## get_root_value
+### get_root_value
`get_root_value` allows to provide a custom root value for your schema, this is
probably not used a lot but it might be useful in certain situations.
@@ -124,7 +125,7 @@ class Query:
Here we are returning a Query where the name is "Patrick", so we when requesting
the field name we'll return "Patrick" in this case.
-## process_result
+### process_result
`process_result` allows to customize and/or process results before they are sent
to the clients. This can be useful logging errors or hiding them (for example to
@@ -151,7 +152,7 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -161,3 +162,20 @@ class MyGraphQLView(GraphQLView):
def encode_json(self, data: GraphQLHTTPResponse) -> str:
return json.dumps(data, indent=2)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.chalice.views import GraphQLView
+from chalice.app import Request, Response
+
+
+class MyGraphQLView(GraphQLView):
+ def render_graphql_ide(self, request: Request) -> Response:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return Response(custom_html, headers={"Content-Type": "text/html"})
+```
diff --git a/docs/integrations/channels.md b/docs/integrations/channels.md
index 4fa4ae716f..af495a1ed9 100644
--- a/docs/integrations/channels.md
+++ b/docs/integrations/channels.md
@@ -522,8 +522,10 @@ GraphQLWebsocketCommunicator(
to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
-- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in
- the GraphiQL interface, defaults to `True`
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
### Extending the consumer
@@ -531,9 +533,10 @@ We allow to extend `GraphQLHTTPConsumer`, by overriding the following methods:
- `async def get_context(self, request: ChannelsRequest, response: TemporalResponse) -> Context`
- `async def get_root_value(self, request: ChannelsRequest) -> Optional[RootValue]`
-- `async def process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse:`.
+- `async def process_result(self, request: Request, result: ExecutionResult) -> GraphQLHTTPResponse`.
+- `async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse`
-### Context
+#### Context
The default context returned by `get_context()` is a `dict` that includes the
following keys by default:
@@ -550,6 +553,22 @@ following keys by default:
errors (defaults to `200`)
- `headers`: Any additional headers that should be send with the response
+#### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.channels import GraphQLHTTPConsumer, ChannelsRequest, ChannelsResponse
+
+
+class MyGraphQLHTTPConsumer(GraphQLHTTPConsumer):
+ async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return ChannelsResponse(content=custom_html, content_type="text/html")
+```
+
## GraphQLWSConsumer (WebSockets / Subscriptions)
### Options
diff --git a/docs/integrations/django.md b/docs/integrations/django.md
index 2a4e75266a..ee582c4797 100644
--- a/docs/integrations/django.md
+++ b/docs/integrations/django.md
@@ -10,13 +10,14 @@ It provides a view that you can use to serve your GraphQL schema:
```python
from django.urls import path
+from django.views.decorators.csrf import csrf_exempt
from strawberry.django.views import GraphQLView
from api.schema import schema
urlpatterns = [
- path("graphql/", GraphQLView.as_view(schema=schema)),
+ path("graphql/", csrf_exempt(GraphQLView.as_view(schema=schema))),
]
```
@@ -38,8 +39,10 @@ The `GraphQLView` accepts the following arguments:
to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
-- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in
- the GraphiQL interface, defaults to `False`.
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
## Deprecated options
@@ -59,11 +62,12 @@ encoding process.
We allow to extend the base `GraphQLView`, by overriding the following methods:
-- `get_context(self, request: HttpRequest, response: HttpResponse) -> Any`
-- `get_root_value(self, request: HttpRequest) -> Any`
-- `process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse`
+- `def get_context(self, request: HttpRequest, response: HttpResponse) -> Any`
+- `def get_root_value(self, request: HttpRequest) -> Any`
+- `def process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse`
+- `def render_graphql_ide(self, request: HttpRequest) -> HttpResponse`
-## get_context
+### get_context
`get_context` allows to provide a custom context object that can be used in your
resolver. You can return anything here, by default we return a
@@ -98,7 +102,7 @@ called "example".
Then we use the context in a resolver, the resolver will return "1" in this
case.
-## get_root_value
+### get_root_value
`get_root_value` allows to provide a custom root value for your schema, this is
probably not used a lot but it might be useful in certain situations.
@@ -119,7 +123,7 @@ class Query:
Here we are returning a Query where the name is "Patrick", so we when requesting
the field name we'll return "Patrick" in this case.
-## process_result
+### process_result
`process_result` allows to customize and/or process results before they are sent
to the clients. This can be useful logging errors or hiding them (for example to
@@ -148,6 +152,24 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.django.views import GraphQLView
+from django.http import HttpResponse
+from django.template.loader import render_to_string
+
+
+class MyGraphQLView(GraphQLView):
+ def render_graphql_ide(self, request: HttpRequest) -> HttpResponse:
+ content = render_to_string("myapp/my_graphql_ide_template.html")
+
+ return HttpResponse(content)
+```
+
# Async Django
Strawberry also provides an async view that you can use with Django 3.1+
@@ -177,8 +199,6 @@ The `AsyncGraphQLView` accepts the following arguments:
to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
-- `subscriptions_enabled`: optional boolean paramenter enabling subscriptions in
- the GraphiQL interface, defaults to `False`.
## Extending the view
@@ -189,8 +209,9 @@ methods:
- `async get_root_value(self, request: HttpRequest) -> Any`
- `async process_result(self, request: HttpRequest, result: ExecutionResult) -> GraphQLHTTPResponse`
- `def encode_json(self, data: GraphQLHTTPResponse) -> str`
+- `async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse`
-## get_context
+### get_context
`get_context` allows to provide a custom context object that can be used in your
resolver. You can return anything here, by default we return a dictionary with
@@ -215,7 +236,7 @@ called "example".
Then we use the context in a resolver, the resolver will return "1" in this
case.
-## get_root_value
+### get_root_value
`get_root_value` allows to provide a custom root value for your schema, this is
probably not used a lot but it might be useful in certain situations.
@@ -236,7 +257,7 @@ class Query:
Here we are returning a Query where the name is "Patrick", so we when requesting
the field name we'll return "Patrick" in this case.
-## process_result
+### process_result
`process_result` allows to customize and/or process results before they are sent
to the clients. This can be useful logging errors or hiding them (for example to
@@ -265,7 +286,7 @@ class MyGraphQLView(AsyncGraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -276,6 +297,24 @@ class MyGraphQLView(AsyncGraphQLView):
return json.dumps(data, indent=2)
```
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.django.views import AsyncGraphQLView
+from django.http import HttpResponse
+from django.template.loader import render_to_string
+
+
+class MyGraphQLView(AsyncGraphQLView):
+ async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse:
+ content = render_to_string("myapp/my_graphql_ide_template.html")
+
+ return HttpResponse(content)
+```
+
## Subscriptions
Subscriptions run over websockets and thus depend on
diff --git a/docs/integrations/fastapi.md b/docs/integrations/fastapi.md
index 1fa144b974..f9ffc1de3f 100644
--- a/docs/integrations/fastapi.md
+++ b/docs/integrations/fastapi.md
@@ -54,8 +54,12 @@ The `GraphQLRouter` accepts the following options:
value.
- `root_value_getter`: optional FastAPI dependency for providing custom root
value.
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
-## context_getter
+### context_getter
The `context_getter` option allows you to provide a custom context object that
can be used in your resolver. `context_getter` is a
@@ -174,7 +178,7 @@ requires `.request` indexing.
Then we use the context in a resolver. The resolver will return “Hello John, you
rock!” in this case.
-### Setting background tasks
+#### Setting background tasks
Similarly,
[background tasks](https://fastapi.tiangolo.com/tutorial/background-tasks/?h=background)
@@ -217,7 +221,7 @@ app.include_router(graphql_app, prefix="/graphql")
If using a custom context class, then background tasks should be stored within
the class object as `.background_tasks`.
-## root_value_getter
+### root_value_getter
The `root_value_getter` option allows you to provide a custom root value for
your schema. This is most likely a rare usecase but might be useful in certain
@@ -255,7 +259,7 @@ app.include_router(graphql_app, prefix="/graphql")
Here we are returning a Query where the name is "Patrick", so when we request
the field name we'll return "Patrick".
-## process_result
+### process_result
The `process_result` option allows you to customize and/or process results
before they are sent to the clients. This can be useful for logging errors or
@@ -286,7 +290,7 @@ class MyGraphQLRouter(GraphQLRouter):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -297,3 +301,20 @@ class MyGraphQLRouter(GraphQLRouter):
def encode_json(self, data: GraphQLHTTPResponse) -> bytes:
return orjson.dumps(data)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.fastapi import GraphQLRouter
+from starlette.responses import HTMLResponse, Response
+
+
+class MyGraphQLRouter(GraphQLRouter):
+ async def render_graphql_ide(self, request: Request) -> HTMLResponse:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return HTMLResponse(custom_html)
+```
diff --git a/docs/integrations/flask.md b/docs/integrations/flask.md
index 9fe4e17ac3..a8dde13b8a 100644
--- a/docs/integrations/flask.md
+++ b/docs/integrations/flask.md
@@ -34,22 +34,28 @@ from strawberry.flask.views import AsyncGraphQLView
## Options
-The `GraphQLView` accepts two options at the moment:
+The `GraphQLView` accepts the following options at the moment:
- `schema`: mandatory, the schema created by `strawberry.Schema`.
-- `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL
- interface.
+- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the
+ GraphQL IDE interface (one of `graphiql`, `apollo-sandbox` or `pathfinder`) or
+ to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
## Extending the view
We allow to extend the base `GraphQLView`, by overriding the following methods:
-- `get_context(self, request: Request, response: Response) -> Any`
-- `get_root_value(self, request: Request) -> Any`
-- `process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse`
-- `encode_json(self, response_data: GraphQLHTTPResponse) -> str`
+- `def get_context(self, request: Request, response: Response) -> Any`
+- `def get_root_value(self, request: Request) -> Any`
+- `def process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse`
+- `def encode_json(self, response_data: GraphQLHTTPResponse) -> str`
+- `def render_graphql_ide(self, request: Request) -> Response`
@@ -59,7 +65,7 @@ async functions.
-## get_context
+### get_context
`get_context` allows to provide a custom context object that can be used in your
resolver. You can return anything here, by default we return a dictionary with
@@ -85,7 +91,7 @@ called "example".
Then we use the context in a resolver, the resolver will return "1" in this
case.
-## get_root_value
+### get_root_value
`get_root_value` allows to provide a custom root value for your schema, this is
probably not used a lot but it might be useful in certain situations.
@@ -106,7 +112,7 @@ class Query:
Here we are returning a Query where the name is "Patrick", so we when requesting
the field name we'll return "Patrick" in this case.
-## process_result
+### process_result
`process_result` allows to customize and/or process results before they are sent
to the clients. This can be useful logging errors or hiding them (for example to
@@ -133,7 +139,7 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -143,3 +149,20 @@ class MyGraphQLView(GraphQLView):
def encode_json(self, data: GraphQLHTTPResponse) -> str:
return json.dumps(data, indent=2)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.flask.views import GraphQLView
+from flask import Request, Response
+
+
+class MyGraphQLView(GraphQLView):
+ def render_graphql_ide(self, request: Request) -> Response:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return Response(custom_html, status=200, content_type="text/html")
+```
diff --git a/docs/integrations/litestar.md b/docs/integrations/litestar.md
index 0ba7c1ded8..86ee0bc179 100644
--- a/docs/integrations/litestar.md
+++ b/docs/integrations/litestar.md
@@ -61,8 +61,12 @@ The `make_graphql_controller` function accepts the following options:
the maximum time to wait for the connection initialization message when using
`graphql-transport-ws`
[protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#connectioninit)
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
-## context_getter
+### context_getter
The `context_getter` option allows you to provide a Litestar dependency that
return a custom context object that can be used in your resolver.
@@ -184,7 +188,7 @@ GraphQLController = make_graphql_controller(
app = Litestar(route_handlers=[GraphQLController])
```
-### Context typing
+#### Context typing
In our previous example using class based context, the actual runtime context a
`CustomContext` type. Because it inherits from `BaseContext`, the `request`,
@@ -279,7 +283,7 @@ GraphQLController = make_graphql_controller(
app = Litestar(route_handlers=[GraphQLController])
```
-## root_value_getter
+### root_value_getter
The `root_value_getter` option allows you to provide a custom root value that
can be used in your resolver
@@ -313,3 +317,43 @@ GraphQLController = make_graphql_controller(
app = Litestar(route_handlers=[GraphQLController])
```
+
+## Extending the controller
+
+The `make_graphql_controller` function returns a `GraphQLController` class that
+can be extended by overriding the following methods:
+
+1. `async def render_graphql_ide(self, request: Request) -> Response`
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+import strawberry
+from strawberry.litestar import make_graphql_controller
+from litestar import MediaType, Request, Response
+
+
+@strawberry.type
+class Query:
+ @strawberry.field
+ def hello(self) -> str:
+ return "world"
+
+
+schema = strawberry.Schema(Query)
+
+GraphQLController = make_graphql_controller(
+ schema,
+ path="/graphql",
+)
+
+
+class MyGraphQLController(GraphQLController):
+ async def render_graphql_ide(self, request: Request) -> Response:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return Response(custom_html, media_type=MediaType.HTML)
+```
diff --git a/docs/integrations/quart.md b/docs/integrations/quart.md
index 05eb0cd994..d863777b04 100644
--- a/docs/integrations/quart.md
+++ b/docs/integrations/quart.md
@@ -26,24 +26,30 @@ if __name__ == "__main__":
## Options
-The `GraphQLView` accepts two options at the moment:
+The `GraphQLView` accepts the following options at the moment:
- `schema`: mandatory, the schema created by `strawberry.Schema`.
-- `graphiql:` optional, defaults to `True`, whether to enable the GraphiQL
- interface.
+- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the
+ GraphQL IDE interface (one of `graphiql`, `apollo-sandbox` or `pathfinder`) or
+ to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
## Extending the view
We allow to extend the base `GraphQLView`, by overriding the following methods:
-- `get_context(self, request: Request, response: Response) -> Any`
-- `get_root_value(self, request: Request) -> Any`
-- `process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse`
-- `encode_json(self, response_data: GraphQLHTTPResponse) -> str`
+- `async def get_context(self, request: Request, response: Response) -> Any`
+- `async def get_root_value(self, request: Request) -> Any`
+- `async def process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse`
+- `def encode_json(self, response_data: GraphQLHTTPResponse) -> str`
+- `async def render_graphql_ide(self, request: Request) -> Response`
-## get_context
+### get_context
`get_context` allows to provide a custom context object that can be used in your
resolver. You can return anything here, by default we return a dictionary with
@@ -69,7 +75,7 @@ called "example".
Then we use the context in a resolver, the resolver will return "1" in this
case.
-## get_root_value
+### get_root_value
`get_root_value` allows to provide a custom root value for your schema, this is
probably not used a lot but it might be useful in certain situations.
@@ -90,7 +96,7 @@ class Query:
Here we are returning a Query where the name is "Patrick", so we when requesting
the field name we'll return "Patrick" in this case.
-## process_result
+### process_result
`process_result` allows to customize and/or process results before they are sent
to the clients. This can be useful logging errors or hiding them (for example to
@@ -117,7 +123,7 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -127,3 +133,20 @@ class MyGraphQLView(GraphQLView):
def encode_json(self, data: GraphQLHTTPResponse) -> str:
return json.dumps(data, indent=2)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.quart.views import GraphQLView
+from quart import Request, Response
+
+
+class MyGraphQLView(GraphQLView):
+ async def render_graphql_ide(self, request: Request) -> Response:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return Response(self.graphql_ide_html)
+```
diff --git a/docs/integrations/sanic.md b/docs/integrations/sanic.md
index 1701a74a83..2cae238b30 100644
--- a/docs/integrations/sanic.md
+++ b/docs/integrations/sanic.md
@@ -22,7 +22,7 @@ app.add_route(
## Options
-The `GraphQLView` accepts two options at the moment:
+The `GraphQLView` accepts the following options at the moment:
- `schema`: mandatory, the schema created by `strawberry.Schema`.
- `graphql_ide`: optional, defaults to `"graphiql"`, allows to choose the
@@ -30,7 +30,10 @@ The `GraphQLView` accepts two options at the moment:
to disable it by passing `None`.
- `allow_queries_via_get`: optional, defaults to `True`, whether to enable
queries via `GET` requests
-- `def encode_json(self, data: GraphQLHTTPResponse) -> str`
+- `multipart_uploads_enabled`: optional, defaults to `False`, controls whether
+ to enable multipart uploads. Please make sure to consider the
+ [security implications mentioned in the GraphQL Multipart Request Specification](https://github.com/jaydenseric/graphql-multipart-request-spec/blob/master/readme.md#security)
+ when enabling this feature.
## Extending the view
@@ -40,8 +43,9 @@ methods:
- `async get_context(self, request: Request, response: Response) -> Any`
- `async get_root_value(self, request: Request) -> Any`
- `async process_result(self, result: ExecutionResult) -> GraphQLHTTPResponse`
+- `async def render_graphql_ide(self, request: Request) -> HTTPResponse`
-## get_context
+### get_context
By overriding `GraphQLView.get_context` you can provide a custom context object
for your resolvers. You can return anything here; by default GraphQLView returns
@@ -66,7 +70,7 @@ called `"example"`.
Then we can use the context in a resolver. In this case the resolver will return
`1`.
-## get_root_value
+### get_root_value
By overriding `GraphQLView.get_root_value` you can provide a custom root value
for your schema. This is probably not used a lot but it might be useful in
@@ -88,7 +92,7 @@ class Query:
Here we configure a Query where requesting the `name` field will return
`"Patrick"` through the custom root value.
-## process_result
+### process_result
By overriding `GraphQLView.process_result` you can customize and/or process
results before they are sent to a client. This can be useful for logging errors,
@@ -117,7 +121,7 @@ class MyGraphQLView(GraphQLView):
In this case we are doing the default processing of the result, but it can be
tweaked based on your needs.
-## encode_json
+### encode_json
`encode_json` allows to customize the encoding of the JSON response. By default
we use `json.dumps` but you can override this method to use a different encoder.
@@ -127,3 +131,21 @@ class MyGraphQLView(GraphQLView):
def encode_json(self, data: GraphQLHTTPResponse) -> str:
return json.dumps(data, indent=2)
```
+
+### render_graphql_ide
+
+In case you need more control over the rendering of the GraphQL IDE than the
+`graphql_ide` option provides, you can override the `render_graphql_ide` method.
+
+```python
+from strawberry.sanic.views import GraphQLView
+from sanic.request import Request
+from sanic.response import HTTPResponse, html
+
+
+class MyGraphQLView(GraphQLView):
+ async def render_graphql_ide(self, request: Request) -> HTTPResponse:
+ custom_html = """Custom GraphQL IDE
"""
+
+ return html(custom_html)
+```
diff --git a/docs/types/enums.md b/docs/types/enums.md
index cb2e99d216..743e91f80b 100644
--- a/docs/types/enums.md
+++ b/docs/types/enums.md
@@ -42,6 +42,17 @@ class IceCreamFlavour(Enum):
CHOCOLATE = "chocolate"
```
+
+
+In some cases you already have an enum defined elsewhere in your code. You can
+safely use it in your schema and strawberry will generate a default graphql
+implementation of it.
+
+The only drawback is that it is not currently possible to configure it
+(documentation / renaming or using `strawberry.enum_value` on it).
+
+
+
Let's see how we can use Enums in our schema.
```python
diff --git a/docs/types/schema-configurations.md b/docs/types/schema-configurations.md
index 5ad12c9c6c..396524938c 100644
--- a/docs/types/schema-configurations.md
+++ b/docs/types/schema-configurations.md
@@ -123,5 +123,5 @@ class CustomInfo(Info):
return self.context["response"].headers
-schema = strawberry.Schema(query=Query, info_class=CustomInfo)
+schema = strawberry.Schema(query=Query, config=StrawberryConfig(info_class=CustomInfo))
```
diff --git a/docs/types/schema.md b/docs/types/schema.md
index f82f4fe7cb..a0d14e8b51 100644
--- a/docs/types/schema.md
+++ b/docs/types/schema.md
@@ -91,7 +91,7 @@ class Query:
@strawberry.field
def get_customer(
self, id: strawberry.ID
- ): # -> Customer note we're returning the interface here
+ ) -> Customer: # note we're returning the interface here
if id == "mark":
return Individual(name="Mark", date_of_birth=date(1984, 5, 14))
diff --git a/noxfile.py b/noxfile.py
index 1a36339397..7e817ea3cb 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -80,7 +80,7 @@ def tests(session: Session, gql_core: str) -> None:
)
-@session(python=["3.11", "3.12"], name="Django tests", tags=["tests"])
+@session(python=["3.12"], name="Django tests", tags=["tests"])
@with_gql_core_parametrize("django", ["4.2.0", "4.1.0", "4.0.0", "3.2.0"])
def tests_django(session: Session, django: str, gql_core: str) -> None:
session.run_always("poetry", "install", external=True)
diff --git a/pyproject.toml b/pyproject.toml
index ef445cf7ac..29e173ab7d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "strawberry-graphql"
packages = [ { include = "strawberry" } ]
-version = "0.242.0"
+version = "0.246.2"
description = "A library for creating GraphQL APIs"
authors = ["Patrick Arminio "]
license = "MIT"
diff --git a/strawberry/aiohttp/handlers/__init__.py b/strawberry/aiohttp/handlers/__init__.py
deleted file mode 100644
index c769c4eec7..0000000000
--- a/strawberry/aiohttp/handlers/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from strawberry.aiohttp.handlers.graphql_transport_ws_handler import (
- GraphQLTransportWSHandler,
-)
-from strawberry.aiohttp.handlers.graphql_ws_handler import GraphQLWSHandler
-
-__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"]
diff --git a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py b/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py
deleted file mode 100644
index 52350199f7..0000000000
--- a/strawberry/aiohttp/handlers/graphql_transport_ws_handler.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Callable, Dict
-
-from aiohttp import http, web
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
- BaseGraphQLTransportWSHandler,
-)
-
-if TYPE_CHECKING:
- from datetime import timedelta
-
- from strawberry.schema import BaseSchema
-
-
-class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- connection_init_wait_timeout: timedelta,
- get_context: Callable[..., Dict[str, Any]],
- get_root_value: Any,
- request: web.Request,
- ) -> None:
- super().__init__(schema, debug, connection_init_wait_timeout)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._request = request
- self._ws = web.WebSocketResponse(protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])
-
- async def get_context(self) -> Any:
- return await self._get_context(request=self._request, response=self._ws) # type: ignore
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value(request=self._request)
-
- async def send_json(self, data: dict) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int, reason: str) -> None:
- await self._ws.close(code=code, message=reason.encode())
-
- async def handle_request(self) -> web.StreamResponse:
- await self._ws.prepare(self._request)
- self.on_request_accepted()
-
- try:
- async for ws_message in self._ws: # type: http.WSMessage
- if ws_message.type == http.WSMsgType.TEXT:
- await self.handle_message(ws_message.json())
- else:
- error_message = "WebSocket message type must be text"
- await self.handle_invalid_message(error_message)
- finally:
- await self.shutdown()
-
- return self._ws
-
-
-__all__ = ["GraphQLTransportWSHandler"]
diff --git a/strawberry/aiohttp/handlers/graphql_ws_handler.py b/strawberry/aiohttp/handlers/graphql_ws_handler.py
deleted file mode 100644
index 677dd34884..0000000000
--- a/strawberry/aiohttp/handlers/graphql_ws_handler.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from __future__ import annotations
-
-from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Callable, Optional
-
-from aiohttp import http, web
-from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
-
-if TYPE_CHECKING:
- from strawberry.schema import BaseSchema
- from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage
-
-
-class GraphQLWSHandler(BaseGraphQLWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- keep_alive: bool,
- keep_alive_interval: float,
- get_context: Callable,
- get_root_value: Callable,
- request: web.Request,
- ) -> None:
- super().__init__(schema, debug, keep_alive, keep_alive_interval)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._request = request
- self._ws = web.WebSocketResponse(protocols=[GRAPHQL_WS_PROTOCOL])
-
- async def get_context(self) -> Any:
- return await self._get_context(request=self._request, response=self._ws)
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value(request=self._request)
-
- async def send_json(self, data: OperationMessage) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
- message = reason.encode() if reason else b""
- await self._ws.close(code=code, message=message)
-
- async def handle_request(self) -> Any:
- await self._ws.prepare(self._request)
-
- try:
- async for ws_message in self._ws: # type: http.WSMessage
- if ws_message.type == http.WSMsgType.TEXT:
- message: OperationMessage = ws_message.json()
- await self.handle_message(message)
- else:
- await self.close(
- code=1002, reason="WebSocket message type must be text"
- )
- finally:
- if self.keep_alive_task:
- self.keep_alive_task.cancel()
- with suppress(BaseException):
- await self.keep_alive_task
-
- for operation_id in list(self.subscriptions.keys()):
- await self.cleanup_operation(operation_id)
-
- return self._ws
-
-
-__all__ = ["GraphQLWSHandler"]
diff --git a/strawberry/aiohttp/test/client.py b/strawberry/aiohttp/test/client.py
index 6f30b7e7aa..0d25f4043a 100644
--- a/strawberry/aiohttp/test/client.py
+++ b/strawberry/aiohttp/test/client.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import warnings
from typing import (
Any,
Dict,
@@ -16,8 +17,9 @@ async def query(
query: str,
variables: Optional[Dict[str, Mapping]] = None,
headers: Optional[Dict[str, object]] = None,
- asserts_errors: Optional[bool] = True,
+ asserts_errors: Optional[bool] = None,
files: Optional[Dict[str, object]] = None,
+ assert_no_errors: Optional[bool] = True,
) -> Response:
body = self._build_body(query, variables, files)
@@ -29,7 +31,19 @@ async def query(
data=data.get("data"),
extensions=data.get("extensions"),
)
- if asserts_errors:
+
+ if asserts_errors is not None:
+ warnings.warn(
+ "The `asserts_errors` argument has been renamed to `assert_no_errors`",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ assert_no_errors = (
+ assert_no_errors if asserts_errors is None else asserts_errors
+ )
+
+ if assert_no_errors:
assert resp.status == 200
assert response.errors is None
diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py
index 0a8143657f..884264dcb0 100644
--- a/strawberry/aiohttp/views.py
+++ b/strawberry/aiohttp/views.py
@@ -4,6 +4,7 @@
import warnings
from datetime import timedelta
from io import BytesIO
+from json.decoder import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
@@ -16,15 +17,16 @@
Union,
cast,
)
+from typing_extensions import TypeGuard
-from aiohttp import web
+from aiohttp import http, web
from aiohttp.multipart import BodyPartReader
-from strawberry.aiohttp.handlers import (
- GraphQLTransportWSHandler,
- GraphQLWSHandler,
+from strawberry.http.async_base_view import (
+ AsyncBaseHTTPView,
+ AsyncHTTPRequestAdapter,
+ AsyncWebSocketAdapter,
)
-from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
-from strawberry.http.exceptions import HTTPException
+from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
@@ -79,11 +81,36 @@ def content_type(self) -> Optional[str]:
return self.headers.get("content-type")
+class AioHTTPWebSocketAdapter(AsyncWebSocketAdapter):
+ def __init__(self, request: web.Request, ws: web.WebSocketResponse) -> None:
+ self.request = request
+ self.ws = ws
+
+ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
+ async for ws_message in self.ws:
+ if ws_message.type == http.WSMsgType.TEXT:
+ try:
+ yield ws_message.json()
+ except JSONDecodeError:
+ raise NonJsonMessageReceived()
+
+ elif ws_message.type == http.WSMsgType.BINARY:
+ raise NonJsonMessageReceived()
+
+ async def send_json(self, message: Mapping[str, object]) -> None:
+ await self.ws.send_json(message)
+
+ async def close(self, code: int, reason: str) -> None:
+ await self.ws.close(code=code, message=reason.encode())
+
+
class GraphQLView(
AsyncBaseHTTPView[
web.Request,
Union[web.Response, web.StreamResponse],
web.Response,
+ web.Request,
+ web.WebSocketResponse,
Context,
RootValue,
]
@@ -92,10 +119,9 @@ class GraphQLView(
# bare handler function.
_is_coroutine = asyncio.coroutines._is_coroutine # type: ignore[attr-defined]
- graphql_transport_ws_handler_class = GraphQLTransportWSHandler
- graphql_ws_handler_class = GraphQLWSHandler
allow_queries_via_get = True
request_adapter_class = AioHTTPRequestAdapter
+ websocket_adapter_class = AioHTTPWebSocketAdapter
def __init__(
self,
@@ -111,6 +137,7 @@ def __init__(
GRAPHQL_WS_PROTOCOL,
),
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
+ multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
@@ -119,6 +146,7 @@ def __init__(
self.debug = debug
self.subscription_protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -136,48 +164,36 @@ async def render_graphql_ide(self, request: web.Request) -> web.Response:
async def get_sub_response(self, request: web.Request) -> web.Response:
return web.Response()
- async def __call__(self, request: web.Request) -> web.StreamResponse:
+ def is_websocket_request(self, request: web.Request) -> TypeGuard[web.Request]:
ws = web.WebSocketResponse(protocols=self.subscription_protocols)
- ws_test = ws.can_prepare(request)
-
- if not ws_test.ok:
- try:
- return await self.run(request=request)
- except HTTPException as e:
- return web.Response(
- body=e.reason,
- status=e.status_code,
- )
-
- if ws_test.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
- return await self.graphql_transport_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- connection_init_wait_timeout=self.connection_init_wait_timeout,
- get_context=self.get_context, # type: ignore
- get_root_value=self.get_root_value,
- request=request,
- ).handle()
- elif ws_test.protocol == GRAPHQL_WS_PROTOCOL:
- return await self.graphql_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- keep_alive=self.keep_alive,
- keep_alive_interval=self.keep_alive_interval,
- get_context=self.get_context,
- get_root_value=self.get_root_value,
- request=request,
- ).handle()
- else:
- await ws.prepare(request)
- await ws.close(code=4406, message=b"Subprotocol not acceptable")
- return ws
+ return ws.can_prepare(request).ok
+
+ async def pick_websocket_subprotocol(self, request: web.Request) -> Optional[str]:
+ ws = web.WebSocketResponse(protocols=self.subscription_protocols)
+ return ws.can_prepare(request).protocol
+
+ async def create_websocket_response(
+ self, request: web.Request, subprotocol: Optional[str]
+ ) -> web.WebSocketResponse:
+ protocols = [subprotocol] if subprotocol else []
+ ws = web.WebSocketResponse(protocols=protocols)
+ await ws.prepare(request)
+ return ws
+
+ async def __call__(self, request: web.Request) -> web.StreamResponse:
+ try:
+ return await self.run(request=request)
+ except HTTPException as e:
+ return web.Response(
+ body=e.reason,
+ status=e.status_code,
+ )
async def get_root_value(self, request: web.Request) -> Optional[RootValue]:
return None
async def get_context(
- self, request: web.Request, response: web.Response
+ self, request: web.Request, response: Union[web.Response, web.WebSocketResponse]
) -> Context:
return {"request": request, "response": response} # type: ignore
diff --git a/strawberry/annotation.py b/strawberry/annotation.py
index 86d06d73e5..dff708a1a1 100644
--- a/strawberry/annotation.py
+++ b/strawberry/annotation.py
@@ -20,7 +20,6 @@
)
from typing_extensions import Annotated, Self, get_args, get_origin
-from strawberry.exceptions.not_a_strawberry_enum import NotAStrawberryEnumError
from strawberry.types.base import (
StrawberryList,
StrawberryObjectDefinition,
@@ -30,6 +29,7 @@
has_object_definition,
)
from strawberry.types.enum import EnumDefinition
+from strawberry.types.enum import enum as strawberry_enum
from strawberry.types.lazy_type import LazyType
from strawberry.types.private import is_private
from strawberry.types.scalar import ScalarDefinition
@@ -187,7 +187,7 @@ def create_enum(self, evaled_type: Any) -> EnumDefinition:
try:
return evaled_type._enum_definition
except AttributeError:
- raise NotAStrawberryEnumError(evaled_type)
+ return strawberry_enum(evaled_type)._enum_definition
def create_list(self, evaled_type: Any) -> StrawberryList:
item_type, *_ = get_args(evaled_type)
diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py
index d5aae404f6..5a3f01203d 100644
--- a/strawberry/asgi/__init__.py
+++ b/strawberry/asgi/__init__.py
@@ -2,9 +2,11 @@
import warnings
from datetime import timedelta
+from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
+ AsyncGenerator,
AsyncIterator,
Callable,
Dict,
@@ -14,6 +16,7 @@
Union,
cast,
)
+from typing_extensions import TypeGuard
from starlette import status
from starlette.requests import Request
@@ -23,14 +26,14 @@
Response,
StreamingResponse,
)
-from starlette.websockets import WebSocket
+from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
-from strawberry.asgi.handlers import (
- GraphQLTransportWSHandler,
- GraphQLWSHandler,
+from strawberry.http.async_base_view import (
+ AsyncBaseHTTPView,
+ AsyncHTTPRequestAdapter,
+ AsyncWebSocketAdapter,
)
-from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
-from strawberry.http.exceptions import HTTPException
+from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import (
Context,
@@ -78,19 +81,41 @@ async def get_form_data(self) -> FormData:
)
+class ASGIWebSocketAdapter(AsyncWebSocketAdapter):
+ def __init__(self, request: WebSocket, response: WebSocket) -> None:
+ self.ws = response
+
+ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
+ try:
+ try:
+ while self.ws.application_state != WebSocketState.DISCONNECTED:
+ yield await self.ws.receive_json()
+ except (KeyError, JSONDecodeError):
+ raise NonJsonMessageReceived()
+ except WebSocketDisconnect: # pragma: no cover
+ pass
+
+ async def send_json(self, message: Mapping[str, object]) -> None:
+ await self.ws.send_json(message)
+
+ async def close(self, code: int, reason: str) -> None:
+ await self.ws.close(code=code, reason=reason)
+
+
class GraphQL(
AsyncBaseHTTPView[
- Union[Request, WebSocket],
+ Request,
Response,
Response,
+ WebSocket,
+ WebSocket,
Context,
RootValue,
]
):
- graphql_transport_ws_handler_class = GraphQLTransportWSHandler
- graphql_ws_handler_class = GraphQLWSHandler
allow_queries_via_get = True
- request_adapter_class = ASGIRequestAdapter # pyright: ignore
+ request_adapter_class = ASGIRequestAdapter
+ websocket_adapter_class = ASGIWebSocketAdapter
def __init__(
self,
@@ -106,6 +131,7 @@ def __init__(
GRAPHQL_WS_PROTOCOL,
),
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
+ multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
@@ -114,6 +140,7 @@ def __init__(
self.debug = debug
self.protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -127,51 +154,25 @@ def __init__(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
- return await self.handle_http(scope, receive, send)
+ http_request = Request(scope=scope, receive=receive)
- elif scope["type"] == "websocket":
- ws = WebSocket(scope, receive=receive, send=send)
- preferred_protocol = self.pick_preferred_protocol(ws)
-
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
- await self.graphql_transport_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- connection_init_wait_timeout=self.connection_init_wait_timeout,
- get_context=self.get_context,
- get_root_value=self.get_root_value,
- ws=ws,
- ).handle()
-
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
- await self.graphql_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- keep_alive=self.keep_alive,
- keep_alive_interval=self.keep_alive_interval,
- get_context=self.get_context,
- get_root_value=self.get_root_value,
- ws=ws,
- ).handle()
-
- else:
- # Subprotocol not acceptable
- await ws.close(code=4406)
+ try:
+ response = await self.run(http_request)
+ except HTTPException as e:
+ response = PlainTextResponse(e.reason, status_code=e.status_code)
+ await response(scope, receive, send)
+ elif scope["type"] == "websocket":
+ ws_request = WebSocket(scope, receive=receive, send=send)
+ await self.run(ws_request)
else: # pragma: no cover
raise ValueError("Unknown scope type: {!r}".format(scope["type"]))
- def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]:
- protocols = ws["subprotocols"]
- intersection = set(protocols) & set(self.protocols)
- sorted_intersection = sorted(intersection, key=protocols.index)
- return next(iter(sorted_intersection), None)
-
async def get_root_value(self, request: Union[Request, WebSocket]) -> Optional[Any]:
return None
async def get_context(
- self, request: Union[Request, WebSocket], response: Response
+ self, request: Union[Request, WebSocket], response: Union[Response, WebSocket]
) -> Context:
return {"request": request, "response": response} # type: ignore
@@ -185,22 +186,7 @@ async def get_sub_response(
return sub_response
- async def handle_http(
- self,
- scope: Scope,
- receive: Receive,
- send: Send,
- ) -> None:
- request = Request(scope=scope, receive=receive)
-
- try:
- response = await self.run(request)
- except HTTPException as e:
- response = PlainTextResponse(e.reason, status_code=e.status_code) # pyright: ignore
-
- await response(scope, receive, send)
-
- async def render_graphql_ide(self, request: Union[Request, WebSocket]) -> Response:
+ async def render_graphql_ide(self, request: Request) -> Response:
return HTMLResponse(self.graphql_ide_html)
def create_response(
@@ -237,3 +223,20 @@ async def create_streaming_response(
**headers,
},
)
+
+ def is_websocket_request(
+ self, request: Union[Request, WebSocket]
+ ) -> TypeGuard[WebSocket]:
+ return request.scope["type"] == "websocket"
+
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
+ protocols = request["subprotocols"]
+ intersection = set(protocols) & set(self.protocols)
+ sorted_intersection = sorted(intersection, key=protocols.index)
+ return next(iter(sorted_intersection), None)
+
+ async def create_websocket_response(
+ self, request: WebSocket, subprotocol: Optional[str]
+ ) -> WebSocket:
+ await request.accept(subprotocol=subprotocol)
+ return request
diff --git a/strawberry/asgi/handlers/__init__.py b/strawberry/asgi/handlers/__init__.py
deleted file mode 100644
index 1891a06ad0..0000000000
--- a/strawberry/asgi/handlers/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from strawberry.asgi.handlers.graphql_transport_ws_handler import (
- GraphQLTransportWSHandler,
-)
-from strawberry.asgi.handlers.graphql_ws_handler import GraphQLWSHandler
-
-__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"]
diff --git a/strawberry/asgi/handlers/graphql_transport_ws_handler.py b/strawberry/asgi/handlers/graphql_transport_ws_handler.py
deleted file mode 100644
index 7cec132ffd..0000000000
--- a/strawberry/asgi/handlers/graphql_transport_ws_handler.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Callable
-
-from starlette.websockets import WebSocketDisconnect, WebSocketState
-
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
- BaseGraphQLTransportWSHandler,
-)
-
-if TYPE_CHECKING:
- from datetime import timedelta
-
- from starlette.websockets import WebSocket
-
- from strawberry.schema import BaseSchema
-
-
-class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- connection_init_wait_timeout: timedelta,
- get_context: Callable,
- get_root_value: Callable,
- ws: WebSocket,
- ) -> None:
- super().__init__(schema, debug, connection_init_wait_timeout)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._ws = ws
-
- async def get_context(self) -> Any:
- return await self._get_context(request=self._ws, response=None)
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value(request=self._ws)
-
- async def send_json(self, data: dict) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int, reason: str) -> None:
- await self._ws.close(code=code, reason=reason)
-
- async def handle_request(self) -> None:
- await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
- self.on_request_accepted()
-
- try:
- while self._ws.application_state != WebSocketState.DISCONNECTED:
- try:
- message = await self._ws.receive_json()
- except KeyError: # noqa: PERF203
- error_message = "WebSocket message type must be text"
- await self.handle_invalid_message(error_message)
- else:
- await self.handle_message(message)
- except WebSocketDisconnect: # pragma: no cover
- pass
- finally:
- await self.shutdown()
-
-
-__all__ = ["GraphQLTransportWSHandler"]
diff --git a/strawberry/asgi/handlers/graphql_ws_handler.py b/strawberry/asgi/handlers/graphql_ws_handler.py
deleted file mode 100644
index 00a314bbd0..0000000000
--- a/strawberry/asgi/handlers/graphql_ws_handler.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from __future__ import annotations
-
-from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Callable, Optional
-
-from starlette.websockets import WebSocketDisconnect, WebSocketState
-
-from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
-
-if TYPE_CHECKING:
- from starlette.websockets import WebSocket
-
- from strawberry.schema import BaseSchema
- from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage
-
-
-class GraphQLWSHandler(BaseGraphQLWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- keep_alive: bool,
- keep_alive_interval: float,
- get_context: Callable,
- get_root_value: Callable,
- ws: WebSocket,
- ) -> None:
- super().__init__(schema, debug, keep_alive, keep_alive_interval)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._ws = ws
-
- async def get_context(self) -> Any:
- return await self._get_context(request=self._ws, response=None)
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value(request=self._ws)
-
- async def send_json(self, data: OperationMessage) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
- await self._ws.close(code=code, reason=reason)
-
- async def handle_request(self) -> Any:
- await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL)
-
- try:
- while self._ws.application_state != WebSocketState.DISCONNECTED:
- try:
- message = await self._ws.receive_json()
- except KeyError: # noqa: PERF203
- await self.close(
- code=1002, reason="WebSocket message type must be text"
- )
- else:
- await self.handle_message(message)
- except WebSocketDisconnect: # pragma: no cover
- pass
- finally:
- if self.keep_alive_task:
- self.keep_alive_task.cancel()
- with suppress(BaseException):
- await self.keep_alive_task
-
- for operation_id in list(self.subscriptions.keys()):
- await self.cleanup_operation(operation_id)
-
-
-__all__ = ["GraphQLWSHandler"]
diff --git a/strawberry/chalice/views.py b/strawberry/chalice/views.py
index 3da2838eaa..9c131f202d 100644
--- a/strawberry/chalice/views.py
+++ b/strawberry/chalice/views.py
@@ -54,7 +54,6 @@ class GraphQLView(
):
allow_queries_via_get: bool = True
request_adapter_class = ChaliceHTTPRequestAdapter
- _ide_subscription_enabled = False
def __init__(
self,
diff --git a/strawberry/channels/__init__.py b/strawberry/channels/__init__.py
index 455513babb..f67fb25a82 100644
--- a/strawberry/channels/__init__.py
+++ b/strawberry/channels/__init__.py
@@ -1,6 +1,4 @@
-from .handlers.base import ChannelsConsumer, ChannelsWSConsumer
-from .handlers.graphql_transport_ws_handler import GraphQLTransportWSHandler
-from .handlers.graphql_ws_handler import GraphQLWSHandler
+from .handlers.base import ChannelsConsumer
from .handlers.http_handler import (
ChannelsRequest,
GraphQLHTTPConsumer,
@@ -12,10 +10,7 @@
__all__ = [
"ChannelsConsumer",
"ChannelsRequest",
- "ChannelsWSConsumer",
"GraphQLProtocolTypeRouter",
- "GraphQLWSHandler",
- "GraphQLTransportWSHandler",
"GraphQLHTTPConsumer",
"GraphQLWSConsumer",
"SyncGraphQLHTTPConsumer",
diff --git a/strawberry/channels/handlers/base.py b/strawberry/channels/handlers/base.py
index ec2ffe6b2c..769ec569e5 100644
--- a/strawberry/channels/handlers/base.py
+++ b/strawberry/channels/handlers/base.py
@@ -16,7 +16,7 @@
from weakref import WeakSet
from channels.consumer import AsyncConsumer
-from channels.generic.websocket import AsyncJsonWebsocketConsumer
+from channels.generic.websocket import AsyncWebsocketConsumer
class ChannelsMessage(TypedDict, total=False):
@@ -210,7 +210,7 @@ async def _listen_to_channel_generator(
return
-class ChannelsWSConsumer(ChannelsConsumer, AsyncJsonWebsocketConsumer):
+class ChannelsWSConsumer(ChannelsConsumer, AsyncWebsocketConsumer):
"""Base channels websocket async consumer."""
diff --git a/strawberry/channels/handlers/graphql_transport_ws_handler.py b/strawberry/channels/handlers/graphql_transport_ws_handler.py
deleted file mode 100644
index db290f4ef8..0000000000
--- a/strawberry/channels/handlers/graphql_transport_ws_handler.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Callable, Optional
-
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
- BaseGraphQLTransportWSHandler,
-)
-
-if TYPE_CHECKING:
- from datetime import timedelta
-
- from strawberry.channels.handlers.base import ChannelsWSConsumer
- from strawberry.schema import BaseSchema
-
-
-class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- connection_init_wait_timeout: timedelta,
- get_context: Callable,
- get_root_value: Callable,
- ws: ChannelsWSConsumer,
- ) -> None:
- super().__init__(schema, debug, connection_init_wait_timeout)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._ws = ws
-
- async def get_context(self) -> Any:
- return await self._get_context(
- request=self._ws, connection_params=self.connection_params
- )
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value(request=self._ws)
-
- async def send_json(self, data: dict) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
- # TODO: We are using `self._ws.base_send` directly instead of `self._ws.close`
- # because the later doesn't accept the `reason` argument.
- await self._ws.base_send(
- {
- "type": "websocket.close",
- "code": code,
- "reason": reason or "",
- }
- )
-
- async def handle_request(self) -> Any:
- await self._ws.accept(subprotocol=GRAPHQL_TRANSPORT_WS_PROTOCOL)
- self.on_request_accepted()
-
- async def handle_disconnect(self, code: int) -> None:
- await self.shutdown()
-
-
-__all__ = ["GraphQLTransportWSHandler"]
diff --git a/strawberry/channels/handlers/graphql_ws_handler.py b/strawberry/channels/handlers/graphql_ws_handler.py
deleted file mode 100644
index 6d967a1d15..0000000000
--- a/strawberry/channels/handlers/graphql_ws_handler.py
+++ /dev/null
@@ -1,72 +0,0 @@
-from __future__ import annotations
-
-from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Callable, Optional
-
-from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
-
-if TYPE_CHECKING:
- from strawberry.channels.handlers.base import ChannelsWSConsumer
- from strawberry.schema import BaseSchema
- from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage
-
-
-class GraphQLWSHandler(BaseGraphQLWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- keep_alive: bool,
- keep_alive_interval: float,
- get_context: Callable,
- get_root_value: Callable,
- ws: ChannelsWSConsumer,
- ) -> None:
- super().__init__(schema, debug, keep_alive, keep_alive_interval)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._ws = ws
-
- async def get_context(self) -> Any:
- return await self._get_context(
- request=self._ws, connection_params=self.connection_params
- )
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value(request=self._ws)
-
- async def send_json(self, data: OperationMessage) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
- # TODO: We are using `self._ws.base_send` directly instead of `self._ws.close`
- # because the latler doesn't accept the `reason` argument.
- await self._ws.base_send(
- {
- "type": "websocket.close",
- "code": code,
- "reason": reason or "",
- }
- )
-
- async def handle_request(self) -> Any:
- await self._ws.accept(subprotocol=GRAPHQL_WS_PROTOCOL)
-
- async def handle_disconnect(self, code: int) -> None:
- if self.keep_alive_task:
- self.keep_alive_task.cancel()
- with suppress(BaseException):
- await self.keep_alive_task
-
- for operation_id in list(self.subscriptions.keys()):
- await self.cleanup_operation(operation_id)
-
- async def handle_invalid_message(self, error_message: str) -> None:
- # This is not part of the BaseGraphQLWSHandler's interface, but the
- # channels integration is a high level wrapper that forwards this to
- # both us and the BaseGraphQLTransportWSHandler.
- await self.close(code=1002, reason=error_message)
-
-
-__all__ = ["GraphQLWSHandler"]
diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py
index 9169265cbb..a60bf2789e 100644
--- a/strawberry/channels/handlers/http_handler.py
+++ b/strawberry/channels/handlers/http_handler.py
@@ -15,7 +15,7 @@
Optional,
Union,
)
-from typing_extensions import assert_never
+from typing_extensions import TypeGuard, assert_never
from urllib.parse import parse_qs
from django.conf import settings
@@ -167,13 +167,12 @@ def __init__(
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
- subscriptions_enabled: bool = True,
+ multipart_uploads_enabled: bool = False,
**kwargs: Any,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
- self.subscriptions_enabled = subscriptions_enabled
- self._ide_subscriptions_enabled = subscriptions_enabled
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -231,6 +230,8 @@ class GraphQLHTTPConsumer(
ChannelsRequest,
Union[ChannelsResponse, MultipartChannelsResponse],
TemporalResponse,
+ ChannelsRequest,
+ TemporalResponse,
Context,
RootValue,
],
@@ -296,6 +297,21 @@ async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse
content=self.graphql_ide_html.encode(), content_type="text/html"
)
+ def is_websocket_request(
+ self, request: ChannelsRequest
+ ) -> TypeGuard[ChannelsRequest]:
+ return False
+
+ async def pick_websocket_subprotocol(
+ self, request: ChannelsRequest
+ ) -> Optional[str]:
+ return None
+
+ async def create_websocket_response(
+ self, request: ChannelsRequest, subprotocol: Optional[str]
+ ) -> TemporalResponse:
+ raise NotImplementedError
+
class SyncGraphQLHTTPConsumer(
BaseGraphQLHTTPConsumer,
diff --git a/strawberry/channels/handlers/ws_handler.py b/strawberry/channels/handlers/ws_handler.py
index 2991059afd..b267f7ea9b 100644
--- a/strawberry/channels/handlers/ws_handler.py
+++ b/strawberry/channels/handlers/ws_handler.py
@@ -1,20 +1,76 @@
from __future__ import annotations
+import asyncio
import datetime
-from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Union
-
+import json
+from typing import (
+ TYPE_CHECKING,
+ AsyncGenerator,
+ Dict,
+ Mapping,
+ Optional,
+ Tuple,
+ TypedDict,
+ Union,
+)
+from typing_extensions import TypeGuard
+
+from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncWebSocketAdapter
+from strawberry.http.exceptions import NonJsonMessageReceived
+from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
-from .base import ChannelsConsumer, ChannelsWSConsumer
-from .graphql_transport_ws_handler import GraphQLTransportWSHandler
-from .graphql_ws_handler import GraphQLWSHandler
+from .base import ChannelsWSConsumer
if TYPE_CHECKING:
- from strawberry.http.typevars import Context, RootValue
+ from strawberry.http import GraphQLHTTPResponse
from strawberry.schema import BaseSchema
-class GraphQLWSConsumer(ChannelsWSConsumer):
+class ChannelsWebSocketAdapter(AsyncWebSocketAdapter):
+ def __init__(self, request: GraphQLWSConsumer, response: GraphQLWSConsumer) -> None:
+ self.ws_consumer = response
+
+ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
+ while True:
+ message = await self.ws_consumer.message_queue.get()
+
+ if message["disconnected"]:
+ break
+
+ if message["message"] is None:
+ raise NonJsonMessageReceived()
+
+ try:
+ yield json.loads(message["message"])
+ except json.JSONDecodeError:
+ raise NonJsonMessageReceived()
+
+ async def send_json(self, message: Mapping[str, object]) -> None:
+ serialized_message = json.dumps(message)
+ await self.ws_consumer.send(serialized_message)
+
+ async def close(self, code: int, reason: str) -> None:
+ await self.ws_consumer.close(code=code, reason=reason)
+
+
+class MessageQueueData(TypedDict):
+ message: Union[str, None]
+ disconnected: bool
+
+
+class GraphQLWSConsumer(
+ ChannelsWSConsumer,
+ AsyncBaseHTTPView[
+ "GraphQLWSConsumer",
+ "GraphQLWSConsumer",
+ "GraphQLWSConsumer",
+ "GraphQLWSConsumer",
+ "GraphQLWSConsumer",
+ Context,
+ RootValue,
+ ],
+):
"""A channels websocket consumer for GraphQL.
This handles the connections, then hands off to the appropriate
@@ -39,9 +95,7 @@ class GraphQLWSConsumer(ChannelsWSConsumer):
```
"""
- graphql_transport_ws_handler_class = GraphQLTransportWSHandler
- graphql_ws_handler_class = GraphQLWSHandler
- _handler: Union[GraphQLWSHandler, GraphQLTransportWSHandler]
+ websocket_adapter_class = ChannelsWebSocketAdapter
def __init__(
self,
@@ -63,70 +117,71 @@ def __init__(
self.keep_alive_interval = keep_alive_interval
self.debug = debug
self.protocols = subscription_protocols
+ self.message_queue: asyncio.Queue[MessageQueueData] = asyncio.Queue()
+ self.run_task: Optional[asyncio.Task] = None
super().__init__()
- def pick_preferred_protocol(
- self, accepted_subprotocols: Sequence[str]
- ) -> Optional[str]:
- intersection = set(accepted_subprotocols) & set(self.protocols)
- sorted_intersection = sorted(intersection, key=accepted_subprotocols.index)
- return next(iter(sorted_intersection), None)
-
async def connect(self) -> None:
- preferred_protocol = self.pick_preferred_protocol(self.scope["subprotocols"])
-
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
- self._handler = self.graphql_transport_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- connection_init_wait_timeout=self.connection_init_wait_timeout,
- get_context=self.get_context,
- get_root_value=self.get_root_value,
- ws=self,
- )
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
- self._handler = self.graphql_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- keep_alive=self.keep_alive,
- keep_alive_interval=self.keep_alive_interval,
- get_context=self.get_context,
- get_root_value=self.get_root_value,
- ws=self,
- )
- else:
- # Subprotocol not acceptable
- return await self.close(code=4406)
+ self.run_task = asyncio.create_task(self.run(self))
- await self._handler.handle()
- return None
-
- async def receive(self, *args: str, **kwargs: Any) -> None:
- # Overriding this so that we can pass the errors to handle_invalid_message
- try:
- await super().receive(*args, **kwargs)
- except ValueError:
- reason = "WebSocket message type must be text"
- await self._handler.handle_invalid_message(reason)
-
- async def receive_json(self, content: Any, **kwargs: Any) -> None:
- await self._handler.handle_message(content)
+ async def receive(
+ self, text_data: Optional[str] = None, bytes_data: Optional[bytes] = None
+ ) -> None:
+ if text_data:
+ self.message_queue.put_nowait({"message": text_data, "disconnected": False})
+ else:
+ self.message_queue.put_nowait({"message": None, "disconnected": False})
async def disconnect(self, code: int) -> None:
- await self._handler.handle_disconnect(code)
+ self.message_queue.put_nowait({"message": None, "disconnected": True})
+ assert self.run_task
+ await self.run_task
- async def get_root_value(self, request: ChannelsConsumer) -> Optional[RootValue]:
+ async def get_root_value(self, request: GraphQLWSConsumer) -> Optional[RootValue]:
return None
async def get_context(
- self, request: ChannelsConsumer, connection_params: Any
+ self, request: GraphQLWSConsumer, response: GraphQLWSConsumer
) -> Context:
return {
"request": request,
- "connection_params": connection_params,
"ws": request,
} # type: ignore
+ @property
+ def allow_queries_via_get(self) -> bool:
+ return False
+
+ async def get_sub_response(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer:
+ raise NotImplementedError
+
+ def create_response(
+ self, response_data: GraphQLHTTPResponse, sub_response: GraphQLWSConsumer
+ ) -> GraphQLWSConsumer:
+ raise NotImplementedError
+
+ async def render_graphql_ide(self, request: GraphQLWSConsumer) -> GraphQLWSConsumer:
+ raise NotImplementedError
+
+ def is_websocket_request(
+ self, request: GraphQLWSConsumer
+ ) -> TypeGuard[GraphQLWSConsumer]:
+ return True
+
+ async def pick_websocket_subprotocol(
+ self, request: GraphQLWSConsumer
+ ) -> Optional[str]:
+ protocols = request.scope["subprotocols"]
+ intersection = set(protocols) & set(self.protocols)
+ sorted_intersection = sorted(intersection, key=protocols.index)
+ return next(iter(sorted_intersection), None)
+
+ async def create_websocket_response(
+ self, request: GraphQLWSConsumer, subprotocol: Optional[str]
+ ) -> GraphQLWSConsumer:
+ await request.accept(subprotocol=subprotocol)
+ return request
+
__all__ = ["GraphQLWSConsumer"]
diff --git a/strawberry/codegen/query_codegen.py b/strawberry/codegen/query_codegen.py
index 92d93d98db..583f7fbe39 100644
--- a/strawberry/codegen/query_codegen.py
+++ b/strawberry/codegen/query_codegen.py
@@ -645,20 +645,22 @@ def _unwrap_type(
) -> Tuple[
Union[type, StrawberryType], Optional[Callable[[GraphQLType], GraphQLType]]
]:
- wrapper = None
+ wrapper: Optional[Callable[[GraphQLType], GraphQLType]] = None
if isinstance(type_, StrawberryOptional):
- type_, wrapper = self._unwrap_type(type_.of_type)
+ type_, previous_wrapper = self._unwrap_type(type_.of_type)
wrapper = (
GraphQLOptional
- if wrapper is None
- else lambda t: GraphQLOptional(wrapper(t)) # type: ignore[misc]
+ if previous_wrapper is None
+ else lambda t: GraphQLOptional(previous_wrapper(t)) # type: ignore[misc]
)
elif isinstance(type_, StrawberryList):
- type_, wrapper = self._unwrap_type(type_.of_type)
+ type_, previous_wrapper = self._unwrap_type(type_.of_type)
wrapper = (
- GraphQLList if wrapper is None else lambda t: GraphQLList(wrapper(t))
+ GraphQLList
+ if previous_wrapper is None
+ else lambda t: GraphQLList(previous_wrapper(t))
)
elif isinstance(type_, LazyType):
diff --git a/strawberry/django/views.py b/strawberry/django/views.py
index 0ce5bf920a..23fa07e886 100644
--- a/strawberry/django/views.py
+++ b/strawberry/django/views.py
@@ -13,6 +13,7 @@
Union,
cast,
)
+from typing_extensions import TypeGuard
from asgiref.sync import markcoroutinefunction
from django.core.serializers.json import DjangoJSONEncoder
@@ -24,12 +25,9 @@
StreamingHttpResponse,
)
from django.http.response import HttpResponseBase
-from django.template import RequestContext, Template
from django.template.exceptions import TemplateDoesNotExist
from django.template.loader import render_to_string
-from django.template.response import TemplateResponse
-from django.utils.decorators import classonlymethod, method_decorator
-from django.views.decorators.csrf import csrf_exempt
+from django.utils.decorators import classonlymethod
from django.views.generic import View
from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
@@ -44,6 +42,8 @@
from .context import StrawberryDjangoContext
if TYPE_CHECKING:
+ from django.template.response import TemplateResponse
+
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
@@ -137,7 +137,6 @@ async def get_form_data(self) -> FormData:
class BaseView:
- _ide_replace_variables = False
graphql_ide_html: str
def __init__(
@@ -146,12 +145,12 @@ def __init__(
graphiql: Optional[str] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
- subscriptions_enabled: bool = False,
+ multipart_uploads_enabled: bool = False,
**kwargs: Any,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
- self.subscriptions_enabled = subscriptions_enabled
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -213,7 +212,6 @@ class GraphQLView(
],
View,
):
- subscriptions_enabled = False
graphiql: Optional[bool] = None
graphql_ide: Optional[GraphQL_IDE] = "graphiql"
allow_queries_via_get = True
@@ -229,7 +227,6 @@ def get_context(self, request: HttpRequest, response: HttpResponse) -> Any:
def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse:
return TemporalHttpResponse()
- @method_decorator(csrf_exempt)
def dispatch(
self, request: HttpRequest, *args: Any, **kwargs: Any
) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]:
@@ -243,26 +240,26 @@ def dispatch(
def render_graphql_ide(self, request: HttpRequest) -> HttpResponse:
try:
- template = Template(render_to_string("graphql/graphiql.html"))
+ content = render_to_string("graphql/graphiql.html")
except TemplateDoesNotExist:
- template = Template(self.graphql_ide_html)
-
- context = {"SUBSCRIPTION_ENABLED": json.dumps(self.subscriptions_enabled)}
+ content = self.graphql_ide_html
- response = TemplateResponse(request=request, template=None, context=context)
- response.content = template.render(RequestContext(request, context))
-
- return response
+ return HttpResponse(content)
class AsyncGraphQLView(
BaseView,
AsyncBaseHTTPView[
- HttpRequest, HttpResponseBase, TemporalHttpResponse, Context, RootValue
+ HttpRequest,
+ HttpResponseBase,
+ TemporalHttpResponse,
+ HttpRequest,
+ TemporalHttpResponse,
+ Context,
+ RootValue,
],
View,
):
- subscriptions_enabled = False
graphiql: Optional[bool] = None
graphql_ide: Optional[GraphQL_IDE] = "graphiql"
allow_queries_via_get = True
@@ -288,7 +285,6 @@ async def get_context(self, request: HttpRequest, response: HttpResponse) -> Any
async def get_sub_response(self, request: HttpRequest) -> TemporalHttpResponse:
return TemporalHttpResponse()
- @method_decorator(csrf_exempt)
async def dispatch( # pyright: ignore
self, request: HttpRequest, *args: Any, **kwargs: Any
) -> Union[HttpResponseNotAllowed, TemplateResponse, HttpResponseBase]:
@@ -302,16 +298,22 @@ async def dispatch( # pyright: ignore
async def render_graphql_ide(self, request: HttpRequest) -> HttpResponse:
try:
- template = Template(render_to_string("graphql/graphiql.html"))
+ content = render_to_string("graphql/graphiql.html")
except TemplateDoesNotExist:
- template = Template(self.graphql_ide_html)
+ content = self.graphql_ide_html
- context = {"SUBSCRIPTION_ENABLED": json.dumps(self.subscriptions_enabled)}
+ return HttpResponse(content=content)
- response = TemplateResponse(request=request, template=None, context=context)
- response.content = template.render(RequestContext(request, context))
+ def is_websocket_request(self, request: HttpRequest) -> TypeGuard[HttpRequest]:
+ return False
- return response
+ async def pick_websocket_subprotocol(self, request: HttpRequest) -> Optional[str]:
+ raise NotImplementedError
+
+ async def create_websocket_response(
+ self, request: HttpRequest, subprotocol: Optional[str]
+ ) -> TemporalHttpResponse:
+ raise NotImplementedError
__all__ = ["GraphQLView", "AsyncGraphQLView"]
diff --git a/strawberry/exceptions/__init__.py b/strawberry/exceptions/__init__.py
index 7d3b3ccb1b..ee331af721 100644
--- a/strawberry/exceptions/__init__.py
+++ b/strawberry/exceptions/__init__.py
@@ -15,7 +15,6 @@
from .missing_dependencies import MissingOptionalDependenciesError
from .missing_field_annotation import MissingFieldAnnotationError
from .missing_return_annotation import MissingReturnAnnotationError
-from .not_a_strawberry_enum import NotAStrawberryEnumError
from .object_is_not_a_class import ObjectIsNotClassError
from .object_is_not_an_enum import ObjectIsNotAnEnumError
from .private_strawberry_field import PrivateStrawberryFieldError
@@ -174,7 +173,6 @@ class StrawberryGraphQLError(GraphQLError):
"UnresolvedFieldTypeError",
"PrivateStrawberryFieldError",
"MultipleStrawberryArgumentsError",
- "NotAStrawberryEnumError",
"ScalarAlreadyRegisteredError",
"WrongNumberOfResultsReturned",
"FieldWithResolverAndDefaultValueError",
diff --git a/strawberry/exceptions/not_a_strawberry_enum.py b/strawberry/exceptions/not_a_strawberry_enum.py
deleted file mode 100644
index dd0c94d48f..0000000000
--- a/strawberry/exceptions/not_a_strawberry_enum.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from __future__ import annotations
-
-from functools import cached_property
-from typing import TYPE_CHECKING, Optional
-
-from .exception import StrawberryException
-from .utils.source_finder import SourceFinder
-
-if TYPE_CHECKING:
- from enum import EnumMeta
-
- from .exception_source import ExceptionSource
-
-
-class NotAStrawberryEnumError(StrawberryException):
- def __init__(self, enum: EnumMeta) -> None:
- self.enum = enum
-
- self.message = f'Enum "{enum.__name__}" is not a Strawberry enum.'
- self.rich_message = (
- f"Enum `[underline]{enum.__name__}[/]` is not a Strawberry enum."
- )
- self.suggestion = (
- "To fix this error you can declare the enum using `@strawberry.enum`."
- )
-
- self.annotation_message = "enum defined here"
-
- super().__init__(self.message)
-
- @cached_property
- def exception_source(self) -> Optional[ExceptionSource]:
- if self.enum is None:
- return None # pragma: no cover
-
- source_finder = SourceFinder()
-
- return source_finder.find_class_from_object(self.enum)
diff --git a/strawberry/ext/mypy_plugin.py b/strawberry/ext/mypy_plugin.py
index 5d5c3d08ef..77a8a406f0 100644
--- a/strawberry/ext/mypy_plugin.py
+++ b/strawberry/ext/mypy_plugin.py
@@ -481,7 +481,14 @@ def strawberry_pydantic_class_callback(ctx: ClassDefContext) -> None:
# Based on pydantic's default value
# https://github.com/pydantic/pydantic/pull/9606/files#diff-469037bbe55bbf9aa359480a16040d368c676adad736e133fb07e5e20d6ac523R1066
extra["force_typevars_invariant"] = False
-
+ if PYDANTIC_VERSION >= (2, 9, 0):
+ extra["model_strict"] = model_type.type.metadata[
+ PYDANTIC_METADATA_KEY
+ ]["config"].get("strict", False)
+ extra["is_root_model_root"] = any(
+ "pydantic.root_model.RootModel" in base.fullname
+ for base in model_type.type.mro[:-1]
+ )
add_method(
ctx,
"to_pydantic",
diff --git a/strawberry/fastapi/handlers/__init__.py b/strawberry/fastapi/handlers/__init__.py
deleted file mode 100644
index 20f336f5ff..0000000000
--- a/strawberry/fastapi/handlers/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from strawberry.fastapi.handlers.graphql_transport_ws_handler import (
- GraphQLTransportWSHandler,
-)
-from strawberry.fastapi.handlers.graphql_ws_handler import GraphQLWSHandler
-
-__all__ = ["GraphQLTransportWSHandler", "GraphQLWSHandler"]
diff --git a/strawberry/fastapi/handlers/graphql_transport_ws_handler.py b/strawberry/fastapi/handlers/graphql_transport_ws_handler.py
deleted file mode 100644
index 817f6996ac..0000000000
--- a/strawberry/fastapi/handlers/graphql_transport_ws_handler.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from typing import Any
-
-from strawberry.asgi.handlers import (
- GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler,
-)
-from strawberry.fastapi.context import BaseContext
-
-
-class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
- async def get_context(self) -> Any:
- context = await self._get_context()
- if isinstance(context, BaseContext):
- context.connection_params = self.connection_params
- return context
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value()
-
-
-__all__ = ["GraphQLTransportWSHandler"]
diff --git a/strawberry/fastapi/handlers/graphql_ws_handler.py b/strawberry/fastapi/handlers/graphql_ws_handler.py
deleted file mode 100644
index 0c43bbbd6e..0000000000
--- a/strawberry/fastapi/handlers/graphql_ws_handler.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import Any
-
-from strawberry.asgi.handlers import GraphQLWSHandler as BaseGraphQLWSHandler
-from strawberry.fastapi.context import BaseContext
-
-
-class GraphQLWSHandler(BaseGraphQLWSHandler):
- async def get_context(self) -> Any:
- context = await self._get_context()
- if isinstance(context, BaseContext):
- context.connection_params = self.connection_params
- return context
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value()
-
-
-__all__ = ["GraphQLWSHandler"]
diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py
index 833b656383..3ed8e6a4a0 100644
--- a/strawberry/fastapi/router.py
+++ b/strawberry/fastapi/router.py
@@ -17,6 +17,7 @@
Union,
cast,
)
+from typing_extensions import TypeGuard
from starlette import status
from starlette.background import BackgroundTasks # noqa: TCH002
@@ -34,10 +35,9 @@
from fastapi.datastructures import Default
from fastapi.routing import APIRoute
from fastapi.utils import generate_unique_id
-from strawberry.asgi import ASGIRequestAdapter
+from strawberry.asgi import ASGIRequestAdapter, ASGIWebSocketAdapter
from strawberry.exceptions import InvalidCustomContext
from strawberry.fastapi.context import BaseContext, CustomContext
-from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.http import process_result
from strawberry.http.async_base_view import AsyncBaseHTTPView
from strawberry.http.exceptions import HTTPException
@@ -58,12 +58,14 @@
class GraphQLRouter(
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue], APIRouter
+ AsyncBaseHTTPView[
+ Request, Response, Response, WebSocket, WebSocket, Context, RootValue
+ ],
+ APIRouter,
):
- graphql_ws_handler_class = GraphQLWSHandler
- graphql_transport_ws_handler_class = GraphQLTransportWSHandler
allow_queries_via_get = True
request_adapter_class = ASGIRequestAdapter
+ websocket_adapter_class = ASGIWebSocketAdapter
@staticmethod
async def __get_root_value() -> None:
@@ -156,6 +158,7 @@ def __init__(
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
+ multipart_uploads_enabled: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
@@ -190,6 +193,7 @@ def __init__(
)
self.protocols = subscription_protocols
self.connection_init_wait_timeout = connection_init_wait_timeout
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -259,44 +263,7 @@ async def websocket_endpoint( # pyright: ignore
context: Context = Depends(self.context_getter),
root_value: RootValue = Depends(self.root_value_getter),
) -> None:
- async def _get_context() -> Context:
- return context
-
- async def _get_root_value() -> RootValue:
- return root_value
-
- preferred_protocol = self.pick_preferred_protocol(websocket)
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
- await self.graphql_transport_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- connection_init_wait_timeout=self.connection_init_wait_timeout,
- get_context=_get_context,
- get_root_value=_get_root_value,
- ws=websocket,
- ).handle()
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
- await self.graphql_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- keep_alive=self.keep_alive,
- keep_alive_interval=self.keep_alive_interval,
- get_context=_get_context,
- get_root_value=_get_root_value,
- ws=websocket,
- ).handle()
- else:
- # Code 4406 is "Subprotocol not acceptable"
- await websocket.close(code=4406)
-
- def pick_preferred_protocol(self, ws: WebSocket) -> Optional[str]:
- protocols = ws["subprotocols"]
- intersection = set(protocols) & set(self.protocols)
- return min(
- intersection,
- key=lambda i: protocols.index(i),
- default=None,
- )
+ await self.run(request=websocket, context=context, root_value=root_value)
async def render_graphql_ide(self, request: Request) -> HTMLResponse:
return HTMLResponse(self.graphql_ide_html)
@@ -307,12 +274,12 @@ async def process_result(
return process_result(result)
async def get_context(
- self, request: Request, response: Response
+ self, request: Union[Request, WebSocket], response: Union[Response, WebSocket]
) -> Context: # pragma: no cover
raise ValueError("`get_context` is not used by FastAPI GraphQL Router")
async def get_root_value(
- self, request: Request
+ self, request: Union[Request, WebSocket]
) -> Optional[RootValue]: # pragma: no cover
raise ValueError("`get_root_value` is not used by FastAPI GraphQL Router")
@@ -348,5 +315,22 @@ async def create_streaming_response(
},
)
+ def is_websocket_request(
+ self, request: Union[Request, WebSocket]
+ ) -> TypeGuard[WebSocket]:
+ return request.scope["type"] == "websocket"
+
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
+ protocols = request["subprotocols"]
+ intersection = set(protocols) & set(self.protocols)
+ sorted_intersection = sorted(intersection, key=protocols.index)
+ return next(iter(sorted_intersection), None)
+
+ async def create_websocket_response(
+ self, request: WebSocket, subprotocol: Optional[str]
+ ) -> WebSocket:
+ await request.accept(subprotocol=subprotocol)
+ return request
+
__all__ = ["GraphQLRouter"]
diff --git a/strawberry/flask/views.py b/strawberry/flask/views.py
index b855c602ea..d93f480b1e 100644
--- a/strawberry/flask/views.py
+++ b/strawberry/flask/views.py
@@ -9,6 +9,7 @@
Union,
cast,
)
+from typing_extensions import TypeGuard
from flask import Request, Response, render_template_string, request
from flask.views import View
@@ -62,7 +63,6 @@ def content_type(self) -> Optional[str]:
class BaseGraphQLView:
- _ide_subscription_enabled = False
graphql_ide: Optional[GraphQL_IDE]
def __init__(
@@ -71,10 +71,12 @@ def __init__(
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
+ multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.graphiql = graphiql
self.allow_queries_via_get = allow_queries_via_get
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -157,7 +159,9 @@ async def get_form_data(self) -> FormData:
class AsyncGraphQLView(
BaseGraphQLView,
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue],
+ AsyncBaseHTTPView[
+ Request, Response, Response, Request, Response, Context, RootValue
+ ],
View,
):
methods = ["GET", "POST"]
@@ -183,7 +187,19 @@ async def dispatch_request(self) -> ResponseReturnValue: # type: ignore
)
async def render_graphql_ide(self, request: Request) -> Response:
- return render_template_string(self.graphql_ide_html) # type: ignore
+ content = render_template_string(self.graphql_ide_html)
+ return Response(content, status=200, content_type="text/html")
+
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
+ return False
+
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
+ raise NotImplementedError
+
+ async def create_websocket_response(
+ self, request: Request, subprotocol: Optional[str]
+ ) -> Response:
+ raise NotImplementedError
__all__ = [
diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py
index 7eef89aa40..a7666018ef 100644
--- a/strawberry/http/async_base_view.py
+++ b/strawberry/http/async_base_view.py
@@ -2,6 +2,7 @@
import asyncio
import contextlib
import json
+from datetime import timedelta
from typing import (
Any,
AsyncGenerator,
@@ -13,8 +14,10 @@
Optional,
Tuple,
Union,
+ cast,
+ overload,
)
-from typing_extensions import Literal
+from typing_extensions import Literal, TypeGuard
from graphql import GraphQLError
@@ -29,6 +32,11 @@
from strawberry.http.ides import GraphQL_IDE
from strawberry.schema.base import BaseSchema
from strawberry.schema.exceptions import InvalidOperationTypeError
+from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
+from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
+ BaseGraphQLTransportWSHandler,
+)
+from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
from strawberry.types import ExecutionResult, SubscriptionExecutionResult
from strawberry.types.graphql import OperationType
@@ -36,7 +44,15 @@
from .exceptions import HTTPException
from .parse_content_type import parse_content_type
from .types import FormData, HTTPMethod, QueryParams
-from .typevars import Context, Request, Response, RootValue, SubResponse
+from .typevars import (
+ Context,
+ Request,
+ Response,
+ RootValue,
+ SubResponse,
+ WebSocketRequest,
+ WebSocketResponse,
+)
class AsyncHTTPRequestAdapter(abc.ABC):
@@ -63,14 +79,42 @@ async def get_body(self) -> Union[str, bytes]: ...
async def get_form_data(self) -> FormData: ...
+class AsyncWebSocketAdapter(abc.ABC):
+ @abc.abstractmethod
+ def iter_json(self) -> AsyncGenerator[Dict[str, object], None]: ...
+
+ @abc.abstractmethod
+ async def send_json(self, message: Mapping[str, object]) -> None: ...
+
+ @abc.abstractmethod
+ async def close(self, code: int, reason: str) -> None: ...
+
+
class AsyncBaseHTTPView(
abc.ABC,
BaseView[Request],
- Generic[Request, Response, SubResponse, Context, RootValue],
+ Generic[
+ Request,
+ Response,
+ SubResponse,
+ WebSocketRequest,
+ WebSocketResponse,
+ Context,
+ RootValue,
+ ],
):
schema: BaseSchema
graphql_ide: Optional[GraphQL_IDE]
+ debug: bool
+ keep_alive = False
+ keep_alive_interval: Optional[float] = None
+ connection_init_wait_timeout: timedelta = timedelta(minutes=1)
request_adapter_class: Callable[[Request], AsyncHTTPRequestAdapter]
+ websocket_adapter_class: Callable[
+ [WebSocketRequest, WebSocketResponse], AsyncWebSocketAdapter
+ ]
+ graphql_transport_ws_handler_class = BaseGraphQLTransportWSHandler
+ graphql_ws_handler_class = BaseGraphQLWSHandler
@property
@abc.abstractmethod
@@ -80,10 +124,16 @@ def allow_queries_via_get(self) -> bool: ...
async def get_sub_response(self, request: Request) -> SubResponse: ...
@abc.abstractmethod
- async def get_context(self, request: Request, response: SubResponse) -> Context: ...
+ async def get_context(
+ self,
+ request: Union[Request, WebSocketRequest],
+ response: Union[SubResponse, WebSocketResponse],
+ ) -> Context: ...
@abc.abstractmethod
- async def get_root_value(self, request: Request) -> Optional[RootValue]: ...
+ async def get_root_value(
+ self, request: Union[Request, WebSocketRequest]
+ ) -> Optional[RootValue]: ...
@abc.abstractmethod
def create_response(
@@ -102,6 +152,21 @@ async def create_streaming_response(
) -> Response:
raise ValueError("Multipart responses are not supported")
+ @abc.abstractmethod
+ def is_websocket_request(
+ self, request: Union[Request, WebSocketRequest]
+ ) -> TypeGuard[WebSocketRequest]: ...
+
+ @abc.abstractmethod
+ async def pick_websocket_subprotocol(
+ self, request: WebSocketRequest
+ ) -> Optional[str]: ...
+
+ @abc.abstractmethod
+ async def create_websocket_response(
+ self, request: WebSocketRequest, subprotocol: Optional[str]
+ ) -> WebSocketResponse: ...
+
async def execute_operation(
self, request: Request, context: Context, root_value: Optional[RootValue]
) -> Union[ExecutionResult, SubscriptionExecutionResult]:
@@ -167,35 +232,90 @@ def _handle_errors(
) -> None:
"""Hook to allow custom handling of errors, used by the Sentry Integration."""
+ @overload
async def run(
self,
request: Request,
context: Optional[Context] = UNSET,
root_value: Optional[RootValue] = UNSET,
- ) -> Response:
- request_adapter = self.request_adapter_class(request)
+ ) -> Response: ...
- if not self.is_request_allowed(request_adapter):
- raise HTTPException(405, "GraphQL only supports GET and POST requests.")
+ @overload
+ async def run(
+ self,
+ request: WebSocketRequest,
+ context: Optional[Context] = UNSET,
+ root_value: Optional[RootValue] = UNSET,
+ ) -> WebSocketResponse: ...
- if self.should_render_graphql_ide(request_adapter):
- if self.graphql_ide:
- return await self.render_graphql_ide(request)
+ async def run(
+ self,
+ request: Union[Request, WebSocketRequest],
+ context: Optional[Context] = UNSET,
+ root_value: Optional[RootValue] = UNSET,
+ ) -> Union[Response, WebSocketResponse]:
+ root_value = (
+ await self.get_root_value(request) if root_value is UNSET else root_value
+ )
+
+ if self.is_websocket_request(request):
+ websocket_subprotocol = await self.pick_websocket_subprotocol(request)
+ websocket_response = await self.create_websocket_response(
+ request, websocket_subprotocol
+ )
+ websocket = self.websocket_adapter_class(request, websocket_response)
+
+ context = (
+ await self.get_context(request, response=websocket_response)
+ if context is UNSET
+ else context
+ )
+
+ if websocket_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
+ await self.graphql_transport_ws_handler_class(
+ websocket=websocket,
+ context=context,
+ root_value=root_value,
+ schema=self.schema,
+ debug=self.debug,
+ connection_init_wait_timeout=self.connection_init_wait_timeout,
+ ).handle()
+ elif websocket_subprotocol == GRAPHQL_WS_PROTOCOL:
+ await self.graphql_ws_handler_class(
+ websocket=websocket,
+ context=context,
+ root_value=root_value,
+ schema=self.schema,
+ debug=self.debug,
+ keep_alive=self.keep_alive,
+ keep_alive_interval=self.keep_alive_interval,
+ ).handle()
else:
- raise HTTPException(404, "Not Found")
+ await websocket.close(4406, "Subprotocol not acceptable")
+
+ return websocket_response
+ else:
+ request = cast(Request, request)
+ request_adapter = self.request_adapter_class(request)
sub_response = await self.get_sub_response(request)
context = (
await self.get_context(request, response=sub_response)
if context is UNSET
else context
)
- root_value = (
- await self.get_root_value(request) if root_value is UNSET else root_value
- )
assert context
+ if not self.is_request_allowed(request_adapter):
+ raise HTTPException(405, "GraphQL only supports GET and POST requests.")
+
+ if self.should_render_graphql_ide(request_adapter):
+ if self.graphql_ide:
+ return await self.render_graphql_ide(request)
+ else:
+ raise HTTPException(404, "Not Found")
+
try:
result = await self.execute_operation(
request=request, context=context, root_value=root_value
@@ -333,7 +453,7 @@ async def parse_http_body(
data = self.parse_query_params(request.query_params)
elif "application/json" in content_type:
data = self.parse_json(await request.get_body())
- elif content_type == "multipart/form-data":
+ elif self.multipart_uploads_enabled and content_type == "multipart/form-data":
data = await self.parse_multipart(request)
else:
raise HTTPException(400, "Unsupported content type")
diff --git a/strawberry/http/base.py b/strawberry/http/base.py
index 7f8e1802bc..a906849ddf 100644
--- a/strawberry/http/base.py
+++ b/strawberry/http/base.py
@@ -23,10 +23,7 @@ def headers(self) -> Mapping[str, str]: ...
class BaseView(Generic[Request]):
graphql_ide: Optional[GraphQL_IDE]
-
- # TODO: we might remove this in future :)
- _ide_replace_variables: bool = True
- _ide_subscription_enabled: bool = True
+ multipart_uploads_enabled: bool = False
def should_render_graphql_ide(self, request: BaseRequestProtocol) -> bool:
return (
@@ -63,11 +60,7 @@ def parse_query_params(self, params: QueryParams) -> Dict[str, Any]:
@property
def graphql_ide_html(self) -> str:
- return get_graphql_ide_html(
- subscription_enabled=self._ide_subscription_enabled,
- replace_variables=self._ide_replace_variables,
- graphql_ide=self.graphql_ide,
- )
+ return get_graphql_ide_html(graphql_ide=self.graphql_ide)
def _is_multipart_subscriptions(
self, content_type: str, params: Dict[str, str]
diff --git a/strawberry/http/exceptions.py b/strawberry/http/exceptions.py
index d934696806..feddf77631 100644
--- a/strawberry/http/exceptions.py
+++ b/strawberry/http/exceptions.py
@@ -4,4 +4,8 @@ def __init__(self, status_code: int, reason: str) -> None:
self.reason = reason
+class NonJsonMessageReceived(Exception):
+ pass
+
+
__all__ = ["HTTPException"]
diff --git a/strawberry/http/ides.py b/strawberry/http/ides.py
index d9c52fb716..9680a0277a 100644
--- a/strawberry/http/ides.py
+++ b/strawberry/http/ides.py
@@ -1,4 +1,3 @@
-import json
import pathlib
from typing import Optional
from typing_extensions import Literal
@@ -7,8 +6,6 @@
def get_graphql_ide_html(
- subscription_enabled: bool = True,
- replace_variables: bool = True,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
) -> str:
here = pathlib.Path(__file__).parents[1]
@@ -22,11 +19,6 @@ def get_graphql_ide_html(
template = path.read_text(encoding="utf-8")
- if replace_variables:
- template = template.replace(
- "{{ SUBSCRIPTION_ENABLED }}", json.dumps(subscription_enabled)
- )
-
return template
diff --git a/strawberry/http/sync_base_view.py b/strawberry/http/sync_base_view.py
index f1ce7ca19a..df770e0541 100644
--- a/strawberry/http/sync_base_view.py
+++ b/strawberry/http/sync_base_view.py
@@ -143,7 +143,7 @@ def parse_http_body(self, request: SyncHTTPRequestAdapter) -> GraphQLRequestData
elif "application/json" in content_type:
data = self.parse_json(request.body)
# TODO: multipart via get?
- elif content_type == "multipart/form-data":
+ elif self.multipart_uploads_enabled and content_type == "multipart/form-data":
data = self.parse_multipart(request)
elif self._is_multipart_subscriptions(content_type, params):
raise HTTPException(
diff --git a/strawberry/http/typevars.py b/strawberry/http/typevars.py
index a48cba848e..53a5d5ac33 100644
--- a/strawberry/http/typevars.py
+++ b/strawberry/http/typevars.py
@@ -3,8 +3,18 @@
Request = TypeVar("Request", contravariant=True)
Response = TypeVar("Response")
SubResponse = TypeVar("SubResponse")
+WebSocketRequest = TypeVar("WebSocketRequest")
+WebSocketResponse = TypeVar("WebSocketResponse")
Context = TypeVar("Context")
RootValue = TypeVar("RootValue")
-__all__ = ["Request", "Response", "SubResponse", "Context", "RootValue"]
+__all__ = [
+ "Request",
+ "Response",
+ "SubResponse",
+ "WebSocketRequest",
+ "WebSocketResponse",
+ "Context",
+ "RootValue",
+]
diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py
index 7ff68c69ad..fed3d2d45f 100644
--- a/strawberry/litestar/controller.py
+++ b/strawberry/litestar/controller.py
@@ -7,19 +7,19 @@
from typing import (
TYPE_CHECKING,
Any,
+ AsyncGenerator,
AsyncIterator,
Callable,
Dict,
FrozenSet,
- List,
Optional,
- Set,
Tuple,
Type,
TypedDict,
Union,
cast,
)
+from typing_extensions import TypeGuard
from msgspec import Struct
@@ -35,23 +35,24 @@
)
from litestar.background_tasks import BackgroundTasks
from litestar.di import Provide
-from litestar.exceptions import NotFoundException, ValidationException
+from litestar.exceptions import (
+ NotFoundException,
+ SerializationException,
+ ValidationException,
+ WebSocketDisconnect,
+)
from litestar.response.streaming import Stream
from litestar.status_codes import HTTP_200_OK
from strawberry.exceptions import InvalidCustomContext
-from strawberry.http.async_base_view import AsyncBaseHTTPView, AsyncHTTPRequestAdapter
-from strawberry.http.exceptions import HTTPException
+from strawberry.http.async_base_view import (
+ AsyncBaseHTTPView,
+ AsyncHTTPRequestAdapter,
+ AsyncWebSocketAdapter,
+)
+from strawberry.http.exceptions import HTTPException, NonJsonMessageReceived
from strawberry.http.types import FormData, HTTPMethod, QueryParams
from strawberry.http.typevars import Context, RootValue
from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_transport_ws import (
- WS_4406_PROTOCOL_NOT_ACCEPTABLE,
-)
-
-from .handlers.graphql_transport_ws_handler import (
- GraphQLTransportWSHandler as BaseGraphQLTransportWSHandler,
-)
-from .handlers.graphql_ws_handler import GraphQLWSHandler as BaseGraphQLWSHandler
if TYPE_CHECKING:
from collections.abc import Mapping
@@ -152,22 +153,6 @@ class GraphQLResource(Struct):
extensions: Optional[dict[str, object]]
-class GraphQLWSHandler(BaseGraphQLWSHandler):
- async def get_context(self) -> Any:
- return await self._get_context()
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value()
-
-
-class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
- async def get_context(self) -> Any:
- return await self._get_context()
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value()
-
-
class LitestarRequestAdapter(AsyncHTTPRequestAdapter):
def __init__(self, request: Request[Any, Any, Any]) -> None:
self.request = request
@@ -203,10 +188,37 @@ async def get_form_data(self) -> FormData:
return FormData(form=multipart_data, files=multipart_data)
+class LitestarWebSocketAdapter(AsyncWebSocketAdapter):
+ def __init__(self, request: WebSocket, response: WebSocket) -> None:
+ self.ws = response
+
+ async def iter_json(self) -> AsyncGenerator[Dict[str, object], None]:
+ try:
+ try:
+ while self.ws.connection_state != "disconnect":
+ yield await self.ws.receive_json()
+ except (SerializationException, ValueError):
+ raise NonJsonMessageReceived()
+ except WebSocketDisconnect:
+ pass
+
+ async def send_json(self, message: Mapping[str, object]) -> None:
+ await self.ws.send_json(message)
+
+ async def close(self, code: int, reason: str) -> None:
+ await self.ws.close(code=code, reason=reason)
+
+
class GraphQLController(
Controller,
AsyncBaseHTTPView[
- Request[Any, Any, Any], Response[Any], Response[Any], Context, RootValue
+ Request[Any, Any, Any],
+ Response[Any],
+ Response[Any],
+ WebSocket,
+ WebSocket,
+ Context,
+ RootValue,
],
):
path: str = ""
@@ -219,10 +231,7 @@ class GraphQLController(
}
request_adapter_class = LitestarRequestAdapter
- graphql_ws_handler_class: Type[GraphQLWSHandler] = GraphQLWSHandler
- graphql_transport_ws_handler_class: Type[GraphQLTransportWSHandler] = (
- GraphQLTransportWSHandler
- )
+ websocket_adapter_class = LitestarWebSocketAdapter
allow_queries_via_get: bool = True
graphiql_allowed_accept: FrozenSet[str] = frozenset({"text/html", "*/*"})
@@ -236,6 +245,23 @@ class GraphQLController(
keep_alive: bool = False
keep_alive_interval: float = 1
+ def is_websocket_request(
+ self, request: Union[Request, WebSocket]
+ ) -> TypeGuard[WebSocket]:
+ return isinstance(request, WebSocket)
+
+ async def pick_websocket_subprotocol(self, request: WebSocket) -> Optional[str]:
+ subprotocols = request.scope["subprotocols"]
+ intersection = set(subprotocols) & set(self.protocols)
+ sorted_intersection = sorted(intersection, key=subprotocols.index)
+ return next(iter(sorted_intersection), None)
+
+ async def create_websocket_response(
+ self, request: WebSocket, subprotocol: Optional[str]
+ ) -> WebSocket:
+ await request.accept(subprotocols=subprotocol)
+ return request
+
async def execute_request(
self,
request: Request[Any, Any, Any],
@@ -245,8 +271,6 @@ async def execute_request(
try:
return await self.run(
request,
- # TODO: check the dependency, above, can we make it so that
- # we don't need to type ignore here?
context=context,
root_value=root_value,
)
@@ -328,14 +352,29 @@ async def handle_http_post(
root_value=root_value,
)
+ @websocket()
+ async def websocket_endpoint(
+ self,
+ socket: WebSocket,
+ context_ws: Any,
+ root_value: Any,
+ ) -> None:
+ await self.run(
+ request=socket,
+ context=context_ws,
+ root_value=root_value,
+ )
+
async def get_context(
- self, request: Request[Any, Any, Any], response: Response
+ self,
+ request: Union[Request[Any, Any, Any], WebSocket],
+ response: Union[Response, WebSocket],
) -> Context: # pragma: no cover
msg = "`get_context` is not used by Litestar's controller"
raise ValueError(msg)
async def get_root_value(
- self, request: Request[Any, Any, Any]
+ self, request: Union[Request[Any, Any, Any], WebSocket]
) -> RootValue | None: # pragma: no cover
msg = "`get_root_value` is not used by Litestar's controller"
raise ValueError(msg)
@@ -343,54 +382,6 @@ async def get_root_value(
async def get_sub_response(self, request: Request[Any, Any, Any]) -> Response:
return self.temporal_response
- @websocket()
- async def websocket_endpoint(
- self,
- socket: WebSocket,
- context_ws: Any,
- root_value: Any,
- ) -> None:
- async def _get_context() -> Any:
- return context_ws
-
- async def _get_root_value() -> Any:
- return root_value
-
- preferred_protocol = self.pick_preferred_protocol(socket)
- if preferred_protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL:
- await self.graphql_transport_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- connection_init_wait_timeout=self.connection_init_wait_timeout,
- get_context=_get_context,
- get_root_value=_get_root_value,
- ws=socket,
- ).handle()
- elif preferred_protocol == GRAPHQL_WS_PROTOCOL:
- await self.graphql_ws_handler_class(
- schema=self.schema,
- debug=self.debug,
- keep_alive=self.keep_alive,
- keep_alive_interval=self.keep_alive_interval,
- get_context=_get_context,
- get_root_value=_get_root_value,
- ws=socket,
- ).handle()
- else:
- await socket.close(code=WS_4406_PROTOCOL_NOT_ACCEPTABLE)
-
- def pick_preferred_protocol(self, socket: WebSocket) -> str | None:
- protocols: List[str] = socket.scope["subprotocols"]
- intersection: Set[str] = set(protocols) & set(self.protocols)
- return (
- min(
- intersection,
- key=lambda i: protocols.index(i) if i else "",
- default=None,
- )
- or None
- )
-
def make_graphql_controller(
schema: BaseSchema,
@@ -410,6 +401,7 @@ def make_graphql_controller(
GRAPHQL_WS_PROTOCOL,
),
connection_init_wait_timeout: timedelta = timedelta(minutes=1),
+ multipart_uploads_enabled: bool = False,
) -> Type[GraphQLController]: # sourcery skip: move-assign
if context_getter is None:
custom_context_getter_ = _none_custom_context_getter
@@ -456,6 +448,7 @@ class _GraphQLController(GraphQLController):
_GraphQLController.schema = schema_
_GraphQLController.allow_queries_via_get = allow_queries_via_get_
_GraphQLController.graphql_ide = graphql_ide_
+ _GraphQLController.multipart_uploads_enabled = multipart_uploads_enabled
return _GraphQLController
diff --git a/strawberry/litestar/handlers/graphql_transport_ws_handler.py b/strawberry/litestar/handlers/graphql_transport_ws_handler.py
deleted file mode 100644
index b5aa915d08..0000000000
--- a/strawberry/litestar/handlers/graphql_transport_ws_handler.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from collections.abc import Callable
-from datetime import timedelta
-from typing import Any
-
-from litestar import WebSocket
-from litestar.exceptions import SerializationException, WebSocketDisconnect
-from strawberry.schema import BaseSchema
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
- BaseGraphQLTransportWSHandler,
-)
-
-
-class GraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- connection_init_wait_timeout: timedelta,
- get_context: Callable,
- get_root_value: Callable,
- ws: WebSocket,
- ) -> None:
- super().__init__(schema, debug, connection_init_wait_timeout)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._ws = ws
-
- async def get_context(self) -> Any:
- return await self._get_context()
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value()
-
- async def send_json(self, data: dict) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int, reason: str) -> None:
- await self._ws.close(code=code, reason=reason)
-
- async def handle_request(self) -> None:
- await self._ws.accept(subprotocols=GRAPHQL_TRANSPORT_WS_PROTOCOL)
- self.on_request_accepted()
-
- try:
- while self._ws.connection_state != "disconnect":
- try:
- message = await self._ws.receive_json()
- except (SerializationException, ValueError): # noqa: PERF203
- error_message = "WebSocket message type must be text"
- await self.handle_invalid_message(error_message)
- else:
- await self.handle_message(message)
- except WebSocketDisconnect: # pragma: no cover
- pass
- finally:
- await self.shutdown()
-
-
-__all__ = ["GraphQLTransportWSHandler"]
diff --git a/strawberry/litestar/handlers/graphql_ws_handler.py b/strawberry/litestar/handlers/graphql_ws_handler.py
deleted file mode 100644
index ada421922f..0000000000
--- a/strawberry/litestar/handlers/graphql_ws_handler.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from collections.abc import Callable
-from contextlib import suppress
-from typing import Any, Optional
-
-from litestar import WebSocket
-from litestar.exceptions import SerializationException, WebSocketDisconnect
-from strawberry.schema import BaseSchema
-from strawberry.subscriptions import GRAPHQL_WS_PROTOCOL
-from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
-from strawberry.subscriptions.protocols.graphql_ws.types import OperationMessage
-
-
-class GraphQLWSHandler(BaseGraphQLWSHandler):
- def __init__(
- self,
- schema: BaseSchema,
- debug: bool,
- keep_alive: bool,
- keep_alive_interval: float,
- get_context: Callable,
- get_root_value: Callable,
- ws: WebSocket,
- ) -> None:
- super().__init__(schema, debug, keep_alive, keep_alive_interval)
- self._get_context = get_context
- self._get_root_value = get_root_value
- self._ws = ws
-
- async def get_context(self) -> Any:
- return await self._get_context()
-
- async def get_root_value(self) -> Any:
- return await self._get_root_value()
-
- async def send_json(self, data: OperationMessage) -> None:
- await self._ws.send_json(data)
-
- async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
- await self._ws.close(code=code, reason=reason)
-
- async def handle_request(self) -> Any:
- await self._ws.accept(subprotocols=GRAPHQL_WS_PROTOCOL)
-
- try:
- while self._ws.connection_state != "disconnect":
- try:
- message = await self._ws.receive_json()
- except (SerializationException, ValueError): # noqa: PERF203
- await self.close(
- code=1002, reason="WebSocket message type must be text"
- )
- else:
- await self.handle_message(message)
- except WebSocketDisconnect: # pragma: no cover
- pass
- finally:
- if self.keep_alive_task:
- self.keep_alive_task.cancel()
- with suppress(BaseException):
- await self.keep_alive_task
-
- for operation_id in list(self.subscriptions.keys()):
- await self.cleanup_operation(operation_id)
-
-
-__all__ = ["GraphQLWSHandler"]
diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py
index f9db21a01d..528a987abc 100644
--- a/strawberry/quart/views.py
+++ b/strawberry/quart/views.py
@@ -1,6 +1,7 @@
import warnings
from collections.abc import Mapping
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast
+from typing_extensions import TypeGuard
from quart import Request, Response, request
from quart.views import View
@@ -46,11 +47,11 @@ async def get_form_data(self) -> FormData:
class GraphQLView(
- AsyncBaseHTTPView[Request, Response, Response, Context, RootValue],
+ AsyncBaseHTTPView[
+ Request, Response, Response, Request, Response, Context, RootValue
+ ],
View,
):
- _ide_subscription_enabled = False
-
methods = ["GET", "POST"]
allow_queries_via_get: bool = True
request_adapter_class = QuartHTTPRequestAdapter
@@ -61,9 +62,11 @@ def __init__(
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
+ multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
@@ -119,5 +122,16 @@ async def create_streaming_response(
},
)
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
+ return False
+
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
+ raise NotImplementedError
+
+ async def create_websocket_response(
+ self, request: Request, subprotocol: Optional[str]
+ ) -> Response:
+ raise NotImplementedError
+
__all__ = ["GraphQLView"]
diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py
index edb30075f6..ee76d2e946 100644
--- a/strawberry/sanic/views.py
+++ b/strawberry/sanic/views.py
@@ -13,6 +13,7 @@
Type,
cast,
)
+from typing_extensions import TypeGuard
from sanic.request import Request
from sanic.response import HTTPResponse, html
@@ -71,7 +72,15 @@ async def get_form_data(self) -> FormData:
class GraphQLView(
- AsyncBaseHTTPView[Request, HTTPResponse, TemporalResponse, Context, RootValue],
+ AsyncBaseHTTPView[
+ Request,
+ HTTPResponse,
+ TemporalResponse,
+ Request,
+ TemporalResponse,
+ Context,
+ RootValue,
+ ],
HTTPMethodView,
):
"""Class based view to handle GraphQL HTTP Requests.
@@ -102,11 +111,13 @@ def __init__(
allow_queries_via_get: bool = True,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
json_dumps_params: Optional[Dict[str, Any]] = None,
+ multipart_uploads_enabled: bool = False,
) -> None:
self.schema = schema
self.allow_queries_via_get = allow_queries_via_get
self.json_encoder = json_encoder
self.json_dumps_params = json_dumps_params
+ self.multipart_uploads_enabled = multipart_uploads_enabled
if self.json_encoder is not None: # pragma: no cover
warnings.warn(
@@ -204,5 +215,16 @@ async def create_streaming_response(
# corner case
return None # type: ignore
+ def is_websocket_request(self, request: Request) -> TypeGuard[Request]:
+ return False
+
+ async def pick_websocket_subprotocol(self, request: Request) -> Optional[str]:
+ raise NotImplementedError
+
+ async def create_websocket_response(
+ self, request: Request, subprotocol: Optional[str]
+ ) -> TemporalResponse:
+ raise NotImplementedError
+
__all__ = ["GraphQLView"]
diff --git a/strawberry/static/graphiql.html b/strawberry/static/graphiql.html
index 95e34c1709..b66082a97f 100644
--- a/strawberry/static/graphiql.html
+++ b/strawberry/static/graphiql.html
@@ -131,10 +131,7 @@
headers["x-csrftoken"] = csrfToken;
}
- const subscriptionsEnabled = JSON.parse("{{ SUBSCRIPTION_ENABLED }}");
- const subscriptionUrl = subscriptionsEnabled
- ? httpUrlToWebSockeUrl(fetchURL)
- : null;
+ const subscriptionUrl = httpUrlToWebSockeUrl(fetchURL);
const fetcher = GraphiQL.createFetcher({
url: fetchURL,
diff --git a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
index 7d19db8e98..f74e9def3d 100644
--- a/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
+++ b/strawberry/subscriptions/protocols/graphql_transport_ws/handlers.py
@@ -2,7 +2,6 @@
import asyncio
import logging
-from abc import ABC, abstractmethod
from contextlib import suppress
from typing import (
TYPE_CHECKING,
@@ -16,6 +15,7 @@
from graphql import GraphQLError, GraphQLSyntaxError, parse
+from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
CompleteMessage,
ConnectionAckMessage,
@@ -38,6 +38,7 @@
if TYPE_CHECKING:
from datetime import timedelta
+ from strawberry.http.async_base_view import AsyncWebSocketAdapter
from strawberry.schema import BaseSchema
from strawberry.schema.subscribe import SubscriptionResult
from strawberry.subscriptions.protocols.graphql_transport_ws.types import (
@@ -45,15 +46,21 @@
)
-class BaseGraphQLTransportWSHandler(ABC):
+class BaseGraphQLTransportWSHandler:
task_logger: logging.Logger = logging.getLogger("strawberry.ws.task")
def __init__(
self,
+ websocket: AsyncWebSocketAdapter,
+ context: object,
+ root_value: object,
schema: BaseSchema,
debug: bool,
connection_init_wait_timeout: timedelta,
) -> None:
+ self.websocket = websocket
+ self.context = context
+ self.root_value = root_value
self.schema = schema
self.debug = debug
self.connection_init_wait_timeout = connection_init_wait_timeout
@@ -65,28 +72,16 @@ def __init__(
self.completed_tasks: List[asyncio.Task] = []
self.connection_params: Optional[Dict[str, Any]] = None
- @abstractmethod
- async def get_context(self) -> Any:
- """Return the operations context."""
-
- @abstractmethod
- async def get_root_value(self) -> Any:
- """Return the schemas root value."""
-
- @abstractmethod
- async def send_json(self, data: dict) -> None:
- """Send the data JSON encoded to the WebSocket client."""
-
- @abstractmethod
- async def close(self, code: int, reason: str) -> None:
- """Close the WebSocket with the passed code and reason."""
-
- @abstractmethod
- async def handle_request(self) -> Any:
- """Handle the request this instance was created for."""
-
async def handle(self) -> Any:
- return await self.handle_request()
+ self.on_request_accepted()
+
+ try:
+ async for message in self.websocket.iter_json():
+ await self.handle_message(message)
+ except NonJsonMessageReceived:
+ await self.handle_invalid_message("WebSocket message type must be text")
+ finally:
+ await self.shutdown()
async def shutdown(self) -> None:
if self.connection_init_timeout_task:
@@ -118,7 +113,7 @@ async def handle_connection_init_timeout(self) -> None:
self.connection_timed_out = True
reason = "Connection initialisation timeout"
- await self.close(code=4408, reason=reason)
+ await self.websocket.close(code=4408, reason=reason)
except Exception as error:
await self.handle_task_exception(error) # pragma: no cover
finally:
@@ -189,14 +184,16 @@ async def handle_connection_init(self, message: ConnectionInitMessage) -> None:
)
if not isinstance(payload, dict):
- await self.close(code=4400, reason="Invalid connection init payload")
+ await self.websocket.close(
+ code=4400, reason="Invalid connection init payload"
+ )
return
self.connection_params = payload
if self.connection_init_received:
reason = "Too many initialisation requests"
- await self.close(code=4429, reason=reason)
+ await self.websocket.close(code=4429, reason=reason)
return
self.connection_init_received = True
@@ -211,13 +208,13 @@ async def handle_pong(self, message: PongMessage) -> None:
async def handle_subscribe(self, message: SubscribeMessage) -> None:
if not self.connection_acknowledged:
- await self.close(code=4401, reason="Unauthorized")
+ await self.websocket.close(code=4401, reason="Unauthorized")
return
try:
graphql_document = parse(message.payload.query)
except GraphQLSyntaxError as exc:
- await self.close(code=4400, reason=exc.message)
+ await self.websocket.close(code=4400, reason=exc.message)
return
try:
@@ -225,12 +222,14 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
graphql_document, message.payload.operationName
)
except RuntimeError:
- await self.close(code=4400, reason="Can't get GraphQL operation type")
+ await self.websocket.close(
+ code=4400, reason="Can't get GraphQL operation type"
+ )
return
if message.id in self.operations:
reason = f"Subscriber for {message.id} already exists"
- await self.close(code=4409, reason=reason)
+ await self.websocket.close(code=4409, reason=reason)
return
if self.debug: # pragma: no cover
@@ -240,26 +239,28 @@ async def handle_subscribe(self, message: SubscribeMessage) -> None:
message.payload.variables,
)
- context = await self.get_context()
- if isinstance(context, dict):
- context["connection_params"] = self.connection_params
- root_value = await self.get_root_value()
+ if isinstance(self.context, dict):
+ self.context["connection_params"] = self.connection_params
+ elif hasattr(self.context, "connection_params"):
+ self.context.connection_params = self.connection_params
+
result_source: Awaitable[ExecutionResult] | Awaitable[SubscriptionResult]
+
# Get an AsyncGenerator yielding the results
if operation_type == OperationType.SUBSCRIPTION:
result_source = self.schema.subscribe(
query=message.payload.query,
variable_values=message.payload.variables,
operation_name=message.payload.operationName,
- context_value=context,
- root_value=root_value,
+ context_value=self.context,
+ root_value=self.root_value,
)
else:
result_source = self.schema.execute(
query=message.payload.query,
variable_values=message.payload.variables,
- context_value=context,
- root_value=root_value,
+ context_value=self.context,
+ root_value=self.root_value,
operation_name=message.payload.operationName,
)
@@ -312,11 +313,11 @@ async def handle_complete(self, message: CompleteMessage) -> None:
await self.cleanup_operation(operation_id=message.id)
async def handle_invalid_message(self, error_message: str) -> None:
- await self.close(code=4400, reason=error_message)
+ await self.websocket.close(code=4400, reason=error_message)
async def send_message(self, message: GraphQLTransportMessage) -> None:
data = message.as_dict()
- await self.send_json(data)
+ await self.websocket.send_json(data)
async def cleanup_operation(self, operation_id: str) -> None:
if operation_id not in self.operations:
diff --git a/strawberry/subscriptions/protocols/graphql_ws/handlers.py b/strawberry/subscriptions/protocols/graphql_ws/handlers.py
index 0451e0934b..fda3db829f 100644
--- a/strawberry/subscriptions/protocols/graphql_ws/handlers.py
+++ b/strawberry/subscriptions/protocols/graphql_ws/handlers.py
@@ -1,11 +1,9 @@
from __future__ import annotations
import asyncio
-from abc import ABC, abstractmethod
from contextlib import suppress
from typing import (
TYPE_CHECKING,
- Any,
AsyncGenerator,
Awaitable,
Dict,
@@ -13,6 +11,7 @@
cast,
)
+from strawberry.http.exceptions import NonJsonMessageReceived
from strawberry.subscriptions.protocols.graphql_ws import (
GQL_COMPLETE,
GQL_CONNECTION_ACK,
@@ -25,29 +24,36 @@
GQL_START,
GQL_STOP,
)
+from strawberry.subscriptions.protocols.graphql_ws.types import (
+ ConnectionInitPayload,
+ DataPayload,
+ OperationMessage,
+ OperationMessagePayload,
+ StartPayload,
+)
from strawberry.types.execution import ExecutionResult, PreExecutionError
from strawberry.utils.debug import pretty_print_graphql_operation
if TYPE_CHECKING:
+ from strawberry.http.async_base_view import AsyncWebSocketAdapter
from strawberry.schema import BaseSchema
from strawberry.schema.subscribe import SubscriptionResult
- from strawberry.subscriptions.protocols.graphql_ws.types import (
- ConnectionInitPayload,
- DataPayload,
- OperationMessage,
- OperationMessagePayload,
- StartPayload,
- )
-class BaseGraphQLWSHandler(ABC):
+class BaseGraphQLWSHandler:
def __init__(
self,
+ websocket: AsyncWebSocketAdapter,
+ context: object,
+ root_value: object,
schema: BaseSchema,
debug: bool,
keep_alive: bool,
- keep_alive_interval: float,
+ keep_alive_interval: Optional[float],
) -> None:
+ self.websocket = websocket
+ self.context = context
+ self.root_value = root_value
self.schema = schema
self.debug = debug
self.keep_alive = keep_alive
@@ -57,28 +63,22 @@ def __init__(
self.tasks: Dict[str, asyncio.Task] = {}
self.connection_params: Optional[ConnectionInitPayload] = None
- @abstractmethod
- async def get_context(self) -> Any:
- """Return the operations context."""
-
- @abstractmethod
- async def get_root_value(self) -> Any:
- """Return the schemas root value."""
-
- @abstractmethod
- async def send_json(self, data: OperationMessage) -> None:
- """Send the data JSON encoded to the WebSocket client."""
-
- @abstractmethod
- async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
- """Close the WebSocket with the passed code and reason."""
-
- @abstractmethod
- async def handle_request(self) -> Any:
- """Handle the request this instance was created for."""
-
- async def handle(self) -> Any:
- return await self.handle_request()
+ async def handle(self) -> None:
+ try:
+ async for message in self.websocket.iter_json():
+ await self.handle_message(cast(OperationMessage, message))
+ except NonJsonMessageReceived:
+ await self.websocket.close(
+ code=1002, reason="WebSocket message type must be text"
+ )
+ finally:
+ if self.keep_alive_task:
+ self.keep_alive_task.cancel()
+ with suppress(BaseException):
+ await self.keep_alive_task
+
+ for operation_id in list(self.subscriptions.keys()):
+ await self.cleanup_operation(operation_id)
async def handle_message(
self,
@@ -99,22 +99,22 @@ async def handle_connection_init(self, message: OperationMessage) -> None:
payload = message.get("payload")
if payload is not None and not isinstance(payload, dict):
error_message: OperationMessage = {"type": GQL_CONNECTION_ERROR}
- await self.send_json(error_message)
- await self.close()
+ await self.websocket.send_json(error_message)
+ await self.websocket.close(code=1000, reason="")
return
payload = cast(Optional["ConnectionInitPayload"], payload)
self.connection_params = payload
acknowledge_message: OperationMessage = {"type": GQL_CONNECTION_ACK}
- await self.send_json(acknowledge_message)
+ await self.websocket.send_json(acknowledge_message)
if self.keep_alive:
keep_alive_handler = self.handle_keep_alive()
self.keep_alive_task = asyncio.create_task(keep_alive_handler)
async def handle_connection_terminate(self, message: OperationMessage) -> None:
- await self.close()
+ await self.websocket.close(code=1000, reason="")
async def handle_start(self, message: OperationMessage) -> None:
operation_id = message["id"]
@@ -123,10 +123,10 @@ async def handle_start(self, message: OperationMessage) -> None:
operation_name = payload.get("operationName")
variables = payload.get("variables")
- context = await self.get_context()
- if isinstance(context, dict):
- context["connection_params"] = self.connection_params
- root_value = await self.get_root_value()
+ if isinstance(self.context, dict):
+ self.context["connection_params"] = self.connection_params
+ elif hasattr(self.context, "connection_params"):
+ self.context.connection_params = self.connection_params
if self.debug:
pretty_print_graphql_operation(operation_name, query, variables)
@@ -135,8 +135,8 @@ async def handle_start(self, message: OperationMessage) -> None:
query=query,
variable_values=variables,
operation_name=operation_name,
- context_value=context,
- root_value=root_value,
+ context_value=self.context,
+ root_value=self.root_value,
)
result_handler = self.handle_async_results(result_source, operation_id)
@@ -147,9 +147,10 @@ async def handle_stop(self, message: OperationMessage) -> None:
await self.cleanup_operation(operation_id)
async def handle_keep_alive(self) -> None:
+ assert self.keep_alive_interval
while True:
data: OperationMessage = {"type": GQL_CONNECTION_KEEP_ALIVE}
- await self.send_json(data)
+ await self.websocket.send_json(data)
await asyncio.sleep(self.keep_alive_interval)
async def handle_async_results(
@@ -191,7 +192,7 @@ async def send_message(
data: OperationMessage = {"type": type_, "id": operation_id}
if payload is not None:
data["payload"] = payload
- await self.send_json(data)
+ await self.websocket.send_json(data)
async def send_data(
self, execution_result: ExecutionResult, operation_id: str
diff --git a/strawberry/test/client.py b/strawberry/test/client.py
index 9127799763..243ac3aadb 100644
--- a/strawberry/test/client.py
+++ b/strawberry/test/client.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import json
+import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Mapping, Optional, Union
@@ -36,8 +37,9 @@ def query(
query: str,
variables: Optional[Dict[str, Mapping]] = None,
headers: Optional[Dict[str, object]] = None,
- asserts_errors: Optional[bool] = True,
+ asserts_errors: Optional[bool] = None,
files: Optional[Dict[str, object]] = None,
+ assert_no_errors: Optional[bool] = True,
) -> Union[Coroutine[Any, Any, Response], Response]:
body = self._build_body(query, variables, files)
@@ -49,7 +51,19 @@ def query(
data=data.get("data"),
extensions=data.get("extensions"),
)
- if asserts_errors:
+
+ if asserts_errors is not None:
+ warnings.warn(
+ "The `asserts_errors` argument has been renamed to `assert_no_errors`",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ assert_no_errors = (
+ assert_no_errors if asserts_errors is None else asserts_errors
+ )
+
+ if assert_no_errors:
assert response.errors is None
return response
diff --git a/tests/aiohttp/__init__.py b/tests/aiohttp/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/aiohttp/app.py b/tests/aiohttp/app.py
deleted file mode 100644
index ba43cea8fc..0000000000
--- a/tests/aiohttp/app.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from typing import Any
-
-from aiohttp import web
-from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
-from strawberry.aiohttp.views import GraphQLView
-from tests.views.schema import Query, schema
-
-
-class DebuggableGraphQLTransportWSHandler(GraphQLTransportWSHandler):
- def get_tasks(self) -> list:
- return [op.task for op in self.operations.values()]
-
- async def get_context(self) -> object:
- context = await super().get_context()
- context["ws"] = self._ws
- context["get_tasks"] = self.get_tasks
- context["connectionInitTimeoutTask"] = self.connection_init_timeout_task
- return context
-
-
-class DebuggableGraphQLWSHandler(GraphQLWSHandler):
- def get_tasks(self) -> list:
- return list(self.tasks.values())
-
- async def get_context(self) -> object:
- context = await super().get_context()
- context["ws"] = self._ws
- context["get_tasks"] = self.get_tasks
- context["connectionInitTimeoutTask"] = None
- return context
-
-
-class MyGraphQLView(GraphQLView):
- graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler
- graphql_ws_handler_class = DebuggableGraphQLWSHandler
-
- async def get_root_value(self, request: web.Request) -> Query:
- await super().get_root_value(request) # for coverage
- return Query()
-
-
-def create_app(**kwargs: Any) -> web.Application:
- app = web.Application()
- app.router.add_route("*", "/graphql", MyGraphQLView(schema=schema, **kwargs))
-
- return app
diff --git a/tests/aiohttp/test_websockets.py b/tests/aiohttp/test_websockets.py
deleted file mode 100644
index 82d6a5db00..0000000000
--- a/tests/aiohttp/test_websockets.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Awaitable, Callable
-
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
-
-if TYPE_CHECKING:
- from aiohttp.test_utils import TestClient
-
-
-async def test_turning_off_graphql_ws(
- aiohttp_client: Callable[..., Awaitable[TestClient]],
-) -> None:
- from .app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])
- aiohttp_app_client = await aiohttp_client(app)
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
- ) as ws:
- data = await ws.receive(timeout=2)
- assert ws.protocol is None
- assert ws.closed
- assert ws.close_code == 4406
- assert data.extra == "Subprotocol not acceptable"
-
-
-async def test_turning_off_graphql_transport_ws(
- aiohttp_client: Callable[..., Awaitable[TestClient]],
-):
- from .app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL])
- aiohttp_app_client = await aiohttp_client(app)
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
- ) as ws:
- data = await ws.receive(timeout=2)
- assert ws.protocol is None
- assert ws.closed
- assert ws.close_code == 4406
- assert data.extra == "Subprotocol not acceptable"
-
-
-async def test_turning_off_all_ws_protocols(
- aiohttp_client: Callable[..., Awaitable[TestClient]],
-):
- from .app import create_app
-
- app = create_app(subscription_protocols=[])
- aiohttp_app_client = await aiohttp_client(app)
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
- ) as ws:
- data = await ws.receive(timeout=2)
- assert ws.protocol is None
- assert ws.closed
- assert ws.close_code == 4406
- assert data.extra == "Subprotocol not acceptable"
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
- ) as ws:
- data = await ws.receive(timeout=2)
- assert ws.protocol is None
- assert ws.closed
- assert ws.close_code == 4406
- assert data.extra == "Subprotocol not acceptable"
-
-
-async def test_unsupported_ws_protocol(
- aiohttp_client: Callable[..., Awaitable[TestClient]],
-):
- from .app import create_app
-
- app = create_app(subscription_protocols=[])
- aiohttp_app_client = await aiohttp_client(app)
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=["imaginary-protocol"]
- ) as ws:
- data = await ws.receive(timeout=2)
- assert ws.protocol is None
- assert ws.closed
- assert ws.close_code == 4406
- assert data.extra == "Subprotocol not acceptable"
-
-
-async def test_clients_can_prefer_protocols(
- aiohttp_client: Callable[..., Awaitable[TestClient]],
-):
- from .app import create_app
-
- app = create_app(
- subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- )
- aiohttp_app_client = await aiohttp_client(app)
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]
- ) as ws:
- assert ws.protocol == GRAPHQL_TRANSPORT_WS_PROTOCOL
-
- async with aiohttp_app_client.ws_connect(
- "/graphql", protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- ) as ws:
- assert ws.protocol == GRAPHQL_WS_PROTOCOL
diff --git a/tests/asgi/app.py b/tests/asgi/app.py
deleted file mode 100644
index a179a3210f..0000000000
--- a/tests/asgi/app.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Any, Dict, Optional, Union
-
-from starlette.requests import Request
-from starlette.responses import Response
-from starlette.websockets import WebSocket
-
-from strawberry.asgi import GraphQL as BaseGraphQL
-from tests.views.schema import Query, schema
-
-
-class GraphQL(BaseGraphQL):
- async def get_root_value(self, request) -> Query:
- return Query()
-
- async def get_context(
- self,
- request: Union[Request, WebSocket],
- response: Optional[Response] = None,
- ) -> Dict[str, Union[Request, WebSocket, Response, str, None]]:
- return {"request": request, "response": response, "custom_value": "Hi"}
-
-
-def create_app(**kwargs: Any) -> GraphQL:
- return GraphQL(schema, **kwargs)
diff --git a/tests/asgi/test_websockets.py b/tests/asgi/test_websockets.py
deleted file mode 100644
index 511661358a..0000000000
--- a/tests/asgi/test_websockets.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import pytest
-
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
-
-
-def test_turning_off_graphql_ws():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.asgi.app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/", [GRAPHQL_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_turning_off_graphql_transport_ws():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.asgi.app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/", [GRAPHQL_TRANSPORT_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_turning_off_all_ws_protocols():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.asgi.app import create_app
-
- app = create_app(subscription_protocols=[])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/", [GRAPHQL_TRANSPORT_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/", [GRAPHQL_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_unsupported_ws_protocol():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.asgi.app import create_app
-
- app = create_app(subscription_protocols=[])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/", ["imaginary-protocol"]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_clients_can_prefer_protocols():
- from starlette.testclient import TestClient
-
- from tests.asgi.app import create_app
-
- app = create_app(
- subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- )
- test_client = TestClient(app)
-
- with test_client.websocket_connect(
- "/", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]
- ) as ws:
- assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL
-
- with test_client.websocket_connect(
- "/", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- ) as ws:
- assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL
diff --git a/tests/channels/test_layers.py b/tests/channels/test_layers.py
index 40a3eed7b6..8db1205fdf 100644
--- a/tests/channels/test_layers.py
+++ b/tests/channels/test_layers.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
-from typing import TYPE_CHECKING, Generator
+from typing import TYPE_CHECKING, AsyncGenerator
import pytest
@@ -21,7 +21,7 @@
@pytest.fixture
-async def ws() -> Generator[WebsocketCommunicator, None, None]:
+async def ws() -> AsyncGenerator[WebsocketCommunicator, None]:
from channels.testing import WebsocketCommunicator
from strawberry.channels import GraphQLWSConsumer
diff --git a/tests/channels/test_testing.py b/tests/channels/test_testing.py
index 921d8abf35..99aa9dd6c8 100644
--- a/tests/channels/test_testing.py
+++ b/tests/channels/test_testing.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Generator
+from typing import TYPE_CHECKING, Any, AsyncGenerator
import pytest
@@ -14,7 +14,7 @@
@pytest.fixture(params=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL])
async def communicator(
request: Any,
-) -> Generator[GraphQLWebsocketCommunicator, None, None]:
+) -> AsyncGenerator[GraphQLWebsocketCommunicator, None]:
from strawberry.channels import GraphQLWSConsumer
from strawberry.channels.testing import GraphQLWebsocketCommunicator
diff --git a/tests/channels/test_ws_handler.py b/tests/channels/test_ws_handler.py
deleted file mode 100644
index 88310bb617..0000000000
--- a/tests/channels/test_ws_handler.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import pytest
-
-from tests.views.schema import schema
-
-try:
- from channels.testing.websocket import WebsocketCommunicator
- from strawberry.channels.handlers.graphql_transport_ws_handler import (
- GraphQLTransportWSHandler,
- )
- from strawberry.channels.handlers.graphql_ws_handler import GraphQLWSHandler
- from strawberry.channels.handlers.ws_handler import GraphQLWSConsumer
-except ImportError:
- pytestmark = pytest.mark.skip("Channels is not installed")
- GraphQLWSHandler = None
- GraphQLTransportWSHandler = None
-
-from strawberry.subscriptions import (
- GRAPHQL_TRANSPORT_WS_PROTOCOL,
- GRAPHQL_WS_PROTOCOL,
-)
-
-
-async def test_wrong_protocol():
- GraphQLWSConsumer.as_asgi(schema=schema)
- client = WebsocketCommunicator(
- GraphQLWSConsumer.as_asgi(schema=schema),
- "/graphql",
- subprotocols=[
- "non-existing",
- ],
- )
- res = await client.connect()
- assert res == (False, 4406)
-
-
-@pytest.mark.parametrize(
- ("protocol", "handler"),
- [
- (GRAPHQL_TRANSPORT_WS_PROTOCOL, GraphQLTransportWSHandler),
- (GRAPHQL_WS_PROTOCOL, GraphQLWSHandler),
- ],
-)
-async def test_correct_protocol(protocol, handler):
- consumer = GraphQLWSConsumer(schema=schema)
- client = WebsocketCommunicator(
- consumer,
- "/graphql",
- subprotocols=[
- protocol,
- ],
- )
- res = await client.connect()
- assert res == (True, protocol)
- assert isinstance(consumer._handler, handler)
diff --git a/tests/codegen/conftest.py b/tests/codegen/conftest.py
index 1b7f80e467..d776a474fd 100644
--- a/tests/codegen/conftest.py
+++ b/tests/codegen/conftest.py
@@ -101,6 +101,7 @@ class Query:
person: Person
optional_person: Optional[Person]
list_of_people: List[Person]
+ optional_list_of_people: Optional[List[Person]]
enum: Color
json: JSON
union: PersonOrAnimal
diff --git a/tests/codegen/queries/nullable_list_of_non_scalars.graphql b/tests/codegen/queries/nullable_list_of_non_scalars.graphql
new file mode 100644
index 0000000000..8bdedbe6a2
--- /dev/null
+++ b/tests/codegen/queries/nullable_list_of_non_scalars.graphql
@@ -0,0 +1,6 @@
+query OperationName {
+ optionalListOfPeople {
+ name
+ age
+ }
+}
diff --git a/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py b/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py
new file mode 100644
index 0000000000..f5efcaa7f8
--- /dev/null
+++ b/tests/codegen/snapshots/python/nullable_list_of_non_scalars.py
@@ -0,0 +1,8 @@
+from typing import List, Optional
+
+class OperationNameResultOptionalListOfPeople:
+ name: str
+ age: int
+
+class OperationNameResult:
+ optional_list_of_people: Optional[List[OperationNameResultOptionalListOfPeople]]
diff --git a/tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts b/tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts
new file mode 100644
index 0000000000..7498cac29d
--- /dev/null
+++ b/tests/codegen/snapshots/typescript/nullable_list_of_non_scalars.ts
@@ -0,0 +1,8 @@
+type OperationNameResultOptionalListOfPeople = {
+ name: string
+ age: number
+}
+
+type OperationNameResult = {
+ optional_list_of_people: OperationNameResultOptionalListOfPeople[] | undefined
+}
diff --git a/tests/django/test_graphql_test_client.py b/tests/django/test_graphql_test_client.py
deleted file mode 100644
index 9a839b71b2..0000000000
--- a/tests/django/test_graphql_test_client.py
+++ /dev/null
@@ -1,7 +0,0 @@
-def test_assertion_error_not_raised_when_asserts_errors_is_false(graphql_client):
- query = "{ }"
-
- try:
- graphql_client.query(query, asserts_errors=False)
- except AssertionError:
- raise AssertionError
diff --git a/tests/enums/test_enum.py b/tests/enums/test_enum.py
index 2189fba3d6..c0dea47813 100644
--- a/tests/enums/test_enum.py
+++ b/tests/enums/test_enum.py
@@ -4,7 +4,7 @@
import strawberry
from strawberry.exceptions import ObjectIsNotAnEnumError
-from strawberry.exceptions.not_a_strawberry_enum import NotAStrawberryEnumError
+from strawberry.types.base import get_object_definition
from strawberry.types.enum import EnumDefinition
@@ -120,25 +120,6 @@ class IceCreamFlavour(Enum):
assert definition.values[2].description is None
-@pytest.mark.raises_strawberry_exception(
- NotAStrawberryEnumError, match='Enum "IceCreamFlavour" is not a Strawberry enum'
-)
-def test_raises_error_when_using_enum_not_decorated():
- class IceCreamFlavour(Enum):
- VANILLA = strawberry.enum_value("vanilla")
- STRAWBERRY = strawberry.enum_value(
- "strawberry",
- description="Our favourite",
- )
- CHOCOLATE = "chocolate"
-
- @strawberry.type
- class Query:
- flavour: IceCreamFlavour
-
- strawberry.Schema(query=Query)
-
-
def test_can_use_enum_values():
@strawberry.enum
class TestEnum(Enum):
@@ -169,3 +150,53 @@ class TestEnum(IntEnum):
assert TestEnum.D.value == 4
assert [x.value for x in TestEnum.__members__.values()] == [1, 2, 3, 4]
+
+
+def test_default_enum_implementation() -> None:
+ class Foo(Enum):
+ BAR = "bar"
+ BAZ = "baz"
+
+ @strawberry.type
+ class Query:
+ @strawberry.field
+ def foo(self, foo: Foo) -> Foo:
+ return foo
+
+ schema = strawberry.Schema(Query)
+ res = schema.execute_sync("{ foo(foo: BAR) }")
+ assert not res.errors
+ assert res.data
+ assert res.data["foo"] == "BAR"
+
+
+def test_default_enum_reuse() -> None:
+ class Foo(Enum):
+ BAR = "bar"
+ BAZ = "baz"
+
+ @strawberry.type
+ class SomeType:
+ foo: Foo
+ bar: Foo
+
+ definition = get_object_definition(SomeType, strict=True)
+ assert definition.fields[1].type is definition.fields[1].type
+
+
+def test_default_enum_with_enum_value() -> None:
+ class Foo(Enum):
+ BAR = "bar"
+ BAZ = strawberry.enum_value("baz")
+
+ @strawberry.type
+ class Query:
+ @strawberry.field
+ def foo(self, foo: Foo) -> str:
+ return foo.value
+
+ schema = strawberry.Schema(Query)
+ res = schema.execute_sync("{ foo(foo: BAZ) }")
+ assert not res.errors
+ assert res.data
+ assert res.data["foo"] == "baz"
diff --git a/tests/fastapi/test_websockets.py b/tests/fastapi/test_websockets.py
deleted file mode 100644
index de729c9f32..0000000000
--- a/tests/fastapi/test_websockets.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from typing import Any
-
-import pytest
-
-import strawberry
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
-
-
-def test_turning_off_graphql_ws():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.fastapi.app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_turning_off_graphql_transport_ws():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.fastapi.app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_turning_off_all_ws_protocols():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.fastapi.app import create_app
-
- app = create_app(subscription_protocols=[])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_unsupported_ws_protocol():
- from starlette.testclient import TestClient
- from starlette.websockets import WebSocketDisconnect
-
- from tests.fastapi.app import create_app
-
- app = create_app(subscription_protocols=[])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", ["imaginary-protocol"]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_clients_can_prefer_protocols():
- from starlette.testclient import TestClient
-
- from tests.fastapi.app import create_app
-
- app = create_app(
- subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- )
- test_client = TestClient(app)
-
- with test_client.websocket_connect(
- "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]
- ) as ws:
- assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL
-
- with test_client.websocket_connect(
- "/graphql", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- ) as ws:
- assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL
-
-
-def test_with_custom_encode_json():
- from starlette.testclient import TestClient
-
- from fastapi import FastAPI
- from strawberry.fastapi.router import GraphQLRouter
-
- @strawberry.type
- class Query:
- @strawberry.field
- def abc(self) -> str:
- return "abc"
-
- class MyRouter(GraphQLRouter[None, None]):
- def encode_json(self, response_data: Any):
- return '"custom"'
-
- app = FastAPI()
- schema = strawberry.Schema(query=Query)
- graphql_app = MyRouter(schema=schema)
- app.include_router(graphql_app, prefix="/graphql")
-
- test_client = TestClient(app)
- response = test_client.post("/graphql", json={"query": "{ abc }"})
-
- assert response.status_code == 200
- assert response.json() == "custom"
diff --git a/tests/http/clients/aiohttp.py b/tests/http/clients/aiohttp.py
index cd552e877c..89b0c718e8 100644
--- a/tests/http/clients/aiohttp.py
+++ b/tests/http/clients/aiohttp.py
@@ -10,7 +10,6 @@
from aiohttp.client_ws import ClientWebSocketResponse
from aiohttp.http_websocket import WSMsgType
from aiohttp.test_utils import TestClient, TestServer
-from strawberry.aiohttp.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.aiohttp.views import GraphQLView as BaseGraphQLView
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
@@ -20,8 +19,8 @@
from ..context import get_context
from .base import (
JSON,
- DebuggableGraphQLTransportWSMixin,
- DebuggableGraphQLWSMixin,
+ DebuggableGraphQLTransportWSHandler,
+ DebuggableGraphQLWSHandler,
HttpClient,
Message,
Response,
@@ -30,16 +29,6 @@
)
-class DebuggableGraphQLTransportWSHandler(
- DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler
-):
- pass
-
-
-class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler):
- pass
-
-
class GraphQLView(BaseGraphQLView):
result_override: ResultOverrideFunction = None
graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler
@@ -72,6 +61,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
view = GraphQLView(
schema=schema,
@@ -79,6 +69,7 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
keep_alive=False,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
view.result_override = result_override
@@ -192,6 +183,9 @@ def __init__(self, ws: ClientWebSocketResponse):
self.ws = ws
self._reason: Optional[str] = None
+ async def send_text(self, payload: str) -> None:
+ await self.ws.send_str(payload)
+
async def send_json(self, payload: Dict[str, Any]) -> None:
await self.ws.send_json(payload)
@@ -211,6 +205,10 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any:
async def close(self) -> None:
await self.ws.close()
+ @property
+ def accepted_subprotocol(self) -> Optional[str]:
+ return self.ws.protocol
+
@property
def closed(self) -> bool:
return self.ws.closed
@@ -220,5 +218,6 @@ def close_code(self) -> int:
assert self.ws.close_code is not None
return self.ws.close_code
- def assert_reason(self, reason: str) -> None:
- assert self._reason == reason
+ @property
+ def close_reason(self) -> Optional[str]:
+ return self._reason
diff --git a/tests/http/clients/asgi.py b/tests/http/clients/asgi.py
index 72d9e95aa6..7910e02f73 100644
--- a/tests/http/clients/asgi.py
+++ b/tests/http/clients/asgi.py
@@ -8,11 +8,10 @@
from starlette.requests import Request
from starlette.responses import Response as StarletteResponse
-from starlette.testclient import TestClient
+from starlette.testclient import TestClient, WebSocketTestSession
from starlette.websockets import WebSocket, WebSocketDisconnect
from strawberry.asgi import GraphQL as BaseGraphQLView
-from strawberry.asgi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
from strawberry.types import ExecutionResult
@@ -21,8 +20,8 @@
from ..context import get_context
from .base import (
JSON,
- DebuggableGraphQLTransportWSMixin,
- DebuggableGraphQLWSMixin,
+ DebuggableGraphQLTransportWSHandler,
+ DebuggableGraphQLWSHandler,
HttpClient,
Message,
Response,
@@ -31,16 +30,6 @@
)
-class DebuggableGraphQLTransportWSHandler(
- DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler
-):
- pass
-
-
-class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler):
- pass
-
-
class GraphQLView(BaseGraphQLView):
result_override: ResultOverrideFunction = None
graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler
@@ -74,6 +63,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
view = GraphQLView(
schema,
@@ -81,6 +71,7 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
keep_alive=False,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
view.result_override = result_override
@@ -179,7 +170,7 @@ async def ws_connect(
class AsgiWebSocketClient(WebSocketClient):
- def __init__(self, ws: Any):
+ def __init__(self, ws: WebSocketTestSession):
self.ws = ws
self._closed: bool = False
self._close_code: Optional[int] = None
@@ -190,6 +181,9 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None:
self._close_code = exc.code
self._close_reason = exc.reason
+ async def send_text(self, payload: str) -> None:
+ self.ws.send_text(payload)
+
async def send_json(self, payload: Dict[str, Any]) -> None:
self.ws.send_json(payload)
@@ -222,6 +216,10 @@ async def close(self) -> None:
self.ws.close()
self._closed = True
+ @property
+ def accepted_subprotocol(self) -> Optional[str]:
+ return self.ws.accepted_subprotocol
+
@property
def closed(self) -> bool:
return self._closed
@@ -231,5 +229,6 @@ def close_code(self) -> int:
assert self._close_code is not None
return self._close_code
- def assert_reason(self, reason: str) -> None:
- assert self._close_reason == reason
+ @property
+ def close_reason(self) -> Optional[str]:
+ return self._close_reason
diff --git a/tests/http/clients/async_django.py b/tests/http/clients/async_django.py
index 0e8bfca0ed..5f5caf4dd7 100644
--- a/tests/http/clients/async_django.py
+++ b/tests/http/clients/async_django.py
@@ -43,6 +43,7 @@ async def _do_request(self, request: RequestFactory) -> Response:
graphql_ide=self.graphql_ide,
allow_queries_via_get=self.allow_queries_via_get,
result_override=self.result_override,
+ multipart_uploads_enabled=self.multipart_uploads_enabled,
)
try:
diff --git a/tests/http/clients/async_flask.py b/tests/http/clients/async_flask.py
index 86e8e5b7f9..4db37c8d1b 100644
--- a/tests/http/clients/async_flask.py
+++ b/tests/http/clients/async_flask.py
@@ -52,6 +52,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.app = Flask(__name__)
self.app.debug = True
@@ -63,6 +64,7 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
self.app.add_url_rule(
diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py
index daf1106359..c85b2efe00 100644
--- a/tests/http/clients/base.py
+++ b/tests/http/clients/base.py
@@ -21,6 +21,10 @@
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
+from strawberry.subscriptions.protocols.graphql_transport_ws.handlers import (
+ BaseGraphQLTransportWSHandler,
+)
+from strawberry.subscriptions.protocols.graphql_ws.handlers import BaseGraphQLWSHandler
from strawberry.types import ExecutionResult
logger = logging.getLogger("strawberry.test.http_client")
@@ -103,6 +107,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
): ...
@abc.abstractmethod
@@ -236,7 +241,7 @@ def create_app(self, **kwargs: Any) -> None:
"""For use by websocket tests."""
raise NotImplementedError
- async def ws_connect(
+ def ws_connect(
self,
url: str,
*,
@@ -259,6 +264,9 @@ class WebSocketClient(abc.ABC):
def name(self) -> str:
return ""
+ @abc.abstractmethod
+ async def send_text(self, payload: str) -> None: ...
+
@abc.abstractmethod
async def send_json(self, payload: Dict[str, Any]) -> None: ...
@@ -273,6 +281,10 @@ async def receive_json(self, timeout: Optional[float] = None) -> Any: ...
@abc.abstractmethod
async def close(self) -> None: ...
+ @property
+ @abc.abstractmethod
+ def accepted_subprotocol(self) -> Optional[str]: ...
+
@property
@abc.abstractmethod
def closed(self) -> bool: ...
@@ -281,15 +293,16 @@ def closed(self) -> bool: ...
@abc.abstractmethod
def close_code(self) -> int: ...
+ @property
@abc.abstractmethod
- def assert_reason(self, reason: str) -> None: ...
+ def close_reason(self) -> Optional[str]: ...
async def __aiter__(self) -> AsyncGenerator[Message, None]:
while not self.closed:
yield await self.receive()
-class DebuggableGraphQLTransportWSMixin:
+class DebuggableGraphQLTransportWSHandler(BaseGraphQLTransportWSHandler):
def on_init(self) -> None:
"""This method can be patched by unit tests to get the instance of the
transport handler when it is initialized.
@@ -297,26 +310,41 @@ def on_init(self) -> None:
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
- DebuggableGraphQLTransportWSMixin.on_init(self)
+ self.original_context = kwargs.get("context", {})
+ DebuggableGraphQLTransportWSHandler.on_init(self)
def get_tasks(self) -> List:
return [op.task for op in self.operations.values()]
- async def get_context(self) -> object:
- context = await super().get_context()
- context["ws"] = self._ws
- context["get_tasks"] = self.get_tasks
- context["connectionInitTimeoutTask"] = self.connection_init_timeout_task
- return context
+ @property
+ def context(self):
+ self.original_context["ws"] = self.websocket
+ self.original_context["get_tasks"] = self.get_tasks
+ self.original_context["connectionInitTimeoutTask"] = (
+ self.connection_init_timeout_task
+ )
+ return self.original_context
+ @context.setter
+ def context(self, value):
+ self.original_context = value
+
+
+class DebuggableGraphQLWSHandler(BaseGraphQLWSHandler):
+ def __init__(self, *args: Any, **kwargs: Any):
+ super().__init__(*args, **kwargs)
+ self.original_context = self.context
-class DebuggableGraphQLWSMixin:
def get_tasks(self) -> List:
return list(self.tasks.values())
- async def get_context(self) -> object:
- context = await super().get_context()
- context["ws"] = self._ws
- context["get_tasks"] = self.get_tasks
- context["connectionInitTimeoutTask"] = None
- return context
+ @property
+ def context(self):
+ self.original_context["ws"] = self.websocket
+ self.original_context["get_tasks"] = self.get_tasks
+ self.original_context["connectionInitTimeoutTask"] = None
+ return self.original_context
+
+ @context.setter
+ def context(self, value):
+ self.original_context = value
diff --git a/tests/http/clients/chalice.py b/tests/http/clients/chalice.py
index eddb7d8ada..2e01c1f400 100644
--- a/tests/http/clients/chalice.py
+++ b/tests/http/clients/chalice.py
@@ -50,6 +50,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.app = Chalice(app_name="TheStackBadger")
diff --git a/tests/http/clients/channels.py b/tests/http/clients/channels.py
index da981403bb..bde2364128 100644
--- a/tests/http/clients/channels.py
+++ b/tests/http/clients/channels.py
@@ -23,6 +23,8 @@
from ..context import get_context
from .base import (
JSON,
+ DebuggableGraphQLTransportWSHandler,
+ DebuggableGraphQLWSHandler,
HttpClient,
Message,
Response,
@@ -63,25 +65,6 @@ def create_multipart_request_body(
return headers, request_body
-class DebuggableGraphQLTransportWSConsumer(GraphQLWSConsumer):
- def get_tasks(self) -> List[Any]:
- if hasattr(self._handler, "operations"):
- return [op.task for op in self._handler.operations.values()]
- else:
- return list(self._handler.tasks.values())
-
- async def get_context(self, *args: str, **kwargs: Any) -> object:
- context = await super().get_context(*args, **kwargs)
- context["ws"] = self._handler._ws
- context["get_tasks"] = self.get_tasks
- context["connectionInitTimeoutTask"] = getattr(
- self._handler, "connection_init_timeout_task", None
- )
- for key, val in get_context({}).items():
- context[key] = val
- return context
-
-
class DebuggableGraphQLHTTPConsumer(GraphQLHTTPConsumer):
result_override: ResultOverrideFunction = None
@@ -130,6 +113,16 @@ def process_result(
return super().process_result(request, result)
+class DebuggableGraphQLWSConsumer(GraphQLWSConsumer):
+ graphql_transport_ws_handler_class = DebuggableGraphQLTransportWSHandler
+ graphql_ws_handler_class = DebuggableGraphQLWSHandler
+
+ async def get_context(self, request, response):
+ context = await super().get_context(request, response)
+
+ return get_context(context)
+
+
class ChannelsHttpClient(HttpClient):
"""A client to test websockets over channels."""
@@ -139,8 +132,9 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
- self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi(
+ self.ws_app = DebuggableGraphQLWSConsumer.as_asgi(
schema=schema,
keep_alive=False,
)
@@ -151,12 +145,11 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
def create_app(self, **kwargs: Any) -> None:
- self.ws_app = DebuggableGraphQLTransportWSConsumer.as_asgi(
- schema=schema, **kwargs
- )
+ self.ws_app = DebuggableGraphQLWSConsumer.as_asgi(schema=schema, **kwargs)
async def _graphql_request(
self,
@@ -245,10 +238,13 @@ async def ws_connect(
) -> AsyncGenerator[WebSocketClient, None]:
client = WebsocketCommunicator(self.ws_app, url, subprotocols=protocols)
- res = await client.connect()
- assert res == (True, protocols[0])
+ connected, subprotocol_or_close_code = await client.connect()
+ assert connected
+
try:
- yield ChannelsWebSocketClient(client)
+ yield ChannelsWebSocketClient(
+ client, accepted_subprotocol=subprotocol_or_close_code
+ )
finally:
await client.disconnect()
@@ -260,6 +256,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.http_app = DebuggableSyncGraphQLHTTPConsumer.as_asgi(
schema=schema,
@@ -267,19 +264,26 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
class ChannelsWebSocketClient(WebSocketClient):
- def __init__(self, client: WebsocketCommunicator):
+ def __init__(
+ self, client: WebsocketCommunicator, accepted_subprotocol: Optional[str]
+ ):
self.ws = client
self._closed: bool = False
self._close_code: Optional[int] = None
self._close_reason: Optional[str] = None
+ self._accepted_subprotocol = accepted_subprotocol
def name(self) -> str:
return "channels"
+ async def send_text(self, payload: str) -> None:
+ await self.ws.send_to(text_data=payload)
+
async def send_json(self, payload: Dict[str, Any]) -> None:
await self.ws.send_json_to(payload)
@@ -307,6 +311,10 @@ async def close(self) -> None:
await self.ws.disconnect()
self._closed = True
+ @property
+ def accepted_subprotocol(self) -> Optional[str]:
+ return self._accepted_subprotocol
+
@property
def closed(self) -> bool:
return self._closed
@@ -316,5 +324,6 @@ def close_code(self) -> int:
assert self._close_code is not None
return self._close_code
- def assert_reason(self, reason: str) -> None:
- assert self._close_reason == reason
+ @property
+ def close_reason(self) -> Optional[str]:
+ return self._close_reason
diff --git a/tests/http/clients/django.py b/tests/http/clients/django.py
index 75efa2825c..a871d7b63b 100644
--- a/tests/http/clients/django.py
+++ b/tests/http/clients/django.py
@@ -48,11 +48,13 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.graphiql = graphiql
self.graphql_ide = graphql_ide
self.allow_queries_via_get = allow_queries_via_get
self.result_override = result_override
+ self.multipart_uploads_enabled = multipart_uploads_enabled
def _get_header_name(self, key: str) -> str:
return f"HTTP_{key.upper().replace('-', '_')}"
@@ -75,6 +77,7 @@ async def _do_request(self, request: RequestFactory) -> Response:
graphql_ide=self.graphql_ide,
allow_queries_via_get=self.allow_queries_via_get,
result_override=self.result_override,
+ multipart_uploads_enabled=self.multipart_uploads_enabled,
)
try:
diff --git a/tests/http/clients/fastapi.py b/tests/http/clients/fastapi.py
index 1a8148c136..b1b80625fa 100644
--- a/tests/http/clients/fastapi.py
+++ b/tests/http/clients/fastapi.py
@@ -11,7 +11,6 @@
from fastapi import BackgroundTasks, Depends, FastAPI, Request, WebSocket
from fastapi.testclient import TestClient
from strawberry.fastapi import GraphQLRouter as BaseGraphQLRouter
-from strawberry.fastapi.handlers import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
from strawberry.types import ExecutionResult
@@ -21,8 +20,8 @@
from .asgi import AsgiWebSocketClient
from .base import (
JSON,
- DebuggableGraphQLTransportWSMixin,
- DebuggableGraphQLWSMixin,
+ DebuggableGraphQLTransportWSHandler,
+ DebuggableGraphQLWSHandler,
HttpClient,
Response,
ResultOverrideFunction,
@@ -30,16 +29,6 @@
)
-class DebuggableGraphQLTransportWSHandler(
- DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler
-):
- pass
-
-
-class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler):
- pass
-
-
def custom_context_dependency() -> str:
return "Hi!"
@@ -86,6 +75,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.app = FastAPI()
@@ -97,6 +87,7 @@ def __init__(
root_value_getter=get_root_value,
allow_queries_via_get=allow_queries_via_get,
keep_alive=False,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
graphql_app.result_override = result_override
self.app.include_router(graphql_app, prefix="/graphql")
diff --git a/tests/http/clients/flask.py b/tests/http/clients/flask.py
index abc0e3cec4..7da42bbaff 100644
--- a/tests/http/clients/flask.py
+++ b/tests/http/clients/flask.py
@@ -61,6 +61,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.app = Flask(__name__)
self.app.debug = True
@@ -72,6 +73,7 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
self.app.add_url_rule(
diff --git a/tests/http/clients/litestar.py b/tests/http/clients/litestar.py
index ccf9999f7f..2548dc563c 100644
--- a/tests/http/clients/litestar.py
+++ b/tests/http/clients/litestar.py
@@ -13,15 +13,14 @@
from strawberry.http import GraphQLHTTPResponse
from strawberry.http.ides import GraphQL_IDE
from strawberry.litestar import make_graphql_controller
-from strawberry.litestar.controller import GraphQLTransportWSHandler, GraphQLWSHandler
from strawberry.types import ExecutionResult
from tests.views.schema import Query, schema
from ..context import get_context
from .base import (
JSON,
- DebuggableGraphQLTransportWSMixin,
- DebuggableGraphQLWSMixin,
+ DebuggableGraphQLTransportWSHandler,
+ DebuggableGraphQLWSHandler,
HttpClient,
Message,
Response,
@@ -42,16 +41,6 @@ async def get_root_value(request: Request = None):
return Query()
-class DebuggableGraphQLTransportWSHandler(
- DebuggableGraphQLTransportWSMixin, GraphQLTransportWSHandler
-):
- pass
-
-
-class DebuggableGraphQLWSHandler(DebuggableGraphQLWSMixin, GraphQLWSHandler):
- pass
-
-
class LitestarHttpClient(HttpClient):
def __init__(
self,
@@ -59,12 +48,14 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.create_app(
graphiql=graphiql,
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
def create_app(self, result_override: ResultOverrideFunction = None, **kwargs: Any):
@@ -188,6 +179,9 @@ def handle_disconnect(self, exc: WebSocketDisconnect) -> None:
self._closed = True
self._close_code = exc.code
+ async def send_text(self, payload: str) -> None:
+ self.ws.send_text(payload)
+
async def send_json(self, payload: Dict[str, Any]) -> None:
self.ws.send_json(payload)
@@ -227,6 +221,10 @@ async def close(self) -> None:
self.ws.close()
self._closed = True
+ @property
+ def accepted_subprotocol(self) -> Optional[str]:
+ return self.ws.accepted_subprotocol
+
@property
def closed(self) -> bool:
return self._closed
@@ -236,5 +234,6 @@ def close_code(self) -> int:
assert self._close_code is not None
return self._close_code
- def assert_reason(self, reason: str) -> None:
- assert self._close_reason == reason
+ @property
+ def close_reason(self) -> Optional[str]:
+ return self._close_reason
diff --git a/tests/http/clients/quart.py b/tests/http/clients/quart.py
index 60bc14b8c2..d9a184dfd4 100644
--- a/tests/http/clients/quart.py
+++ b/tests/http/clients/quart.py
@@ -54,6 +54,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.app = Quart(__name__)
self.app.debug = True
@@ -65,6 +66,7 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
self.app.add_url_rule(
diff --git a/tests/http/clients/sanic.py b/tests/http/clients/sanic.py
index 449aa316e8..8edc351db9 100644
--- a/tests/http/clients/sanic.py
+++ b/tests/http/clients/sanic.py
@@ -53,6 +53,7 @@ def __init__(
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
result_override: ResultOverrideFunction = None,
+ multipart_uploads_enabled: bool = False,
):
self.app = Sanic(
f"test_{int(randint(0, 1000))}", # noqa: S311
@@ -63,6 +64,7 @@ def __init__(
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
result_override=result_override,
+ multipart_uploads_enabled=multipart_uploads_enabled,
)
self.app.add_route(
view,
diff --git a/tests/http/test_upload.py b/tests/http/test_upload.py
index 86871c9f2a..7a991db846 100644
--- a/tests/http/test_upload.py
+++ b/tests/http/test_upload.py
@@ -20,7 +20,18 @@ def http_client(http_client_class: Type[HttpClient]) -> HttpClient:
return http_client_class()
-async def test_upload(http_client: HttpClient):
+@pytest.fixture()
+def enabled_http_client(http_client_class: Type[HttpClient]) -> HttpClient:
+ with contextlib.suppress(ImportError):
+ from .clients.chalice import ChaliceHttpClient
+
+ if http_client_class is ChaliceHttpClient:
+ pytest.xfail(reason="Chalice does not support uploads")
+
+ return http_client_class(multipart_uploads_enabled=True)
+
+
+async def test_multipart_uploads_are_disabled_by_default(http_client: HttpClient):
f = BytesIO(b"strawberry")
query = """
@@ -35,16 +46,35 @@ async def test_upload(http_client: HttpClient):
files={"textFile": f},
)
+ assert response.status_code == 400
+ assert response.data == b"Unsupported content type"
+
+
+async def test_upload(enabled_http_client: HttpClient):
+ f = BytesIO(b"strawberry")
+
+ query = """
+ mutation($textFile: Upload!) {
+ readText(textFile: $textFile)
+ }
+ """
+
+ response = await enabled_http_client.query(
+ query,
+ variables={"textFile": None},
+ files={"textFile": f},
+ )
+
assert response.json.get("errors") is None
assert response.json["data"] == {"readText": "strawberry"}
-async def test_file_list_upload(http_client: HttpClient):
+async def test_file_list_upload(enabled_http_client: HttpClient):
query = "mutation($files: [Upload!]!) { readFiles(files: $files) }"
file1 = BytesIO(b"strawberry1")
file2 = BytesIO(b"strawberry2")
- response = await http_client.query(
+ response = await enabled_http_client.query(
query=query,
variables={"files": [None, None]},
files={"file1": file1, "file2": file2},
@@ -57,12 +87,12 @@ async def test_file_list_upload(http_client: HttpClient):
assert data["readFiles"][1] == "strawberry2"
-async def test_nested_file_list(http_client: HttpClient):
+async def test_nested_file_list(enabled_http_client: HttpClient):
query = "mutation($folder: FolderInput!) { readFolder(folder: $folder) }"
file1 = BytesIO(b"strawberry1")
file2 = BytesIO(b"strawberry2")
- response = await http_client.query(
+ response = await enabled_http_client.query(
query=query,
variables={"folder": {"files": [None, None]}},
files={"file1": file1, "file2": file2},
@@ -74,7 +104,7 @@ async def test_nested_file_list(http_client: HttpClient):
assert data["readFolder"][1] == "strawberry2"
-async def test_upload_single_and_list_file_together(http_client: HttpClient):
+async def test_upload_single_and_list_file_together(enabled_http_client: HttpClient):
query = """
mutation($files: [Upload!]!, $textFile: Upload!) {
readFiles(files: $files)
@@ -85,7 +115,7 @@ async def test_upload_single_and_list_file_together(http_client: HttpClient):
file2 = BytesIO(b"strawberry2")
file3 = BytesIO(b"strawberry3")
- response = await http_client.query(
+ response = await enabled_http_client.query(
query=query,
variables={"files": [None, None], "textFile": None},
files={"file1": file1, "file2": file2, "textFile": file3},
@@ -98,7 +128,7 @@ async def test_upload_single_and_list_file_together(http_client: HttpClient):
assert data["readText"] == "strawberry3"
-async def test_upload_invalid_query(http_client: HttpClient):
+async def test_upload_invalid_query(enabled_http_client: HttpClient):
f = BytesIO(b"strawberry")
query = """
@@ -106,7 +136,7 @@ async def test_upload_invalid_query(http_client: HttpClient):
readT
"""
- response = await http_client.query(
+ response = await enabled_http_client.query(
query,
variables={"textFile": None},
files={"textFile": f},
@@ -122,7 +152,7 @@ async def test_upload_invalid_query(http_client: HttpClient):
]
-async def test_upload_missing_file(http_client: HttpClient):
+async def test_upload_missing_file(enabled_http_client: HttpClient):
f = BytesIO(b"strawberry")
query = """
@@ -131,7 +161,7 @@ async def test_upload_missing_file(http_client: HttpClient):
}
"""
- response = await http_client.query(
+ response = await enabled_http_client.query(
query,
variables={"textFile": None},
# using the wrong name to simulate a missing file
@@ -155,7 +185,7 @@ def value(self) -> bytes:
return self.buffer.getvalue()
-async def test_extra_form_data_fields_are_ignored(http_client: HttpClient):
+async def test_extra_form_data_fields_are_ignored(enabled_http_client: HttpClient):
query = """mutation($textFile: Upload!) {
readText(textFile: $textFile)
}"""
@@ -175,7 +205,7 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient):
data, header = encode_multipart_formdata(fields)
- response = await http_client.post(
+ response = await enabled_http_client.post(
url="/graphql",
data=data,
headers={
@@ -188,9 +218,9 @@ async def test_extra_form_data_fields_are_ignored(http_client: HttpClient):
assert response.json["data"] == {"readText": "strawberry"}
-async def test_sending_invalid_form_data(http_client: HttpClient):
+async def test_sending_invalid_form_data(enabled_http_client: HttpClient):
headers = {"content-type": "multipart/form-data; boundary=----fake"}
- response = await http_client.post("/graphql", headers=headers)
+ response = await enabled_http_client.post("/graphql", headers=headers)
assert response.status_code == 400
# TODO: consolidate this, it seems only AIOHTTP returns the second error
@@ -202,7 +232,7 @@ async def test_sending_invalid_form_data(http_client: HttpClient):
@pytest.mark.aiohttp
-async def test_sending_invalid_json_body(http_client: HttpClient):
+async def test_sending_invalid_json_body(enabled_http_client: HttpClient):
f = BytesIO(b"strawberry")
operations = "}"
file_map = json.dumps({"textFile": ["variables.textFile"]})
@@ -215,7 +245,7 @@ async def test_sending_invalid_json_body(http_client: HttpClient):
data, header = encode_multipart_formdata(fields)
- response = await http_client.post(
+ response = await enabled_http_client.post(
"/graphql",
data=data,
headers={
diff --git a/tests/litestar/test_websockets.py b/tests/litestar/test_websockets.py
deleted file mode 100644
index b5554cb264..0000000000
--- a/tests/litestar/test_websockets.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import pytest
-
-from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
-
-
-def test_turning_off_graphql_ws():
- from litestar.exceptions import WebSocketDisconnect
- from litestar.testing import TestClient
- from tests.litestar.app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_turning_off_graphql_transport_ws():
- from litestar.exceptions import WebSocketDisconnect
- from litestar.testing import TestClient
- from tests.litestar.app import create_app
-
- app = create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_turning_off_all_ws_protocols():
- from litestar.exceptions import WebSocketDisconnect
- from litestar.testing import TestClient
- from tests.litestar.app import create_app
-
- app = create_app(subscription_protocols=[])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", [GRAPHQL_WS_PROTOCOL]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_unsupported_ws_protocol():
- from litestar.exceptions import WebSocketDisconnect
- from litestar.testing import TestClient
- from tests.litestar.app import create_app
-
- app = create_app(subscription_protocols=[])
- test_client = TestClient(app)
-
- with pytest.raises(WebSocketDisconnect) as exc:
- with test_client.websocket_connect("/graphql", ["imaginary-protocol"]):
- pass
-
- assert exc.value.code == 4406
-
-
-def test_clients_can_prefer_protocols():
- from litestar.testing import TestClient
- from tests.litestar.app import create_app
-
- app = create_app(
- subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- )
- test_client = TestClient(app)
-
- with test_client.websocket_connect(
- "/graphql", [GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]
- ) as ws:
- assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL
-
- with test_client.websocket_connect(
- "/graphql", [GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
- ) as ws:
- assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL
diff --git a/strawberry/litestar/handlers/__init__.py b/tests/test/__init__.py
similarity index 100%
rename from strawberry/litestar/handlers/__init__.py
rename to tests/test/__init__.py
diff --git a/tests/test/conftest.py b/tests/test/conftest.py
new file mode 100644
index 0000000000..30ecd326bd
--- /dev/null
+++ b/tests/test/conftest.py
@@ -0,0 +1,66 @@
+from __future__ import annotations
+
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, AsyncGenerator
+
+import pytest
+
+from tests.views.schema import schema
+
+if TYPE_CHECKING:
+ from strawberry.test import BaseGraphQLTestClient
+
+
+@asynccontextmanager
+async def aiohttp_graphql_client() -> AsyncGenerator[BaseGraphQLTestClient]:
+ try:
+ from aiohttp import web
+ from aiohttp.test_utils import TestClient, TestServer
+ from strawberry.aiohttp.test import GraphQLTestClient
+ from strawberry.aiohttp.views import GraphQLView
+ except ImportError:
+ pytest.skip("Aiohttp not installed")
+
+ view = GraphQLView(schema=schema)
+ app = web.Application()
+ app.router.add_route("*", "/graphql/", view)
+
+ async with TestClient(TestServer(app)) as client:
+ yield GraphQLTestClient(client)
+
+
+@asynccontextmanager
+async def asgi_graphql_client() -> AsyncGenerator[BaseGraphQLTestClient]:
+ try:
+ from starlette.testclient import TestClient
+
+ from strawberry.asgi import GraphQL
+ from strawberry.asgi.test import GraphQLTestClient
+ except ImportError:
+ pytest.skip("Starlette not installed")
+
+ yield GraphQLTestClient(TestClient(GraphQL(schema)))
+
+
+@asynccontextmanager
+async def django_graphql_client() -> AsyncGenerator[BaseGraphQLTestClient]:
+ try:
+ from django.test.client import Client
+
+ from strawberry.django.test import GraphQLTestClient
+ except ImportError:
+ pytest.skip("Django not installed")
+
+ yield GraphQLTestClient(Client())
+
+
+@pytest.fixture(
+ params=[
+ pytest.param(aiohttp_graphql_client, marks=[pytest.mark.aiohttp]),
+ pytest.param(asgi_graphql_client, marks=[pytest.mark.asgi]),
+ pytest.param(django_graphql_client, marks=[pytest.mark.django]),
+ ]
+)
+async def graphql_client(request) -> AsyncGenerator[BaseGraphQLTestClient]:
+ async with request.param() as graphql_client:
+ yield graphql_client
diff --git a/tests/test/test_client.py b/tests/test/test_client.py
new file mode 100644
index 0000000000..0e03d12f77
--- /dev/null
+++ b/tests/test/test_client.py
@@ -0,0 +1,34 @@
+from contextlib import nullcontext
+
+import pytest
+
+from strawberry.utils.await_maybe import await_maybe
+
+
+@pytest.mark.parametrize("asserts_errors", [True, False])
+async def test_query_asserts_errors_option_is_deprecated(
+ graphql_client, asserts_errors
+):
+ with pytest.warns(
+ DeprecationWarning,
+ match="The `asserts_errors` argument has been renamed to `assert_no_errors`",
+ ):
+ await await_maybe(
+ graphql_client.query("{ hello }", asserts_errors=asserts_errors)
+ )
+
+
+@pytest.mark.parametrize("option_name", ["asserts_errors", "assert_no_errors"])
+@pytest.mark.parametrize(
+ ("assert_no_errors", "expectation"),
+ [(True, pytest.raises(AssertionError)), (False, nullcontext())],
+)
+async def test_query_with_assert_no_errors_option(
+ graphql_client, option_name, assert_no_errors, expectation
+):
+ query = "{ ThisIsNotAValidQuery }"
+
+ with expectation:
+ await await_maybe(
+ graphql_client.query(query, **{option_name: assert_no_errors})
+ )
diff --git a/tests/websockets/test_graphql_transport_ws.py b/tests/websockets/test_graphql_transport_ws.py
index 02f8366852..4dbea524f4 100644
--- a/tests/websockets/test_graphql_transport_ws.py
+++ b/tests/websockets/test_graphql_transport_ws.py
@@ -24,7 +24,7 @@
SubscribeMessage,
SubscribeMessagePayload,
)
-from tests.http.clients.base import DebuggableGraphQLTransportWSMixin
+from tests.http.clients.base import DebuggableGraphQLTransportWSHandler
from tests.views.schema import MyExtension, Schema
if TYPE_CHECKING:
@@ -76,7 +76,7 @@ async def test_unknown_message_type(ws_raw: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Unknown message type: NOT_A_MESSAGE_TYPE")
+ assert ws.close_reason == "Unknown message type: NOT_A_MESSAGE_TYPE"
async def test_missing_message_type(ws_raw: WebSocketClient):
@@ -87,7 +87,7 @@ async def test_missing_message_type(ws_raw: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Failed to parse message")
+ assert ws.close_reason == "Failed to parse message"
async def test_parsing_an_invalid_message(ws_raw: WebSocketClient):
@@ -98,7 +98,7 @@ async def test_parsing_an_invalid_message(ws_raw: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Failed to parse message")
+ assert ws.close_reason == "Failed to parse message"
async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient):
@@ -109,7 +109,7 @@ async def test_parsing_an_invalid_payload(ws_raw: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Failed to parse message")
+ assert ws.close_reason == "Failed to parse message"
async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
@@ -120,7 +120,18 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("WebSocket message type must be text")
+ assert ws.close_reason == "WebSocket message type must be text"
+
+
+async def test_ws_messages_must_be_json(ws_raw: WebSocketClient):
+ ws = ws_raw
+
+ await ws.send_text("not valid json")
+
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 4400
+ assert ws.close_reason == "WebSocket message type must be text"
async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
@@ -145,7 +156,7 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("WebSocket message type must be text")
+ assert ws.close_reason == "WebSocket message type must be text"
async def test_connection_init_timeout(
@@ -169,7 +180,7 @@ async def test_connection_init_timeout(
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4408
- ws.assert_reason("Connection initialisation timeout")
+ assert ws.close_reason == "Connection initialisation timeout"
@pytest.mark.flaky
@@ -229,7 +240,7 @@ async def test_close_twice(
await ws.receive(timeout=0.5)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Invalid connection init payload")
+ assert ws.close_reason == "Invalid connection init payload"
transport_close.assert_not_called()
@@ -238,7 +249,7 @@ async def test_too_many_initialisation_requests(ws: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4429
- ws.assert_reason("Too many initialisation requests")
+ assert ws.close_reason == "Too many initialisation requests"
async def test_ping_pong(ws: WebSocketClient):
@@ -309,7 +320,7 @@ async def test_unauthorized_subscriptions(ws_raw: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4401
- ws.assert_reason("Unauthorized")
+ assert ws.close_reason == "Unauthorized"
async def test_duplicated_operation_ids(ws: WebSocketClient):
@@ -334,7 +345,7 @@ async def test_duplicated_operation_ids(ws: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4409
- ws.assert_reason("Subscriber for sub1 already exists")
+ assert ws.close_reason == "Subscriber for sub1 already exists"
async def test_reused_operation_ids(ws: WebSocketClient):
@@ -398,7 +409,7 @@ async def test_subscription_syntax_error(ws: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Syntax Error: Expected Name, found .")
+ assert ws.close_reason == "Syntax Error: Expected Name, found ."
async def test_subscription_field_errors(ws: WebSocketClient):
@@ -652,7 +663,7 @@ async def test_single_result_invalid_operation_selection(ws: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Can't get GraphQL operation type")
+ assert ws.close_reason == "Can't get GraphQL operation type"
async def test_single_result_execution_error(ws: WebSocketClient):
@@ -732,7 +743,7 @@ async def test_single_result_duplicate_ids_sub(ws: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4409
- ws.assert_reason("Subscriber for sub1 already exists")
+ assert ws.close_reason == "Subscriber for sub1 already exists"
async def test_single_result_duplicate_ids_query(ws: WebSocketClient):
@@ -763,7 +774,7 @@ async def test_single_result_duplicate_ids_query(ws: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4409
- ws.assert_reason("Subscriber for sub1 already exists")
+ assert ws.close_reason == "Subscriber for sub1 already exists"
async def test_injects_connection_params(ws_raw: WebSocketClient):
@@ -793,7 +804,7 @@ async def test_rejects_connection_params_not_dict(ws_raw: WebSocketClient):
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Invalid connection init payload")
+ assert ws.close_reason == "Invalid connection init payload"
@pytest.mark.parametrize(
@@ -809,7 +820,7 @@ async def test_rejects_connection_params_with_wrong_type(
data = await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 4400
- ws.assert_reason("Invalid connection init payload")
+ assert ws.close_reason == "Invalid connection init payload"
# timings can sometimes fail currently. Until this test is rewritten when
@@ -879,7 +890,7 @@ def on_init(_handler):
# cause an attribute error in the timeout task
handler.connection_init_wait_timeout = None
- with patch.object(DebuggableGraphQLTransportWSMixin, "on_init", on_init):
+ with patch.object(DebuggableGraphQLTransportWSHandler, "on_init", on_init):
async with http_client.ws_connect(
"/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
) as ws:
diff --git a/tests/websockets/test_graphql_ws.py b/tests/websockets/test_graphql_ws.py
index 6752eaa7f8..daf56f90fb 100644
--- a/tests/websockets/test_graphql_ws.py
+++ b/tests/websockets/test_graphql_ws.py
@@ -289,7 +289,18 @@ async def test_ws_messages_must_be_text(ws_raw: WebSocketClient):
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 1002
- ws.assert_reason("WebSocket message type must be text")
+ assert ws.close_reason == "WebSocket message type must be text"
+
+
+async def test_ws_messages_must_be_json(ws_raw: WebSocketClient):
+ ws = ws_raw
+
+ await ws.send_text("not valid json")
+
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 1002
+ assert ws.close_reason == "WebSocket message type must be text"
async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
@@ -315,7 +326,7 @@ async def test_ws_message_frame_types_cannot_be_mixed(ws_raw: WebSocketClient):
await ws.receive(timeout=2)
assert ws.closed
assert ws.close_code == 1002
- ws.assert_reason("WebSocket message type must be text")
+ assert ws.close_reason == "WebSocket message type must be text"
async def test_unknown_protocol_messages_are_ignored(ws_raw: WebSocketClient):
diff --git a/tests/websockets/test_websockets.py b/tests/websockets/test_websockets.py
new file mode 100644
index 0000000000..2018316be0
--- /dev/null
+++ b/tests/websockets/test_websockets.py
@@ -0,0 +1,82 @@
+from typing import Type
+
+from strawberry.subscriptions import GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL
+from tests.http.clients.base import HttpClient
+
+
+async def test_turning_off_graphql_ws(http_client_class: Type[HttpClient]):
+ http_client = http_client_class()
+ http_client.create_app(subscription_protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL])
+
+ async with http_client.ws_connect(
+ "/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
+ ) as ws:
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 4406
+ assert ws.close_reason == "Subprotocol not acceptable"
+
+
+async def test_turning_off_graphql_transport_ws(http_client_class: Type[HttpClient]):
+ http_client = http_client_class()
+ http_client.create_app(subscription_protocols=[GRAPHQL_WS_PROTOCOL])
+
+ async with http_client.ws_connect(
+ "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
+ ) as ws:
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 4406
+ assert ws.close_reason == "Subprotocol not acceptable"
+
+
+async def test_turning_off_all_subprotocols(http_client_class: Type[HttpClient]):
+ http_client = http_client_class()
+ http_client.create_app(subscription_protocols=[])
+
+ async with http_client.ws_connect(
+ "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL]
+ ) as ws:
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 4406
+ assert ws.close_reason == "Subprotocol not acceptable"
+
+ async with http_client.ws_connect(
+ "/graphql", protocols=[GRAPHQL_WS_PROTOCOL]
+ ) as ws:
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 4406
+ assert ws.close_reason == "Subprotocol not acceptable"
+
+
+async def test_generally_unsupported_subprotocols_are_rejected(http_client: HttpClient):
+ async with http_client.ws_connect(
+ "/graphql", protocols=["imaginary-protocol"]
+ ) as ws:
+ await ws.receive(timeout=2)
+ assert ws.closed
+ assert ws.close_code == 4406
+ assert ws.close_reason == "Subprotocol not acceptable"
+
+
+async def test_clients_can_prefer_subprotocols(http_client_class: Type[HttpClient]):
+ http_client = http_client_class()
+ http_client.create_app(
+ subscription_protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
+ )
+
+ async with http_client.ws_connect(
+ "/graphql", protocols=[GRAPHQL_TRANSPORT_WS_PROTOCOL, GRAPHQL_WS_PROTOCOL]
+ ) as ws:
+ assert ws.accepted_subprotocol == GRAPHQL_TRANSPORT_WS_PROTOCOL
+ await ws.close()
+ assert ws.closed
+
+ async with http_client.ws_connect(
+ "/graphql", protocols=[GRAPHQL_WS_PROTOCOL, GRAPHQL_TRANSPORT_WS_PROTOCOL]
+ ) as ws:
+ assert ws.accepted_subprotocol == GRAPHQL_WS_PROTOCOL
+ await ws.close()
+ assert ws.closed