Skip to content

Commit

Permalink
Initial working version
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick91 committed Nov 20, 2024
1 parent c61f422 commit 1b82b34
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 59 deletions.
2 changes: 1 addition & 1 deletion strawberry/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def create_optional(self, evaled_type: Any) -> StrawberryOptional:
)

# Note that passing a single type to `Union` is equivalent to not using `Union`
# at all. This allows us to not di any checks for how many types have been
# at all. This allows us to not do any checks for how many types have been
# passed as we can safely use `Union` for both optional types
# (e.g. `Optional[str]`) and optional unions (e.g.
# `Optional[Union[TypeA, TypeB]]`)
Expand Down
45 changes: 24 additions & 21 deletions strawberry/relay/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@
Tuple,
Type,
Union,
UnionType,
cast,
overload,
)
from typing_extensions import Annotated, get_origin
from typing_extensions import Annotated, get_args, get_origin

from strawberry.annotation import StrawberryAnnotation
from strawberry.extensions.field_extension import (
Expand All @@ -45,9 +44,9 @@
from strawberry.types.fields.resolver import StrawberryResolver
from strawberry.types.lazy_type import LazyType
from strawberry.utils.aio import asyncgen_to_list
from strawberry.utils.typing import eval_type, is_generic_alias
from strawberry.utils.typing import eval_type, is_generic_alias, is_union

from .types import Connection, GlobalID, Node, NodeIterableType, NodeType
from .types import Connection, GlobalID, Node, NodeIterableType

if TYPE_CHECKING:
from typing_extensions import Literal
Expand Down Expand Up @@ -234,20 +233,10 @@ def apply(self, field: StrawberryField) -> None:
f_type = f_type.resolve_type()
field.type = f_type

# Handle Optional[Connection[T]] and Union[Connection[T], None] cases
type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type
if isinstance(f_type, StrawberryOptional):
f_type = f_type.of_type

# If it's Optional or Union, extract the inner type
if type_origin in (Union, UnionType):
types = getattr(f_type, "__args__", ())
# Find the non-None type in the Union
inner_type = next((t for t in types if t is not type(None)), None)
if inner_type is not None:
type_origin = (
get_origin(inner_type)
if is_generic_alias(inner_type)
else inner_type
)
type_origin = get_origin(f_type) if is_generic_alias(f_type) else f_type

if not isinstance(type_origin, type) or not issubclass(type_origin, Connection):
raise RelayWrongAnnotationError(field.name, cast(type, field.origin))
Expand All @@ -268,13 +257,18 @@ def apply(self, field: StrawberryField) -> None:
None,
)

if is_union(resolver_type):
# TODO: actually check if is optional and get correct type
resolver_type = get_args(resolver_type)[0]

origin = get_origin(resolver_type)

if origin is None or not issubclass(
origin, (Iterator, Iterable, AsyncIterator, AsyncIterable)
):
raise RelayWrongResolverAnnotationError(field.name, field.base_resolver)

self.connection_type = cast(Type[Connection[Node]], field.type)
self.connection_type = cast(Type[Connection[Node]], f_type)

