diff --git a/.gitignore b/.gitignore index 83e8a08..bece568 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,5 @@ parsetab_* .tox .eggs/ venv/ +.\#* +\#*\# \ No newline at end of file diff --git a/ming/odm/mapper.py b/ming/odm/mapper.py index 6e89a22..b79eff0 100644 --- a/ming/odm/mapper.py +++ b/ming/odm/mapper.py @@ -1,6 +1,6 @@ import six import warnings -from copy import copy +from copy import copy, deepcopy from ming.base import Object, NoDefault from ming.utils import wordwrap @@ -61,6 +61,16 @@ def __init__(self, mapped_class, collection, session, **kwargs): raise TypeError('Unknown kwd args: %r' % kwargs) self._instrument_class(properties, include_properties, exclude_properties) + @classmethod + def replace_session(cls, session): + for _mapper in cls.all_mappers(): + _mapper.session = session + _mapper.mapped_class.query.session = session + _mapper.mapped_class.__mongometa__.session = session + _mapper._compiled = False + _mapper.compile() + _mapper.session.ensure_indexes(_mapper.collection) + def __repr__(self): return '' % ( self.mapped_class.__name__, self.collection.m.collection_name) @@ -74,13 +84,12 @@ def insert(self, obj, state, session, **kwargs): @_with_hooks('update') def update(self, obj, state, session, **kwargs): - fields = state.options.get('fields', None) - if fields is None: - fields = () - doc = self.collection(state.document, skip_from_bson=True) - ret = session.impl.save(doc, *fields, validate=False) + ret = session.impl.save(doc, validate=False, state=state) state.status = state.clean + # Make sure that st.document is never the same as st.original_document + # otherwise mutating one mutates the other. + state.original_document = deepcopy(doc) return ret @_with_hooks('delete') @@ -177,9 +186,7 @@ def _from_doc(self, doc, options, validate=True): # Make sure that st.document is never the same as st.original_document # otherwise mutating one mutates the other. - # There is no need to deepcopy as nested mutable objects are already - # copied by InstrumentedList and InstrumentedObj to instrument them. - st.original_document = doc + st.original_document = deepcopy(doc) if validate is False: # .create calls this after it already created the document with the diff --git a/ming/session.py b/ming/session.py index 3dd33ea..7a829b7 100644 --- a/ming/session.py +++ b/ming/session.py @@ -9,11 +9,12 @@ import six from .base import Cursor, Object -from .utils import fixup_index, fix_write_concern +from .utils import fixup_index, fix_write_concern, doc_to_set from . import exc log = logging.getLogger(__name__) + def annotate_doc_failure(func): '''Decorator to wrap a session operation so that any pymongo errors raised will note the document that caused the failure @@ -30,7 +31,7 @@ def wrapper(self, doc, *args, **kwargs): return update_wrapper(wrapper, func) -class Session(object): +class Session: _registry = {} _datastores = {} @@ -139,7 +140,8 @@ def find_and_modify(self, cls, query=None, sort=None, new=False, **kw): def _prep_save(self, doc, validate): hook = doc.m.before_save - if hook: hook(doc) + if hook: + hook(doc) if validate: if doc.m.schema is None: data = dict(doc) @@ -151,8 +153,12 @@ def _prep_save(self, doc, validate): return data @annotate_doc_failure - def save(self, doc, *args, **kwargs): + def save(self, doc, *args, state=None, **kwargs): data = self._prep_save(doc, kwargs.pop('validate', True)) + if not args and state is not None and state.original_document: + args = tuple(set((k for k, v in + doc_to_set(state.original_document) + ^ doc_to_set(data)))) if args: values = dict((arg, data[arg]) for arg in args) result = self._impl(doc).update( diff --git a/ming/tests/odm/test_declarative.py b/ming/tests/odm/test_declarative.py index 9cbd9f3..1af5e89 100644 --- a/ming/tests/odm/test_declarative.py +++ b/ming/tests/odm/test_declarative.py @@ -1,6 +1,7 @@ import sys from collections import defaultdict from unittest import TestCase, SkipTest +from unittest.mock import MagicMock from ming import schema as S from ming import create_datastore @@ -811,3 +812,60 @@ def test_hook_base(self): [ {'_id': doc._id, 'a': doc.a} ]) + + +class TestReplacingSession(TestCase): + + def setUp(self): + Mapper._mapper_by_classname.clear() + self.datastore = create_datastore('mim:///test_db') + self.session = ODMSession(bind=self.datastore) + class Basic(MappedClass): + class __mongometa__: + name = 'hook' + session = self.session + _id = FieldProperty(S.ObjectId) + a = FieldProperty(int) + Mapper.compile_all() + self.Basic = Basic + self.session.remove(self.Basic) + + def test_hook_base(self): + assert id(self.Basic.query.session) == id(self.session) + session2 = MagicMock() + new_session = ODMSession(bind=session2) + Mapper.replace_session(new_session) + assert id(self.Basic.query.session) == id(new_session) + assert id(self.session) != id(new_session) + +class TestBeforeSave(TestCase): + + def setUp(self): + Mapper._mapper_by_classname.clear() + self.datastore = create_datastore('mim:///test_db') + self.session = ODMSession(bind=self.datastore) + class Basic(MappedClass): + class __mongometa__: + name = 'hook' + session = self.session + def before_save(instance): + instance.a = 9 + + _id = FieldProperty(S.ObjectId) + a = FieldProperty(int) + Mapper.compile_all() + self.Basic = Basic + self.session.remove(self.Basic) + + def test_hook_base(self): + doc = self.Basic() + doc.a = 5 + self.session.flush() # first insert + self.session.close() + doc = self.Basic.query.get(doc._id) + assert doc.a == 9, doc.a + doc.a = 6 + self.session.flush() # then save + self.session.close() + doc = self.Basic.query.get(doc._id) + assert doc.a == 9, doc.a diff --git a/ming/tests/odm/test_mapper.py b/ming/tests/odm/test_mapper.py index e8055a7..0c202da 100644 --- a/ming/tests/odm/test_mapper.py +++ b/ming/tests/odm/test_mapper.py @@ -277,6 +277,23 @@ def test_group(self, pymongo_group): self.Basic.query.group() assert pymongo_group.called + def test_multiple_update_flushes(self): + initial_doc = self.Basic() + initial_doc.a = 1 + self.session.flush() + self.session.close() + + doc_updating = self.Basic.query.get(_id=initial_doc._id) + doc_updating.a = 2 + self.session.flush() + doc_updating.a = 1 # back to "initial" value + doc_updating.e = 'foo' # change something else too + self.session.flush() + self.session.close() + + doc_after_updates = self.Basic.query.get(_id=doc_updating._id) + assert doc_after_updates.a == 1 + class TestRelation(TestCase): def setUp(self): diff --git a/ming/tests/test_session.py b/ming/tests/test_session.py index 4631e25..80fe412 100644 --- a/ming/tests/test_session.py +++ b/ming/tests/test_session.py @@ -81,7 +81,7 @@ def test_base_session(self): doc = self.TestDocNoSchema({'_id':5, 'a':5}) sess.save(doc) impl.save.assert_called_with(dict(_id=5, a=5)) - doc = self.TestDocNoSchema({'_id':5, 'a':5}) + doc = self.TestDocNoSchema({'_id':5, 'a':5, 'b': 6}) sess.save(doc, 'a') impl.update.assert_called_with(dict(_id=5), {'$set':dict(a=5)}) doc = self.TestDocNoSchema({'_id':5, 'a':5}) diff --git a/ming/utils.py b/ming/utils.py index cabf3c6..8c66f7d 100644 --- a/ming/utils.py +++ b/ming/utils.py @@ -145,3 +145,16 @@ def fix_write_concern(kwargs): warnings.warn('safe option is now deprecated', DeprecationWarning) kwargs['w'] = int(kwargs.pop('safe')) return kwargs + + +def to_hashable(v): + if isinstance(v, list): + return tuple((to_hashable(sv) for sv in v)) + elif isinstance(v, dict): + return tuple(((to_hashable(k), to_hashable(sv)) + for k, sv in sorted(v.items()))) + return v + + +def doc_to_set(doc): + return set((k, to_hashable(v)) for k, v in doc.copy().items()) diff --git a/setup.cfg b/setup.cfg index ebde9bc..1be30d6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,6 @@ +[pylint] +# disabling protected-access because of mongodb _id property +disable = protected-access [nosetests] detailed-errors=1