Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Schema Extensions in subscriptions #2784

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
b466ecf
schema.subscribe is now an async generator. Apply extension contexts…
kristjanvalur Jul 11, 2023
9d5d16f
Update graphql_ws handler
kristjanvalur May 27, 2023
46c9aa1
Remove unused code
kristjanvalur May 27, 2023
e0cd1fc
fix tests now that field verification is done by Strawberry
kristjanvalur May 27, 2023
22855f4
Return "extensions" as part of ExecutionResult if present.
kristjanvalur May 27, 2023
7213377
Add tests for blocking operation validation
kristjanvalur Mar 8, 2023
c8cec42
Add some error/validation test cases
kristjanvalur Apr 9, 2023
6dedf22
Add release.md
kristjanvalur Apr 23, 2023
4998305
Fix schema tests for new `subscribe()` signature
kristjanvalur May 27, 2023
41d053e
Add extensive tests for extension hook execution while running subscr…
kristjanvalur May 31, 2023
8a039af
make extension test more robust, wait for old connections to drain.
kristjanvalur May 31, 2023
b622dac
update docs for `resolve` extension hook
kristjanvalur May 31, 2023
9a8a633
fix channels tests to work with sync channels
kristjanvalur May 31, 2023
7aa2009
schema.subscribe() now raises a SubscribeSingleResult when not return…
kristjanvalur Jun 5, 2023
c5eec36
Move closure function out into a separate function
kristjanvalur Jun 5, 2023
edd6d2e
ruff
kristjanvalur Jun 5, 2023
64c33ca
tests for process_errors() (graphql_ws)
kristjanvalur Jul 9, 2023
a033624
update benchmark test
kristjanvalur Aug 25, 2023
699a910
Update newly added tests
kristjanvalur Aug 25, 2023
4fdbac4
Fix typing
kristjanvalur May 3, 2024
5130fba
fix new tests
kristjanvalur May 3, 2024
ef6688d
formatting
kristjanvalur May 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Subscriptions now support Schema Extensions.
17 changes: 15 additions & 2 deletions docs/guides/custom-extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ resolvers.
If you need to wrap only certain field resolvers with additional logic, please
check out [field extensions](field-extensions.md).

Note that `resolve` can also be implemented asynchronously.

```python
from strawberry.extensions import SchemaExtension

Expand All @@ -47,6 +45,21 @@ class MyExtension(SchemaExtension):
return _next(root, info, *args, **kwargs)
```

Note that `resolve` can also be implemented asynchronously, in which case the
result from `_next` must be optionally awaited:

```python
from inspect import isawaitable
from strawberry.types import Info
from strawberry.extensions import SchemaExtension


class MyExtension(SchemaExtension):
async def resolve(self, _next, root, info: Info, *args, **kwargs):
result = _next(root, info, *args, **kwargs)
return await result if isawaitable(result) else result
```

### Get results

