Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic unions #3515

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch
enoua5 marked this conversation as resolved.
Show resolved Hide resolved

Attempt to merge union types during schema conversion.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the last change this is probably a minor now, since it's a brand new feature (plus a fix) 😊

21 changes: 15 additions & 6 deletions strawberry/schema/name_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional, Union, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
from typing_extensions import Protocol

from strawberry.custom_scalar import ScalarDefinition
Expand Down Expand Up @@ -107,8 +107,14 @@ def from_union(self, union: StrawberryUnion) -> str:
return union.graphql_name

name = ""
types: Tuple[StrawberryType, ...] = union.types

for type_ in union.types:
if union.concrete_of and union.concrete_of.graphql_name:
concrete_of_types = set(union.concrete_of.types)

types = tuple(type_ for type_ in types if type_ not in concrete_of_types)

for type_ in types:
if isinstance(type_, LazyType):
type_ = cast("StrawberryType", type_.resolve_type()) # noqa: PLW2901

Expand All @@ -121,6 +127,9 @@ def from_union(self, union: StrawberryUnion) -> str:

name += type_name

if union.concrete_of and union.concrete_of.graphql_name:
name += union.concrete_of.graphql_name

return name

def from_generic(
Expand All @@ -133,12 +142,12 @@ def from_generic(
names: List[str] = []

for type_ in types:
name = self.get_from_type(type_)
name = self.get_name_from_type(type_)
names.append(name)

return "".join(names) + generic_type_name

def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
def get_name_from_type(self, type_: Union[StrawberryType, type]) -> str:
type_ = eval_type(type_)

if isinstance(type_, LazyType):
Expand All @@ -148,9 +157,9 @@ def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
elif isinstance(type_, StrawberryUnion):
name = type_.graphql_name if type_.graphql_name else self.from_union(type_)
elif isinstance(type_, StrawberryList):
name = self.get_from_type(type_.of_type) + "List"
name = self.get_name_from_type(type_.of_type) + "List"
elif isinstance(type_, StrawberryOptional):
name = self.get_from_type(type_.of_type) + "Optional"
name = self.get_name_from_type(type_.of_type) + "Optional"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the name of this function 😊

elif hasattr(type_, "_scalar_definition"):
strawberry_type = type_._scalar_definition

Expand Down
10 changes: 8 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,14 +851,20 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:
return graphql_union

graphql_types: List[GraphQLObjectType] = []

for type_ in union.types:
graphql_type = self.from_type(type_)

if isinstance(graphql_type, GraphQLInputObjectType):
raise InvalidTypeInputForUnion(graphql_type)
assert isinstance(graphql_type, GraphQLObjectType)
assert isinstance(graphql_type, (GraphQLObjectType, GraphQLUnionType))

graphql_types.append(graphql_type)
# If the graphql_type is a GraphQLUnionType, merge its child types
if isinstance(graphql_type, GraphQLUnionType):
enoua5 marked this conversation as resolved.
Show resolved Hide resolved
# Add the child types of the GraphQLUnionType to the list of graphql_types
graphql_types.extend(graphql_type.types)
else:
graphql_types.append(graphql_type)

graphql_union = GraphQLUnionType(
name=union_name,
Expand Down
7 changes: 6 additions & 1 deletion strawberry/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.directives = directives
self._source_file = None
self._source_line = None
self.concrete_of: Optional[StrawberryUnion] = None

def __eq__(self, other: object) -> bool:
if isinstance(other, StrawberryType):
Expand Down Expand Up @@ -139,6 +140,7 @@ def copy_with(
return self

new_types = []

for type_ in self.types:
new_type: Union[StrawberryType, type]

Expand All @@ -154,10 +156,13 @@ def copy_with(

new_types.append(new_type)

return StrawberryUnion(
new_union = StrawberryUnion(
type_annotations=tuple(map(StrawberryAnnotation, new_types)),
description=self.description,
)
new_union.concrete_of = self

return new_union

def __call__(self, *args: str, **kwargs: Any) -> NoReturn:
"""Do not use.
Expand Down
175 changes: 175 additions & 0 deletions tests/schema/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,178 @@

assert not result.errors
assert result.data["something"] == {"__typename": "A", "a": 5}


def test_generic_union_with_annotated():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to clarify, this is quite different from the use case below, as this creates a generic union 😊


@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
raise NotImplementedError()

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_generic_union_with_annotated_inside():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
return ObjectQueries(SomeType)

Check warning on line 941 in tests/schema/test_union.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/test_union.py#L941

Added line #L941 was not covered by tests

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_annoted_union_with_two_generics():
@strawberry.type
class SomeType:
a: str

@strawberry.type
class OtherType:
b: str

@strawberry.type
class NotFoundError:
message: str

T = TypeVar("T")
U = TypeVar("U")

@strawberry.type
class UnionObjectQueries(Generic[T, U]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> T | Annotated[U | NotFoundError, strawberry.union("ByIdResult")]: ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test captures the combined case of one generic inside and one generic outside the annotation. Both in and both out also work, but I figured the tests would be getting a bit redundant if I included every combination of in and out.


@strawberry.type
class Query:
@strawberry.field
def some_type_queries(
self, id: strawberry.ID
) -> UnionObjectQueries[SomeType, OtherType]:
return UnionObjectQueries()

Check warning on line 1002 in tests/schema/test_union.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/test_union.py#L1002

Added line #L1002 was not covered by tests

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
message: String!
}

type OtherType {
b: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeOtherTypeUnionObjectQueries!
}

type SomeType {
a: String!
}

union SomeTypeOtherTypeByIdResult = SomeType | OtherType | NotFoundError

type SomeTypeOtherTypeUnionObjectQueries {
byId(id: ID!): SomeTypeOtherTypeByIdResult!
}
"""
).strip()
)
Loading