Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure parents are created when creating a node #2262

Merged
merged 12 commits into from
Sep 27, 2024
19 changes: 19 additions & 0 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,16 @@ async def set(self, key: str, value: Buffer) -> None:
"""
...

@abstractmethod
async def setdefault(self, key: str, default: Buffer) -> None:
TomAugspurger marked this conversation as resolved.
Show resolved Hide resolved
"""
Store a key with a value of ``default`` if the key is not already present.

Unlike MutableMapping.setdefault, this method does not provide any way to
know whether ``default`` was actually set.
"""
...

async def _set_many(self, values: Iterable[tuple[str, Buffer]]) -> None:
"""
Insert multiple (key, value) pairs into storage.
Expand Down Expand Up @@ -297,9 +307,18 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -

# Interface stub (ellipsis body): implementers remove whatever this setter points at.
async def delete(self) -> None: ...

# Interface stub (ellipsis body): implementers store ``default`` only when
# nothing is stored yet — presumably mirrors ``Store.setdefault``; confirm.
async def setdefault(self, default: Buffer) -> None: ...


async def set_or_delete(byte_setter: ByteSetter, value: Buffer | None) -> None:
    """Write ``value`` through ``byte_setter``, or delete the entry when ``value`` is ``None``."""
    if value is not None:
        await byte_setter.set(value)
        return
    # ``None`` signals that the entry should be removed.
    await byte_setter.delete()


async def setdefault(byte_setter: ByteSetter, value: Buffer | None) -> None:
    """
    Store ``value`` through ``byte_setter`` only if nothing is stored there yet.

    Parameters
    ----------
    byte_setter : ByteSetter
        Target whose ``setdefault`` method performs the conditional write.
    value : Buffer | None
        Default to store. ``None`` means there is no default, so nothing is written.
    """
    if value is None:
        # A "set if absent" operation must never remove an existing entry,
        # so a missing default is deliberately a no-op.
        return
    # Delegate to ``setdefault`` rather than ``set`` so an existing entry is
    # not overwritten.
    await byte_setter.setdefault(value)
3 changes: 3 additions & 0 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
async def delete(self) -> None:
    """Remove the ``shard_dict`` entry keyed by this object's ``chunk_coords``."""
    del self.shard_dict[self.chunk_coords]

async def setdefault(self, default: Buffer) -> None:
    """Delegate to ``shard_dict.setdefault`` with this object's ``chunk_coords`` as the key."""
    self.shard_dict.setdefault(self.chunk_coords, default)
Copy link
Contributor Author

@TomAugspurger TomAugspurger Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is actually tested anywhere. I'm not 100% sure, but I think all the Group / Array metadata creation methods will be using StorePath as their ByteSetter.