`get_results` allows to return a dictionary of data or alternatively an
Expand Down
4 changes: 1 addition & 3 deletions docs/operations/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,8 @@ async def test_subscription():
}
"""

sub = await schema.subscribe(query)

index = 0
async for result in sub:
async for result in schema.subscribe(query):
assert not result.errors
assert result.data == {"count": index}

Expand Down
4 changes: 2 additions & 2 deletions strawberry/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import BaseSchema
from .base import BaseSchema, SubscribeSingleResult
from .schema import Schema

__all__ = ["BaseSchema", "Schema"]
__all__ = ["BaseSchema", "Schema", "SubscribeSingleResult"]
25 changes: 22 additions & 3 deletions strawberry/schema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

from abc import abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
Iterable,
List,
Optional,
Type,
Union,
)
from typing_extensions import Protocol

from strawberry.utils.logging import StrawberryLogger
Expand All @@ -22,6 +32,15 @@
from .config import StrawberryConfig


class SubscribeSingleResult(RuntimeError):
"""Raised when Schema.subscribe() returns a single execution result, instead of a
subscription generator, typically as a result of validation errors.
"""

def __init__(self, value: ExecutionResult) -> None:
self.value = value


class BaseSchema(Protocol):
config: StrawberryConfig
schema_converter: GraphQLCoreConverter
Expand Down Expand Up @@ -55,14 +74,14 @@ def execute_sync(
raise NotImplementedError

@abstractmethod
async def subscribe(
def subscribe(
self,
query: str,
variable_values: Optional[Dict[str, Any]] = None,
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> Any:
) -> AsyncGenerator[ExecutionResult, None]:
raise NotImplementedError

@abstractmethod
Expand Down
152 changes: 151 additions & 1 deletion strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
AsyncGenerator,
AsyncIterable,
Awaitable,
Callable,
Iterable,
Expand All @@ -17,21 +19,23 @@
cast,
)

from graphql import ExecutionResult as GraphQLExecutionResult
from graphql import GraphQLError, parse
from graphql import execute as original_execute
from graphql import subscribe as original_subscribe
from graphql.validation import validate

from strawberry.exceptions import MissingQueryError
from strawberry.extensions.runner import SchemaExtensionsRunner
from strawberry.types import ExecutionResult

from .base import SubscribeSingleResult
from .exceptions import InvalidOperationTypeError

if TYPE_CHECKING:
from typing_extensions import NotRequired, Unpack

from graphql import ExecutionContext as GraphQLExecutionContext
from graphql import ExecutionResult as GraphQLExecutionResult
from graphql import GraphQLSchema
from graphql.language import DocumentNode
from graphql.validation import ASTValidationRule
Expand Down Expand Up @@ -272,3 +276,149 @@ def execute_sync(
errors=execution_context.result.errors,
extensions=extensions_runner.get_extensions_results_sync(),
)


async def subscribe(
schema: GraphQLSchema,
*,
extensions: Sequence[Union[Type[SchemaExtension], SchemaExtension]],
execution_context: ExecutionContext,
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None],
) -> AsyncGenerator[ExecutionResult, None]:
"""
The graphql-core subscribe function returns either an ExecutionResult or an
AsyncGenerator[ExecutionResult, None]. The former is returned in case of an error
during parsing or validation.
Because we need to maintain execution context, we cannot return an
async generator, we must _be_ an async generator. So we yield a
(bool, ExecutionResult) tuple, where the bool indicates whether the result is an
Copy link
Member

@nrbnlulu nrbnlulu Jun 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tuple hack is not very pythonic IMO other than the fact that it is a breaking change... You might wanna check what I did at https://github.com/strawberry-graphql/strawberry/pull/2810/files#diff-88aa6fd17e4c6feac6e7152ebd3f2b8f972544c444a071b550e4d23061b97a3fR215 where if there is an error I return ExecutionResultError which is basically the same as normal ExecutionResult

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good idea. Yes, tuples are problematic and thought this might be a sticking point. I'll create a special exception class instead, much nicer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, either an exception, or a special result class.. I think the exception might be cleaner, since one expects a subscription and the failure to get one is an exception of sorts. I'll see which one is nicer.

