Skip to content

Commit

Permalink
Pydantic v2 support (strawberry-graphql#2972)
Browse files Browse the repository at this point in the history
* work so far

* add modelfield

* it works kinda

* fix missing type

* fix default tesT

* skip tests for constrained types

* add tests pass?

* ruff

* Revert "ruff"

This reverts commit dc0b9dd.

* make most of mypy happy

* make linters and mypy happy

* fix pydantic v1 import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove extra ruff

* fix tests for pydantic v1

* remove weird stuff ruff added

* add release file

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bump pydantic in pyproject.toml

* upgrade pydantic lock file

* remove unused

* fix mypy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix poetry

* remove implicit default test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* force pydantic 1.10

* Update noxfile.py

Co-authored-by: Patrick Arminio <[email protected]>

* Update tests/experimental/pydantic/schema/test_defaults.py

Co-authored-by: Patrick Arminio <[email protected]>

* rename v2_compat -> _compat

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ignore mypy windows tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick Arminio <[email protected]>
  • Loading branch information
3 people authored and etripier committed Oct 25, 2023
1 parent c4865d8 commit 667f1da
Show file tree
Hide file tree
Showing 18 changed files with 566 additions and 195 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,15 @@ jobs:

- run: poetry install --with integrations
if: steps.setup-python.outputs.cache-hit != 'true'
# Since we are running all the integrations at once, we can't use
# pydantic v2. It is not compatible with starlette yet
- run: poetry run pip install pydantic==1.10

# we use poetry directly instead of nox since we want to
# test all integrations at once on windows
# but we want to exclude tests/mypy since we are using an old version of pydantic
- run: |
poetry run pytest --cov=. --cov-append --cov-report=xml -n auto --showlocals -vv
poetry run pytest --cov=. --cov-append --cov-report=xml -n auto --showlocals --ignore tests/mypy -vv
- name: coverage xml
run: coverage xml -i
if: ${{ always() }}
Expand Down
Empty file added 2.0.0
Empty file.
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Release type: minor

Adds initial support for pydantic V2.

This is extremely experimental for wider initial testing.

We do not encourage using this in production systems yet.
3 changes: 1 addition & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def tests_integrations(session: Session, integration: str) -> None:


@session(python=["3.11"], name="Pydantic tests", tags=["tests"])
# TODO: add pydantic 2.0 here :)
@nox.parametrize("pydantic", ["1.10"])
@nox.parametrize("pydantic", ["1.10", "2.0.3"])
def test_pydantic(session: Session, pydantic: str) -> None:
session.run_always("poetry", "install", external=True)

Expand Down
288 changes: 191 additions & 97 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ flask = {version = ">=1.1", optional = true}
opentelemetry-api = {version = "<2", optional = true}
opentelemetry-sdk = {version = "<2", optional = true}
chalice = {version = "^1.22", optional = true}
pydantic = {version = "<2", optional = true}
pydantic = {version = ">1.6.1", optional = true}
python-multipart = {version = ">=0.0.5,<0.0.7", optional = true}
sanic = {version = ">=20.12.2", optional = true}
aiohttp = {version = "^3.7.4.post0", optional = true}
Expand Down Expand Up @@ -112,7 +112,7 @@ channels = "^3.0.5"
Django = ">=3.2"
fastapi = {version = ">=0.65.0", optional = false}
flask = ">=1.1"
pydantic = {version = "<2", optional = false}
pydantic = {version = ">1.6.1", optional = false}
pytest-aiohttp = "^1.0.3"
pytest-django = {version = "^4.5"}
pytest-flask = {version = "^1.2.0"}
Expand Down
104 changes: 104 additions & 0 deletions strawberry/experimental/pydantic/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import dataclasses
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type

import pydantic
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION

if TYPE_CHECKING:
from pydantic.fields import FieldInfo

IS_PYDANTIC_V2: bool = PYDANTIC_VERSION.startswith("2.")
IS_PYDANTIC_V1: bool = not IS_PYDANTIC_V2


@dataclass
class CompatModelField:
name: str
type_: Any
outer_type_: Any
default: Any
default_factory: Optional[Callable[[], Any]]
required: bool
alias: Optional[str]
allow_none: bool
has_alias: bool
description: Optional[str]


if pydantic.VERSION[0] == "2":
from typing_extensions import get_args, get_origin

from pydantic._internal._typing_extra import is_new_type
from pydantic._internal._utils import lenient_issubclass, smart_deepcopy
from pydantic_core import PydanticUndefined

PYDANTIC_MISSING_TYPE = PydanticUndefined

def new_type_supertype(type_: Any) -> Any:
return type_.__supertype__

def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]:
field_info: dict[str, FieldInfo] = model.model_fields
new_fields = {}
# Convert it into CompatModelField
for name, field in field_info.items():
new_fields[name] = CompatModelField(
name=name,
type_=field.annotation,
outer_type_=field.annotation,
default=field.default,
default_factory=field.default_factory,
required=field.is_required(),
alias=field.alias,
# v2 doesn't have allow_none
allow_none=False,
has_alias=field is not None,
description=field.description,
)
return new_fields

