diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py
index c4d0f9dfc..bd6befce7 100644
--- a/src/zarr/abc/store.py
+++ b/src/zarr/abc/store.py
@@ -172,6 +172,22 @@ async def set(self, key: str, value: Buffer) -> None:
         """
         ...
 
+    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
+        """
+        Store ``value`` at ``key`` if ``key`` is not already present in the store.
+
+        Parameters
+        ----------
+        key : str
+        value : Buffer
+        """
+        # Note for implementers: the default implementation provided here
+        # is not safe for concurrent writers. There's a race condition between
+        # the `exists` check and the `set` where another writer could set some
+        # value at `key` or delete `key`.
+        if not await self.exists(key):
+            await self.set(key, value)
+
     async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
         """
         Insert multiple (key, value) pairs into storage.
@@ -297,6 +313,8 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
 
     async def delete(self) -> None: ...
 
+    async def set_if_not_exists(self, default: Buffer) -> None: ...
+
 
 async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
     if value is None:
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py
index c818e3c66..2181e9eb7 100644
--- a/src/zarr/codecs/sharding.py
+++ b/src/zarr/codecs/sharding.py
@@ -97,6 +97,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
     async def delete(self) -> None:
         del self.shard_dict[self.chunk_coords]
 
+    async def set_if_not_exists(self, default: Buffer) -> None:
+        self.shard_dict.setdefault(self.chunk_coords, default)
+
 
 class _ShardIndex(NamedTuple):
     # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py
index 25af05508..e1de15c74 100644
--- a/src/zarr/core/array.py
+++ b/src/zarr/core/array.py
@@ -13,7 +13,12 @@
 from zarr.codecs import BytesCodec
 from zarr.codecs._v2 import V2Compressor, V2Filters
 from zarr.core.attributes import Attributes
-from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
+from zarr.core.buffer import (
+    BufferPrototype,
+    NDArrayLike,
+    NDBuffer,
+    default_buffer_prototype,
+)
 from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks
 from zarr.core.chunk_key_encodings import (
     ChunkKeyEncoding,
@@ -71,6 +76,7 @@
     from collections.abc import Iterable, Iterator, Sequence
 
     from zarr.abc.codec import Codec, CodecPipeline
+    from zarr.core.group import AsyncGroup
     from zarr.core.metadata.common import ArrayMetadata
 
 # Array and AsyncArray are defined in the base ``zarr`` namespace
@@ -337,7 +343,7 @@
         )
 
         array = cls(metadata=metadata, store_path=store_path)
-        await array._save_metadata(metadata)
+        await array._save_metadata(metadata, ensure_parents=True)
         return array
 
     @classmethod
@@ -376,7 +382,7 @@
             attributes=attributes,
         )
         array = cls(metadata=metadata, store_path=store_path)
-        await array._save_metadata(metadata)
+        await array._save_metadata(metadata, ensure_parents=True)
         return array
 
     @classmethod
@@ -621,9 +627,24 @@
         )
         return await self._get_selection(indexer, prototype=prototype)
 
-    async def _save_metadata(self, metadata: ArrayMetadata) -> None:
+    async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
         to_save = metadata.to_buffer_dict(default_buffer_prototype())
         awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
