diff --git a/strawberry/extensions/context.py b/strawberry/extensions/context.py index 225f2bbbf5..87fade02b3 100644 --- a/strawberry/extensions/context.py +++ b/strawberry/extensions/context.py @@ -20,13 +20,13 @@ ) from strawberry.extensions import SchemaExtension -from strawberry.types.execution import ExecutionContext -from strawberry.utils.await_maybe import AwaitableOrValue if TYPE_CHECKING: from types import TracebackType from strawberry.extensions.base_extension import Hook + from strawberry.types.execution import ExecutionContext + from strawberry.utils.await_maybe import AwaitableOrValue class WrappedHook(NamedTuple): @@ -94,7 +94,7 @@ def get_hook(cls, extension: SchemaExtension) -> Optional[WrappedHook]: return cls.from_callable(extension, hook_fn) raise ValueError( - f"Hook {self.HOOK_NAME} on {extension} " + f"Hook {cls.HOOK_NAME} on {extension} " f"must be callable, received {hook_fn!r}" ) diff --git a/tests/schema/extensions/schema_extensions/conftest.py b/tests/schema/extensions/schema_extensions/conftest.py index 2f22310878..afd47a28cc 100644 --- a/tests/schema/extensions/schema_extensions/conftest.py +++ b/tests/schema/extensions/schema_extensions/conftest.py @@ -97,6 +97,10 @@ def hook_wrap(list_: List[str], hook_name: str): @pytest.fixture() def async_extension() -> Type[ExampleExtension]: class MyExtension(ExampleExtension): + async def on_operation(self, execution_context: ExecutionContext): + with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__): + yield + async def on_validate(self, execution_context: ExecutionContext): with hook_wrap(self.called_hooks, SchemaExtension.on_validate.__name__): yield diff --git a/tests/schema/extensions/schema_extensions/test_extensions.py b/tests/schema/extensions/schema_extensions/test_extensions.py index 264c97f2cc..e4cae542a8 100644 --- a/tests/schema/extensions/schema_extensions/test_extensions.py +++ b/tests/schema/extensions/schema_extensions/test_extensions.py @@ -390,21 +390,6 @@ async def on_parse(self, execution_context: ExecutionContext): SyncExt.assert_expected() -def test_raise_if_defined_both_legacy_and_new_style(default_query_types_and_query): - class WrongUsageExtension(SchemaExtension): - def on_execute(self, execution_context: ExecutionContext): - yield - - def on_executing_start(self): ... - - schema = strawberry.Schema( - query=default_query_types_and_query.query_type, extensions=[WrongUsageExtension] - ) - result = schema.execute_sync(default_query_types_and_query.query) - assert len(result.errors) == 1 - assert isinstance(result.errors[0].original_error, ValueError) - - def test_warning_about_async_get_results_hooks_in_sync_context(): class MyExtension(SchemaExtension): async def get_results(self, execution_context: ExecutionContext): @@ -632,7 +617,6 @@ class MyExtension(SchemaExtension): def on_parse(self, execution_context: ExecutionContext): nonlocal execution_errors yield - execution_context = execution_context execution_errors = execution_context.errors @strawberry.type @@ -779,7 +763,6 @@ def test_execution_cache_example(mock_original_execute): class ExecutionCache(SchemaExtension): def on_execute(self, execution_context: ExecutionContext): # Check if we've come across this query before - execution_context = execution_context self.cache_key = ( f"{execution_context.query}:{json.dumps(execution_context.variables)}" ) @@ -849,7 +832,6 @@ def test_execution_reject_example(mock_original_execute): class RejectSomeQueries(SchemaExtension): def on_execute(self, execution_context: ExecutionContext): # Reject all operations called "RejectMe" - execution_context = execution_context if execution_context.operation_name == "RejectMe": execution_context.result = GraphQLExecutionResult( data=None, @@ -986,11 +968,9 @@ def test_raise_if_hook_is_not_callable(default_query_types_and_query: SchemaHelp class MyExtension(SchemaExtension): on_operation = "ABC" # type: ignore - schema = strawberry.Schema( - query=default_query_types_and_query.query_type, extensions=[MyExtension] - ) - result = schema.execute_sync(default_query_types_and_query.query) - assert len(result.errors) == 1 - assert isinstance(result.errors[0].original_error, ValueError) - assert result.errors[0].message.startswith("Hook on_operation on <") - assert result.errors[0].message.endswith("> must be callable, received 'ABC'") + with pytest.raises(ValueError) as exc_info: + schema = strawberry.Schema( + query=default_query_types_and_query.query_type, extensions=[MyExtension] + ) + _ = schema.execute_sync(default_query_types_and_query.query) + assert exc_info.match("Hook on_operation on <.*MyExtension.*> must be callable") diff --git a/tests/schema/extensions/schema_extensions/test_subscription.py b/tests/schema/extensions/schema_extensions/test_subscription.py index 1c86e29a39..1ebe133ebc 100644 --- a/tests/schema/extensions/schema_extensions/test_subscription.py +++ b/tests/schema/extensions/schema_extensions/test_subscription.py @@ -4,7 +4,7 @@ import strawberry from strawberry.extensions import SchemaExtension -from strawberry.types.execution import ExecutionResult +from strawberry.types.execution import ExecutionContext, ExecutionResult from tests.conftest import skip_if_gql_32 from .conftest import ExampleExtension, SchemaHelper diff --git a/tests/types/test_execution.py b/tests/types/test_execution.py index 917541ee53..34e43f5777 100644 --- a/tests/types/test_execution.py +++ b/tests/types/test_execution.py @@ -1,5 +1,6 @@ import strawberry from strawberry.extensions import SchemaExtension +from strawberry.types.execution import ExecutionContext @strawberry.type