Skip to content

Commit

Permalink
Merge branch 'fix-3393' of https://github.com/enoua5/strawberry-3393
Browse files Browse the repository at this point in the history
  • Loading branch information
enoua5 committed Sep 29, 2024
2 parents 36bd99e + b34106e commit 5be8943
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 9 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Attempt to merge union types during schema conversion.
21 changes: 15 additions & 6 deletions strawberry/schema/name_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional, Union, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
from typing_extensions import Protocol

from strawberry.directive import StrawberryDirective
Expand Down Expand Up @@ -107,8 +107,14 @@ def from_union(self, union: StrawberryUnion) -> str:
return union.graphql_name

name = ""
types: Tuple[StrawberryType, ...] = union.types

for type_ in union.types:
if union.concrete_of and union.concrete_of.graphql_name:
concrete_of_types = set(union.concrete_of.types)

types = tuple(type_ for type_ in types if type_ not in concrete_of_types)

for type_ in types:
if isinstance(type_, LazyType):
type_ = cast("StrawberryType", type_.resolve_type()) # noqa: PLW2901

Expand All @@ -121,6 +127,9 @@ def from_union(self, union: StrawberryUnion) -> str:

name += type_name

if union.concrete_of and union.concrete_of.graphql_name:
name += union.concrete_of.graphql_name

return name

def from_generic(
Expand All @@ -133,12 +142,12 @@ def from_generic(
names: List[str] = []

for type_ in types:
name = self.get_from_type(type_)
name = self.get_name_from_type(type_)
names.append(name)

return "".join(names) + generic_type_name

def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
def get_name_from_type(self, type_: Union[StrawberryType, type]) -> str:
type_ = eval_type(type_)

if isinstance(type_, LazyType):
Expand All @@ -148,9 +157,9 @@ def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
elif isinstance(type_, StrawberryUnion):
name = type_.graphql_name if type_.graphql_name else self.from_union(type_)
elif isinstance(type_, StrawberryList):
name = self.get_from_type(type_.of_type) + "List"
name = self.get_name_from_type(type_.of_type) + "List"
elif isinstance(type_, StrawberryOptional):
name = self.get_from_type(type_.of_type) + "Optional"
name = self.get_name_from_type(type_.of_type) + "Optional"
elif hasattr(type_, "_scalar_definition"):
strawberry_type = type_._scalar_definition

Expand Down
10 changes: 8 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,14 +865,20 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:
return graphql_union

graphql_types: List[GraphQLObjectType] = []

for type_ in union.types:
graphql_type = self.from_type(type_)

if isinstance(graphql_type, GraphQLInputObjectType):
raise InvalidTypeInputForUnion(graphql_type)
assert isinstance(graphql_type, GraphQLObjectType)
assert isinstance(graphql_type, (GraphQLObjectType, GraphQLUnionType))

graphql_types.append(graphql_type)
# If the graphql_type is a GraphQLUnionType, merge its child types
if isinstance(graphql_type, GraphQLUnionType):
# Add the child types of the GraphQLUnionType to the list of graphql_types
graphql_types.extend(graphql_type.types)
else:
graphql_types.append(graphql_type)

graphql_union = GraphQLUnionType(
name=union_name,
Expand Down
7 changes: 6 additions & 1 deletion strawberry/types/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.directives = directives
self._source_file = None
self._source_line = None
self.concrete_of: Optional[StrawberryUnion] = None

def __eq__(self, other: object) -> bool:
if isinstance(other, StrawberryType):
Expand Down Expand Up @@ -139,6 +140,7 @@ def copy_with(
return self

new_types = []

for type_ in self.types:
new_type: Union[StrawberryType, type]

Expand All @@ -154,10 +156,13 @@ def copy_with(

new_types.append(new_type)

return StrawberryUnion(
new_union = StrawberryUnion(
type_annotations=tuple(map(StrawberryAnnotation, new_types)),
description=self.description,
)
new_union.concrete_of = self

return new_union

def __call__(self, *args: str, **kwargs: Any) -> NoReturn:
"""Do not use.
Expand Down
175 changes: 175 additions & 0 deletions tests/schema/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,178 @@ class Query:

assert not result.errors
assert result.data["something"] == {"__typename": "A", "a": 5}


def test_generic_union_with_annotated():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
raise NotImplementedError()

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}
type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}
type SomeType {
id: ID!
name: String!
}
union SomeTypeByIdResult = SomeType | NotFoundError
type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_generic_union_with_annotated_inside():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]: ...

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}
type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}
type SomeType {
id: ID!
name: String!
}
union SomeTypeByIdResult = SomeType | NotFoundError
type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_annoted_union_with_two_generics():
@strawberry.type
class SomeType:
a: str

@strawberry.type
class OtherType:
b: str

@strawberry.type
class NotFoundError:
message: str

T = TypeVar("T")
U = TypeVar("U")

@strawberry.type
class UnionObjectQueries(Generic[T, U]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[
T, Annotated[Union[U, NotFoundError], strawberry.union("ByIdResult")]
]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(
self, id: strawberry.ID
) -> UnionObjectQueries[SomeType, OtherType]: ...

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
message: String!
}
type OtherType {
b: String!
}
type Query {
someTypeQueries(id: ID!): SomeTypeOtherTypeUnionObjectQueries!
}
type SomeType {
a: String!
}
union SomeTypeOtherTypeByIdResult = SomeType | OtherType | NotFoundError
type SomeTypeOtherTypeUnionObjectQueries {
byId(id: ID!): SomeTypeOtherTypeByIdResult!
}
"""
).strip()
)

0 comments on commit 5be8943

Please sign in to comment.