From a595a9cd73dbb824bb654b5fb98f325f84fcd919 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 7 Dec 2023 11:18:48 +0100 Subject: [PATCH] Add doc root type iterator (#39) --- python/pycrdt/_pycrdt.pyi | 2 ++ python/pycrdt/base.py | 2 -- python/pycrdt/doc.py | 31 ++++++++++++++++++++++++++----- src/doc.rs | 14 +++++++++++++- tests/test_doc.py | 25 +++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 8 deletions(-) diff --git a/python/pycrdt/_pycrdt.pyi b/python/pycrdt/_pycrdt.pyi index 31acb3a..e7caff2 100644 --- a/python/pycrdt/_pycrdt.pyi +++ b/python/pycrdt/_pycrdt.pyi @@ -24,6 +24,8 @@ class Doc: """Get the update from the given state to the current state.""" def apply_update(self, update: bytes) -> None: """Apply the update to the document.""" + def roots(self, txn: Transaction) -> dict[str, Text | Array | Map]: + """Get top-level (root) shared types available in current document.""" def observe(self, callback: Callable[[TransactionEvent], None]) -> int: """Subscribes a callback to be called with the shared document change event. Returns a subscription ID that can be used to unsubscribe.""" diff --git a/python/pycrdt/base.py b/python/pycrdt/base.py index 9d95314..c7b5657 100644 --- a/python/pycrdt/base.py +++ b/python/pycrdt/base.py @@ -20,7 +20,6 @@ class BaseDoc: _twin_doc: BaseDoc | None _txn: Transaction | None _Model: Any - _dict: dict[str, BaseType] def __init__( self, @@ -36,7 +35,6 @@ def __init__( self._doc = doc self._txn = None self._Model = Model - self._dict = {} class BaseType(ABC): diff --git a/python/pycrdt/doc.py b/python/pycrdt/doc.py index f5f2ad0..640527d 100644 --- a/python/pycrdt/doc.py +++ b/python/pycrdt/doc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable, cast +from typing import Callable, Type, cast from ._pycrdt import Doc as _Doc from ._pycrdt import SubdocsEvent, TransactionEvent @@ -9,7 +9,6 @@ class Doc(BaseDoc): - def __init__( self, init: dict[str, BaseType] = {}, @@ -53,7 +52,7 @@ def apply_update(self, update: bytes) -> None: try: self._Model(**d) except Exception as e: - self._twin_doc = Doc(self._dict) + self._twin_doc = Doc(dict(self)) raise e self._doc.apply_update(update) @@ -63,10 +62,32 @@ def __setitem__(self, key: str, value: BaseType) -> None: integrated = value._get_or_insert(key, self) prelim = value._integrate(self, integrated) value._init(prelim) - self._dict[key] = value def __getitem__(self, key: str) -> BaseType: - return self._dict[key] + return self._roots[key] + + def __iter__(self): + return self.keys() + + def keys(self): + return self._roots.keys() + + def values(self): + return self._roots.values() + + def items(self): + return self._roots.items() + + @property + def _roots(self) -> dict[str, BaseType]: + with self.transaction() as txn: + assert txn._txn is not None + return { + key: cast(Type[BaseType], base_types[type(val)])( + _integrated=val, _doc=self + ) + for key, val in self._doc.roots(txn._txn).items() + } def observe(self, callback: Callable[[TransactionEvent], None]) -> int: return self._doc.observe(callback) diff --git a/src/doc.rs b/src/doc.rs index 210990c..0c560df 100644 --- a/src/doc.rs +++ b/src/doc.rs @@ -1,5 +1,5 @@ use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyLong, PyList}; +use pyo3::types::{PyBytes, PyDict, PyLong, PyList}; use yrs::{ Doc as _Doc, ReadTxn, @@ -16,6 +16,7 @@ use crate::text::Text; use crate::array::Array; use crate::map::Map; use crate::transaction::Transaction; +use crate::type_conversions::ToPython; #[pyclass(unsendable)] @@ -100,6 +101,17 @@ impl Doc { Ok(()) } + fn roots(&self, py: Python<'_>, txn: &mut Transaction) -> PyObject { + let mut t0 = txn.transaction(); + let t1 = t0.as_mut().unwrap(); + let t = t1.as_ref(); + let result = PyDict::new(py); + for (k, v) in t.root_refs() { + result.set_item(k, v.into_py(py)).unwrap(); + } + result.into() + } + pub fn observe(&mut self, f: PyObject) -> PyResult { let id: u32 = self.doc .observe_transaction_cleanup(move |txn, event| { diff --git a/tests/test_doc.py b/tests/test_doc.py index 91b3c6d..f03af73 100644 --- a/tests/test_doc.py +++ b/tests/test_doc.py @@ -109,3 +109,28 @@ def test_client_id(): b = encode_client_id(client_id_bytes) assert update[2 : 2 + len(b)] == b + + +def test_roots(): + remote_doc = Doc( + { + "a": Text("foo"), + "b": Array([5, 2, 8]), + "c": Map({"k1": 1, "k2": 2}), + } + ) + roots = dict(remote_doc) + assert str(roots["a"]) == "foo" + assert list(roots["b"]) == [5, 2, 8] + assert dict(roots["c"]) == {"k1": 1, "k2": 2} + + # need to update to yrs v0.17.2 + # see https://github.com/y-crdt/y-crdt/issues/364#issuecomment-1839791409 + + # local_doc = Doc() + # update = remote_doc.get_update() + # local_doc.apply_update(update) + # roots = dict(local_doc) + # assert str(roots["a"]) == "foo" + # assert list(roots["b"]) == [5, 2, 8] + # assert dict(roots["c"]) == {"k1": 1, "k2": 2}