Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: automatically include implementers of interfaces in schema #3686

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,30 @@ class Query:
},
)

# due to 7dc715c9a65e6e0b6ea0ea3e903cf20284e3316b (#1684, #1586),
# fields are evaluated lazily. This means, that we only know about all
# interfaces after the schema is created.
# We need to find a way to add the extra implementations to the schema after creating it.
# This is not officially supported by GraphQL core and would be somewhat hacky.

# TODO: prevent duplicates - no error, but duplicate processing is inefficient
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (performance): Consider using a set or bulk extend operation to prevent duplicate processing

Instead of appending types one by one, consider collecting them in a set first or using list.extend() to add them all at once. This would prevent duplicate processing and improve performance.

            types = set()  # Using a set to automatically prevent duplicates

for (
extra_interface_type
) in self.schema_converter.extra_interface_child_map.values():
graphql_type = self.schema_converter.from_object(extra_interface_type)
graphql_types.append(graphql_type)

self._schema = GraphQLSchema(
erikwrede marked this conversation as resolved.
Show resolved Hide resolved
query=query_type,
mutation=mutation_type,
subscription=subscription_type if subscription else None,
directives=specified_directives + tuple(graphql_directives),
types=graphql_types,
extensions={
GraphQLCoreConverter.DEFINITION_BACKREF: self,
},
)

except TypeError as error:
# GraphQL core throws a TypeError if there's any exception raised
# during the schema creation, so we check if the cause was a
Expand Down
16 changes: 15 additions & 1 deletion strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
from strawberry.types.info import Info
from strawberry.types.scalar import ScalarDefinition


FieldType = TypeVar(
"FieldType", bound=Union[GraphQLField, GraphQLInputField], covariant=True
)
Expand Down Expand Up @@ -245,6 +244,7 @@ def __init__(
scalar_registry: Dict[object, Union[ScalarWrapper, ScalarDefinition]],
get_fields: Callable[[StrawberryObjectDefinition], List[StrawberryField]],
) -> None:
self.extra_interface_child_map: Dict[str, StrawberryObjectDefinition] = {}
self.type_map: Dict[str, ConcreteType] = {}
self.config = config
self.scalar_registry = scalar_registry
Expand Down Expand Up @@ -583,6 +583,20 @@ def resolve_type(
definition=interface, implementation=graphql_interface
)

# get all subclasses of the interface
subclasses = interface.origin.__subclasses__()
for subclass in subclasses:
# check if subclass is strawberry type

subclass_object_definition = get_object_definition(subclass, strict=False)
object_type_name = self.config.name_converter.from_type(
subclass_object_definition
)

if object_type_name not in self.type_map:
self.extra_interface_child_map[object_type_name] = (
subclass_object_definition
)
return graphql_interface

def from_list(self, type_: StrawberryList) -> GraphQLList:
Expand Down
95 changes: 95 additions & 0 deletions tests/schema/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import textwrap
from dataclasses import dataclass
from typing import Any, List

Expand All @@ -8,6 +9,100 @@
from strawberry.types.base import StrawberryObjectDefinition


def test_query_interface_without_extra_types_duplicate_reference():
@strawberry.interface
class Cheese:
name: str

@strawberry.type
class Swiss(Cheese):
canton: str

@strawberry.type
class Italian(Cheese):
province: str

@strawberry.type
class Query:
@strawberry.field
def assortment(self) -> List[Cheese]:
return [

Check warning on line 29 in tests/schema/test_interface.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/test_interface.py#L29

Added line #L29 was not covered by tests
Italian(name="Asiago", province="Friuli"),
Swiss(name="Tomme", canton="Vaud"),
]

@strawberry.field
def italians(self) -> List[Italian]:
return [Italian(name="Asiago", province="Friuli")]

Check warning on line 36 in tests/schema/test_interface.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/test_interface.py#L36

Added line #L36 was not covered by tests

schema = strawberry.Schema(query=Query)

expected_schema = """
interface Cheese {
name: String!
}

type Italian implements Cheese {
name: String!
province: String!
}

type Query {
assortment: [Cheese!]!
italians: [Italian!]!
}

type Swiss implements Cheese {
name: String!
canton: String!
}
"""

assert str(schema) == textwrap.dedent(expected_schema).strip()


def test_query_interface_without_extra_types():
@strawberry.interface
class Cheese:
name: str

@strawberry.type
class Swiss(Cheese):
canton: str

@strawberry.type
class Italian(Cheese):
province: str

@strawberry.type
class Root:
@strawberry.field
def assortment(self) -> List[Cheese]:
return [
Italian(name="Asiago", province="Friuli"),
Swiss(name="Tomme", canton="Vaud"),
]

schema = strawberry.Schema(query=Root)

query = """{
assortment {
name
... on Italian { province }
... on Swiss { canton }
}
}"""

result = schema.execute_sync(query)

assert not result.errors
assert result.data is not None
assert result.data["assortment"] == [
{"name": "Asiago", "province": "Friuli"},
{"canton": "Vaud", "name": "Tomme"},
]


def test_query_interface():
@strawberry.interface
class Cheese:
Expand Down
Loading