class _ShardIndex(NamedTuple):
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
Expand Down
47 changes: 43 additions & 4 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from zarr.codecs import BytesCodec
from zarr.codecs._v2 import V2Compressor, V2Filters
from zarr.core.attributes import Attributes
from zarr.core.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
from zarr.core.buffer import (
BufferPrototype,
NDArrayLike,
NDBuffer,
default_buffer_prototype,
)
from zarr.core.chunk_grids import RegularChunkGrid, _guess_chunks
from zarr.core.chunk_key_encodings import (
ChunkKeyEncoding,
Expand Down Expand Up @@ -71,6 +76,7 @@
from collections.abc import Iterable, Iterator, Sequence

from zarr.abc.codec import Codec, CodecPipeline
from zarr.core.group import AsyncGroup
from zarr.core.metadata.common import ArrayMetadata

# Array and AsyncArray are defined in the base ``zarr`` namespace
Expand Down Expand Up @@ -276,7 +282,7 @@ async def _create_v3(
)

array = cls(metadata=metadata, store_path=store_path)
await array._save_metadata(metadata)
await array._save_metadata(metadata, ensure_parents=True)
return array

@classmethod
Expand Down Expand Up @@ -315,7 +321,7 @@ async def _create_v2(
attributes=attributes,
)
array = cls(metadata=metadata, store_path=store_path)
await array._save_metadata(metadata)
await array._save_metadata(metadata, ensure_parents=True)
return array

@classmethod
Expand Down Expand Up @@ -603,9 +609,24 @@ async def getitem(
)
return await self._get_selection(indexer, prototype=prototype)

async def _save_metadata(self, metadata: ArrayMetadata) -> None:
async def _save_metadata(self, metadata: ArrayMetadata, ensure_parents: bool = False) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new keyword ensures that updates to existing nodes don't have to issue all the setdefault operations that make sure parents exist. Any time we create a brand-new node we should call _save_metadata with ensure_parents=True.

to_save = metadata.to_buffer_dict(default_buffer_prototype())
awaitables = [set_or_delete(self.store_path / key, value) for key, value in to_save.items()]

if ensure_parents:
# To enable zarr.create(store, path="a/b/c"), we need to create all the intermediates.
TomAugspurger marked this conversation as resolved.
Show resolved Hide resolved
parents = _build_parents(self)

for parent in parents:
awaitables.extend(
[
(parent.store_path / key).setdefault(value)
for key, value in parent.metadata.to_buffer_dict(
default_buffer_prototype()
).items()
]
)

await gather(*awaitables)

async def _set_selection(
Expand Down Expand Up @@ -2336,3 +2357,21 @@ def chunks_initialized(array: Array | AsyncArray) -> tuple[str, ...]:
out.append(chunk_key)

return tuple(out)


def _build_parents(node: AsyncArray | AsyncGroup) -> list[AsyncGroup]:
    """
    Construct group objects for every ancestor path of ``node``.

    For a node stored at ``"a/b/c"`` this returns groups for ``"a"`` and
    ``"a/b"`` (in that order), each built with default ``GroupMetadata`` in
    the node's zarr format and bound to the node's store. This function only
    builds the objects; it makes no store calls itself.
    """
    # Imported inside the function rather than at module level.
    from zarr.core.group import AsyncGroup, GroupMetadata

    # Drop the final component, which is the node's own name.
    ancestor_names = node.store_path.path.split("/")[:-1]

    return [
        AsyncGroup(
            metadata=GroupMetadata(zarr_format=node.metadata.zarr_format),
            store_path=StorePath(
                store=node.store_path.store,
                path="/".join(ancestor_names[:depth]),
            ),
        )
        for depth in range(1, len(ancestor_names) + 1)
    ]
19 changes: 16 additions & 3 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import zarr.api.asynchronous as async_api
from zarr.abc.metadata import Metadata
from zarr.abc.store import Store, set_or_delete
from zarr.core.array import Array, AsyncArray
from zarr.core.array import Array, AsyncArray, _build_parents
from zarr.core.attributes import Attributes
from zarr.core.buffer import default_buffer_prototype
from zarr.core.common import (
Expand Down Expand Up @@ -144,7 +144,7 @@ async def from_store(
metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
store_path=store_path,
)
await group._save_metadata()
await group._save_metadata(ensure_parents=True)
return group

@classmethod
Expand Down Expand Up @@ -279,9 +279,22 @@ async def delitem(self, key: str) -> None:
else:
raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}")

async def _save_metadata(self) -> None:
async def _save_metadata(self, ensure_parents: bool = False) -> None:
    """
    Write this group's metadata documents to the store.

    Parameters
    ----------
    ensure_parents : bool, default False
        If True, also issue ``setdefault`` writes of default group metadata
        for every ancestor returned by ``_build_parents``.
    """
    own_docs = self.metadata.to_buffer_dict(default_buffer_prototype())
    awaitables = [set_or_delete(self.store_path / key, value) for key, value in own_docs.items()]

    if ensure_parents:
        for parent in _build_parents(self):
            parent_docs = parent.metadata.to_buffer_dict(default_buffer_prototype())
            # ``setdefault`` is used for ancestors, unlike the node's own documents above.
            awaitables += [
                (parent.store_path / key).setdefault(value) for key, value in parent_docs.items()
            ]

    # All writes (own documents and ancestor defaults) are awaited together.
    await asyncio.gather(*awaitables)

@property
Expand Down
3 changes: 3 additions & 0 deletions src/zarr/store/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ async def set(self, value: Buffer, byte_range: ByteRangeRequest | None = None) -
async def delete(self) -> None:
    """Forward to the store's ``delete`` using this object's ``path`` as the key."""
    await self.store.delete(self.path)

async def setdefault(self, default: Buffer) -> None:
    """Forward to the store's ``setdefault`` using this object's ``path`` as the key."""
    await self.store.setdefault(self.path, default)

async def exists(self) -> bool:
    """Return the store's ``exists`` result for this object's ``path``."""
    return await self.store.exists(self.path)

Expand Down
20 changes: 18 additions & 2 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _put(
path: Path,
value: Buffer,
start: int | None = None,
exclusive: bool = False,
) -> int | None:
path.parent.mkdir(parents=True, exist_ok=True)
if start is not None:
Expand All @@ -68,7 +69,13 @@ def _put(
f.write(value.as_numpy_array().tobytes())
return None
else:
return path.write_bytes(value.as_numpy_array().tobytes())
view = memoryview(value.as_numpy_array().tobytes())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pathlib.Path.write_bytes doesn't provide control over the mode. So this is just that method inlined, with mode="xb" if we want exclusive create.

if exclusive:
mode = "xb"
else:
mode = "wb"
with path.open(mode=mode) as f:
return f.write(view)


class LocalStore(Store):
Expand Down Expand Up @@ -152,14 +159,23 @@ async def get_partial_values(
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: Buffer) -> None:
    """Write ``value`` under ``key`` by delegating to ``_set`` with its default ``exclusive=False``."""
    return await self._set(key, value)

async def setdefault(self, key: str, value: Buffer) -> None:
    """Write ``value`` under ``key`` with exclusive creation; do nothing if the file already exists."""
    try:
        return await self._set(key, value, exclusive=True)
    except FileExistsError:
        # The key already holds a value; leave it untouched.
        return None

async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None:
    """
    Validate the arguments and write ``value`` to ``self.root / key`` via ``to_thread``.

    Parameters
    ----------
    key : str
        Store key, joined onto ``self.root`` to form the file path.
    value : Buffer
        Contents to write.
    exclusive : bool, default False
        Forwarded unchanged to ``_put``.

    Raises
    ------
    TypeError
        If ``value`` is not a ``Buffer``.
    """
    # Open the store on first use, then refuse the write if it is not writable.
    if not self._is_open:
        await self._open()
    self._check_writable()
    assert isinstance(key, str)
    if not isinstance(value, Buffer):
        raise TypeError("LocalStore.set(): `value` must be a Buffer instance")
    path = self.root / key
    # ``start=None`` requests a whole-file write rather than a partial one.
    await to_thread(_put, path, value, start=None, exclusive=exclusive)

async def set_partial_values(
self, key_start_values: Iterable[tuple[str, int, bytes | bytearray | memoryview]]
Expand Down
5 changes: 5 additions & 0 deletions src/zarr/store/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING

from zarr.abc.store import AccessMode, ByteRangeRequest, Store
from zarr.core.buffer import Buffer

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Iterable
Expand Down Expand Up @@ -138,6 +139,10 @@ async def set(self, key: str, value: Buffer) -> None:
with self.log():
return await self._store.set(key=key, value=value)

async def setdefault(self, key: str, default: Buffer) -> None:
    """
    Log the call and forward it to the wrapped store's ``setdefault``.

    Parameters
    ----------
    key : str
        Key to populate if absent.
    default : Buffer
        Value to store when ``key`` is not already present.
    """
    with self.log():
        # Forward to ``setdefault`` (not ``set``) so an existing value is never
        # overwritten. Arguments are positional because store implementations
        # do not agree on the name of the second parameter.
        return await self._store.setdefault(key, default)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return await self._store.set(key=key, value=default)
return await self._store.setdefault(key=key, value=default)


async def delete(self, key: str) -> None:
    """Forward to the wrapped store's ``delete`` inside the ``self.log()`` context."""
    with self.log():
        return await self._store.delete(key=key)
Expand Down
8 changes: 6 additions & 2 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,8 @@ async def exists(self, key: str) -> bool:
return key in self._store_dict

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
if not self._is_open:
await self._open()
self._check_writable()
await self._ensure_open()
assert isinstance(key, str)
if not isinstance(value, Buffer):
raise TypeError(f"Expected Buffer. Got {type(value)}.")
Expand All @@ -99,6 +98,11 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
else:
self._store_dict[key] = value

async def setdefault(self, key: str, default: Buffer) -> None:
    """Call ``setdefault`` on the backing ``_store_dict`` after the writable and open checks."""
    # The writability check runs before the store is opened.
    self._check_writable()
    await self._ensure_open()
    self._store_dict.setdefault(key, default)

async def delete(self, key: str) -> None:
self._check_writable()
try:
Expand Down
6 changes: 6 additions & 0 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fsspec

from zarr.abc.store import ByteRangeRequest, Store
from zarr.core.buffer import Buffer
from zarr.store.common import _dereference_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -208,6 +209,11 @@ async def set_partial_values(
) -> None:
raise NotImplementedError

async def setdefault(self, key: str, default: Buffer) -> None:
    """Write ``default`` under ``key`` only when the filesystem reports nothing at that path."""
    # The existence check and the write are separate calls, so this isn't safe
    # for concurrent writers, but that's probably unavoidable.
    target = _dereference_path(self.path, key)
    already_present = await self.fs._exists(target)
    if already_present:
        return
    await self.set(key, default)

async def list(self) -> AsyncGenerator[str, None]:
allfiles = await self.fs._find(self.path, detail=False, withdirs=False)
for onefile in (a.replace(self.path + "/", "") for a in allfiles):
Expand Down
7 changes: 7 additions & 0 deletions src/zarr/store/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ async def set(self, key: str, value: Buffer) -> None:
async def set_partial_values(self, key_start_values: Iterable[tuple[str, int, bytes]]) -> None:
raise NotImplementedError

async def setdefault(self, key: str, default: Buffer) -> None:
    """Add ``default`` to the archive under ``key`` unless a member with that name already exists."""
    self._check_writable()
    with self._lock:
        # The membership check and the write happen under the same lock acquisition.
        if key in self._zf.namelist():
            return
        self._set(key, default)

async def delete(self, key: str) -> None:
raise NotImplementedError

Expand Down
17 changes: 17 additions & 0 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,20 @@ async def test_list_dir(self, store: S) -> None:

keys_observed = await _collect_aiterator(store.list_dir(root + "/"))
assert sorted(keys_expected) == sorted(keys_observed)

async def test_setdefault(self, store: S) -> None:
    """``setdefault`` keeps an existing value and stores the default for a missing key."""
    key = "k"
    data_buf = self.buffer_cls.from_bytes(b"0000")
    # NOTE(review): ``self.set`` is called without ``await`` — assumes the
    # helper is synchronous; confirm against the test base class.
    self.set(store, key, data_buf)

    new = self.buffer_cls.from_bytes(b"1111")

    # Existing key: the call must not raise and must leave the stored value untouched.
    await store.setdefault(key, new)
    result = await store.get(key, default_buffer_prototype())
    assert result == data_buf

    # Missing key: the default is stored.
    await store.setdefault("k2", new)
    result = await store.get("k2", default_buffer_prototype())
    assert result == new
37 changes: 37 additions & 0 deletions tests/v3/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import numpy as np
import pytest

import zarr.api.asynchronous
from zarr import Array, AsyncArray, Group
from zarr.core.array import chunks_initialized
from zarr.core.buffer.cpu import NDBuffer
from zarr.core.common import ZarrFormat
from zarr.core.group import AsyncGroup
from zarr.core.indexing import ceildiv
from zarr.core.sync import sync
from zarr.errors import ContainsArrayError, ContainsGroupError
Expand Down Expand Up @@ -65,6 +67,41 @@ def test_array_creation_existing_node(
)


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("zarr_format", [2, 3])
async def test_create_creates_parents(
    store: LocalStore | MemoryStore, zarr_format: ZarrFormat
) -> None:
    """Creating an array at a nested path also writes group metadata for every ancestor."""
    await zarr.api.asynchronous.create(
        shape=(2, 2), store=store, path="a/b/c/d", zarr_format=zarr_format
    )
    ancestors = ["a", "a/b", "a/b/c"]

    # Metadata document names differ between the two formats.
    if zarr_format == 2:
        group_docs = [".zattrs", ".zgroup"]
        array_docs = [".zarray", ".zattrs"]
    else:
        group_docs = ["zarr.json"]
        array_docs = ["zarr.json"]

    expected = sorted(
        [f"{ancestor}/{doc}" for doc in group_docs for ancestor in ancestors]
        + [f"a/b/c/d/{doc}" for doc in array_docs]
    )
    observed = sorted([key async for key in store.list_prefix("")])

    assert observed == expected

    # Every ancestor must be openable as a group.
    for ancestor in ancestors:
        group = await zarr.api.asynchronous.open_group(store=store, path=ancestor)
        assert isinstance(group, AsyncGroup)


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("zarr_format", [2, 3])
def test_array_name_properties_no_group(
Expand Down
Loading
Loading