Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nullable Connection types in relay field decorator #3706

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Tuple,
Type,
Union,
UnionType,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnionType should be imported from types, but it is only available on python 3.10+

see is_union in strawberry/utils/typing.py

cast,
overload,
)
Expand Down Expand Up @@ -233,7 +234,21 @@ def apply(self, field: StrawberryField) -> None:
f_type = f_type.resolve_type()
field.type = f_type

# Handle Optional[Connection[T]] and Union[Connection[T], None] cases
type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type

# If it's Optional or Union, extract the inner type
if type_origin in (Union, UnionType):
types = getattr(f_type, "__args__", ())
# Find the non-None type in the Union
inner_type = next((t for t in types if t is not type(None)), None)
if inner_type is not None:
type_origin = (
get_origin(inner_type)
if is_generic_alias(inner_type)
else inner_type
)

if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
raise RelayWrongAnnotationError(field.name, cast(type, field.origin))

Expand Down
122 changes: 122 additions & 0 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import Optional, Union

import strawberry
from strawberry.permission import BasePermission
from strawberry.relay import Connection, Node, connection


@strawberry.type
class User(Node):
name: str = "John"

@classmethod
def resolve_nodes(cls, *, info, node_ids, required):
return [cls() for _ in node_ids]


class TestPermission(BasePermission):
message = "Not allowed"

def has_permission(self, source, info, **kwargs):
return False


def test_nullable_connection_with_optional():
@strawberry.type
class Query:
@connection
def users(self) -> Optional[Connection[User]]:
return None

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


def test_nullable_connection_with_union():
@strawberry.type
class Query:
@connection
def users(self) -> Union[Connection[User], None]:
return None

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


def test_nullable_connection_with_permission():
@strawberry.type
class Query:
@strawberry.permission_classes([TestPermission])
@connection
def users(self) -> Optional[Connection[User]]:
return Connection[User](edges=[], page_info=None)

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


def test_non_nullable_connection():
@strawberry.type
class Query:
@connection
def users(self) -> Connection[User]:
return Connection[User](edges=[], page_info=None)

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": {"edges": []}}
assert not result.errors
Loading