Skip to content

Add transaction to event slots #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pycrdt/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class ArrayEvent(BaseEvent):
path (list[int | str]): A list with the indices pointing to the array that was changed.
"""

__slots__ = "target", "delta", "path"
__slots__ = "target", "delta", "path", "transaction"


class ArrayIterator:
Expand Down
58 changes: 55 additions & 3 deletions python/pycrdt/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import threading
from abc import ABC, abstractmethod
from collections import Counter
from functools import lru_cache, partial
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Type, cast, get_type_hints
Expand All @@ -11,7 +12,7 @@
from ._pycrdt import Doc as _Doc
from ._pycrdt import Subscription
from ._pycrdt import Transaction as _Transaction
from ._transaction import ReadTransaction, Transaction
from ._transaction import BaseTransaction, ReadTransaction, Transaction

if TYPE_CHECKING:
from ._doc import Doc
Expand All @@ -26,6 +27,55 @@ def forbid_read_transaction(txn: Transaction):
raise RuntimeError("Read-only transaction cannot be used to modify document structure")


def hash_origin(origin: Any) -> int:
try:
return hash(origin)
except Exception:
raise TypeError("Origin must be hashable")


class Origins:
"""Mapping of origins to their ID."""

def __init__(self) -> None:
self._map: dict[int, Any] = {}
self._counter = Counter[int]()

def __len__(self) -> int:
"""Return the number of origins."""
return len(self._map)

def add(self, value: Any) -> int:
"""Add a new origin.

Args:
value: Origin
Returns:
The origin ID.
"""
key = hash_origin(value)
if key not in self._map:
self._map[key] = value
self._counter.update([key])

return key

def get(self, key: int) -> Any | None:
"""Get the origin by its ID.

Returns:
The origin or None if not found.
"""
return self._map.get(key)

def remove(self, key: int) -> None:
"""Remove the origin by its ID."""
if key in self._map:
self._counter.subtract([key])
if self._counter[key] == 0:
del self._map[key]


class BaseDoc:
_doc: _Doc
_twin_doc: BaseDoc | None
Expand All @@ -35,7 +85,7 @@ class BaseDoc:
_allow_multithreading: bool
_Model: Any
_subscriptions: list[Subscription]
_origins: dict[int, Any]
_origins: Origins

def __init__(
self,
Expand All @@ -55,7 +105,7 @@ def __init__(
self._txn_async_lock = anyio.Lock()
self._Model = Model
self._subscriptions = []
self._origins = {}
self._origins = Origins()
self._allow_multithreading = allow_multithreading


Expand Down Expand Up @@ -237,6 +287,8 @@ def process_event(value: Any, doc: Doc) -> Any:
elif isinstance(value, dict):
for key, val in value.items():
value[key] = process_event(val, doc)
elif isinstance(value, _Transaction):
value = BaseTransaction(doc, doc._origins.get(value.origin()))
else:
val_type = type(value)
if val_type in base_types:
Expand Down
2 changes: 1 addition & 1 deletion python/pycrdt/_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class MapEvent(BaseEvent):
path (list[int | str]): A list with the indices pointing to the map that was changed.
"""

__slots__ = "target", "keys", "path"
__slots__ = "target", "keys", "path", "transaction"


base_types[_Map] = Map
Expand Down
2 changes: 1 addition & 1 deletion python/pycrdt/_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class TextEvent(BaseEvent):
path (list[int | str]): A list with the indices pointing to the text that was changed.
"""

__slots__ = "target", "delta", "path"
__slots__ = "target", "delta", "path", "transaction"


base_types[_Text] = Text
Expand Down
56 changes: 35 additions & 21 deletions python/pycrdt/_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from functools import partial
from types import TracebackType
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from anyio import to_thread

Expand All @@ -12,7 +12,38 @@
from ._doc import Doc


class Transaction:
class BaseTransaction:
"""
Base class for read-write and read-only transactions.

