diff --git a/strawberry/relay/fields.py b/strawberry/relay/fields.py index 5af00700f8..a65a2ed279 100644 --- a/strawberry/relay/fields.py +++ b/strawberry/relay/fields.py @@ -23,6 +23,7 @@ Tuple, Type, Union, + UnionType, cast, overload, ) @@ -233,7 +234,21 @@ def apply(self, field: StrawberryField) -> None: f_type = f_type.resolve_type() field.type = f_type + # Handle Optional[Connection[T]] and Union[Connection[T], None] cases type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type + + # If it's Optional or Union, extract the inner type + if type_origin in (Union, UnionType): + types = getattr(f_type, "__args__", ()) + # Find the non-None type in the Union + inner_type = next((t for t in types if t is not type(None)), None) + if inner_type is not None: + type_origin = ( + get_origin(inner_type) + if is_generic_alias(inner_type) + else inner_type + ) + if not isinstance(type_origin, type) or not issubclass(type_origin, Connection): raise RelayWrongAnnotationError(field.name, cast(type, field.origin)) diff --git a/tests/relay/test_connection.py b/tests/relay/test_connection.py new file mode 100644 index 0000000000..40a8203eda --- /dev/null +++ b/tests/relay/test_connection.py @@ -0,0 +1,122 @@ +from typing import Optional, Union + +import strawberry +from strawberry.permission import BasePermission +from strawberry.relay import Connection, Node, connection + + +@strawberry.type +class User(Node): + name: str = "John" + + @classmethod + def resolve_nodes(cls, *, info, node_ids, required): + return [cls() for _ in node_ids] + + +class TestPermission(BasePermission): + message = "Not allowed" + + def has_permission(self, source, info, **kwargs): + return False + + +def test_nullable_connection_with_optional(): + @strawberry.type + class Query: + @connection + def users(self) -> Optional[Connection[User]]: + return None + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": None} + assert not result.errors + + +def test_nullable_connection_with_union(): + @strawberry.type + class Query: + @connection + def users(self) -> Union[Connection[User], None]: + return None + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": None} + assert not result.errors + + +def test_nullable_connection_with_permission(): + @strawberry.type + class Query: + @strawberry.permission_classes([TestPermission]) + @connection + def users(self) -> Optional[Connection[User]]: + return Connection[User](edges=[], page_info=None) + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": None} + assert not result.errors + + +def test_non_nullable_connection(): + @strawberry.type + class Query: + @connection + def users(self) -> Connection[User]: + return Connection[User](edges=[], page_info=None) + + schema = strawberry.Schema(query=Query) + query = """ + query { + users { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute_sync(query) + assert result.data == {"users": {"edges": []}} + assert not result.errors