else:
from pydantic.typing import ( # type: ignore[no-redef]
get_args,
get_origin,
is_new_type,
new_type_supertype,
)
from pydantic.utils import ( # type: ignore[no-redef]
lenient_issubclass,
smart_deepcopy,
)

PYDANTIC_MISSING_TYPE = dataclasses.MISSING # type: ignore[assignment]

def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]:
new_fields = {}
# Convert it into CompatModelField
for name, field in model.__fields__.items(): # type: ignore[attr-defined]
new_fields[name] = CompatModelField(
name=name,
type_=field.type_,
outer_type_=field.outer_type_,
default=field.default,
default_factory=field.default_factory,
required=field.required,
alias=field.alias,
allow_none=field.allow_none,
has_alias=field.has_alias,
description=field.field_info.description,
)
return new_fields


__all__ = [
"smart_deepcopy",
"lenient_issubclass",
"get_args",
"get_origin",
"is_new_type",
"new_type_supertype",
"get_model_fields",
"PYDANTIC_MISSING_TYPE",
]
14 changes: 7 additions & 7 deletions strawberry/experimental/pydantic/error_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dataclasses
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Expand All @@ -16,9 +15,13 @@
)

from pydantic import BaseModel
from pydantic.utils import lenient_issubclass

from strawberry.auto import StrawberryAuto
from strawberry.experimental.pydantic._compat import (
CompatModelField,
get_model_fields,
lenient_issubclass,
)
from strawberry.experimental.pydantic.utils import (
get_private_fields,
get_strawberry_type_from_model,
Expand All @@ -30,11 +33,8 @@

from .exceptions import MissingFieldsListError

if TYPE_CHECKING:
from pydantic.fields import ModelField


def get_type_for_field(field: ModelField) -> Union[Any, Type[None], Type[List]]:
def get_type_for_field(field: CompatModelField) -> Union[Any, Type[None], Type[List]]:
type_ = field.outer_type_
type_ = normalize_type(type_)
return field_type_to_type(type_)
Expand Down Expand Up @@ -72,7 +72,7 @@ def error_type(
all_fields: bool = False,
) -> Callable[..., Type]:
def wrap(cls: Type) -> Type:
model_fields = model.__fields__
model_fields = get_model_fields(model)
fields_set = set(fields) if fields else set()

if fields:
Expand Down
56 changes: 38 additions & 18 deletions strawberry/experimental/pydantic/fields.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
import builtins
from decimal import Decimal
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Type, Union
from uuid import UUID

import pydantic
from pydantic import BaseModel
from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype
from pydantic.utils import lenient_issubclass

from strawberry.experimental.pydantic._compat import (
IS_PYDANTIC_V1,
get_args,
get_origin,
is_new_type,
lenient_issubclass,
new_type_supertype,
)
from strawberry.experimental.pydantic.exceptions import (
UnregisteredTypeException,
UnsupportedTypeError,
)
from strawberry.types.types import StrawberryObjectDefinition

try:
from types import UnionType as TypingUnionType
from typing import GenericAlias as TypingGenericAlias # type: ignore
except ImportError:
import sys
Expand All @@ -25,6 +32,10 @@
TypingGenericAlias = ()
else:
raise
if sys.version_info < (3, 10):
TypingUnionType = ()
else:
raise

ATTR_TO_TYPE_MAP = {
"NoneStr": Optional[str],
Expand Down Expand Up @@ -70,23 +81,31 @@
"RedisDsn": str,
}


FIELDS_MAP = {
getattr(pydantic, field_name): type
for field_name, type in ATTR_TO_TYPE_MAP.items()
if hasattr(pydantic, field_name)
}
"""TODO:
Most of these fields are not supported by pydantic V2
"""
FIELDS_MAP = (
{
getattr(pydantic, field_name): type
for field_name, type in ATTR_TO_TYPE_MAP.items()
if hasattr(pydantic, field_name)
}
if IS_PYDANTIC_V1
else {}
)


def get_basic_type(type_: Any) -> Type[Any]:
if lenient_issubclass(type_, pydantic.ConstrainedInt):
return int
if lenient_issubclass(type_, pydantic.ConstrainedFloat):
return float
if lenient_issubclass(type_, pydantic.ConstrainedStr):
return str
if lenient_issubclass(type_, pydantic.ConstrainedList):
return List[get_basic_type(type_.item_type)] # type: ignore
if IS_PYDANTIC_V1:
# only pydantic v1 has these
if lenient_issubclass(type_, pydantic.ConstrainedInt):
return int
if lenient_issubclass(type_, pydantic.ConstrainedFloat):
return float
if lenient_issubclass(type_, pydantic.ConstrainedStr):
return str
if lenient_issubclass(type_, pydantic.ConstrainedList):
return List[get_basic_type(type_.item_type)] # type: ignore

if type_ in FIELDS_MAP:
type_ = FIELDS_MAP.get(type_)
Expand Down Expand Up @@ -125,7 +144,8 @@ def replace_types_recursively(type_: Any, is_input: bool) -> Any:

if isinstance(replaced_type, TypingGenericAlias):
return TypingGenericAlias(origin, converted)

if isinstance(replaced_type, TypingUnionType):
return Union[converted]
replaced_type = replaced_type.copy_with(converted)

if isinstance(replaced_type, StrawberryObjectDefinition):
Expand Down
26 changes: 16 additions & 10 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

from strawberry.annotation import StrawberryAnnotation
from strawberry.auto import StrawberryAuto
from strawberry.experimental.pydantic._compat import (
IS_PYDANTIC_V1,
CompatModelField,
get_model_fields,
)
from strawberry.experimental.pydantic.conversion import (
convert_pydantic_model_to_strawberry_class,
convert_strawberry_class_to_pydantic_model,
Expand All @@ -37,21 +42,22 @@

if TYPE_CHECKING:
from graphql import GraphQLResolveInfo
from pydantic.fields import ModelField


def get_type_for_field(field: ModelField, is_input: bool): # noqa: ANN201
def get_type_for_field(field: CompatModelField, is_input: bool): # noqa: ANN201
outer_type = field.outer_type_
replaced_type = replace_types_recursively(outer_type, is_input)
should_add_optional: bool = field.allow_none
if should_add_optional:
return Optional[replaced_type]
else:
return replaced_type
if IS_PYDANTIC_V1:
# only pydantic v1 has this Optional logic
should_add_optional: bool = field.allow_none
if should_add_optional:
return Optional[replaced_type]

return replaced_type


def _build_dataclass_creation_fields(
field: ModelField,
field: CompatModelField,
is_input: bool,
existing_fields: Dict[str, StrawberryField],
auto_fields_set: Set[str],
Expand Down Expand Up @@ -85,7 +91,7 @@ def _build_dataclass_creation_fields(
default=dataclasses.MISSING,
default_factory=get_default_factory_for_field(field),
type_annotation=StrawberryAnnotation.from_annotation(field_type),
description=field.field_info.description,
description=field.description,
deprecation_reason=(
existing_field.deprecation_reason if existing_field else None
),
Expand Down Expand Up @@ -123,7 +129,7 @@ def type(
use_pydantic_alias: bool = True,
) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]:
def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]:
model_fields = model.__fields__
model_fields = get_model_fields(model)
original_fields_set = set(fields) if fields else set()

if fields:
Expand Down
Loading

0 comments on commit 667f1da

Please sign in to comment.