diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index a8fce095b5..6a0cf6be39 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -187,6 +187,30 @@ class Query: }, ) + # due to 7dc715c9a65e6e0b6ea0ea3e903cf20284e3316b (#1684, #1586), + # fields are evaluated lazily. This means, that we only know about all + # interfaces after the schema is created. + # We need to find a way to add the extra implementations to the schema after creating it. + # This is not officially supported by GraphQL core and would be somewhat hacky. + + # TODO: prevent duplicates - no error, but duplicate processing is inefficient + for ( + extra_interface_type + ) in self.schema_converter.extra_interface_child_map.values(): + graphql_type = self.schema_converter.from_object(extra_interface_type) + graphql_types.append(graphql_type) + + self._schema = GraphQLSchema( + query=query_type, + mutation=mutation_type, + subscription=subscription_type if subscription else None, + directives=specified_directives + tuple(graphql_directives), + types=graphql_types, + extensions={ + GraphQLCoreConverter.DEFINITION_BACKREF: self, + }, + ) + except TypeError as error: # GraphQL core throws a TypeError if there's any exception raised # during the schema creation, so we check if the cause was a diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 1083b46f9b..5eed057388 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -92,7 +92,6 @@ from strawberry.types.info import Info from strawberry.types.scalar import ScalarDefinition - FieldType = TypeVar( "FieldType", bound=Union[GraphQLField, GraphQLInputField], covariant=True ) @@ -245,6 +244,7 @@ def __init__( scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]], get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]], ) -> None: + self.extra_interface_child_map: Dict[str, StrawberryObjectDefinition] = {} self.type_map: Dict[str, ConcreteType] = {} self.config = config self.scalar_registry = scalar_registry @@ -583,6 +583,20 @@ def resolve_type( definition=interface, implementation=graphql_interface ) + # get all subclasses of the interface + subclasses = interface.origin.__subclasses__() + for subclass in subclasses: + # check if subclass is strawberry type + + subclass_object_definition = get_object_definition(subclass, strict=False) + object_type_name = self.config.name_converter.from_type( + subclass_object_definition + ) + + if object_type_name not in self.type_map: + self.extra_interface_child_map[object_type_name] = ( + subclass_object_definition + ) return graphql_interface def from_list(self, type_: StrawberryList) -> GraphQLList: diff --git a/tests/schema/test_interface.py b/tests/schema/test_interface.py index 31f106b09e..7a8d8296e3 100644 --- a/tests/schema/test_interface.py +++ b/tests/schema/test_interface.py @@ -1,3 +1,4 @@ +import textwrap from dataclasses import dataclass from typing import Any, List @@ -8,6 +9,100 @@ from strawberry.types.base import StrawberryObjectDefinition +def test_query_interface_without_extra_types_duplicate_reference(): + @strawberry.interface + class Cheese: + name: str + + @strawberry.type + class Swiss(Cheese): + canton: str + + @strawberry.type + class Italian(Cheese): + province: str + + @strawberry.type + class Query: + @strawberry.field + def assortment(self) -> List[Cheese]: + return [ + Italian(name="Asiago", province="Friuli"), + Swiss(name="Tomme", canton="Vaud"), + ] + + @strawberry.field + def italians(self) -> List[Italian]: + return [Italian(name="Asiago", province="Friuli")] + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + interface Cheese { + name: String! + } + + type Italian implements Cheese { + name: String! + province: String! + } + + type Query { + assortment: [Cheese!]! + italians: [Italian!]! + } + + type Swiss implements Cheese { + name: String! + canton: String! + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + +def test_query_interface_without_extra_types(): + @strawberry.interface + class Cheese: + name: str + + @strawberry.type + class Swiss(Cheese): + canton: str + + @strawberry.type + class Italian(Cheese): + province: str + + @strawberry.type + class Root: + @strawberry.field + def assortment(self) -> List[Cheese]: + return [ + Italian(name="Asiago", province="Friuli"), + Swiss(name="Tomme", canton="Vaud"), + ] + + schema = strawberry.Schema(query=Root) + + query = """{ + assortment { + name + ... on Italian { province } + ... on Swiss { canton } + } + }""" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data is not None + assert result.data["assortment"] == [ + {"name": "Asiago", "province": "Friuli"}, + {"canton": "Vaud", "name": "Tomme"}, + ] + + def test_query_interface(): @strawberry.interface class Cheese: