diff --git a/strawberry/annotation.py b/strawberry/annotation.py index dff708a1a1..8934f5bad8 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -208,7 +208,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional: ) # Note that passing a single type to `Union` is equivalent to not using `Union` - # at all. This allows us to not di any checks for how many types have been + # at all. This allows us to not do any checks for how many types have been # passed as we can safely use `Union` for both optional types # (e.g. `Optional[str]`) and optional unions (e.g. # `Optional[Union[TypeA, TypeB]]`) diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index a65a2ed279..ac002fa978 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -23,11 +23,10 @@ Tuple, Type, Union, - UnionType, cast, overload, ) -from typing_extensions import Annotated, get_origin +from typing_extensions import Annotated, get_args, get_origin from strawberry.annotation import StrawberryAnnotation from strawberry.extensions.field_extension import ( @@ -45,9 +44,9 @@ from strawberry.types.fields.resolver import StrawberryResolver from strawberry.types.lazy_type import LazyType from strawberry.utils.aio import asyncgen_to_list -from strawberry.utils.typing import eval_type, is_generic_alias +from strawberry.utils.typing import eval_type, is_generic_alias, is_union -from .types import Connection, GlobalID, Node, NodeIterableType, NodeType +from .types import Connection, GlobalID, Node, NodeIterableType if TYPE_CHECKING: from typing_extensions import Literal @@ -234,20 +233,10 @@ def apply(self, field: StrawberryField) -> None: f_type = f_type.resolve_type() field.type = f_type - # Handle Optional[Connection[T]] and Union[Connection[T], None] cases - type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type + if isinstance(f_type, StrawberryOptional): + f_type = f_type.of_type - # If it's Optional or Union, extract the inner type - if type_origin in (Union, UnionType): - types = getattr(f_type, "__args__", ()) - # Find the non-None type in the Union - inner_type = next((t for t in types if t is not type(None)), None) - if inner_type is not None: - type_origin = ( - get_origin(inner_type) - if is_generic_alias(inner_type) - else inner_type - ) + type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type if not isinstance(type_origin, type) or not issubclass(type_origin, Connection): raise RelayWrongAnnotationError(field.name, cast(type, field.origin)) @@ -268,13 +257,18 @@ def apply(self, field: StrawberryField) -> None: None, ) + if is_union(resolver_type): + # TODO: actually check if is optional and get correct type + resolver_type = get_args(resolver_type)[0] + origin = get_origin(resolver_type) + if origin is None or not issubclass( origin, (Iterator, Iterable, AsyncIterator, AsyncIterable) ): raise RelayWrongResolverAnnotationError(field.name, field.base_resolver) - self.connection_type = cast(Type[Connection[Node]], field.type) + self.connection_type = cast(Type[Connection[Node]], f_type) def resolve( self, @@ -342,9 +336,18 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField: return field(*args, **kwargs) +# we used to have `Type[Connection[NodeType]]` here, but that when we added +# support for making the Connection type optional, we had to change it to +# `Any` because otherwise it wouldn't be type check since `Optional[Connection[Something]]` +# is not a `Type`, but a special form, see https://discuss.python.org/t/is-annotated-compatible-with-type-t/43898/46 +# for more information, and also https://peps.python.org/pep-0747/, which is currently +# in draft status (and no type checker supports it yet) +ConnectionGraphQLType = Any + + @overload def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, + graphql_type: Optional[ConnectionGraphQLType] = None, *, resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None, name: Optional[str] = None, @@ -363,7 +366,7 @@ def connection( @overload def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, + graphql_type: Optional[ConnectionGraphQLType] = None, *, name: Optional[str] = None, is_subscription: bool = False, @@ -379,7 +382,7 @@ def connection( def connection( - graphql_type: Optional[Type[Connection[NodeType]]] = None, + graphql_type: Optional[ConnectionGraphQLType] = None, *, resolver: Optional[_RESOLVER_TYPE[Any]] = None, name: Optional[str] = None, diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py index 40a8203eda..3595a8c7a2 100644 --- a/tests/relay/test_connection.py +++ b/tests/relay/test_connection.py @@ -1,12 +1,17 @@ -from typing import Optional, Union +import sys +from typing import Any, Iterable, List, Optional +from typing_extensions import Self + +import pytest import strawberry from strawberry.permission import BasePermission -from strawberry.relay import Connection, Node, connection +from strawberry.relay import Connection, Node @strawberry.type class User(Node): + id: strawberry.relay.NodeID name: str = "John" @classmethod @@ -14,6 +19,22 @@ def resolve_nodes(cls, *, info, node_ids, required): return [cls() for _ in node_ids] +@strawberry.type +class UserConnection(Connection[User]): + @classmethod + def resolve_connection( + cls, + nodes: Iterable[User], + *, + info: Any, + after: Optional[str] = None, + before: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + ) -> Optional[Self]: + return None + + class TestPermission(BasePermission): message = "Not allowed" @@ -24,8 +45,8 @@ def has_permission(self, source, info, **kwargs): def test_nullable_connection_with_optional(): @strawberry.type class Query: - @connection - def users(self) -> Optional[Connection[User]]: + @strawberry.relay.connection(Optional[UserConnection]) + def users(self) -> Optional[List[User]]: return None schema = strawberry.Schema(query=Query) @@ -46,11 +67,17 @@ def users(self) -> Optional[Connection[User]]: assert not result.errors -def test_nullable_connection_with_union(): +pytest.mark.skipif( + sys.version_info < (3, 10), + reason="pipe syntax for union is only available on python 3.10+", +) + + +def test_nullable_connection_with_pipe(): @strawberry.type class Query: - @connection - def users(self) -> Union[Connection[User], None]: + @strawberry.relay.connection(UserConnection | None) + def users(self) -> List[User] | None: return None schema = strawberry.Schema(query=Query) @@ -74,10 +101,11 @@ def users(self) -> Union[Connection[User], None]: def test_nullable_connection_with_permission(): @strawberry.type class Query: - @strawberry.permission_classes([TestPermission]) - @connection - def users(self) -> Optional[Connection[User]]: - return Connection[User](edges=[], page_info=None) + @strawberry.relay.connection( + Optional[UserConnection], permission_classes=[TestPermission] + ) + def users(self) -> Optional[List[User]]: + return None schema = strawberry.Schema(query=Query) query = """ @@ -94,29 +122,4 @@ def users(self) -> Optional[Connection[User]]: result = schema.execute_sync(query) assert result.data == {"users": None} - assert not result.errors - - -def test_non_nullable_connection(): - @strawberry.type - class Query: - @connection - def users(self) -> Connection[User]: - return Connection[User](edges=[], page_info=None) - - schema = strawberry.Schema(query=Query) - query = """ - query { - users { - edges { - node { - name - } - } - } - } - """ - - result = schema.execute_sync(query) - assert result.data == {"users": {"edges": []}} - assert not result.errors + assert result.errors[0].message == "Not allowed"