Skip to content

Commit cc8c71d

Browse files
Add logic to handle globalns and localns in pydantic_model_creator
1 parent 1d4d60b commit cc8c71d

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

tortoise/contrib/pydantic/creator.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def _pydantic_recursion_protector(
6161
name=None,
6262
allow_cycles: bool = False,
6363
sort_alphabetically: Optional[bool] = None,
64+
globalns: dict = None,
65+
localns: dict = None,
6466
) -> Optional[type[PydanticModel]]:
6567
"""
6668
It is an inner function to protect pydantic model creator against cyclic recursion
@@ -93,6 +95,8 @@ def _pydantic_recursion_protector(
9395
_stack=stack,
9496
allow_cycles=allow_cycles,
9597
sort_alphabetically=sort_alphabetically,
98+
globalns=globalns,
99+
localns=localns,
96100
_as_submodel=True,
97101
)
98102
return pmc.create_pydantic_model()
@@ -237,6 +241,8 @@ def __init__(
237241
model_config: Optional[ConfigDict] = None,
238242
validators: Optional[dict[str, Any]] = None,
239243
module: str = __name__,
244+
globalns: dict = None,
245+
localns: dict = None,
240246
_stack: tuple = (),
241247
_as_submodel: bool = False,
242248
) -> None:
@@ -288,7 +294,7 @@ def __init__(
288294

289295
self._as_submodel = _as_submodel
290296

291-
self._annotations = get_annotations(cls)
297+
self._annotations = get_annotations(cls, globalns=globalns, localns=localns)
292298

293299
self._pconfig: ConfigDict
294300

@@ -305,6 +311,9 @@ def __init__(
305311
self._validators = validators
306312
self._module = module
307313

314+
self.globalns = globalns
315+
self.localns = localns
316+
308317
self._stack = _stack
309318

310319
@property
@@ -523,7 +532,7 @@ def _process_computed_field(
523532
field: ComputedFieldDescription,
524533
) -> Optional[Any]:
525534
func = field.function
526-
annotation = get_annotations(self._cls, func).get("return", None)
535+
annotation = get_annotations(self._cls, func, globalns=self.globalns, localns=self.localns).get("return", None)
527536
comment = _cleandoc(func)
528537
if annotation is not None:
529538
c_f = computed_field(return_type=annotation, description=comment)
@@ -555,6 +564,8 @@ def get_fields_to_carry_on(field_tuple: tuple[str, ...]) -> tuple[str, ...]:
555564
stack=new_stack,
556565
allow_cycles=self.meta.allow_cycles,
557566
sort_alphabetically=self.meta.sort_alphabetically,
567+
globalns=self.globalns,
568+
localns=self.localns,
558569
)
559570
else:
560571
pmodel = None
@@ -581,6 +592,8 @@ def pydantic_model_creator(
581592
model_config: Optional[ConfigDict] = None,
582593
validators: Optional[dict[str, Any]] = None,
583594
module: str = __name__,
595+
globalns: dict = None,
596+
localns: dict = None,
584597
) -> type[PydanticModel]:
585598
"""
586599
Function to build `Pydantic Model <https://docs.pydantic.dev/latest/concepts/models/>`__ off Tortoise Model.
@@ -607,6 +620,8 @@ def pydantic_model_creator(
607620
:param model_config: A custom config to use as pydantic config.
608621
:param validators: A dictionary of methods that validate fields.
609622
:param module: The name of the module that the model belongs to.
623+
:param globalns: If specified, use this dictionary as the globals map.
624+
:param localns: If specified, use this dictionary as the locals map.
610625
611626
Note: Created pydantic model uses config_class parameter and PydanticMeta's
612627
config_class as its Config class's bases(Only if provided!), but it
@@ -627,5 +642,7 @@ def pydantic_model_creator(
627642
model_config=model_config,
628643
validators=validators,
629644
module=module,
645+
globalns=globalns,
646+
localns=localns,
630647
)
631648
return pmc.create_pydantic_model()

tortoise/contrib/pydantic/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@
55
from tortoise.models import Model
66

77

8+
<<<<<<< HEAD
89
def get_annotations(cls: "type[Model]", method: Optional[Callable] = None) -> dict[str, Any]:
10+
=======
11+
def get_annotations(cls: "Type[Model]", method: Optional[Callable] = None, globalns: dict = None, localns: dict = None) -> dict[str, Any]:
12+
>>>>>>> Add logic to handle `globalns` and `localns` in `pydantic_model_creator`
913
"""
1014
Get all annotations including base classes
1115
:param cls: The model class we need annotations from
1216
:param method: If specified, we try to get the annotations for the callable
17+
:param globalns: If specified, use this dictionary as the globals map
18+
:param localns: If specified, use this dictionary as the locals map
1319
:return: The list of annotations
1420
"""
21+
<<<<<<< HEAD
1522
return get_type_hints(method or cls)
23+
=======
24+
return typing.get_type_hints(method or cls, globalns, localns)
25+
>>>>>>> Add logic to handle `globalns` and `localns` in `pydantic_model_creator`

0 commit comments

Comments
 (0)