Skip to content

Commit

Permalink
Raise an assertion if a flush() occurs on changes without a proper cl…
Browse files Browse the repository at this point in the history
…ock tick (#24)

* Fix for issue #21: Raise an assertion if a flush() occurs on temporalized changes without a proper clock tick
* Refactor invalid flush() assertion to only be active if a flag (strict_mode) is set
* Bump version number to 0.3.2
  • Loading branch information
bijanvakili authored Jun 1, 2017
1 parent a3a3ed6 commit 3b44746
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 13 deletions.
1 change: 1 addition & 0 deletions AUTHORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Here is an entirely incomplete list of amazing contributors.

Alec Clowes <[email protected]>
Ben Kudria <[email protected]>
Bijan Vakili <[email protected]>
Dave Flerlage <[email protected]>
Diego Argueta <[email protected]>
George Leslie-Waksman <[email protected]>
Expand Down
10 changes: 10 additions & 0 deletions temporal_sqlalchemy/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import psycopg2.extras as psql_extras

from temporal_sqlalchemy import nine
from temporal_sqlalchemy.metadata import get_session_metadata


_ClockSet = collections.namedtuple('_ClockSet', ('effective', 'vclock'))

Expand Down Expand Up @@ -86,12 +88,16 @@ def record_history(self,
timestamp: dt.datetime):
"""record all history for a given clocked object"""
state = attributes.instance_state(clocked)
vclock_history = attributes.get_history(clocked, 'vclock')
try:
new_tick = state.dict['vclock']
except KeyError:
# TODO understand why this is necessary
new_tick = getattr(clocked, 'vclock')

is_strict_mode = get_session_metadata(session).get('strict_mode', False)
is_vclock_unchanged = vclock_history.unchanged and new_tick == vclock_history.unchanged[0]

new_clock = self.make_clock(timestamp, new_tick)
attr = {'entity': clocked}

Expand All @@ -109,6 +115,10 @@ def record_history(self,
changes = attributes.get_history(clocked, prop.key)

if changes.added:
if is_strict_mode:
assert not is_vclock_unchanged, \
'flush() has triggered for a changed temporalized property outside of a clock tick'

# Cap previous history row if exists
if sa.inspect(clocked).identity is not None:
# but only if it already exists!!
Expand Down
24 changes: 24 additions & 0 deletions temporal_sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import sqlalchemy.orm as orm

TEMPORAL_METADATA_KEY = '__temporal'

__all__ = [
'get_session_metadata',
'set_session_metadata',
]


def set_session_metadata(session: orm.Session, metadata: dict):
if isinstance(session, orm.Session):
session.info[TEMPORAL_METADATA_KEY] = metadata
elif isinstance(session, orm.sessionmaker):
session.configure(info={TEMPORAL_METADATA_KEY: metadata})
else:
raise ValueError('Invalid session')


def get_session_metadata(session: orm.Session) -> dict:
"""
:return: metadata dictionary, or None if it was never installed
"""
return session.info.get(TEMPORAL_METADATA_KEY)
35 changes: 24 additions & 11 deletions temporal_sqlalchemy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import sqlalchemy.orm as orm

from temporal_sqlalchemy.bases import ClockedOption, Clocked

TEMPORAL_FLAG = '__temporal'
from temporal_sqlalchemy.metadata import (
get_session_metadata,
set_session_metadata
)


def _temporal_models(session: orm.Session) -> typing.Iterable[Clocked]:
Expand All @@ -25,18 +27,29 @@ def persist_history(session: orm.Session, flush_context, instances):
obj.temporal_options.record_history(obj, session, correlate_timestamp)


def temporal_session(session: typing.Union[orm.Session, orm.sessionmaker]) -> orm.Session:
if not is_temporal_session(session):
def temporal_session(session: typing.Union[orm.Session, orm.sessionmaker], strict_mode=False) -> orm.Session:
"""
Setup the session to track changes via temporal
:param session: SQLAlchemy ORM session to temporalize
:param strict_mode: if True, will raise exceptions when improper flush() calls are made (default is False)
:return: temporalized SQLALchemy ORM session
"""
temporal_metadata = {
'strict_mode': strict_mode
}

# defer listening to the flush hook until after we update the metadata
install_flush_hook = not is_temporal_session(session)

# update to the latest metadata
set_session_metadata(session, temporal_metadata)

if install_flush_hook:
event.listen(session, 'before_flush', persist_history)
if isinstance(session, orm.Session):
session.info[TEMPORAL_FLAG] = True
elif isinstance(session, orm.sessionmaker):
session.configure(info={TEMPORAL_FLAG: True})
else:
raise ValueError('Invalid session')

return session


def is_temporal_session(session: orm.Session) -> bool:
return isinstance(session, orm.Session) and session.info.get(TEMPORAL_FLAG)
return isinstance(session, orm.Session) and get_session_metadata(session) is not None
2 changes: 1 addition & 1 deletion temporal_sqlalchemy/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Version information."""
__version__ = '0.3.1'
__version__ = '0.3.2'
102 changes: 101 additions & 1 deletion tests/test_temporal_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import re

import pytest
import psycopg2.extras as psql_extras
import sqlalchemy as sa

Expand Down Expand Up @@ -242,7 +244,7 @@ def test_edit_on_double_wrapped(self, session):
with t.clock_tick():
t.prop_a = 2
t.prop_b = 'bar'
double_wrapped_session.commit()
double_wrapped_session.commit()

history_tables = {
'prop_a': temporal.get_history_model(
Expand Down Expand Up @@ -294,3 +296,101 @@ def test_doesnt_duplicate_unnecessary_history(self, session):
recorded_history = clock_query.first()
assert 1 in recorded_history.vclock
assert getattr(t, attr) == getattr(recorded_history, attr)

@pytest.mark.parametrize('session_func_name', (
'flush',
'commit'
))
def test_disallow_flushes_within_clock_ticks_when_strict(self, session, session_func_name):
session = temporal.temporal_session(session, strict_mode=True)

t = models.SimpleTableTemporal(
prop_a=1,
prop_b='foo',
prop_c=datetime.datetime(2016, 5, 11,
tzinfo=datetime.timezone.utc))
session.add(t)
session.commit()

with t.clock_tick():
t.prop_a = 2

with pytest.raises(AssertionError) as excinfo:
eval('session.{func_name}()'.format(func_name=session_func_name))

assert re.match(
r'.*flush\(\) has triggered for a changed temporalized property outside of a clock tick.*',
str(excinfo)
)


@pytest.mark.parametrize('session_func_name', (
'flush',
'commit'
))
def test_allow_flushes_within_clock_ticks_when_strict_but_no_change(self, session, session_func_name):
session = temporal.temporal_session(session, strict_mode=True)

t = models.SimpleTableTemporal(
prop_a=1,
prop_b='foo',
prop_c=datetime.datetime(2016, 5, 11,
tzinfo=datetime.timezone.utc))
session.add(t)
session.commit()

with t.clock_tick():
t.prop_a = 1

eval('session.{func_name}()'.format(func_name=session_func_name))


@pytest.mark.parametrize('session_func_name', (
'flush',
'commit'
))
def test_disallow_flushes_on_changes_without_clock_ticks_when_strict(self, session, session_func_name):
session = temporal.temporal_session(session, strict_mode=True)

t = models.SimpleTableTemporal(
prop_a=1,
prop_b='foo',
prop_c=datetime.datetime(2016, 5, 11,
tzinfo=datetime.timezone.utc))
session.add(t)
session.commit()

# this change should have been done within a clock tick
t.prop_a = 2

with pytest.raises(AssertionError) as excinfo:
eval('session.{func_name}()'.format(func_name=session_func_name))

assert re.match(
r'.*flush\(\) has triggered for a changed temporalized property outside of a clock tick.*',
str(excinfo)
)

# TODO this test should be removed once strict flush() checking becomes the default behavior
@pytest.mark.parametrize('session_func_name', (
'flush',
'commit'
))
def test_allow_loose_flushes_when_not_strict(self, session, session_func_name):
t = models.SimpleTableTemporal(
prop_a=1,
prop_b='foo',
prop_c=datetime.datetime(2016, 5, 11,
tzinfo=datetime.timezone.utc))
session.add(t)
session.commit()

with t.clock_tick():
t.prop_a = 2

# this should succeed in non-strict mode
eval('session.{func_name}()'.format(func_name=session_func_name))

# this should also succeed in non-strict mode
t.prop_a = 3
eval('session.{func_name}()'.format(func_name=session_func_name))

0 comments on commit 3b44746

Please sign in to comment.