Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 1, 2024
1 parent 5f2e4e1 commit a444dd2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 90 deletions.
6 changes: 4 additions & 2 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ def from_callable(
extension: SchemaExtension,
func: Callable[[SchemaExtension, ExecutionContext], AwaitableOrValue[Any]],
) -> WrappedHook:
self_ = extension
self_ = extension
if iscoroutinefunction(func):

@contextlib.asynccontextmanager
async def iterator(execution_context: ExecutionContext) -> AsyncIterator[None]:
async def iterator(
execution_context: ExecutionContext,
) -> AsyncIterator[None]:
await func(self_, execution_context)
yield

Expand Down
157 changes: 78 additions & 79 deletions tests/schema/extensions/schema_extensions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,115 +7,114 @@

import strawberry
from strawberry.extensions import SchemaExtension

from strawberry.types.execution import ExecutionContext


@dataclasses.dataclass
class SchemaHelper:
query_type: type
subscription_type: type
query: str
subscription: str
query_type: type
subscription_type: type
query: str
subscription: str


class ExampleExtension(SchemaExtension):
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
cls.called_hooks = []

expected = [
"on_operation Entered",
"on_parse Entered",
"on_parse Exited",
"on_validate Entered",
"on_validate Exited",
"on_execute Entered",
"resolve",
"resolve",
"on_execute Exited",
"on_operation Exited",
"get_results",
]
called_hooks: List[str]

@classmethod
def assert_expected(cls) -> None:
assert cls.called_hooks == cls.expected
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)
cls.called_hooks = []

expected = [
"on_operation Entered",
"on_parse Entered",
"on_parse Exited",
"on_validate Entered",
"on_validate Exited",
"on_execute Entered",
"resolve",
"resolve",
"on_execute Exited",
"on_operation Exited",
"get_results",
]
called_hooks: List[str]

@classmethod
def assert_expected(cls) -> None:
assert cls.called_hooks == cls.expected


@pytest.fixture()
def default_query_types_and_query() -> SchemaHelper:
@strawberry.type
class Person:
name: str = "Jess"

@strawberry.type
class Query:
@strawberry.field
def person(self) -> Person:
return Person()

@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> AsyncGenerator[int, None]:
for i in range(5):
yield i

subscription = "subscription TestSubscribe { count }"
query = "query TestQuery { person { name } }"
return SchemaHelper(
query_type=Query,
query=query,
subscription_type=Subscription,
subscription=subscription,
)
@strawberry.type
class Person:
name: str = "Jess"

@strawberry.type
class Query:
@strawberry.field
def person(self) -> Person:
return Person()

@strawberry.type
class Subscription:
@strawberry.subscription
async def count(self) -> AsyncGenerator[int, None]:
for i in range(5):
yield i

subscription = "subscription TestSubscribe { count }"
query = "query TestQuery { person { name } }"
return SchemaHelper(
query_type=Query,
query=query,
subscription_type=Subscription,
subscription=subscription,
)


class ExecType(enum.Enum):
SYNC = enum.auto()
ASYNC = enum.auto()
SYNC = enum.auto()
ASYNC = enum.auto()

def is_async(self) -> bool:
return self == ExecType.ASYNC
def is_async(self) -> bool:
return self == ExecType.ASYNC


@pytest.fixture(params=[ExecType.ASYNC, ExecType.SYNC])
def exec_type(request: pytest.FixtureRequest) -> ExecType:
return request.param
return request.param


@contextlib.contextmanager
def hook_wrap(list_: List[str], hook_name: str):
list_.append(f"{hook_name} Entered")
try:
yield
finally:
list_.append(f"{hook_name} Exited")
list_.append(f"{hook_name} Entered")
try:
yield
finally:
list_.append(f"{hook_name} Exited")


@pytest.fixture()
def async_extension() -> Type[ExampleExtension]:
class MyExtension(ExampleExtension):
async def on_validate(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_validate.__name__):
yield
class MyExtension(ExampleExtension):
async def on_validate(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_validate.__name__):
yield

async def on_parse(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_parse.__name__):
yield
async def on_parse(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_parse.__name__):
yield

async def on_execute(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_execute.__name__):
yield
async def on_execute(self, execution_context: ExecutionContext):
with hook_wrap(self.called_hooks, SchemaExtension.on_execute.__name__):
yield

async def get_results(self, execution_context: ExecutionContext):
self.called_hooks.append("get_results")
return {"example": "example"}
async def get_results(self, execution_context: ExecutionContext):
self.called_hooks.append("get_results")
return {"example": "example"}

async def resolve(self, _next, root, info, *args: str, **kwargs: Any):
self.called_hooks.append("resolve")
return _next(root, info, *args, **kwargs)
async def resolve(self, _next, root, info, *args: str, **kwargs: Any):
self.called_hooks.append("resolve")
return _next(root, info, *args, **kwargs)

return MyExtension
return MyExtension
29 changes: 20 additions & 9 deletions tests/schema/extensions/schema_extensions/test_extensions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import contextlib
import json
from unittest import mock
import warnings
from typing import Any, List, Optional, Type
from unittest import mock
from unittest.mock import patch

import pytest
Expand All @@ -13,10 +12,8 @@
import strawberry
from strawberry.exceptions import StrawberryGraphQLError
from strawberry.extensions import SchemaExtension

from strawberry.types.execution import ExecutionContext


from .conftest import ExampleExtension, ExecType, SchemaHelper, hook_wrap


Expand Down Expand Up @@ -49,9 +46,20 @@ def person(self) -> Person:


def test_called_only_if_overridden(monkeypatch: pytest.MonkeyPatch) -> None:
from strawberry.extensions.context import OperationContextManager, ParsingContextManager, ValidationContextManager, ExecutingContextManager
from strawberry.extensions.context import (
ExecutingContextManager,
OperationContextManager,
ParsingContextManager,
ValidationContextManager,
)

hooks_mock = mock.Mock()
for manager in [OperationContextManager, ParsingContextManager, ValidationContextManager, ExecutingContextManager]:
for manager in [
OperationContextManager,
ParsingContextManager,
ValidationContextManager,
ExecutingContextManager,
]:
monkeypatch.setattr(manager, "DEFAULT_HOOK", hooks_mock)

@strawberry.type
Expand All @@ -63,8 +71,10 @@ class Query:
@strawberry.field
def person(self) -> Person:
return Person()

class ExtNoHooks(SchemaExtension):
pass

schema = strawberry.Schema(query=Query, extensions=[ExtNoHooks])

query = """
Expand All @@ -80,8 +90,7 @@ class ExtNoHooks(SchemaExtension):
assert not result.errors

assert result.extensions == {}
hooks_mock.assert_not_called()

hooks_mock.assert_not_called()


def test_extension_access_to_parsed_document():
Expand Down Expand Up @@ -395,6 +404,7 @@ def on_executing_start(self): ...
assert len(result.errors) == 1
assert isinstance(result.errors[0].original_error, ValueError)


def test_warning_about_async_get_results_hooks_in_sync_context():
class MyExtension(SchemaExtension):
async def get_results(self, execution_context: ExecutionContext):
Expand All @@ -419,11 +429,12 @@ class ExceptionTestingExtension(SchemaExtension):
def __init__(self, failing_hook: str):
self.failing_hook = failing_hook
self.called_hooks = set()

def on_operation(self, execution_context: ExecutionContext):
if self.failing_hook == "on_operation_start":
raise Exception(self.failing_hook)
self.called_hooks.add(1)

with contextlib.suppress(Exception):
yield

Expand Down

0 comments on commit a444dd2

Please sign in to comment.