Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/integrations/pydantic.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,60 @@ class UserInput:
friends: strawberry.auto
```

## Controlling omission semantics with `use_pydantic_default`

`use_pydantic_default` on `@strawberry.experimental.pydantic.input` determines
how omitted GraphQL fields are represented.

- `True` (default): omitted fields use the Pydantic model's default or
`default_factory`.
- `False`: omitted fields become `strawberry.UNSET`, allowing omission to be
distinguished from `null` and explicit values.

| GraphQL input | True (default) | False |
| -------------- | ------------------------ | --------- |
| omitted | pydantic default applied | `UNSET` |
| provided value | unchanged | unchanged |
| null | `None` | `None` |

When `False`, `UNSET` values remain on the Strawberry input and are not passed
to the Pydantic constructor, enabling patch-style updates.

```python
from __future__ import annotations

import pydantic
import strawberry
from strawberry import UNSET


class UserModel(pydantic.BaseModel):
name: str
interests: list[str] | None = pydantic.Field(default_factory=list)


@strawberry.experimental.pydantic.input(model=UserModel, use_pydantic_default=False)
class UpdateUserInput:
name: strawberry.auto
interests: strawberry.auto


@strawberry.type
class Mutation:
@strawberry.mutation
async def update_user(self, user_data: UpdateUserInput) -> str:
changes: dict[str, object] = {}
if user_data.name is not UNSET:
changes["name"] = user_data.name
if user_data.interests is not UNSET:
changes["interests"] = user_data.interests

current = UserModel(name="Alice", interests=["games"])
updated = current.model_copy(update=changes)

