Skip to content

Commit

Permalink
model: handle copying embedded model collections
Browse files Browse the repository at this point in the history
Find all instances of _BaseODMModel subclasses, including those stored
in lists or tuples, and set the required attribute on them during copy.

Fixes art049#321.
  • Loading branch information
bartoszflis-silvair committed Jan 17, 2023
1 parent 739a683 commit 9fa6897
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
32 changes: 27 additions & 5 deletions odmantic/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import decimal
import enum
import functools
import pathlib
import uuid
import warnings
Expand All @@ -16,6 +17,7 @@
FrozenSet,
Iterable,
List,
NamedTuple,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -499,6 +501,29 @@ def __new__(


BaseT = TypeVar("BaseT", bound="_BaseODMModel")
TraversalStateT = NamedTuple(
"TraversalStateT", [("output", List[Any]), ("staging", List[Any])]
)


def flat_tree(o: BaseT) -> List[BaseT]:
state = TraversalStateT(output=[], staging=[o])

def obj_fields(obj):
return [getattr(obj, name) for name in set(obj.__odm_fields__)]

def unpack(acc: TraversalStateT, obj: Any) -> TraversalStateT:
output, (_, *staging_tail) = acc
if isinstance(obj, _BaseODMModel):
return TraversalStateT(output + [obj], staging_tail + obj_fields(obj))
elif isinstance(obj, Iterable) and not isinstance(obj, (str, dict)):
return TraversalStateT(output, staging_tail + [*obj])
else:
return TraversalStateT(output, staging_tail)

while state.staging:
state = functools.reduce(unpack, state.staging, state)
return state.output


class _BaseODMModel(pydantic.BaseModel, metaclass=ABCMeta):
Expand Down Expand Up @@ -583,11 +608,8 @@ def _post_copy_update(self: BaseT) -> None:
"""Recursively update internal fields of the copied model after it has been
copied.
"""
object.__setattr__(self, "__fields_modified__", set(self.__fields__))
for field_name, field in self.__odm_fields__.items():
if isinstance(field, ODMEmbedded):
value = getattr(self, field_name)
value._post_copy_update()
for model in flat_tree(self):
object.__setattr__(model, "__fields_modified__", set(model.__fields__))

def update(
self,
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_model_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,27 @@ class M(Model):
assert instance.e.f.g != copied.e.f.g


@pytest.mark.parametrize(
"hint, ctor",
[
pytest.param(List, list),
pytest.param(Tuple, tuple),
],
)
def test_model_copy_deep_embedded_model_collection(hint, ctor):
class E(EmbeddedModel):
f: int

class M(Model):
e: hint[E]

instance = M(e=ctor([E(f=1)]))
copied = instance.copy(deep=True)
copied.e[0].f = 2

assert copied.e[0].f != instance.e[0].f


def test_model_copy_not_deep_embedded():
class E(EmbeddedModel):
f: int
Expand Down

0 comments on commit 9fa6897

Please sign in to comment.