It allows to persist the origin of the transaction outside of its context.
"""

_doc: Doc
_origin_hash: int | None

def __init__(
self,
doc: Doc,
origin: Any = None,
) -> None:
self._doc = doc
if origin is None:
self._origin_hash = None
else:
self._origin_hash = doc._origins.add(origin)

def __del__(self) -> None:
if getattr(self, "_origin_hash", None) is not None:
self._doc._origins.remove(cast(int, self._origin_hash))
Comment on lines +37 to +38
Copy link
Collaborator

@davidbrochart davidbrochart Feb 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why using getattr, since _origin_hash is set in __init__?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why, but I got hit by missing _origin_hash in some cases. It may have to do with some inheritance pattern.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that it's something to investigate, instead of using getattr because we don't know why it doesn't work without. Otherwise that's going to lead to code that we don't understand and don't know how to maintain.


@property
def origin(self) -> Any:
"""The origin of the transaction."""
return None if self._origin_hash is None else self._doc._origins.get(self._origin_hash)


class Transaction(BaseTransaction):
"""
A read-write transaction that can be used to mutate a document.
It must be used with a context manager (see [Doc.transaction()][pycrdt.Doc.transaction]):
Expand All @@ -22,10 +53,8 @@ class Transaction:
```
"""

_doc: Doc
_txn: _Transaction | None
_leases: int
_origin_hash: int | None
_timeout: float

def __init__(
Expand All @@ -36,14 +65,9 @@ def __init__(
origin: Any = None,
timeout: float | None = None,
) -> None:
self._doc = doc
super().__init__(doc, origin)
self._txn = _txn
self._leases = 0
if origin is None:
self._origin_hash = None
else:
self._origin_hash = hash_origin(origin)
doc._origins[self._origin_hash] = origin
self._timeout = -1 if timeout is None else timeout

def __enter__(self, _acquire_transaction: bool = True) -> Transaction:
Expand Down Expand Up @@ -75,9 +99,6 @@ def __exit__(
assert self._txn is not None
if not isinstance(self, ReadTransaction):
self._txn.commit()
origin_hash = self._txn.origin()
if origin_hash is not None:
del self._doc._origins[origin_hash]
Comment on lines -78 to -80
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was a very deterministic behavior, why did you move that part to the transaction's __del__ method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you leave the logic there, then the tests won't work. Because when requesting the transaction origin from the event object the entry in the doc._origins will be gone.

The usage of __del__ is not less deterministic (at least in CPython) and it is more consistent as Transaction.origin will always be correct while the transaction object is reference somewhere.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you leave the logic there, then the tests won't work. Because when requesting the transaction origin from the event object the entry in the doc._origins will be gone.

I think it's just a matter of dropping the transaction after removing the origin from the doc.

The usage of __del__ is not less deterministic

Yes it is, because now tests behave differently if running on CPython or PyPy.

if self._doc._allow_multithreading:
self._doc._txn_lock.release()
self._txn.drop()
Expand All @@ -99,7 +120,7 @@ def origin(self) -> Any:
if origin_hash is None:
return None

return self._doc._origins[origin_hash]
return self._doc._origins.get(origin_hash)


class NewTransaction(Transaction):
Expand Down Expand Up @@ -141,10 +162,3 @@ class ReadTransaction(Transaction):
"""
A read-only transaction that cannot be used to mutate a document.
"""


def hash_origin(origin: Any) -> int:
try:
return hash(origin)
except Exception:
raise TypeError("Origin must be hashable")
3 changes: 1 addition & 2 deletions python/pycrdt/_undo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from time import time_ns
from typing import TYPE_CHECKING, Any, Callable

from ._base import BaseType
from ._base import BaseType, hash_origin
from ._pycrdt import (
StackItem,
)
from ._pycrdt import (
UndoManager as _UndoManager,
)
from ._transaction import hash_origin

if TYPE_CHECKING:
from ._doc import Doc
Expand Down
2 changes: 1 addition & 1 deletion python/pycrdt/_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def clear(self) -> None:


class XmlEvent(BaseEvent):
__slots__ = ["children_changed", "target", "path", "delta", "keys"]
__slots__ = ["children_changed", "target", "path", "delta", "keys", "transaction"]


class XmlAttributesView:
Expand Down
11 changes: 9 additions & 2 deletions tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,15 @@ def cb(events):
deep_events.append(events)

sid4 = array.observe_deep(cb)
array.append("bar")
assert str(deep_events[0][0]) == """{target: ["bar"], delta: [{'insert': ['bar']}], path: []}"""
origin = "test"
with doc.transaction(origin=origin):
array.append("bar")
assert deep_events[0]
event = deep_events[0][0]
assert str(event.target) == '["bar"]'
assert str(event.delta) == "[{'insert': ['bar']}]"
assert str(event.path) == "[]"
assert event.transaction.origin == origin
deep_events.clear()
array.unobserve(sid4)
array.append("baz")
Expand Down
22 changes: 13 additions & 9 deletions tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,14 @@ def test_observe():
events = []

sub = map0.observe(partial(callback, events))
map0["0"] = 0
assert (
str(events[0])
== """{target: {"0":0}, keys: {'0': {'action': 'add', 'newValue': 0.0}}, path: []}"""
)
origin = "test-map"
with doc.transaction(origin=origin):
map0["0"] = 0
event = events[0]
assert str(event.target) == '{"0":0}'
assert str(event.keys) == "{'0': {'action': 'add', 'newValue': 0.0}}"
assert event.path == []
assert event.transaction.origin == origin
events.clear()
map0.unobserve(sub)
map0["1"] = 1
Expand All @@ -155,10 +158,11 @@ def test_observe():
deep_events = []
sub = map1.observe_deep(partial(callback_deep, deep_events))
map1["1"] = 1
assert (
str(deep_events[0][0])
== """{target: {"1":1}, keys: {'1': {'action': 'add', 'newValue': 1.0}}, path: []}"""
)
event = deep_events[0][0]
assert str(event.target) == '{"1":1}'
assert str(event.keys) == "{'1': {'action': 'add', 'newValue': 1.0}}"
assert event.path == []
assert event.transaction.origin is None
deep_events.clear()
map1.unobserve(sub)
map1["0"] = 0
Expand Down
19 changes: 17 additions & 2 deletions tests/test_text.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import pytest
from pycrdt import Array, Doc, Map, Text

Expand Down Expand Up @@ -161,5 +163,18 @@ def callback(event):
events.append(event)

sub = text.observe(callback) # noqa: F841
text += hello
assert str(events[0]) == """{target: Hello, delta: [{'insert': 'Hello'}], path: []}"""
origin = "test-text"
with doc.transaction(origin=origin):
text += hello
event = events[0]
assert str(event.target) == hello
assert str(event.delta) == f"[{{'insert': '{hello}'}}]"
assert event.path == []
assert event.transaction.origin == origin
assert (
re.match(
r"{target: Hello, delta: \[{'insert': 'Hello'}\], path: \[\], transaction: <[\w\.]+ object at 0x[a-fA-F\d]+>}", # noqa E501
str(event),
)
is not None
), "Event string representation"
19 changes: 13 additions & 6 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import platform
import sys
import time
from functools import partial

import pytest
from anyio import create_task_group, fail_after, sleep, to_thread
from pycrdt import Array, Doc, Map, Text
from pycrdt._base import hash_origin

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup # pragma: no cover

pytestmark = pytest.mark.anyio


IS_CPYTHON = platform.python_implementation() == "CPython"


def test_callback_transaction():
text = Text()
array = Array()
Expand Down Expand Up @@ -91,13 +96,15 @@ def callback(event, txn):
assert txn0.origin == origin0
assert txn1.origin == origin0
assert len(doc0._origins) == 1
assert list(doc0._origins.values())[0] == origin0
assert doc0._origins == doc1._origins
hashed_origin = hash_origin(origin0)
assert doc0._origins.get(hashed_origin) == origin0
assert doc1._origins.get(hashed_origin) == origin0
assert len(doc0._origins) == 1
assert list(doc0._origins.values())[0] == origin0
assert len(doc1._origins) == 0
assert len(doc0._origins) == 0
assert len(doc1._origins) == 0
assert doc0._origins.get(hash_origin(origin0)) == origin0
del txn1
assert (doc1._origins.get(hashed_origin) is None) is (True if IS_CPYTHON else False)
del txn0
assert (doc1._origins.get(hashed_origin) is None) is (True if IS_CPYTHON else False)

with doc0.transaction(origin=123):
with doc0.transaction(origin=123):
Expand Down
Loading