return f"changes={changes} before={current.model_dump()} after={updated.model_dump()}"
```

## Interface types

Interface types are similar to normal types; we can create one by using the
Expand Down
36 changes: 30 additions & 6 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from strawberry.types.field import StrawberryField
from strawberry.types.object_type import _process_type, _wrap_dataclass
from strawberry.types.type_resolver import _get_fields
from strawberry.types.unset import UNSET

if TYPE_CHECKING:
import builtins
Expand Down Expand Up @@ -61,6 +62,8 @@ def _build_dataclass_creation_fields(
auto_fields_set: set[str],
use_pydantic_alias: bool,
compat: PydanticCompat,
*,
use_pydantic_default: bool,
) -> DataclassCreationFields:
field_type = (
get_type_for_field(field, is_input, compat=compat)
Expand All @@ -83,12 +86,18 @@ def _build_dataclass_creation_fields(
elif field.has_alias and use_pydantic_alias:
graphql_name = field.alias

# for inputs with use_pydantic_default, default_factory should be used
if is_input and not use_pydantic_default:
default_factory = UNSET
else:
default_factory = get_default_factory_for_field(field, compat=compat)

strawberry_field = StrawberryField(
python_name=field.name,
graphql_name=graphql_name,
# always unset because we use default_factory instead
default=dataclasses.MISSING,
default_factory=get_default_factory_for_field(field, compat=compat),
default_factory=default_factory,
type_annotation=StrawberryAnnotation.from_annotation(field_type),
description=field.description,
deprecation_reason=(
Expand Down Expand Up @@ -127,6 +136,7 @@ def type(
all_fields: bool = False,
include_computed: bool = False,
use_pydantic_alias: bool = True,
use_pydantic_default: bool = True,
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
compat = PydanticCompat.from_model(model)
Expand Down Expand Up @@ -192,6 +202,7 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
auto_fields_set,
use_pydantic_alias,
compat=compat,
use_pydantic_default=use_pydantic_default,
)
for field_name, field in model_fields.items()
if field_name in fields_set
Expand Down Expand Up @@ -280,12 +291,15 @@ def from_pydantic_default(
return ret

def to_pydantic_default(self: Any, **kwargs: Any) -> PydanticModel:
instance_kwargs = {
f.name: convert_strawberry_class_to_pydantic_model(
getattr(self, f.name)
# when preserving omission on inputs, drop UNSET fields
instance_kwargs = {}
for f in dataclasses.fields(self):
value = getattr(self, f.name)
if is_input and value is UNSET and not use_pydantic_default:
continue
instance_kwargs[f.name] = convert_strawberry_class_to_pydantic_model(
value
)
for f in dataclasses.fields(self)
}
instance_kwargs.update(kwargs)
return model(**instance_kwargs)

Expand All @@ -309,12 +323,21 @@ def input(
directives: Sequence[object] | None = (),
all_fields: bool = False,
use_pydantic_alias: bool = True,
use_pydantic_default: bool = True,
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
"""Convenience decorator for creating an input type from a Pydantic model.

Equal to `partial(type, is_input=True)`

See https://github.com/strawberry-graphql/strawberry/issues/1830.

Parameters
----------
use_pydantic_default:
When False, fields omitted by the GraphQL client are represented as
:data:`strawberry.UNSET` on the generated input class, instead of
materialising the Pydantic default or default_factory. This enables
true tri-state semantics (omitted vs. null vs. value) for inputs.
"""
return type(
model=model,
Expand All @@ -326,6 +349,7 @@ def input(
directives=directives,
all_fields=all_fields,
use_pydantic_alias=use_pydantic_alias,
use_pydantic_default=use_pydantic_default,
)


Expand Down
3 changes: 3 additions & 0 deletions strawberry/types/unset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def __new__(cls: type["UnsetType"]) -> "UnsetType":
return ret
return cls.__instance

def __call__(self) -> "UnsetType":
return self

def __str__(self) -> str:
return ""

Expand Down
82 changes: 82 additions & 0 deletions tests/experimental/pydantic/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,3 +1331,85 @@ def user(self) -> User:

assert not result.errors
assert result.data["user"] == {"age": 20, "location": "earth"}


@pytest.mark.parametrize(
"use_default, provided_interests, expected_raw, expected_pydantic",
[
# use_pydantic_default=False: omitted results in UNSET, no pydantic default
(
False,
strawberry.UNSET,
{"name": "John", "interests": strawberry.UNSET},
{"name": "John", "interests": []},
),
# use_pydantic_default=False: provided list passed through
(
False,
["games"],
{"name": "John", "interests": ["games"]},
{"name": "John", "interests": ["games"]},
),
# use_pydantic_default=False: provided None passed as None
(
False,
None,
{"name": "John", "interests": None},
{"name": "John", "interests": None},
),
# use_pydantic_default=True: omitted, default_factory=list is applied (not UNSET)
(
True,
strawberry.UNSET,
{"name": "John", "interests": []},
{"name": "John", "interests": []},
),
# use_pydantic_default=True: provided list unchanged
(
True,
["games"],
{"name": "John", "interests": ["games"]},
{"name": "John", "interests": ["games"]},
),
# use_pydantic_default=True: provided None unchanged
(
True,
None,
{"name": "John", "interests": None},
{"name": "John", "interests": None},
),
],
)
def test_input_use_pydantic_default_parameterized(
use_default,
provided_interests,
expected_raw,
expected_pydantic,
):
class UserModel(BaseModel):
name: str
interests: list[str] | None = Field(default_factory=list)

@strawberry.experimental.pydantic.input(
UserModel,
use_pydantic_default=use_default,
)
class UpdateUserInput:
name: strawberry.auto
interests: strawberry.auto

if provided_interests is strawberry.UNSET:
data = UpdateUserInput(name="John")
else:
data = UpdateUserInput(name="John", interests=provided_interests)

raw = strawberry.asdict(data)
assert raw["name"] == expected_raw["name"]

if expected_raw["interests"] is strawberry.UNSET:
assert raw["interests"] is strawberry.UNSET
else:
assert raw["interests"] == expected_raw["interests"]

p = data.to_pydantic().model_dump()
assert p == expected_pydantic
16 changes: 16 additions & 0 deletions tests/experimental/pydantic/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,19 @@ class Type:

assert field.python_name == "field"
assert field.type == Literal["field"]


def test_input_use_pydantic_default_false_field_types():
class UserModel(pydantic.BaseModel):
name: str
interests: list[str] | None = pydantic.Field(default_factory=list)

@strawberry.experimental.pydantic.input(UserModel, use_pydantic_default=False)
class UpdateUserInput:
name: strawberry.auto
interests: strawberry.auto

definition = UpdateUserInput.__strawberry_definition__
fields = {f.python_name: f for f in definition.fields}

assert isinstance(fields["interests"].type, StrawberryOptional)
91 changes: 91 additions & 0 deletions tests/schema/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,94 @@ class Query:

assert not result.errors
assert result.data["user"] == {"__typename": "User", "age_": 5}


@pytest.mark.parametrize(
"use_pydantic_default, expected_raw_interests, expected_pydantic",
[
(False, "UNSET", {"name": "John", "interests": []}),
(True, [], {"name": "John", "interests": []}),
],
)
def test_graphql_input_use_pydantic_default_integration(
use_pydantic_default,
expected_raw_interests,
expected_pydantic,
):
from pydantic import BaseModel, Field

class UserModel(BaseModel):
name: str
interests: list[str] | None = Field(default_factory=list)

@strawberry.experimental.pydantic.input(
UserModel,
use_pydantic_default=use_pydantic_default,
)
class UpdateUserInput:
name: strawberry.auto
interests: strawberry.auto

@strawberry.type
class RawResult:
name: str
interests: strawberry.scalars.JSON

@strawberry.type
class PydanticResult:
name: str
interests: strawberry.scalars.JSON | None

@strawberry.type
class UpdateResult:
raw: RawResult
pydantic: PydanticResult

@strawberry.type
class Mutation:
@strawberry.mutation
def update_user(self, user_data: UpdateUserInput) -> UpdateResult:
raw_dict = strawberry.asdict(user_data)
p_dict = user_data.to_pydantic().model_dump()

# JSON-friendly representation
raw_interests = raw_dict["interests"]
if raw_interests is strawberry.UNSET:
raw_interests = "UNSET"

return UpdateResult(
raw=RawResult(
name=raw_dict["name"],
interests=raw_interests,
),
pydantic=PydanticResult(
name=p_dict["name"],
interests=p_dict.get("interests"),
),
)

@strawberry.type
class Query:
ok: bool

schema = strawberry.Schema(query=Query, mutation=Mutation)

query = """
mutation {
updateUser(userData: { name: "John" }) {
raw { name interests }
pydantic { name interests }
}
}
"""

result = schema.execute_sync(query)
assert not result.errors

raw = result.data["updateUser"]["raw"]
pyd = result.data["updateUser"]["pydantic"]

assert raw["name"] == "John"
assert raw["interests"] == expected_raw_interests

assert pyd == expected_pydantic
Loading