Skip to content

Add SchemaExtension.should_await #3724

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions strawberry/extensions/base_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@
"""Called before and after the execution step."""
yield None

def should_await(self, _next: Callable) -> bool:
"""Whether the extension should await the result from next.

This relies on the _is_async attribute set by the `SchemaConverter`,
normally you'd use `inspect.isawaitable` instead, but that has
some performance hits, especially because we know if the resolver
is async or not at schema creation time.
"""
return _next._is_async # type: ignore

Check warning on line 61 in strawberry/extensions/base_extension.py

View check run for this annotation

Codecov / codecov/patch

strawberry/extensions/base_extension.py#L61

Added line #L61 was not covered by tests

def resolve(
self,
_next: Callable,
Expand Down
2 changes: 2 additions & 0 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,11 @@ async def _async_resolver(

if field.is_async:
_async_resolver._is_default = not field.base_resolver # type: ignore
_async_resolver._is_async = True # type: ignore
return _async_resolver
else:
_resolver._is_default = not field.base_resolver # type: ignore
_resolver._is_async = False # type: ignore
return _resolver

def from_scalar(self, scalar: Type) -> GraphQLScalarType:
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmarks/test_execute_with_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_results(self) -> AwaitableOrValue[Dict[str, Any]]:
class ResolveExtension(SchemaExtension):
async def resolve(self, _next, root, info, *args: Any, **kwargs: Any) -> Any:
result = _next(root, info, *args, **kwargs)
if isawaitable(result):
if (hasattr(_next, "_is_async") and _next._is_async) or isawaitable(result):
result = await result
return result

Expand Down
Loading