Skip to content

Commit

Permalink
all core extension tests pass & lints
Browse files Browse the repository at this point in the history
  • Loading branch information
nrbnlulu committed Oct 1, 2024
1 parent 2c59c71 commit e60f1ef
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 30 deletions.
6 changes: 3 additions & 3 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
)

from strawberry.extensions import SchemaExtension
from strawberry.types.execution import ExecutionContext
from strawberry.utils.await_maybe import AwaitableOrValue

if TYPE_CHECKING:
from types import TracebackType

from strawberry.extensions.base_extension import Hook
from strawberry.types.execution import ExecutionContext
from strawberry.utils.await_maybe import AwaitableOrValue


class WrappedHook(NamedTuple):
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_hook(cls, extension: SchemaExtension) -> Optional[WrappedHook]:
return cls.from_callable(extension, hook_fn)

raise ValueError(
f"Hook {self.HOOK_NAME} on {extension} "
f"Hook {cls.HOOK_NAME} on {extension} "
f"must be callable, received {hook_fn!r}"
)

Expand Down
4 changes: 4 additions & 0 deletions tests/schema/extensions/schema_extensions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def hook_wrap(list_: List[str], hook_name: str):
@pytest.fixture()
def async_extension() -> Type[ExampleExtension]:
class MyExtension(ExampleExtension):
async def on_operation(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_operation.__name__):
yield

async def on_validate(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_validate.__name__):
yield
Expand Down
32 changes: 6 additions & 26 deletions tests/schema/extensions/schema_extensions/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,21 +390,6 @@ async def on_parse(self, execution_context: ExecutionContext):
SyncExt.assert_expected()


def test_raise_if_defined_both_legacy_and_new_style(default_query_types_and_query):
class WrongUsageExtension(SchemaExtension):
def on_execute(self, execution_context: ExecutionContext):
yield

def on_executing_start(self): ...

schema = strawberry.Schema(
query=default_query_types_and_query.query_type, extensions=[WrongUsageExtension]
)
result = schema.execute_sync(default_query_types_and_query.query)
assert len(result.errors) == 1
assert isinstance(result.errors[0].original_error, ValueError)


def test_warning_about_async_get_results_hooks_in_sync_context():
class MyExtension(SchemaExtension):
async def get_results(self, execution_context: ExecutionContext):
Expand Down Expand Up @@ -632,7 +617,6 @@ class MyExtension(SchemaExtension):
def on_parse(self, execution_context: ExecutionContext):
nonlocal execution_errors
yield
execution_context = execution_context
execution_errors = execution_context.errors

@strawberry.type
Expand Down Expand Up @@ -779,7 +763,6 @@ def test_execution_cache_example(mock_original_execute):
class ExecutionCache(SchemaExtension):
def on_execute(self, execution_context: ExecutionContext):
# Check if we've come across this query before
execution_context = execution_context
self.cache_key = (
f"{execution_context.query}:{json.dumps(execution_context.variables)}"
)
Expand Down Expand Up @@ -849,7 +832,6 @@ def test_execution_reject_example(mock_original_execute):
class RejectSomeQueries(SchemaExtension):
def on_execute(self, execution_context: ExecutionContext):
# Reject all operations called "RejectMe"
execution_context = execution_context
if execution_context.operation_name == "RejectMe":
execution_context.result = GraphQLExecutionResult(
data=None,
Expand Down Expand Up @@ -986,11 +968,9 @@ def test_raise_if_hook_is_not_callable(default_query_types_and_query: SchemaHelp
class MyExtension(SchemaExtension):
on_operation = "ABC" # type: ignore

schema = strawberry.Schema(
query=default_query_types_and_query.query_type, extensions=[MyExtension]
)
result = schema.execute_sync(default_query_types_and_query.query)
assert len(result.errors) == 1
assert isinstance(result.errors[0].original_error, ValueError)
assert result.errors[0].message.startswith("Hook on_operation on <")
assert result.errors[0].message.endswith("> must be callable, received 'ABC'")
with pytest.raises(ValueError) as exc_info:
schema = strawberry.Schema(
query=default_query_types_and_query.query_type, extensions=[MyExtension]
)
_ = schema.execute_sync(default_query_types_and_query.query)
assert exc_info.match("Hook on_operation on <.*MyExtension.*> must be callable")
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.types.execution import ExecutionResult
from strawberry.types.execution import ExecutionContext, ExecutionResult
from tests.conftest import skip_if_gql_32

from .conftest import ExampleExtension, SchemaHelper
Expand Down
1 change: 1 addition & 0 deletions tests/types/test_execution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import strawberry
from strawberry.extensions import SchemaExtension
from strawberry.types.execution import ExecutionContext


@strawberry.type
Expand Down

0 comments on commit e60f1ef

Please sign in to comment.