-
-
Notifications
You must be signed in to change notification settings - Fork 535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Schema Extensions in subscriptions #2784
Changes from all commits
b466ecf
9d5d16f
46c9aa1
e0cd1fc
22855f4
7213377
c8cec42
6dedf22
4998305
41d053e
8a039af
b622dac
9a8a633
7aa2009
c5eec36
edd6d2e
64c33ca
a033624
699a910
4fdbac4
5130fba
ef6688d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Release type: minor | ||
|
||
Subscriptions now support Schema Extensions. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .base import BaseSchema | ||
from .base import BaseSchema, SubscribeSingleResult | ||
from .schema import Schema | ||
|
||
__all__ = ["BaseSchema", "Schema"] | ||
__all__ = ["BaseSchema", "Schema", "SubscribeSingleResult"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
from inspect import isawaitable | ||
from typing import ( | ||
TYPE_CHECKING, | ||
AsyncGenerator, | ||
AsyncIterable, | ||
Awaitable, | ||
Callable, | ||
Iterable, | ||
|
@@ -17,21 +19,23 @@ | |
cast, | ||
) | ||
|
||
from graphql import ExecutionResult as GraphQLExecutionResult | ||
from graphql import GraphQLError, parse | ||
from graphql import execute as original_execute | ||
from graphql import subscribe as original_subscribe | ||
from graphql.validation import validate | ||
|
||
from strawberry.exceptions import MissingQueryError | ||
from strawberry.extensions.runner import SchemaExtensionsRunner | ||
from strawberry.types import ExecutionResult | ||
|
||
from .base import SubscribeSingleResult | ||
from .exceptions import InvalidOperationTypeError | ||
|
||
if TYPE_CHECKING: | ||
from typing_extensions import NotRequired, Unpack | ||
|
||
from graphql import ExecutionContext as GraphQLExecutionContext | ||
from graphql import ExecutionResult as GraphQLExecutionResult | ||
from graphql import GraphQLSchema | ||
from graphql.language import DocumentNode | ||
from graphql.validation import ASTValidationRule | ||
|
@@ -272,3 +276,149 @@ def execute_sync( | |
errors=execution_context.result.errors, | ||
extensions=extensions_runner.get_extensions_results_sync(), | ||
) | ||
|
||
|
||
async def subscribe( | ||
schema: GraphQLSchema, | ||
*, | ||
extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]], | ||
execution_context: ExecutionContext, | ||
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], | ||
) -> AsyncGenerator[ExecutionResult, None]: | ||
""" | ||
The graphql-core subscribe function returns either an ExecutionResult or an | ||
AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error | ||
during parsing or validation. | ||
Because we need to maintain execution context, we cannot return an | ||
async generator, we must _be_ an async generator. So we yield a | ||
(bool, ExecutionResult) tuple, where the bool indicates whether the result is an | ||
potentially multiple execution result or a single result. | ||
A False value indicates an single result, most likely an intial | ||
failure (and no more values will be yielded) whereas a True value indicates a | ||
successful subscription, and more values may be yielded. | ||
""" | ||
|
||
extensions_runner = SchemaExtensionsRunner( | ||
execution_context=execution_context, | ||
extensions=list(extensions), | ||
) | ||
|
||
# unlike execute(), the entire operation, including the results hooks, | ||
# is run within the operation() hook. | ||
async with extensions_runner.operation(): | ||
# Note: In graphql-core the schema would be validated here but in | ||
# Strawberry we are validating it at initialisation time instead | ||
assert execution_context.query is not None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should raise here MissingQueryError There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should it? I removed that code because it didn't get hit by coverage testing. It is my understanding that that can only happen if a query parameter is missing from a "query string", and we don't have these for subscriptions. Under what conditions could that possibly happen? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm I'm not sure I think @jthorniley added this. |
||
|
||
async with extensions_runner.parsing(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is duplicated from async execution, probably can be reused... see https://github.com/nrbnlulu/strawberry/blob/support_extensions_on_subscriptions/strawberry/schema/execute.py#L192 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll see if I can do that after I change the tuple semantics. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did refactor the common validation tests |
||
try: | ||
if not execution_context.graphql_document: | ||
execution_context.graphql_document = parse_document( | ||
execution_context.query, **execution_context.parse_options | ||
) | ||
|
||
except GraphQLError as error: | ||
execution_context.errors = [error] | ||
process_errors([error], execution_context) | ||
raise SubscribeSingleResult( | ||
ExecutionResult( | ||
data=None, | ||
errors=[error], | ||
extensions=await extensions_runner.get_extensions_results(), | ||
) | ||
) | ||
|
||
except Exception as error: # pragma: no cover | ||
error = GraphQLError(str(error), original_error=error) | ||
execution_context.errors = [error] | ||
process_errors([error], execution_context) | ||
|
||
raise SubscribeSingleResult( | ||
ExecutionResult( | ||
data=None, | ||
errors=[error], | ||
extensions=await extensions_runner.get_extensions_results(), | ||
) | ||
) | ||
|
||
async with extensions_runner.validation(): | ||
_run_validation(execution_context) | ||
if execution_context.errors: | ||
process_errors(execution_context.errors, execution_context) | ||
raise SubscribeSingleResult( | ||
ExecutionResult(data=None, errors=execution_context.errors) | ||
) | ||
|
||
async with extensions_runner.executing(): | ||
# currently original_subscribe is an async function. A future release | ||
# of graphql-core will make it optionally awaitable | ||
result: Union[AsyncIterable[GraphQLExecutionResult], GraphQLExecutionResult] | ||
result_or_awaitable = original_subscribe( | ||
schema, | ||
execution_context.graphql_document, | ||
root_value=execution_context.root_value, | ||
context_value=execution_context.context, | ||
variable_values=execution_context.variables, | ||
operation_name=execution_context.operation_name, | ||
) | ||
if isawaitable(result_or_awaitable): | ||
result = await cast( | ||
Awaitable[ | ||
Union[ | ||
AsyncIterable["GraphQLExecutionResult"], | ||
"GraphQLExecutionResult", | ||
] | ||
], | ||
result_or_awaitable, | ||
) | ||
else: # pragma: no cover | ||
result = cast( | ||
Union[ | ||
AsyncIterable["GraphQLExecutionResult"], | ||
"GraphQLExecutionResult", | ||
], | ||
result_or_awaitable, | ||
) | ||
|
||
if isinstance(result, GraphQLExecutionResult): | ||
raise SubscribeSingleResult( | ||
await process_subscribe_result( | ||
execution_context, process_errors, extensions_runner, result | ||
) | ||
) | ||
|
||
aiterator = result.__aiter__() | ||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suppress what exactly? AttributeError? Because that exception might come from anywhere. This is surgically testing for the existence of the aclose method. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand.
What's wrong with with contextlib.supress(BaseException):
async for result in aiterator:
yield True, await process_result(result)
if hasattr(aiterator, "aclose"):
await aiterator.aclose() AFAIK this is the same as what you are doing... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, you don't want to suppress No, try-finally exists precisely to do this kind of thing. In fact, that is how with graphql-core 3.3, there will be an |
||
async for result in aiterator: | ||
yield await process_subscribe_result( | ||
execution_context, process_errors, extensions_runner, result | ||
) | ||
finally: | ||
# grapql-core's iterator may or may not have an aclose() method | ||
if hasattr(aiterator, "aclose"): | ||
await aiterator.aclose() | ||
|
||
|
||
async def process_subscribe_result( | ||
execution_context: ExecutionContext, | ||
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None], | ||
extensions_runner: SchemaExtensionsRunner, | ||
result: GraphQLExecutionResult, | ||
) -> ExecutionResult: | ||
execution_context.result = result | ||
# Also set errors on the execution_context so that it's easier | ||
# to access in extensions | ||
if result.errors: | ||
execution_context.errors = result.errors | ||
|
||
# Run the `Schema.process_errors` function here before | ||
# extensions have a chance to modify them (see the MaskErrors | ||
# extension). That way we can log the original errors but | ||
# only return a sanitised version to the client. | ||
process_errors(result.errors, execution_context) | ||
|
||
return ExecutionResult( | ||
data=execution_context.result.data, | ||
errors=execution_context.result.errors, | ||
extensions=await extensions_runner.get_extensions_results(), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This tuple hack is not very pythonic IMO other than the fact that it is a breaking change... You might wanna check what I did at https://github.com/strawberry-graphql/strawberry/pull/2810/files#diff-88aa6fd17e4c6feac6e7152ebd3f2b8f972544c444a071b550e4d23061b97a3fR215 where if there is an error I return
ExecutionResultError
which is basically the same as normalExecutionResult
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is a good idea. Yes, tuples are problematic and thought this might be a sticking point. I'll create a special exception class instead, much nicer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, either an exception, or a special result class.. I think the exception might be cleaner, since one expects a subscription and the failure to get one is an exception of sorts. I'll see which one is nicer.