diff --git a/RELEASED.md b/RELEASED.md new file mode 100644 index 0000000000..28da267b4a --- /dev/null +++ b/RELEASED.md @@ -0,0 +1,18 @@ +Release type: minor + +Add `PartialType` metaclass to make fields optional dynamically. + +```py +from strawberry.tools import PartialType + + +@strawberry.type +class UserCreate: + firstname: str + lastname: str + + +@strawberry.type +class UserUpdate(UserCreate, metaclass=PartialType): + pass +``` diff --git a/docs/guides/tools.md b/docs/guides/tools.md index 5b3d2c02ad..79cc16cf18 100644 --- a/docs/guides/tools.md +++ b/docs/guides/tools.md @@ -94,4 +94,33 @@ type ComboQuery { } ``` +### `PartialType` + +`PartialType` metaclass is used to extend your type but makes its all field +optional. Consider you have different types for each operation on the same model +such as `UserCreate`, `UserUpdate` and `UserQuery`. `UserQuery` should have id +field but the other does not. All fields of `UserQuery` and `UserUpdate` might +be optional. In this case instead of defining the same field for each type one +can define in a single type and extend it. + +```py +from strawberry.tools import PartialType + + +@strawberry.type +class UserCreate: + firstname: str + lastname: str + + +@strawberry.type +class UserUpdate(UserCreate, metaclass=PartialType): + pass + + +@strawberry.type +class UserQuery(UserCreate, metaclass=PartialType): + id: Optional[strawberry.ID] +``` + diff --git a/strawberry/tools/__init__.py b/strawberry/tools/__init__.py index b6be7a91dc..117de68c82 100644 --- a/strawberry/tools/__init__.py +++ b/strawberry/tools/__init__.py @@ -1,7 +1,9 @@ from .create_type import create_type from .merge_types import merge_types +from .partialtype import PartialType __all__ = [ + "PartialType", "create_type", "merge_types", ] diff --git a/strawberry/tools/partialtype.py b/strawberry/tools/partialtype.py new file mode 100644 index 0000000000..f46012cec3 --- /dev/null +++ b/strawberry/tools/partialtype.py @@ -0,0 +1,30 @@ +from abc import ABCMeta +from typing import Optional + +import strawberry + + +class PartialType(ABCMeta): + def __new__(cls, name: str, bases: tuple, namespaces: dict, **kwargs): + mro = super().__new__(cls, name, bases, namespaces, **kwargs).mro() + annotations = namespaces.get("__annotations__", {}) + fields: list[str] = [] + for base in mro[:-1]: # the object class has no __annotations__ attr + for k, v in base.__annotations__.items(): + # To prevent overriding the higher attr annotation + if k not in annotations: + annotations[k] = v + + fields.extend(field.name for field in base._type_definition.fields) + + for field in annotations: + if not field.startswith("_"): + annotations[field] = Optional[annotations[field]] + + namespaces["__annotations__"] = annotations + klass = super().__new__(cls, name, bases, namespaces, **kwargs) + for field in fields: + if not hasattr(klass, field): + setattr(klass, field, strawberry.UNSET) + + return klass diff --git a/tests/tools/test_partialtype.py b/tests/tools/test_partialtype.py new file mode 100644 index 0000000000..8395c4c57f --- /dev/null +++ b/tests/tools/test_partialtype.py @@ -0,0 +1,68 @@ +import dataclasses + +import strawberry +from strawberry.tools import PartialType +from strawberry.type import StrawberryOptional + + +def test_partialtype(): + @strawberry.type + class RoleRead: + name: str + description: str + + @strawberry.type + class UserRead: + firstname: str + lastname: str + role: RoleRead + + @strawberry.input + class RoleInput(RoleRead): + pass + + @strawberry.input + class UserQuery(UserRead, metaclass=PartialType): + role: RoleInput + + read_firstname, read_lastname, read_role = UserRead._type_definition.fields + + # user read type firstname field tests + assert read_firstname.python_name == "firstname" + assert read_firstname.graphql_name is None + assert read_firstname.default is dataclasses.MISSING + assert read_firstname.type is str + + # user read type lastname field tests + assert read_lastname.python_name == "lastname" + assert read_lastname.graphql_name is None + assert read_lastname.default is dataclasses.MISSING + assert read_lastname.type is str + + assert read_role.python_name == "role" + assert read_role.graphql_name is None + assert read_role.default is dataclasses.MISSING + assert read_role.type is RoleRead + + query_firstname, query_lastname, query_role = UserQuery._type_definition.fields + + # user query type firstname field tests + assert query_firstname.python_name == "firstname" + assert query_firstname.graphql_name is None + assert query_firstname.default is strawberry.UNSET + assert isinstance(query_firstname.type, StrawberryOptional) + assert query_firstname.type.of_type is str + + # user query type lastname field tests + assert query_lastname.python_name == "lastname" + assert query_lastname.graphql_name is None + assert query_lastname.default is strawberry.UNSET + assert isinstance(query_lastname.type, StrawberryOptional) + assert query_lastname.type.of_type is str + + # user query type role field tests + assert query_role.python_name == "role" + assert query_role.graphql_name is None + assert query_role.default is strawberry.UNSET + assert isinstance(query_role.type, StrawberryOptional) + assert query_role.type.of_type is RoleInput