def resolve(
self,
Expand Down Expand Up @@ -342,9 +336,18 @@ def node(*args: Any, **kwargs: Any) -> StrawberryField:
return field(*args, **kwargs)


# we used to have `Type[Connection[NodeType]]` here, but that when we added
# support for making the Connection type optional, we had to change it to
# `Any` because otherwise it wouldn't be type check since `Optional[Connection[Something]]`
# is not a `Type`, but a special form, see https://discuss.python.org/t/is-annotated-compatible-with-type-t/43898/46
# for more information, and also https://peps.python.org/pep-0747/, which is currently
# in draft status (and no type checker supports it yet)
ConnectionGraphQLType = Any


@overload
def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
graphql_type: Optional[ConnectionGraphQLType] = None,
*,
resolver: Optional[_RESOLVER_TYPE[NodeIterableType[Any]]] = None,
name: Optional[str] = None,
Expand All @@ -363,7 +366,7 @@ def connection(

@overload
def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
graphql_type: Optional[ConnectionGraphQLType] = None,
*,
name: Optional[str] = None,
is_subscription: bool = False,
Expand All @@ -379,7 +382,7 @@ def connection(


def connection(
graphql_type: Optional[Type[Connection[NodeType]]] = None,
graphql_type: Optional[ConnectionGraphQLType] = None,
*,
resolver: Optional[_RESOLVER_TYPE[Any]] = None,
name: Optional[str] = None,
Expand Down
77 changes: 40 additions & 37 deletions tests/relay/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,40 @@
from typing import Optional, Union
import sys
from typing import Any, Iterable, List, Optional
from typing_extensions import Self

import pytest

import strawberry
from strawberry.permission import BasePermission
from strawberry.relay import Connection, Node, connection
from strawberry.relay import Connection, Node


@strawberry.type
class User(Node):
id: strawberry.relay.NodeID
name: str = "John"

@classmethod
def resolve_nodes(cls, *, info, node_ids, required):
return [cls() for _ in node_ids]

Check warning on line 19 in tests/relay/test_connection.py

View check run for this annotation

Codecov / codecov/patch

tests/relay/test_connection.py#L19

Added line #L19 was not covered by tests


@strawberry.type
class UserConnection(Connection[User]):
@classmethod
def resolve_connection(
cls,
nodes: Iterable[User],
*,
info: Any,
after: Optional[str] = None,
before: Optional[str] = None,
first: Optional[int] = None,
last: Optional[int] = None,
) -> Optional[Self]:
return None


class TestPermission(BasePermission):
message = "Not allowed"

Expand All @@ -24,8 +45,8 @@ def has_permission(self, source, info, **kwargs):
def test_nullable_connection_with_optional():
@strawberry.type
class Query:
@connection
def users(self) -> Optional[Connection[User]]:
@strawberry.relay.connection(Optional[UserConnection])
def users(self) -> Optional[List[User]]:
return None

schema = strawberry.Schema(query=Query)
Expand All @@ -46,11 +67,17 @@ def users(self) -> Optional[Connection[User]]:
assert not result.errors


def test_nullable_connection_with_union():
pytest.mark.skipif(
sys.version_info < (3, 10),
reason="pipe syntax for union is only available on python 3.10+",
)


def test_nullable_connection_with_pipe():
@strawberry.type
class Query:
@connection
def users(self) -> Union[Connection[User], None]:
@strawberry.relay.connection(UserConnection | None)
def users(self) -> List[User] | None:
return None

schema = strawberry.Schema(query=Query)
Expand All @@ -74,10 +101,11 @@ def users(self) -> Union[Connection[User], None]:
def test_nullable_connection_with_permission():
@strawberry.type
class Query:
@strawberry.permission_classes([TestPermission])
@connection
def users(self) -> Optional[Connection[User]]:
return Connection[User](edges=[], page_info=None)
@strawberry.relay.connection(
Optional[UserConnection], permission_classes=[TestPermission]
)
def users(self) -> Optional[List[User]]:
return None

Check warning on line 108 in tests/relay/test_connection.py

View check run for this annotation

Codecov / codecov/patch

tests/relay/test_connection.py#L108

Added line #L108 was not covered by tests

schema = strawberry.Schema(query=Query)
query = """
Expand All @@ -94,29 +122,4 @@ def users(self) -> Optional[Connection[User]]:

result = schema.execute_sync(query)
assert result.data == {"users": None}
assert not result.errors


def test_non_nullable_connection():
@strawberry.type
class Query:
@connection
def users(self) -> Connection[User]:
return Connection[User](edges=[], page_info=None)

schema = strawberry.Schema(query=Query)
query = """
query {
users {
edges {
node {
name
}
}
}
}
"""

result = schema.execute_sync(query)
assert result.data == {"users": {"edges": []}}
assert not result.errors
assert result.errors[0].message == "Not allowed"

0 comments on commit 1b82b34

Please sign in to comment.