From d8f9b019da8a586f645bf8ae04792650c88849ea Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 26 Sep 2024 10:41:59 -0500 Subject: [PATCH 1/8] Ensure parents are created when creating a node This updates our Array and Group creation methods to ensure that parents implicitly defined through a nested path are also created. To accomplish this semi-safely and efficiently, we require a new setdefulat method on the Store class. --- src/zarr/abc/store.py | 19 +++++++++++++++ src/zarr/codecs/sharding.py | 3 +++ src/zarr/core/array.py | 47 +++++++++++++++++++++++++++++++++---- src/zarr/core/group.py | 19 ++++++++++++--- src/zarr/store/common.py | 3 +++ src/zarr/store/local.py | 20 ++++++++++++++-- src/zarr/store/memory.py | 8 +++++-- src/zarr/store/remote.py | 6 +++++ src/zarr/store/zip.py | 7 ++++++ src/zarr/testing/store.py | 17 ++++++++++++++ tests/v3/test_array.py | 31 ++++++++++++++++++++++++ tests/v3/test_group.py | 25 ++++++++++++++++++++ 12 files changed, 194 insertions(+), 11 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 42eb18ce0..b41b5c47e 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -173,6 +173,16 @@ async def set(self, key: str, value: Buffer) -> None: """ ... + @abstractmethod + async def setdefault(self, key: str, default: Buffer) -> None: + """ + Store a key with a value of ``default`` if the key is not already present. + + Unlike MutableMapping.default, this method does not provide any way to + know whether ``default`` was actually set. + """ + ... + async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ Insert multiple (key, value) pairs into storage. @@ -298,9 +308,18 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - async def delete(self) -> None: ... + async def setdefault(self, default: Buffer) -> None: ... + async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None: if value is None: await byte_setter.delete() else: await byte_setter.set(value) + + +async def setdefault(byte_setter: ByteSetter, value: Buffer | None) -> None: + if value is None: + await byte_setter.delete() + else: + await byte_setter.set(value) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 2f8946e46..34c56df2e 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -98,6 +98,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - async def delete(self) -> None: del self.shard_dict[self.chunk_coords] + async def setdefault(self, default: Buffer) -> None: + self.shard_dict.setdefault(self.chunk_coords, default) + class _ShardIndex(NamedTuple): # dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index fac0facd7..c8d657772 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -13,7 +13,12 @@ from zarr.codecs import BytesCodec from zarr.codecs._v2 import V2Compressor, V2Filters from zarr.core.attributes import Attributes -from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype +from zarr.core.buffer import ( + BufferPrototype, + NDArrayLike, + NDBuffer, + default_buffer_prototype, +) from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks from zarr.core.chunk_key_encodings import ( ChunkKeyEncoding, @@ -71,6 +76,7 @@ from collections.abc import Iterable, Iterator, Sequence from zarr.abc.codec import Codec, CodecPipeline + from zarr.core.group import AsyncGroup from zarr.core.metadata.common import ArrayMetadata # Array and AsyncArray are defined in the base ``zarr`` namespace @@ -276,7 +282,7 @@ async def _create_v3( ) array = cls(metadata=metadata, store_path=store_path) - await array._save_metadata(metadata) + await array._save_metadata(metadata, ensure_parents=True) return array @classmethod @@ -315,7 +321,7 @@ async def _create_v2( attributes=attributes, ) array = cls(metadata=metadata, store_path=store_path) - await array._save_metadata(metadata) + await array._save_metadata(metadata, ensure_parents=True) return array @classmethod @@ -603,9 +609,24 @@ async def getitem( ) return await self._get_selection(indexer, prototype=prototype) - async def _save_metadata(self, metadata: ArrayMetadata) -> None: + async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None: to_save = metadata.to_buffer_dict(default_buffer_prototype()) awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] + + if ensure_parents: + # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediates. + parents = _build_parents(self) + + for parent in parents: + awaitables.extend( + [ + (parent.store_path / key).setdefault(value) + for key, value in parent.metadata.to_buffer_dict( + default_buffer_prototype() + ).items() + ] + ) + await gather(*awaitables) async def _set_selection( @@ -2336,3 +2357,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]: out.append(chunk_key) return tuple(out) + + +def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]: + from zarr.core.group import AsyncGroup, GroupMetadata + + required_parts = node.store_path.path.split("/")[:-1] + parents = [] + + for i, part in enumerate(required_parts): + path = "/".join(required_parts[:i] + [part]) + parents.append( + AsyncGroup( + metadata=GroupMetadata(zarr_format=node.metadata.zarr_format), + store_path=StorePath(store=node.store_path.store, path=path), + ) + ) + + return parents diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index d7ad960b1..6db4bf704 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -13,7 +13,7 @@ import zarr.api.asynchronous as async_api from zarr.abc.metadata import Metadata from zarr.abc.store import Store, set_or_delete -from zarr.core.array import Array, AsyncArray +from zarr.core.array import Array, AsyncArray, _build_parents from zarr.core.attributes import Attributes from zarr.core.buffer import default_buffer_prototype from zarr.core.common import ( @@ -144,7 +144,7 @@ async def from_store( metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format), store_path=store_path, ) - await group._save_metadata() + await group._save_metadata(ensure_parents=True) return group @classmethod @@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None: else: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") - async def _save_metadata(self) -> None: + async def _save_metadata(self, ensure_parents: bool = False) -> None: to_save = self.metadata.to_buffer_dict(default_buffer_prototype()) awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] + + if ensure_parents: + parents = _build_parents(self) + for parent in parents: + awaitables.extend( + [ + (parent.store_path / key).setdefault(value) + for key, value in parent.metadata.to_buffer_dict( + default_buffer_prototype() + ).items() + ] + ) + await asyncio.gather(*awaitables) @property diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py index f39edb19a..952dea601 100644 --- a/src/zarr/store/common.py +++ b/src/zarr/store/common.py @@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - async def delete(self) -> None: await self.store.delete(self.path) + async def setdefault(self, default: Buffer) -> None: + await self.store.setdefault(self.path, default) + async def exists(self) -> bool: return await self.store.exists(self.path) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index f1bce769d..e38c74e33 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -60,6 +60,7 @@ def _put( path: Path, value: Buffer, start: int | None = None, + exclusive: bool = False, ) -> int | None: path.parent.mkdir(parents=True, exist_ok=True) if start is not None: @@ -68,7 +69,13 @@ def _put( f.write(value.as_numpy_array().tobytes()) return None else: - return path.write_bytes(value.as_numpy_array().tobytes()) + view = memoryview(value.as_numpy_array().tobytes()) + if exclusive: + mode = "xb" + else: + mode = "wb" + with path.open(mode=mode) as f: + return f.write(view) class LocalStore(Store): @@ -152,6 +159,15 @@ async def get_partial_values( return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit async def set(self, key: str, value: Buffer) -> None: + return await self._set(key, value) + + async def setdefault(self, key: str, value: Buffer) -> None: + try: + return await self._set(key, value, exclusive=True) + except FileExistsError: + pass + + async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: if not self._is_open: await self._open() self._check_writable() @@ -159,7 +175,7 @@ async def set(self, key: str, value: Buffer) -> None: if not isinstance(value, Buffer): raise TypeError("LocalStore.set(): `value` must a Buffer instance") path = self.root / key - await to_thread(_put, path, value) + await to_thread(_put, path, value, start=None, exclusive=exclusive) async def set_partial_values( self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]] diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index ee4107b0a..2048202d5 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -85,9 +85,8 @@ async def exists(self, key: str) -> bool: return key in self._store_dict async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: - if not self._is_open: - await self._open() self._check_writable() + await self._ensure_open() assert isinstance(key, str) if not isinstance(value, Buffer): raise TypeError(f"Expected Buffer. Got {type(value)}.") @@ -99,6 +98,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None else: self._store_dict[key] = value + async def setdefault(self, key: str, default: Buffer) -> None: + self._check_writable() + await self._ensure_open() + self._store_dict.setdefault(key, default) + async def delete(self, key: str) -> None: self._check_writable() try: diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index 284cd8d77..c72b9d144 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -5,6 +5,7 @@ import fsspec from zarr.abc.store import ByteRangeRequest, Store +from zarr.core.buffer import Buffer from zarr.store.common import _dereference_path if TYPE_CHECKING: @@ -208,6 +209,11 @@ async def set_partial_values( ) -> None: raise NotImplementedError + async def setdefault(self, key: str, default: Buffer) -> None: + # this isn't safe for concurrent writers, but that's probably unavoidable. + if not await self.fs._exists(_dereference_path(self.path, key)): + await self.set(key, default) + async def list(self) -> AsyncGenerator[str, None]: allfiles = await self.fs._find(self.path, detail=False, withdirs=False) for onefile in (a.replace(self.path + "/", "") for a in allfiles): diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index 949660913..3f5c2a87d 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -191,6 +191,13 @@ async def set(self, key: str, value: Buffer) -> None: async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None: raise NotImplementedError + async def setdefault(self, key: str, default: Buffer) -> None: + self._check_writable() + with self._lock: + members = self._zf.namelist() + if key not in members: + self._set(key, default) + async def delete(self, key: str) -> None: raise NotImplementedError diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 7b78b8ed0..9c4257868 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -273,3 +273,20 @@ async def test_list_dir(self, store: S) -> None: keys_observed = await _collect_aiterator(store.list_dir(root + "/")) assert sorted(keys_expected) == sorted(keys_observed) + + async def test_setdefault(self, store: S) -> None: + key = "k" + data_buf = self.buffer_cls.from_bytes(b"0000") + self.set(store, key, data_buf) + + new = self.buffer_cls.from_bytes(b"1111") + await store.setdefault("k", new) # no error + + result = await store.get(key, default_buffer_prototype()) + assert result == data_buf + + await store.setdefault("k2", new) # no error + await store.get("k2", default_buffer_prototype()) + + result = await store.get("k2", default_buffer_prototype()) + assert result == new diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 95bbde174..7bf968c0e 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import zarr.api.asynchronous from zarr import Array, AsyncArray, Group from zarr.core.array import chunks_initialized from zarr.core.buffer.cpu import NDBuffer @@ -65,6 +66,36 @@ def test_array_creation_existing_node( ) +@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) +@pytest.mark.parametrize("zarr_format", [2, 3]) +async def test_create_creates_parents( + store: LocalStore | MemoryStore, zarr_format: ZarrFormat +) -> None: + await zarr.api.asynchronous.create( + shape=(2, 2), store=store, path="a/b/c/d", zarr_format=zarr_format + ) + parts = ["a", "a/b", "a/b/c"] + + if zarr_format == 2: + files = [".zattrs", ".zgroup"] + else: + files = ["zarr.json"] + + expected = [f"{part}/{file}" for file in files for part in parts] + + if zarr_format == 2: + expected.append("a/b/c/d/.zarray") + expected.append("a/b/c/d/.zattrs") + else: + expected.append("a/b/c/d/zarr.json") + + expected = sorted(expected) + + result = sorted([x async for x in store.list_prefix("")]) + + assert result == expected + + @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) @pytest.mark.parametrize("zarr_format", [2, 3]) def test_array_name_properties_no_group( diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 8c6464d3b..a0e66cb2a 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -7,6 +7,7 @@ import pytest import zarr +import zarr.api.asynchronous from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store from zarr.core.buffer import default_buffer_prototype @@ -56,6 +57,30 @@ def test_group_init(store: Store, zarr_format: ZarrFormat) -> None: assert group._async_group == agroup +async def test_create_creates_parents(store: Store, zarr_format: ZarrFormat) -> None: + await zarr.api.asynchronous.open_group(store=store, path="a/b/c/d", zarr_format=zarr_format) + parts = ["a", "a/b", "a/b/c"] + + if zarr_format == 2: + files = [".zattrs", ".zgroup"] + else: + files = ["zarr.json"] + + expected = [f"{part}/{file}" for file in files for part in parts] + + if zarr_format == 2: + expected.append("a/b/c/d/.zgroup") + expected.append("a/b/c/d/.zattrs") + else: + expected.append("a/b/c/d/zarr.json") + + expected = sorted(expected) + + result = sorted([x async for x in store.list_prefix("")]) + + assert result == expected + + def test_group_name_properties(store: Store, zarr_format: ZarrFormat) -> None: """ Test basic properties of groups From 84cfe18279388f067dd74484b04d7ac5e169c49d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 26 Sep 2024 14:44:48 -0500 Subject: [PATCH 2/8] use the API --- tests/v3/test_array.py | 6 ++++++ tests/v3/test_group.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 7bf968c0e..5b6a5b24d 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -10,6 +10,7 @@ from zarr.core.array import chunks_initialized from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import ZarrFormat +from zarr.core.group import AsyncGroup from zarr.core.indexing import ceildiv from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError @@ -95,6 +96,11 @@ async def test_create_creates_parents( assert result == expected + paths = ["a", "a/b", "a/b/c"] + for path in paths: + g = await zarr.api.asynchronous.open_group(store=store, path=path) + assert isinstance(g, AsyncGroup) + @pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"]) @pytest.mark.parametrize("zarr_format", [2, 3]) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index a0e66cb2a..dbda82271 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -80,6 +80,11 @@ async def test_create_creates_parents(store: Store, zarr_format: ZarrFormat) -> assert result == expected + paths = ["a", "a/b", "a/b/c"] + for path in paths: + g = await zarr.api.asynchronous.open_group(store=store, path=path) + assert isinstance(g, AsyncGroup) + def test_group_name_properties(store: Store, zarr_format: ZarrFormat) -> None: """ From 7ba26483d6305a9fcdb883ec328bbd6fa8b680d4 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 26 Sep 2024 20:40:40 -0500 Subject: [PATCH 3/8] fixed logging store --- src/zarr/store/logging.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/zarr/store/logging.py b/src/zarr/store/logging.py index 0c05b4265..6f324ec32 100644 --- a/src/zarr/store/logging.py +++ b/src/zarr/store/logging.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING from zarr.abc.store import AccessMode, ByteRangeRequest, Store +from zarr.core.buffer import Buffer if TYPE_CHECKING: from collections.abc import AsyncGenerator, Generator, Iterable @@ -138,6 +139,10 @@ async def set(self, key: str, value: Buffer) -> None: with self.log(): return await self._store.set(key=key, value=value) + async def setdefault(self, key: str, default: Buffer) -> None: + with self.log(): + return await self._store.set(key=key, value=default) + async def delete(self, key: str) -> None: with self.log(): return await self._store.delete(key=key) From d65047ec10ba237f8e78b4df87cec47a4d8642e3 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 26 Sep 2024 20:57:23 -0500 Subject: [PATCH 4/8] Update src/zarr/testing/store.py --- src/zarr/testing/store.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index 91b354e98..f1e812bcc 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -286,7 +286,6 @@ async def test_setdefault(self, store: S) -> None: assert result == data_buf await store.setdefault("k2", new) # no error - await store.get("k2", default_buffer_prototype()) result = await store.get("k2", default_buffer_prototype()) assert result == new From d44f9552e29dc22a12d40f6396bafda9c6b2ceb2 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 27 Sep 2024 07:28:20 -0500 Subject: [PATCH 5/8] fixes --- src/zarr/abc/store.py | 7 ------- src/zarr/core/array.py | 2 +- src/zarr/store/logging.py | 2 +- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index f1af8e8c9..7e3133ff5 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -315,10 +315,3 @@ async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None: await byte_setter.delete() else: await byte_setter.set(value) - - -async def setdefault(byte_setter: ByteSetter, value: Buffer | None) -> None: - if value is None: - await byte_setter.delete() - else: - await byte_setter.set(value) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index c8d657772..aaebf417e 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -614,7 +614,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()] if ensure_parents: - # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediates. + # To enable zarr.create(store, path="a/b/c"), we need to create all the intermediate groups. parents = _build_parents(self) for parent in parents: diff --git a/src/zarr/store/logging.py b/src/zarr/store/logging.py index 6f324ec32..aa5091a3a 100644 --- a/src/zarr/store/logging.py +++ b/src/zarr/store/logging.py @@ -141,7 +141,7 @@ async def set(self, key: str, value: Buffer) -> None: async def setdefault(self, key: str, default: Buffer) -> None: with self.log(): - return await self._store.set(key=key, value=default) + return await self._store.setdefault(key=key, default=default) async def delete(self, key: str) -> None: with self.log(): From 264327dce765719c1389b1bde0a4f242e439ec81 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 27 Sep 2024 09:53:39 -0500 Subject: [PATCH 6/8] fixup --- tests/v3/test_array.py | 6 ++++++ tests/v3/test_group.py | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 4dbfce849..5de4a4d12 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -73,6 +73,12 @@ def test_array_creation_existing_node( async def test_create_creates_parents( store: LocalStore | MemoryStore, zarr_format: ZarrFormat ) -> None: + # prepare a root node, with some data set + await zarr.api.asynchronous.open_group( + store=store, path="a", zarr_format=zarr_format, attributes={"key": "value"} + ) + + # create a child node with a couple intermediates await zarr.api.asynchronous.create( shape=(2, 2), store=store, path="a/b/c/d", zarr_format=zarr_format ) diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index dbda82271..cb1681daa 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -58,6 +58,11 @@ def test_group_init(store: Store, zarr_format: ZarrFormat) -> None: async def test_create_creates_parents(store: Store, zarr_format: ZarrFormat) -> None: + # prepare a root node, with some data set + await zarr.api.asynchronous.open_group( + store=store, path="a", zarr_format=zarr_format, attributes={"key": "value"} + ) + # create a child node with a couple intermediates await zarr.api.asynchronous.open_group(store=store, path="a/b/c/d", zarr_format=zarr_format) parts = ["a", "a/b", "a/b/c"] @@ -85,6 +90,12 @@ async def test_create_creates_parents(store: Store, zarr_format: ZarrFormat) -> g = await zarr.api.asynchronous.open_group(store=store, path=path) assert isinstance(g, AsyncGroup) + if path == "a": + # ensure we didn't overwrite the root attributes + assert g.attrs == {"key": "value"} + else: + assert g.attrs == {} + def test_group_name_properties(store: Store, zarr_format: ZarrFormat) -> None: """ From 1039c165583585869450147ea073e039a32d1db3 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 27 Sep 2024 12:39:47 -0500 Subject: [PATCH 7/8] fixes --- src/zarr/abc/store.py | 20 +++++++++++++------- src/zarr/codecs/sharding.py | 2 +- src/zarr/core/array.py | 2 +- src/zarr/core/group.py | 2 +- src/zarr/store/common.py | 4 ++-- src/zarr/store/local.py | 2 +- src/zarr/store/logging.py | 4 ++-- src/zarr/store/memory.py | 2 +- src/zarr/store/remote.py | 5 ----- src/zarr/store/zip.py | 2 +- src/zarr/testing/store.py | 6 +++--- 11 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index 7e3133ff5..39d4ac03c 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -172,15 +172,21 @@ async def set(self, key: str, value: Buffer) -> None: """ ... - @abstractmethod - async def setdefault(self, key: str, default: Buffer) -> None: + async def set_if_not_exists(self, key: str, value: Buffer) -> None: """ - Store a key with a value of ``default`` if the key is not already present. + Store a key to ``value`` if the key is not already present. - Unlike MutableMapping.default, this method does not provide any way to - know whether ``default`` was actually set. + Parameters + ----------- + key : str + value : Buffer """ - ... + # Note for implementers: the default implementation provided here + # is not safe for concurrent writers. There's a race condition between + # the `exists` check and the `set` where another writer could set some + # value at `key` or delete `key`. + if not await self.exists(key): + await self.set(key, value) async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None: """ @@ -307,7 +313,7 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - async def delete(self) -> None: ... - async def setdefault(self, default: Buffer) -> None: ... + async def set_if_not_exists(self, default: Buffer) -> None: ... async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None: diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index f7ffe57e0..2181e9eb7 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -97,7 +97,7 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - async def delete(self) -> None: del self.shard_dict[self.chunk_coords] - async def setdefault(self, default: Buffer) -> None: + async def set_if_not_exists(self, default: Buffer) -> None: self.shard_dict.setdefault(self.chunk_coords, default) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 45b9a11e0..99a043c23 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -627,7 +627,7 @@ async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = F for parent in parents: awaitables.extend( [ - (parent.store_path / key).setdefault(value) + (parent.store_path / key).set_if_not_exists(value) for key, value in parent.metadata.to_buffer_dict( default_buffer_prototype() ).items() diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 6db4bf704..79d03d3fc 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -288,7 +288,7 @@ async def _save_metadata(self, ensure_parents: bool = False) -> None: for parent in parents: awaitables.extend( [ - (parent.store_path / key).setdefault(value) + (parent.store_path / key).set_if_not_exists(value) for key, value in parent.metadata.to_buffer_dict( default_buffer_prototype() ).items() diff --git a/src/zarr/store/common.py b/src/zarr/store/common.py index 952dea601..8a87880f3 100644 --- a/src/zarr/store/common.py +++ b/src/zarr/store/common.py @@ -51,8 +51,8 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) - async def delete(self) -> None: await self.store.delete(self.path) - async def setdefault(self, default: Buffer) -> None: - await self.store.setdefault(self.path, default) + async def set_if_not_exists(self, default: Buffer) -> None: + await self.store.set_if_not_exists(self.path, default) async def exists(self) -> bool: return await self.store.exists(self.path) diff --git a/src/zarr/store/local.py b/src/zarr/store/local.py index e38c74e33..23a87ea49 100644 --- a/src/zarr/store/local.py +++ b/src/zarr/store/local.py @@ -161,7 +161,7 @@ async def get_partial_values( async def set(self, key: str, value: Buffer) -> None: return await self._set(key, value) - async def setdefault(self, key: str, value: Buffer) -> None: + async def set_if_not_exists(self, key: str, value: Buffer) -> None: try: return await self._set(key, value, exclusive=True) except FileExistsError: diff --git a/src/zarr/store/logging.py b/src/zarr/store/logging.py index aa5091a3a..7ba2f07b9 100644 --- a/src/zarr/store/logging.py +++ b/src/zarr/store/logging.py @@ -139,9 +139,9 @@ async def set(self, key: str, value: Buffer) -> None: with self.log(): return await self._store.set(key=key, value=value) - async def setdefault(self, key: str, default: Buffer) -> None: + async def set_if_not_exists(self, key: str, default: Buffer) -> None: with self.log(): - return await self._store.setdefault(key=key, default=default) + return await self._store.set_if_not_exists(key=key, value=default) async def delete(self, key: str) -> None: with self.log(): diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py index 2048202d5..d5294c9d2 100644 --- a/src/zarr/store/memory.py +++ b/src/zarr/store/memory.py @@ -98,7 +98,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None else: self._store_dict[key] = value - async def setdefault(self, key: str, default: Buffer) -> None: + async def set_if_not_exists(self, key: str, default: Buffer) -> None: self._check_writable() await self._ensure_open() self._store_dict.setdefault(key, default) diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py index c72b9d144..6cc631d3b 100644 --- a/src/zarr/store/remote.py +++ b/src/zarr/store/remote.py @@ -209,11 +209,6 @@ async def set_partial_values( ) -> None: raise NotImplementedError - async def setdefault(self, key: str, default: Buffer) -> None: - # this isn't safe for concurrent writers, but that's probably unavoidable. - if not await self.fs._exists(_dereference_path(self.path, key)): - await self.set(key, default) - async def list(self) -> AsyncGenerator[str, None]: allfiles = await self.fs._find(self.path, detail=False, withdirs=False) for onefile in (a.replace(self.path + "/", "") for a in allfiles): diff --git a/src/zarr/store/zip.py b/src/zarr/store/zip.py index 8fcca076d..82ce7d024 100644 --- a/src/zarr/store/zip.py +++ b/src/zarr/store/zip.py @@ -188,7 +188,7 @@ async def set(self, key: str, value: Buffer) -> None: async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None: raise NotImplementedError - async def setdefault(self, key: str, default: Buffer) -> None: + async def set_if_not_exists(self, key: str, default: Buffer) -> None: self._check_writable() with self._lock: members = self._zf.namelist() diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py index f1e812bcc..ed49936d1 100644 --- a/src/zarr/testing/store.py +++ b/src/zarr/testing/store.py @@ -274,18 +274,18 @@ async def test_list_dir(self, store: S) -> None: keys_observed = await _collect_aiterator(store.list_dir(root + "/")) assert sorted(keys_expected) == sorted(keys_observed) - async def test_setdefault(self, store: S) -> None: + async def test_set_if_not_exists(self, store: S) -> None: key = "k" data_buf = self.buffer_cls.from_bytes(b"0000") self.set(store, key, data_buf) new = self.buffer_cls.from_bytes(b"1111") - await store.setdefault("k", new) # no error + await store.set_if_not_exists("k", new) # no error result = await store.get(key, default_buffer_prototype()) assert result == data_buf - await store.setdefault("k2", new) # no error + await store.set_if_not_exists("k2", new) # no error result = await store.get("k2", default_buffer_prototype()) assert result == new From 4e07e015dc4f52ac5cedcba07d743c9f5a54edf2 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 27 Sep 2024 15:22:52 -0500 Subject: [PATCH 8/8] pre-commit --- tests/v3/test_store/test_stateful_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v3/test_store/test_stateful_store.py b/tests/v3/test_store/test_stateful_store.py index 1a43b55fe..efa1953a5 100644 --- a/tests/v3/test_store/test_stateful_store.py +++ b/tests/v3/test_store/test_stateful_store.py @@ -239,7 +239,7 @@ def check_zarr_keys(self) -> None: def test_zarr_hierarchy(sync_store: Store) -> None: - def mk_test_instance_sync(): + def mk_test_instance_sync() -> None: return ZarrStoreStateMachine(sync_store) if isinstance(sync_store, ZipStore):