potentially multiple execution result or a single result.
A False value indicates an single result, most likely an intial
failure (and no more values will be yielded) whereas a True value indicates a
successful subscription, and more values may be yielded.
"""

extensions_runner = SchemaExtensionsRunner(
execution_context=execution_context,
extensions=list(extensions),
)

# unlike execute(), the entire operation, including the results hooks,
# is run within the operation() hook.
async with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
assert execution_context.query is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should raise here MissingQueryError

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it? I removed that code because it didn't get hit by coverage testing. It is my understanding that that can only happen if a query parameter is missing from a "query string", and we don't have these for subscriptions. Under what conditions could that possibly happen?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I'm not sure I think @jthorniley added this.


async with extensions_runner.parsing():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll see if I can do that after I change the tuple semantics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did refactor the common validation tests

try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
raise SubscribeSingleResult(
ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)
)

except Exception as error: # pragma: no cover
error = GraphQLError(str(error), original_error=error)
execution_context.errors = [error]
process_errors([error], execution_context)

raise SubscribeSingleResult(
ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)
)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
raise SubscribeSingleResult(
ExecutionResult(data=None, errors=execution_context.errors)
)

async with extensions_runner.executing():
# currently original_subscribe is an async function. A future release
# of graphql-core will make it optionally awaitable
result: Union[AsyncIterable[GraphQLExecutionResult], GraphQLExecutionResult]
result_or_awaitable = original_subscribe(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
context_value=execution_context.context,
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
)
if isawaitable(result_or_awaitable):
result = await cast(
Awaitable[
Union[
AsyncIterable["GraphQLExecutionResult"],
"GraphQLExecutionResult",
]
],
result_or_awaitable,
)
else: # pragma: no cover
result = cast(
Union[
AsyncIterable["GraphQLExecutionResult"],
"GraphQLExecutionResult",
],
result_or_awaitable,
)

if isinstance(result, GraphQLExecutionResult):
raise SubscribeSingleResult(
await process_subscribe_result(
execution_context, process_errors, extensions_runner, result
)
)

aiterator = result.__aiter__()
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not contextlib.suppress?

Copy link
Contributor Author

@kristjanvalur kristjanvalur Jun 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suppress what exactly? AttributeError? Because that exception might come from anywhere. This is surgically testing for the existence of the aclose method.
The need for this will go away with release 3.3.0 in graphql-core, where the subscribe() will return an async-generator.

Copy link
Member

@nrbnlulu nrbnlulu Jun 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand.

Because that exception might come from anywhere

What's wrong with

with contextlib.supress(BaseException):
    async for result in aiterator:
        yield True, await process_result(result)

if hasattr(aiterator, "aclose"):
    await aiterator.aclose()

AFAIK this is the same as what you are doing...

Copy link
Contributor Author

@kristjanvalur kristjanvalur Jun 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, you don't want to suppress BaseException. You never want to do that. CancelledError is one reason. And we would be suppressing any kind of Exception happening during iteration, including errors we want to pass to the callers, internal errors, whatnot.

No, try-finally exists precisely to do this kind of thing. In fact, that is how contextlib.aclosing() is implemented and it would be appropriate here, except that a) it requires python 3.8 and b), it assumes that aclose() is present.

with graphql-core 3.3, there will be an aclose() method present, and so aclosing() can be used. We can implement it manually if using 3.7

async for result in aiterator:
yield await process_subscribe_result(
execution_context, process_errors, extensions_runner, result
)
finally:
# grapql-core's iterator may or may not have an aclose() method
if hasattr(aiterator, "aclose"):
await aiterator.aclose()


async def process_subscribe_result(
execution_context: ExecutionContext,
process_errors: Callable[[List[GraphQLError], Optional[ExecutionContext]], None],
extensions_runner: SchemaExtensionsRunner,
result: GraphQLExecutionResult,
) -> ExecutionResult:
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)

return ExecutionResult(
data=execution_context.result.data,
errors=execution_context.result.errors,
extensions=await extensions_runner.get_extensions_results(),
)
32 changes: 18 additions & 14 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
AsyncGenerator,
Dict,
Iterable,
List,
Expand All @@ -20,10 +20,8 @@
GraphQLNonNull,
GraphQLSchema,
get_introspection_query,
parse,
validate_schema,
)
from graphql.execution import subscribe
from graphql.type.directives import specified_directives

from strawberry import relay
Expand All @@ -43,11 +41,10 @@
from . import compat
from .base import BaseSchema
from .config import StrawberryConfig
from .execute import execute, execute_sync
from .execute import execute, execute_sync, subscribe

if TYPE_CHECKING:
from graphql import ExecutionContext as GraphQLExecutionContext
from graphql import ExecutionResult as GraphQLExecutionResult

from strawberry.custom_scalar import ScalarDefinition, ScalarWrapper
from strawberry.directive import StrawberryDirective
Expand Down Expand Up @@ -305,22 +302,29 @@ def execute_sync(

async def subscribe(
self,
# TODO: make this optional when we support extensions
query: str,
query: Optional[str],
variable_values: Optional[Dict[str, Any]] = None,
context_value: Optional[Any] = None,
root_value: Optional[Any] = None,
operation_name: Optional[str] = None,
) -> Union[AsyncIterator[GraphQLExecutionResult], GraphQLExecutionResult]:
return await subscribe(
self._schema,
parse(query),
) -> AsyncGenerator[ExecutionResult, None]:
execution_context = ExecutionContext(
query=query,
schema=self,
context=context_value,
root_value=root_value,
context_value=context_value,
variable_values=variable_values,
operation_name=operation_name,
variables=variable_values,
provided_operation_name=operation_name,
)

async for result in subscribe(
self._schema,
extensions=self.get_extensions(),
execution_context=execution_context,
process_errors=self.process_errors,
):
yield result

def _resolve_node_ids(self):
for concrete_type in self.schema_converter.type_map.values():
type_def = concrete_type.definition
Expand Down
Loading
Loading