Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/contrib/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from tests.testmodels import (
Address,
Author,
CamelCaseAliasPerson,
Employee,
EnumFields,
Expand All @@ -13,6 +14,7 @@
JSONFields,
ModelTestPydanticMetaBackwardRelations1,
ModelTestPydanticMetaBackwardRelations2,
Node,
Reporter,
Team,
Tournament,
Expand Down Expand Up @@ -48,6 +50,10 @@ class PydanticMetaOverride:
self.Event_Pydantic_non_backward_from_override = pydantic_model_creator(
Event, meta_override=PydanticMetaOverride, name="Event_non_backward"
)
self.Author_Pydantic = pydantic_model_creator(Author, meta_override=PydanticMetaOverride)
self.Node_Pydantic = pydantic_model_creator(
Node, meta_override=PydanticMetaOverride, exclude=("o2opkmodelwithm2ms",)
)

self.tournament = await Tournament.create(name="New Tournament")
self.reporter = await Reporter.create(name="The Reporter")
Expand All @@ -62,6 +68,24 @@ class PydanticMetaOverride:
await self.event2.participants.add(self.team1, self.team2)
self.maxDiff = None

async def test_with_default_but_not_null(self):
author_data = self.Author_Pydantic.model_validate({"id": 1}).model_dump()
assert author_data == {"id": 1, "name": ""}
node_data = self.Node_Pydantic.model_validate({"id": 1}).model_dump()
assert node_data["id"] == 1 and node_data["name"] and isinstance(node_data["name"], str)
info = {
"input": None,
"type": "string_type",
"loc": ("name",),
"msg": "Input should be a valid string",
}
with self.assertRaises(ValidationError) as cm:
self.Author_Pydantic.model_validate({"id": 1, "name": None})
self.assertEqual([info], cm.exception.errors(include_url=False))
with self.assertRaises(ValidationError) as cm:
self.Node_Pydantic.model_validate({"id": 1, "name": None})
self.assertEqual([info], cm.exception.errors(include_url=False))

async def test_backward_relations_with_meta_override(self):
event_schema = copy.deepcopy(dict(self.Event_Pydantic.model_json_schema()))
event_non_backward_schema_by_override = copy.deepcopy(
Expand Down
28 changes: 18 additions & 10 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Union

import pytz
from anyio.lowlevel import checkpoint
from pydantic import BaseModel, ConfigDict

from tortoise import fields
Expand All @@ -33,10 +34,15 @@
)


def generate_token():
def generate_token() -> str:
return binascii.hexlify(os.urandom(16)).decode("ascii")


async def generate_unique_string() -> str:
await checkpoint()
return uuid.uuid4().hex[:10]


class TestSchemaForJSONField(BaseModel):
foo: int
bar: str
Expand All @@ -47,7 +53,7 @@ class TestSchemaForJSONField(BaseModel):


class Author(Model):
name = fields.CharField(max_length=255)
name = fields.CharField(max_length=255, default="", null=False)


class Book(Model):
Expand Down Expand Up @@ -151,7 +157,7 @@ class ModelTestPydanticMetaBackwardRelations3(Model):


class Node(Model):
name = fields.CharField(max_length=10)
name = fields.CharField(max_length=10, default=generate_unique_string, null=False)


class Tree(Model):
Expand Down Expand Up @@ -333,7 +339,7 @@ class FloatFields(Model):
floatnum_null = fields.FloatField(null=True)


def raise_if_not_dict_or_list(value: dict | list):
def raise_if_not_dict_or_list(value: dict | list) -> None:
if not isinstance(value, (dict, list)):
raise ValidationError("Value must be a dict or list.")

Expand Down Expand Up @@ -570,7 +576,7 @@ class Employee(Model):
def __str__(self):
return self.name

async def full_hierarchy__async_for(self, level=0):
async def full_hierarchy__async_for(self, level=0) -> str:
"""
Demonstrates ``async for` to fetch relations

Expand All @@ -588,7 +594,7 @@ async def full_hierarchy__async_for(self, level=0):
text.append(await member.full_hierarchy__async_for(level + 1))
return "\n".join(text)

async def full_hierarchy__fetch_related(self, level=0):
async def full_hierarchy__fetch_related(self, level=0) -> str:
"""
Demonstrates ``await .fetch_related`` to fetch relations

Expand Down Expand Up @@ -883,16 +889,18 @@ class NumberSourceField(Model):


class StatusQuerySet(QuerySet):
def active(self):
def active(self) -> QuerySet:
return self.filter(status=1)


class StatusManager(Manager):
def __init__(self, model=None, queryset_cls=None) -> None:
def __init__(
self, model: type[Model] | None = None, queryset_cls: type[QuerySet] | None = None
) -> None:
super().__init__(model=model)
self.queryset_cls = queryset_cls or QuerySet

def get_queryset(self):
def get_queryset(self) -> QuerySet:
return self.queryset_cls(self._model)


Expand Down Expand Up @@ -961,7 +969,7 @@ class OldStyleModel(Model):
external_id = fields.IntField(index=True)


def camelize_var(var_name: str):
def camelize_var(var_name: str) -> str:
var_parts: list[str] = var_name.split("_")
return var_parts[0] + "".join([part.title() for part in var_parts[1:]])

Expand Down
50 changes: 39 additions & 11 deletions tortoise/contrib/pydantic/creator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import functools
import inspect
from base64 import b32encode
from collections.abc import MutableMapping
from collections.abc import Awaitable, MutableMapping
from copy import copy
from enum import Enum, IntEnum
from hashlib import sha3_224
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

from anyio import from_thread
from pydantic import ConfigDict, computed_field, create_model
from pydantic import Field as PydanticField
from pydantic.fields import ComputedFieldInfo
Expand Down Expand Up @@ -100,6 +102,27 @@ def _pydantic_recursion_protector(
return pmc.create_pydantic_model()


T_Retval = TypeVar("T_Retval")


async def async_to_sync(func: Callable[[], Awaitable[T_Retval]]) -> Callable[[], T_Retval]:
"""Wrap the async function to be sync(will be run in worker thread)"""

@functools.wraps(func)
def wrapped() -> T_Retval:
result: list[T_Retval] = []

async def runner() -> None:
res = await func()
result.append(res)

with from_thread.start_blocking_portal() as portal:
portal.call(runner)
return result[0]

return wrapped


class FieldMap(MutableMapping[str, Union[Field, ComputedFieldDescription]]):
def __init__(self, meta: PydanticMetaData, pk_field: Field | None = None):
self._field_map: dict[str, Field | ComputedFieldDescription] = {}
Expand Down Expand Up @@ -432,20 +455,24 @@ def _process_field(
description = _br_it(field.docstring or field.description or "")
if description:
fconfig["description"] = description
if field_name in self._optional or (
field.default is not None and not callable(field.default)
):
self._properties[field_name] = (
field_property,
PydanticField(default=field.default, **fconfig),
)
field_default = field.default
if field_name in self._optional:
fconfig["default"] = field_default
elif field_default is not None:
if callable(field_default):
if inspect.iscoroutinefunction(field_default):
fconfig["default_factory"] = async_to_sync(field_default)
else:
fconfig["default_factory"] = field_default
else:
fconfig["default"] = field_default
else:
if (json_schema_extra.get("nullable") and not is_to_one_relation) or (
self._exclude_read_only and json_schema_extra.get("readOnly")
):
# see: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
fconfig["default"] = None
self._properties[field_name] = (field_property, PydanticField(**fconfig))
self._properties[field_name] = (field_property, PydanticField(**fconfig))
elif isinstance(field, ComputedFieldDescription):
field_property, is_to_one_relation = self._process_computed_field(field), False
if field_property:
Expand Down Expand Up @@ -525,7 +552,8 @@ def _process_data_field(
if field.null:
json_schema_extra["nullable"] = True
if not field.pk and (
field_name in self._optional or field.default is not None or field.null
# field_name in self._optional or field.default is not None or field.null
field_name in self._optional or field.null
):
ptype = Optional[ptype]
if not (self._exclude_read_only and json_schema_extra.get("readOnly") is True):
Expand Down
Loading