From c15a2befc6662ae5f35beca20f5330ea3bbd0aea Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Sat, 26 Jul 2025 22:27:58 +0800 Subject: [PATCH 1/3] fix not none field with default accept None value for pydantic model --- tests/contrib/test_pydantic.py | 31 ++++++++++++++++++++++++++++ tests/testmodels.py | 26 ++++++++++++++--------- tortoise/contrib/pydantic/creator.py | 3 ++- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index 56ebeaf37..f688aca0b 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -5,6 +5,7 @@ from tests.testmodels import ( Address, + Author, CamelCaseAliasPerson, Employee, EnumFields, @@ -13,10 +14,12 @@ JSONFields, ModelTestPydanticMetaBackwardRelations1, ModelTestPydanticMetaBackwardRelations2, + Node, Reporter, Team, Tournament, User, + generate_unique_string, json_pydantic_default, ) from tortoise.contrib import test @@ -48,6 +51,12 @@ class PydanticMetaOverride: self.Event_Pydantic_non_backward_from_override = pydantic_model_creator( Event, meta_override=PydanticMetaOverride, name="Event_non_backward" ) + self.Author_Pydantic = pydantic_model_creator(Author, meta_override=PydanticMetaOverride) + # TODO: no exclude + # self.Node_Pydantic = pydantic_model_creator(Node, meta_override=PydanticMetaOverride) + self.Node_Pydantic = pydantic_model_creator( + Node, meta_override=PydanticMetaOverride, exclude=("o2opkmodelwithm2ms",) + ) self.tournament = await Tournament.create(name="New Tournament") self.reporter = await Reporter.create(name="The Reporter") @@ -62,6 +71,28 @@ class PydanticMetaOverride: await self.event2.participants.add(self.team1, self.team2) self.maxDiff = None + async def test_with_default_but_not_null(self): + author_data = self.Author_Pydantic.model_validate({"id": 1}).model_dump() + assert author_data == {"id": 1, "name": ""} + # TODO: no name + # node_data = self.Node_Pydantic.model_validate({"id": 1}).model_dump() + node_data = self.Node_Pydantic.model_validate( + {"id": 1, "name": generate_unique_string()} + ).model_dump() + assert node_data["id"] == 1 and node_data["name"] and isinstance(node_data["name"], str) + info = { + "input": None, + "type": "string_type", + "loc": ("name",), + "msg": "Input should be a valid string", + } + with self.assertRaises(ValidationError) as cm: + self.Author_Pydantic.model_validate({"id": 1, "name": None}) + self.assertEqual([info], cm.exception.errors(include_url=False)) + with self.assertRaises(ValidationError) as cm: + self.Node_Pydantic.model_validate({"id": 1, "name": None}) + self.assertEqual([info], cm.exception.errors(include_url=False)) + async def test_backward_relations_with_meta_override(self): event_schema = copy.deepcopy(dict(self.Event_Pydantic.model_json_schema())) event_non_backward_schema_by_override = copy.deepcopy( diff --git a/tests/testmodels.py b/tests/testmodels.py index c0b95f352..fdbe785ca 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -33,10 +33,14 @@ ) -def generate_token(): +def generate_token() -> str: return binascii.hexlify(os.urandom(16)).decode("ascii") +def generate_unique_string() -> str: + return uuid.uuid4().hex[:10] + + class TestSchemaForJSONField(BaseModel): foo: int bar: str @@ -47,7 +51,7 @@ class TestSchemaForJSONField(BaseModel): class Author(Model): - name = fields.CharField(max_length=255) + name = fields.CharField(max_length=255, default="", null=False) class Book(Model): @@ -151,7 +155,7 @@ class ModelTestPydanticMetaBackwardRelations3(Model): class Node(Model): - name = fields.CharField(max_length=10) + name = fields.CharField(max_length=10, default=generate_unique_string, null=False) class Tree(Model): @@ -333,7 +337,7 @@ class FloatFields(Model): floatnum_null = fields.FloatField(null=True) -def raise_if_not_dict_or_list(value: dict | list): +def raise_if_not_dict_or_list(value: dict | list) -> None: if not isinstance(value, (dict, list)): raise ValidationError("Value must be a dict or list.") @@ -570,7 +574,7 @@ class Employee(Model): def __str__(self): return self.name - async def full_hierarchy__async_for(self, level=0): + async def full_hierarchy__async_for(self, level=0) -> str: """ Demonstrates ``async for` to fetch relations @@ -588,7 +592,7 @@ async def full_hierarchy__async_for(self, level=0): text.append(await member.full_hierarchy__async_for(level + 1)) return "\n".join(text) - async def full_hierarchy__fetch_related(self, level=0): + async def full_hierarchy__fetch_related(self, level=0) -> str: """ Demonstrates ``await .fetch_related`` to fetch relations @@ -883,16 +887,18 @@ class NumberSourceField(Model): class StatusQuerySet(QuerySet): - def active(self): + def active(self) -> QuerySet: return self.filter(status=1) class StatusManager(Manager): - def __init__(self, model=None, queryset_cls=None) -> None: + def __init__( + self, model: type[Model] | None = None, queryset_cls: type[QuerySet] | None = None + ) -> None: super().__init__(model=model) self.queryset_cls = queryset_cls or QuerySet - def get_queryset(self): + def get_queryset(self) -> QuerySet: return self.queryset_cls(self._model) @@ -961,7 +967,7 @@ class OldStyleModel(Model): external_id = fields.IntField(index=True) -def camelize_var(var_name: str): +def camelize_var(var_name: str) -> str: var_parts: list[str] = var_name.split("_") return var_parts[0] + "".join([part.title() for part in var_parts[1:]]) diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py index d67f49d38..047231aa5 100644 --- a/tortoise/contrib/pydantic/creator.py +++ b/tortoise/contrib/pydantic/creator.py @@ -525,7 +525,8 @@ def _process_data_field( if field.null: json_schema_extra["nullable"] = True if not field.pk and ( - field_name in self._optional or field.default is not None or field.null + # field_name in self._optional or field.default is not None or field.null + field_name in self._optional or field.null ): ptype = Optional[ptype] if not (self._exclude_read_only and json_schema_extra.get("readOnly") is True): From 9346f321ec8de11918926930f286390a1a6c8c89 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Sat, 26 Jul 2025 22:34:09 +0800 Subject: [PATCH 2/3] Use default_factory if field.default is callable --- tests/contrib/test_pydantic.py | 7 +------ tortoise/contrib/pydantic/creator.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index f688aca0b..9c30c23db 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -19,7 +19,6 @@ Team, Tournament, User, - generate_unique_string, json_pydantic_default, ) from tortoise.contrib import test @@ -74,11 +73,7 @@ class PydanticMetaOverride: async def test_with_default_but_not_null(self): author_data = self.Author_Pydantic.model_validate({"id": 1}).model_dump() assert author_data == {"id": 1, "name": ""} - # TODO: no name - # node_data = self.Node_Pydantic.model_validate({"id": 1}).model_dump() - node_data = self.Node_Pydantic.model_validate( - {"id": 1, "name": generate_unique_string()} - ).model_dump() + node_data = self.Node_Pydantic.model_validate({"id": 1}).model_dump() assert node_data["id"] == 1 and node_data["name"] and isinstance(node_data["name"], str) info = { "input": None, diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py index 047231aa5..6d924b513 100644 --- a/tortoise/contrib/pydantic/creator.py +++ b/tortoise/contrib/pydantic/creator.py @@ -432,20 +432,20 @@ def _process_field( description = _br_it(field.docstring or field.description or "") if description: fconfig["description"] = description - if field_name in self._optional or ( - field.default is not None and not callable(field.default) - ): - self._properties[field_name] = ( - field_property, - PydanticField(default=field.default, **fconfig), - ) + if field_name in self._optional: + fconfig["default"] = field.default + elif field.default is not None: + if callable(field.default): + fconfig["default_factory"] = field.default + else: + fconfig["default"] = field.default else: if (json_schema_extra.get("nullable") and not is_to_one_relation) or ( self._exclude_read_only and json_schema_extra.get("readOnly") ): # see: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields fconfig["default"] = None - self._properties[field_name] = (field_property, PydanticField(**fconfig)) + self._properties[field_name] = (field_property, PydanticField(**fconfig)) elif isinstance(field, ComputedFieldDescription): field_property, is_to_one_relation = self._process_computed_field(field), False if field_property: From 415d8bdbf687d748d34397ba52a0f9024e2ec9ac Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Mon, 28 Jul 2025 23:08:57 +0800 Subject: [PATCH 3/3] fix: default value is async function --- tests/contrib/test_pydantic.py | 2 -- tests/testmodels.py | 4 ++- tortoise/contrib/pydantic/creator.py | 41 +++++++++++++++++++++++----- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index 9c30c23db..f75ce2a44 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -51,8 +51,6 @@ class PydanticMetaOverride: Event, meta_override=PydanticMetaOverride, name="Event_non_backward" ) self.Author_Pydantic = pydantic_model_creator(Author, meta_override=PydanticMetaOverride) - # TODO: no exclude - # self.Node_Pydantic = pydantic_model_creator(Node, meta_override=PydanticMetaOverride) self.Node_Pydantic = pydantic_model_creator( Node, meta_override=PydanticMetaOverride, exclude=("o2opkmodelwithm2ms",) ) diff --git a/tests/testmodels.py b/tests/testmodels.py index fdbe785ca..d02a66e86 100644 --- a/tests/testmodels.py +++ b/tests/testmodels.py @@ -14,6 +14,7 @@ from typing import Union import pytz +from anyio.lowlevel import checkpoint from pydantic import BaseModel, ConfigDict from tortoise import fields @@ -37,7 +38,8 @@ def generate_token() -> str: return binascii.hexlify(os.urandom(16)).decode("ascii") -def generate_unique_string() -> str: +async def generate_unique_string() -> str: + await checkpoint() return uuid.uuid4().hex[:10] diff --git a/tortoise/contrib/pydantic/creator.py b/tortoise/contrib/pydantic/creator.py index 6d924b513..9d6859e07 100644 --- a/tortoise/contrib/pydantic/creator.py +++ b/tortoise/contrib/pydantic/creator.py @@ -1,13 +1,15 @@ from __future__ import annotations +import functools import inspect from base64 import b32encode -from collections.abc import MutableMapping +from collections.abc import Awaitable, MutableMapping from copy import copy from enum import Enum, IntEnum from hashlib import sha3_224 -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +from anyio import from_thread from pydantic import ConfigDict, computed_field, create_model from pydantic import Field as PydanticField from pydantic.fields import ComputedFieldInfo @@ -100,6 +102,27 @@ def _pydantic_recursion_protector( return pmc.create_pydantic_model() +T_Retval = TypeVar("T_Retval") + + +async def async_to_sync(func: Callable[[], Awaitable[T_Retval]]) -> Callable[[], T_Retval]: + """Wrap the async function to be sync(will be run in worker thread)""" + + @functools.wraps(func) + def wrapped() -> T_Retval: + result: list[T_Retval] = [] + + async def runner() -> None: + res = await func() + result.append(res) + + with from_thread.start_blocking_portal() as portal: + portal.call(runner) + return result[0] + + return wrapped + + class FieldMap(MutableMapping[str, Union[Field, ComputedFieldDescription]]): def __init__(self, meta: PydanticMetaData, pk_field: Field | None = None): self._field_map: dict[str, Field | ComputedFieldDescription] = {} @@ -432,13 +455,17 @@ def _process_field( description = _br_it(field.docstring or field.description or "") if description: fconfig["description"] = description + field_default = field.default if field_name in self._optional: - fconfig["default"] = field.default - elif field.default is not None: - if callable(field.default): - fconfig["default_factory"] = field.default + fconfig["default"] = field_default + elif field_default is not None: + if callable(field_default): + if inspect.iscoroutinefunction(field_default): + fconfig["default_factory"] = async_to_sync(field_default) + else: + fconfig["default_factory"] = field_default else: - fconfig["default"] = field.default + fconfig["default"] = field_default else: if (json_schema_extra.get("nullable") and not is_to_one_relation) or ( self._exclude_read_only and json_schema_extra.get("readOnly")