diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..853d3b19b6 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,28 @@ +Release type: patch + +Enhancements: +- Improved pydantic conversion compatibility with specialized list classes. + - Modified `StrawberryAnnotation._is_list` to check if the `annotation` extends from the `list` type, enabling it to be considered a list. + - in `StrawberryAnnotation` Moved the `_is_list` check before the `_is_generic` check in `resolve` to avoid `unsupported` error in `_is_generic` before it checked `_is_list`. + +This enhancement enables the usage of constrained lists as class types and allows the creation of specialized lists. The following example demonstrates this feature: + +```python +import strawberry +from pydantic import BaseModel, ConstrainedList + + +class FriendList(ConstrainedList): + min_items = 1 + + +class UserModel(BaseModel): + age: int + friend_names: FriendList[str] + + +@strawberry.experimental.pydantic.type(UserModel) +class User: + age: strawberry.auto + friend_names: strawberry.auto +``` diff --git a/strawberry/annotation.py b/strawberry/annotation.py index 22d8fafe4d..f46a111140 100644 --- a/strawberry/annotation.py +++ b/strawberry/annotation.py @@ -141,6 +141,8 @@ def resolve(self) -> Union[StrawberryType, type]: if self._is_lazy_type(evaled_type): return evaled_type + if self._is_list(evaled_type): + return self.create_list(evaled_type) if self._is_generic(evaled_type): if any(is_type_var(type_) for type_ in get_args(evaled_type)): @@ -151,8 +153,6 @@ def resolve(self) -> Union[StrawberryType, type]: # a StrawberryType if self._is_enum(evaled_type): return self.create_enum(evaled_type) - if self._is_list(evaled_type): - return self.create_list(evaled_type) elif self._is_optional(evaled_type, args): return self.create_optional(evaled_type) elif self._is_union(evaled_type, args): @@ -298,8 +298,14 @@ def _is_list(cls, annotation: Any) -> bool: """Returns True if annotation is a List""" annotation_origin = get_origin(annotation) + annotation_mro = getattr(annotation, "__mro__", []) + is_list = any(x is list for x in annotation_mro) - return (annotation_origin in (list, tuple)) or annotation_origin is abc.Sequence + return ( + (annotation_origin in (list, tuple)) + or annotation_origin is abc.Sequence + or is_list + ) @classmethod def _is_strawberry_type(cls, evaled_type: Any) -> bool: diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 06a95a8fbd..e636a9bdce 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -526,3 +526,32 @@ def user(self) -> User: assert not result.errors assert result.data["user"]["age"] == 1 assert result.data["user"]["password"] is None + + +def test_basic_type_with_constrained_list(): + class FriendList(pydantic.ConstrainedList): + min_items = 1 + + class UserModel(pydantic.BaseModel): + age: int + friend_names: FriendList[str] + + @strawberry.experimental.pydantic.type(UserModel) + class User: + age: strawberry.auto + friend_names: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, friend_names=["A", "B"]) + + schema = strawberry.Schema(query=Query) + + query = "{ user { friendNames } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["friendNames"] == ["A", "B"] diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 78aa1d778a..935e162020 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -3,10 +3,10 @@ import re import sys from enum import Enum -from typing import Any, Dict, List, NewType, Optional, Union +from typing import Any, Dict, List, NewType, Optional, TypeVar, Union import pytest -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, ConstrainedList, Field, ValidationError import strawberry from strawberry.experimental.pydantic._compat import ( @@ -1225,3 +1225,64 @@ class Test: assert test.optional_list == [1, 2, 3] assert test.optional_str is None + + +def test_can_convert_pydantic_type_to_strawberry_with_constrained_list(): + class WorkModel(BaseModel): + name: str + + class workList(ConstrainedList): + min_items = 1 + + class UserModel(BaseModel): + work: workList[WorkModel] + + @strawberry.experimental.pydantic.type(WorkModel) + class Work: + name: strawberry.auto + + @strawberry.experimental.pydantic.type(UserModel) + class User: + work: strawberry.auto + + origin_user = UserModel( + work=[WorkModel(name="developer"), WorkModel(name="tester")] + ) + + user = User.from_pydantic(origin_user) + + assert user == User(work=[Work(name="developer"), Work(name="tester")]) + + +SI = TypeVar("SI", covariant=True) # pragma: no mutate + + +class SpecialList(List[SI]): + pass + + +def test_can_convert_pydantic_type_to_strawberry_with_specialized_list(): + class WorkModel(BaseModel): + name: str + + class workList(SpecialList[SI]): + min_items = 1 + + class UserModel(BaseModel): + work: workList[WorkModel] + + @strawberry.experimental.pydantic.type(WorkModel) + class Work: + name: strawberry.auto + + @strawberry.experimental.pydantic.type(UserModel) + class User: + work: strawberry.auto + + origin_user = UserModel( + work=[WorkModel(name="developer"), WorkModel(name="tester")] + ) + + user = User.from_pydantic(origin_user) + + assert user == User(work=[Work(name="developer"), Work(name="tester")])