Skip to content

Commit

Permalink
Add doc root type iterator (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Dec 7, 2023
1 parent 37ca4df commit a595a9c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 8 deletions.
2 changes: 2 additions & 0 deletions python/pycrdt/_pycrdt.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Doc:
"""Get the update from the given state to the current state."""
def apply_update(self, update: bytes) -> None:
"""Apply the update to the document."""
def roots(self, txn: Transaction) -> dict[str, Text | Array | Map]:
"""Get top-level (root) shared types available in current document."""
def observe(self, callback: Callable[[TransactionEvent], None]) -> int:
"""Subscribes a callback to be called with the shared document change event.
Returns a subscription ID that can be used to unsubscribe."""
Expand Down
2 changes: 0 additions & 2 deletions python/pycrdt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class BaseDoc:
_twin_doc: BaseDoc | None
_txn: Transaction | None
_Model: Any
_dict: dict[str, BaseType]

def __init__(
self,
Expand All @@ -36,7 +35,6 @@ def __init__(
self._doc = doc
self._txn = None
self._Model = Model
self._dict = {}


class BaseType(ABC):
Expand Down
31 changes: 26 additions & 5 deletions python/pycrdt/doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable, cast
from typing import Callable, Type, cast

from ._pycrdt import Doc as _Doc
from ._pycrdt import SubdocsEvent, TransactionEvent
Expand All @@ -9,7 +9,6 @@


class Doc(BaseDoc):

def __init__(
self,
init: dict[str, BaseType] = {},
Expand Down Expand Up @@ -53,7 +52,7 @@ def apply_update(self, update: bytes) -> None:
try:
self._Model(**d)
except Exception as e:
self._twin_doc = Doc(self._dict)
self._twin_doc = Doc(dict(self))
raise e
self._doc.apply_update(update)

Expand All @@ -63,10 +62,32 @@ def __setitem__(self, key: str, value: BaseType) -> None:
integrated = value._get_or_insert(key, self)
prelim = value._integrate(self, integrated)
value._init(prelim)
self._dict[key] = value

def __getitem__(self, key: str) -> BaseType:
return self._dict[key]
return self._roots[key]

def __iter__(self):
return self.keys()

def keys(self):
return self._roots.keys()

def values(self):
return self._roots.values()

def items(self):
return self._roots.items()

@property
def _roots(self) -> dict[str, BaseType]:
with self.transaction() as txn:
assert txn._txn is not None
return {
key: cast(Type[BaseType], base_types[type(val)])(
_integrated=val, _doc=self
)
for key, val in self._doc.roots(txn._txn).items()
}

def observe(self, callback: Callable[[TransactionEvent], None]) -> int:
return self._doc.observe(callback)
Expand Down
14 changes: 13 additions & 1 deletion src/doc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyLong, PyList};
use pyo3::types::{PyBytes, PyDict, PyLong, PyList};
use yrs::{
Doc as _Doc,
ReadTxn,
Expand All @@ -16,6 +16,7 @@ use crate::text::Text;
use crate::array::Array;
use crate::map::Map;
use crate::transaction::Transaction;
use crate::type_conversions::ToPython;


#[pyclass(unsendable)]
Expand Down Expand Up @@ -100,6 +101,17 @@ impl Doc {
Ok(())
}

fn roots(&self, py: Python<'_>, txn: &mut Transaction) -> PyObject {
let mut t0 = txn.transaction();
let t1 = t0.as_mut().unwrap();
let t = t1.as_ref();
let result = PyDict::new(py);
for (k, v) in t.root_refs() {
result.set_item(k, v.into_py(py)).unwrap();
}
result.into()
}

pub fn observe(&mut self, f: PyObject) -> PyResult<u32> {
let id: u32 = self.doc
.observe_transaction_cleanup(move |txn, event| {
Expand Down
25 changes: 25 additions & 0 deletions tests/test_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,28 @@ def test_client_id():
b = encode_client_id(client_id_bytes)

assert update[2 : 2 + len(b)] == b


def test_roots():
remote_doc = Doc(
{
"a": Text("foo"),
"b": Array([5, 2, 8]),
"c": Map({"k1": 1, "k2": 2}),
}
)
roots = dict(remote_doc)
assert str(roots["a"]) == "foo"
assert list(roots["b"]) == [5, 2, 8]
assert dict(roots["c"]) == {"k1": 1, "k2": 2}

# need to update to yrs v0.17.2
# see https://github.com/y-crdt/y-crdt/issues/364#issuecomment-1839791409

# local_doc = Doc()
# update = remote_doc.get_update()
# local_doc.apply_update(update)
# roots = dict(local_doc)
# assert str(roots["a"]) == "foo"
# assert list(roots["b"]) == [5, 2, 8]
# assert dict(roots["c"]) == {"k1": 1, "k2": 2}

0 comments on commit a595a9c

Please sign in to comment.