+
+        if ensure_parents:
+            # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups.
+            parents = _build_parents(self)
+
+            for parent in parents:
+                awaitables.extend(
+                    [
+                        (parent.store_path / key).set_if_not_exists(value)
+                        for key, value in parent.metadata.to_buffer_dict(
+                            default_buffer_prototype()
+                        ).items()
+                    ]
+                )
+
         await gather(*awaitables)
 
     async def _set_selection(
@@ -2354,3 +2375,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
             out.append(chunk_key)
 
     return tuple(out)
+
+
+def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
+    from zarr.core.group import AsyncGroup, GroupMetadata
+
+    required_parts = node.store_path.path.split("/")[:-1]
+    parents = []
+
+    for i, part in enumerate(required_parts):
+        path = "/".join(required_parts[:i] + [part])
+        parents.append(
+            AsyncGroup(
+                metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
+                store_path=StorePath(store=node.store_path.store, path=path),
+            )
+        )
+
+    return parents
diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py
index d7ad960b1..79d03d3fc 100644
--- a/src/zarr/core/group.py
+++ b/src/zarr/core/group.py
@@ -13,7 +13,7 @@
 import zarr.api.asynchronous as async_api
 from zarr.abc.metadata import Metadata
 from zarr.abc.store import Store, set_or_delete
-from zarr.core.array import Array, AsyncArray
+from zarr.core.array import Array, AsyncArray, _build_parents
 from zarr.core.attributes import Attributes
 from zarr.core.buffer import default_buffer_prototype
 from zarr.core.common import (
@@ -144,7 +144,7 @@ async def from_store(
             metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
             store_path=store_path,
         )
-        await group._save_metadata()
+        await group._save_metadata(ensure_parents=True)
         return group
 
     @classmethod
@@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None:
         else:
             raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")
 
-    async def _save_metadata(self) -> None:
+    async def _save_metadata(self, ensure_parents: bool = False) -> None:
         to_save = self.metadata.to_buffer_dict(default_buffer_prototype())
         awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]
+
+        if ensure_parents:
+            parents = _build_parents(self)
+            for parent in parents:
+                awaitables.extend(
+                    [
+                        (parent.store_path / key).set_if_not_exists(value)
+                        for key, value in parent.metadata.to_buffer_dict(
+                            default_buffer_prototype()
+                        ).items()
+                    ]
+                )
+
         await asyncio.gather(*awaitables)
 
     @property
diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py
index c4d837e74..ea0edbe5e 100644
--- a/src/zarr/store/common.py
+++ b/src/zarr/store/common.py
@@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
     async def delete(self) -> None:
         await self.store.delete(self.path)
 
+    async def set_if_not_exists(self, default: Buffer) -> None:
+        await self.store.set_if_not_exists(self.path, default)
+
     async def exists(self) -> bool:
         return await self.store.exists(self.path)
 
diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py
index f1bce769d..23a87ea49 100644
--- a/src/zarr/store/local.py
+++ b/src/zarr/store/local.py
@@ -60,6 +60,7 @@ def _put(
     path: Path,
     value: Buffer,
     start: int | None = None,
+    exclusive: bool = False,
 ) -> int | None:
     path.parent.mkdir(parents=True, exist_ok=True)
     if start is not None:
@@ -68,7 +69,13 @@ def _put(
             f.write(value.as_numpy_array().tobytes())
         return None
     else:
-        return path.write_bytes(value.as_numpy_array().tobytes())
+        view = memoryview(value.as_numpy_array().tobytes())
+        if exclusive:
+            mode = "xb"
+        else:
+            mode = "wb"
+        with path.open(mode=mode) as f:
+            return f.write(view)
 
 
 class LocalStore(Store):
@@ -152,6 +159,15 @@ async def get_partial_values(
         return await concurrent_map(args, to_thread, limit=None)  # TODO: fix limit
 
     async def set(self, key: str, value: Buffer) -> None:
+        return await self._set(key, value)
+
+    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
+        try:
+            return await self._set(key, value, exclusive=True)
+        except FileExistsError:
+            pass
+
+    async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
         if not self._is_open:
             await self._open()
         self._check_writable()
@@ -159,7 +175,7 @@ async def set(self, key: str, value: Buffer) -> None:
         if not isinstance(value, Buffer):
             raise TypeError("LocalStore.set(): `value` must a Buffer instance")
         path = self.root / key
-        await to_thread(_put, path, value)
+        await to_thread(_put, path, value, start=None, exclusive=exclusive)
 
     async def set_partial_values(
         self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]
diff --git a/src/zarr/store/logging.py b/src/zarr/store/logging.py
index ee0614172..3a4ae26c3 100644
--- a/src/zarr/store/logging.py
+++ b/src/zarr/store/logging.py
@@ -8,6 +8,7 @@
 from typing import TYPE_CHECKING
 
 from zarr.abc.store import AccessMode, ByteRangeRequest, Store
+from zarr.core.buffer import Buffer
 
 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator, Generator, Iterable
@@ -138,6 +139,10 @@ async def set(self, key: str, value: Buffer) -> None:
         with self.log():
             return await self._store.set(key=key, value=value)
 
+    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
+        with self.log():
+            return await self._store.set_if_not_exists(key=key, value=value)
+
     async def delete(self, key: str) -> None:
         with self.log():
             return await self._store.delete(key=key)
diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py
index ee4107b0a..d5294c9d2 100644
--- a/src/zarr/store/memory.py
+++ b/src/zarr/store/memory.py
@@ -85,9 +85,8 @@ async def exists(self, key: str) -> bool:
         return key in self._store_dict
 
     async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
-        if not self._is_open:
-            await self._open()
         self._check_writable()
+        await self._ensure_open()
         assert isinstance(key, str)
         if not isinstance(value, Buffer):
             raise TypeError(f"Expected Buffer. Got {type(value)}.")
@@ -99,6 +98,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
         else:
             self._store_dict[key] = value
 
+    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
+        self._check_writable()
+        await self._ensure_open()
+        self._store_dict.setdefault(key, value)
+
     async def delete(self, key: str) -> None:
         self._check_writable()
         try:
diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py
index 284cd8d77..6cc631d3b 100644
--- a/src/zarr/store/remote.py
+++ b/src/zarr/store/remote.py
@@ -5,6 +5,7 @@
 import fsspec
 
 from zarr.abc.store import ByteRangeRequest, Store
+from zarr.core.buffer import Buffer
 from zarr.store.common import _dereference_path
 
 if TYPE_CHECKING:
diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py
index f9c458709..82ce7d024 100644
--- a/src/zarr/store/zip.py
+++ b/src/zarr/store/zip.py
@@ -188,6 +188,13 @@ async def set(self, key: str, value: Buffer) -> None:
     async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
         raise NotImplementedError
 
+    async def set_if_not_exists(self, key: str, value: Buffer) -> None:
+        self._check_writable()
+        with self._lock:
+            members = self._zf.namelist()
+            if key not in members:
+                self._set(key, value)
+
     async def delete(self, key: str) -> None:
         raise NotImplementedError
 
diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py
index 5c7500734..ed49936d1 100644
--- a/src/zarr/testing/store.py
+++ b/src/zarr/testing/store.py
@@ -273,3 +273,19 @@ async def test_list_dir(self, store: S) -> None:
         keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
 
         assert sorted(keys_expected) == sorted(keys_observed)
+
+    async def test_set_if_not_exists(self, store: S) -> None:
+        key = "k"
+        data_buf = self.buffer_cls.from_bytes(b"0000")
+        self.set(store, key, data_buf)
+
+        new = self.buffer_cls.from_bytes(b"1111")
+        await store.set_if_not_exists("k", new)  # no error
+
+        result = await store.get(key, default_buffer_prototype())
+        assert result == data_buf
+
+        await store.set_if_not_exists("k2", new)  # no error
+
+        result = await store.get("k2", default_buffer_prototype())
+        assert result == new
diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py
index 6224bc39e..5de4a4d12 100644
--- a/tests/v3/test_array.py
+++ b/tests/v3/test_array.py
@@ -5,11 +5,13 @@
 import numpy as np
 import pytest
 
+import zarr.api.asynchronous
 from zarr import Array, AsyncArray, Group
 from zarr.codecs.bytes import BytesCodec
 from zarr.core.array import chunks_initialized
 from zarr.core.buffer.cpu import NDBuffer
 from zarr.core.common import JSON, ZarrFormat
+from zarr.core.group import AsyncGroup
 from zarr.core.indexing import ceildiv
 from zarr.core.sync import sync
 from zarr.errors import ContainsArrayError, ContainsGroupError
@@ -66,6 +68,47 @@
             )
 
 
+@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
+@pytest.mark.parametrize("zarr_format", [2, 3])
+async def test_create_creates_parents(
+    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
+) -> None:
+    # prepare a root node, with some data set
+    await zarr.api.asynchronous.open_group(
+        store=store, path="a", zarr_format=zarr_format, attributes={"key": "value"}
+    )
+
+    # create a child node with a couple intermediates
+    await zarr.api.asynchronous.create(
+        shape=(2, 2), store=store, path="a/b/c/d", zarr_format=zarr_format
+    )
+    parts = ["a", "a/b", "a/b/c"]
+
+    if zarr_format == 2:
+        files = [".zattrs", ".zgroup"]
+    else:
+        files = ["zarr.json"]
+
+    expected = [f"{part}/{file}" for file in files for part in parts]
+
+    if zarr_format == 2:
+        expected.append("a/b/c/d/.zarray")
+        expected.append("a/b/c/d/.zattrs")
+    else:
+        expected.append("a/b/c/d/zarr.json")
+
+    expected = sorted(expected)
+
+    result = sorted([x async for x in store.list_prefix("")])
+
+    assert result == expected
+
+    paths = ["a", "a/b", "a/b/c"]
+    for path in paths:
+        g = await zarr.api.asynchronous.open_group(store=store, path=path)
+        assert isinstance(g, AsyncGroup)
+
+
 @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
 @pytest.mark.parametrize("zarr_format", [2, 3])
 def test_array_name_properties_no_group(
diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py
index 8c6464d3b..cb1681daa 100644
--- a/tests/v3/test_group.py
+++ b/tests/v3/test_group.py
@@ -7,6 +7,7 @@
 import pytest
 
 import zarr
+import zarr.api.asynchronous
 from zarr import Array, AsyncArray, AsyncGroup, Group
 from zarr.abc.store import Store
 from zarr.core.buffer import default_buffer_prototype
@@ -56,6 +57,46 @@ def test_group_init(store: Store, zarr_format: ZarrFormat) -> None:
     assert group._async_group == agroup
 
 
+async def test_create_creates_parents(store: Store, zarr_format: ZarrFormat) -> None:
+    # prepare a root node, with some data set
+    await zarr.api.asynchronous.open_group(
+        store=store, path="a", zarr_format=zarr_format, attributes={"key": "value"}
+    )
+    # create a child node with a couple intermediates
+    await zarr.api.asynchronous.open_group(store=store, path="a/b/c/d", zarr_format=zarr_format)
+    parts = ["a", "a/b", "a/b/c"]
+
+    if zarr_format == 2:
+        files = [".zattrs", ".zgroup"]
+    else:
+        files = ["zarr.json"]
+
+    expected = [f"{part}/{file}" for file in files for part in parts]
+
+    if zarr_format == 2:
+        expected.append("a/b/c/d/.zgroup")
+        expected.append("a/b/c/d/.zattrs")
+    else:
+        expected.append("a/b/c/d/zarr.json")
+
+    expected = sorted(expected)
+
+    result = sorted([x async for x in store.list_prefix("")])
+
+    assert result == expected
+
+    paths = ["a", "a/b", "a/b/c"]
+    for path in paths:
+        g = await zarr.api.asynchronous.open_group(store=store, path=path)
+        assert isinstance(g, AsyncGroup)
+
+        if path == "a":
+            # ensure we didn't overwrite the root attributes
+            assert g.attrs == {"key": "value"}
+        else:
+            assert g.attrs == {}
+
+
 def test_group_name_properties(store: Store, zarr_format: ZarrFormat) -> None:
     """
     Test basic properties of groups
diff --git a/tests/v3/test_store/test_stateful_store.py b/tests/v3/test_store/test_stateful_store.py
index 1a43b55fe..efa1953a5 100644
--- a/tests/v3/test_store/test_stateful_store.py
+++ b/tests/v3/test_store/test_stateful_store.py
@@ -239,7 +239,7 @@ def check_zarr_keys(self) -> None:
 
 
 def test_zarr_hierarchy(sync_store: Store) -> None:
-    def mk_test_instance_sync():
+    def mk_test_instance_sync() -> ZarrStoreStateMachine:
         return ZarrStoreStateMachine(sync_store)
 
     if isinstance(sync_store, ZipStore):