Skip to content

Commit

Permalink
Allow passing model for document update validation (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Dec 6, 2023
1 parent eba6787 commit 886523b
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 42 deletions.
4 changes: 2 additions & 2 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ Unregistering the callback is done with the same `unobserve` method.
### Document events

Observing changes made to a document is mostly meant to send the changes to another document, usually over the wire to a remote machine.
Changes can be serialized to binary by calling `get_update()` on the event:
Changes can be serialized to binary by getting the event's `update`:

```py
from pycrdt import TransactionEvent
def handle_doc_changes(event: TransactionEvent):
update: bytes = event.get_update()
update: bytes = event.update
# send binary update on the wire
doc.observe(handle_doc_changes)
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]

[project.optional-dependencies]
test = [
"pytest >=7.4.2,<8",
"y-py >=0.7.0a1,<0.8",
"mypy",
"pytest >=7.4.2,<8",
"y-py >=0.7.0a1,<0.8",
"pydantic >=2.5.2,<3",
"mypy",
]
docs = [ "mkdocs", "mkdocs-material" ]

Expand Down
7 changes: 7 additions & 0 deletions python/pycrdt/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ def __str__(self) -> str:
with self.doc.transaction() as txn:
return self.integrated.to_json(txn._txn)

def to_py(self) -> list | None:
if self._integrated is None:
return self._prelim
return list(self)

def observe(self, callback: Callable[[Any], None]) -> str:
_callback = partial(observe_callback, callback, self.doc)
return f"o_{self.integrated.observe(_callback)}"
Expand All @@ -193,7 +198,9 @@ def unobserve(self, subscription_id: str) -> None:

def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any):
_event = event_types[type(event)](event, doc)
doc._txn = ReadTransaction(doc=doc, _txn=event.transaction)
callback(_event)
doc._txn = None


def observe_deep_callback(callback: Callable[[Any], None], doc: Doc, events: list[Any]):
Expand Down
23 changes: 17 additions & 6 deletions python/pycrdt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ._pycrdt import Doc as _Doc
from ._pycrdt import Transaction as _Transaction
from .transaction import ReadTransaction, Transaction
from .transaction import Transaction

if TYPE_CHECKING:
from .doc import Doc
Expand All @@ -17,18 +17,26 @@

class BaseDoc:
_doc: _Doc
_twin_doc: BaseDoc | None
_txn: Transaction | None
_Model: Any
_dict: dict[str, BaseType]

def __init__(
self,
*,
client_id: int | None = None,
doc: _Doc | None = None,
Model=None,
**data,
) -> None:
super().__init__(**data)
if doc is None:
doc = _Doc(client_id)
self._doc = doc
self._txn = None
self._Model = Model
self._dict = {}


class BaseType(ABC):
Expand Down Expand Up @@ -56,6 +64,10 @@ def __init__(
self._prelim = init
self._integrated = None

@abstractmethod
def to_py(self) -> Any:
...

@abstractmethod
def _get_or_insert(self, name: str, doc: Doc) -> Any:
...
Expand Down Expand Up @@ -136,7 +148,7 @@ class BaseEvent:
def __init__(self, event: Any, doc: Doc):
slot: str
for slot in self.__slots__:
processed = process_event(getattr(event, slot), doc, event.transaction)
processed = process_event(getattr(event, slot), doc)
setattr(self, slot, processed)

def __str__(self):
Expand All @@ -148,13 +160,13 @@ def __str__(self):
return "{" + ret + "}"


def process_event(value: Any, doc: Doc, txn) -> Any:
def process_event(value: Any, doc: Doc) -> Any:
if isinstance(value, list):
for idx, val in enumerate(value):
value[idx] = process_event(val, doc, txn)
value[idx] = process_event(val, doc)
elif isinstance(value, dict):
for key, val in value.items():
value[key] = process_event(val, doc, txn)
value[key] = process_event(val, doc)
else:
val_type = type(value)
if val_type in base_types:
Expand All @@ -164,5 +176,4 @@ def process_event(value: Any, doc: Doc, txn) -> Any:
else:
base_type = cast(Type[BaseType], base_types[val_type])
value = base_type(_integrated=value, _doc=doc)
doc._txn = ReadTransaction(doc=doc, _txn=txn)
return value
21 changes: 19 additions & 2 deletions python/pycrdt/doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, cast

from ._pycrdt import Doc as _Doc
from ._pycrdt import SubdocsEvent, TransactionEvent
Expand All @@ -9,16 +9,20 @@


class Doc(BaseDoc):

def __init__(
self,
init: dict[str, BaseType] = {},
*,
client_id: int | None = None,
doc: _Doc | None = None,
Model=None,
) -> None:
super().__init__(client_id=client_id, doc=doc)
super().__init__(client_id=client_id, doc=doc, Model=Model)
for k, v in init.items():
self[k] = v
if Model is not None:
self._twin_doc = Doc(init)

@property
def guid(self) -> int:
Expand All @@ -42,6 +46,15 @@ def get_update(self, state: bytes | None = None) -> bytes:
return self._doc.get_update(state)

def apply_update(self, update: bytes) -> None:
if self._Model is not None:
twin_doc = cast(Doc, self._twin_doc)
twin_doc.apply_update(update)
d = {k: twin_doc[k].to_py() for k in self._Model.model_fields}
try:
self._Model(**d)
except Exception as e:
self._twin_doc = Doc(self._dict)
raise e
self._doc.apply_update(update)

def __setitem__(self, key: str, value: BaseType) -> None:
Expand All @@ -50,6 +63,10 @@ def __setitem__(self, key: str, value: BaseType) -> None:
integrated = value._get_or_insert(key, self)
prelim = value._integrate(self, integrated)
value._init(prelim)
self._dict[key] = value

def __getitem__(self, key: str) -> BaseType:
return self._dict[key]

def observe(self, callback: Callable[[TransactionEvent], None]) -> int:
return self._doc.observe(callback)
Expand Down
7 changes: 7 additions & 0 deletions python/pycrdt/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def __str__(self) -> str:
with self.doc.transaction() as txn:
return self.integrated.to_json(txn._txn)

def to_py(self) -> dict | None:
if self._integrated is None:
return self._prelim
return dict(self)

def __delitem__(self, key: str) -> None:
if not isinstance(key, str):
raise RuntimeError("Key must be of type string")
Expand Down Expand Up @@ -147,7 +152,9 @@ def unobserve(self, subscription_id: str) -> None:

def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any):
_event = event_types[type(event)](event, doc)
doc._txn = ReadTransaction(doc=doc, _txn=event.transaction)
callback(_event)
doc._txn = None


def observe_deep_callback(callback: Callable[[Any], None], doc: Doc, events: list[Any]):
Expand Down
15 changes: 14 additions & 1 deletion python/pycrdt/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def __str__(self) -> str:
with self.doc.transaction() as txn:
return self.integrated.get_string(txn._txn)

def to_py(self) -> str | None:
if self._integrated is None:
return self._prelim
return str(self)

def __iadd__(self, value: str) -> Text:
with self.doc.transaction() as txn:
if isinstance(txn, ReadTransaction):
Expand Down Expand Up @@ -89,7 +94,13 @@ def __setitem__(self, key: int | slice, value: str) -> None:
"Read-only transaction cannot be used to modify document structure"
)
if isinstance(key, int):
raise RuntimeError("Single item assignment not supported")
value_len = len(value)
if value_len != 1:
raise RuntimeError(
f"Single item assigned value must have a length of 1, not {value_len}"
)
del self[key]
self.integrated.insert(txn._txn, key, value)
elif isinstance(key, slice):
if key.step is not None:
raise RuntimeError("Step not supported")
Expand Down Expand Up @@ -118,7 +129,9 @@ def unobserve(self, subscription_id: str) -> None:

def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any):
_event = event_types[type(event)](event, doc)
doc._txn = ReadTransaction(doc=doc, _txn=event.transaction)
callback(_event)
doc._txn = None


class TextEvent(BaseEvent):
Expand Down
8 changes: 7 additions & 1 deletion python/pycrdt/transaction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from types import TracebackType
from typing import TYPE_CHECKING

from ._pycrdt import Transaction as _Transaction
Expand All @@ -25,7 +26,12 @@ def __enter__(self) -> Transaction:
self._doc._txn = self
return self

def __exit__(self, exc_type, exc_value, exc_tb) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._nb -= 1
# only drop the transaction when exiting root context manager
# since nested transactions reuse the root transaction
Expand Down
Loading

0 comments on commit 886523b

Please sign in to comment.