Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ def _dataset_from_backend_dataset(
create_default_indexes,
**extra_tokens,
):
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
if not isinstance(chunks, int | dict) and chunks not in {None, "auto", "preserve"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
f"chunks must be an int, dict, 'auto', 'preserve', or None. Instead found {chunks}."
)

_protect_dataset_variables_inplace(backend_ds, cache)
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def dtype(self) -> _DType_co: ...
_NormalizedChunks = tuple[tuple[int, ...], ...]
# FYI in some cases we don't allow `None`, which this doesn't take account of.
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051
T_ChunkDim: TypeAlias = str | int | Literal["auto", "preserve"] | tuple[int, ...] | None # noqa: PYI051
# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]

Expand Down
81 changes: 81 additions & 0 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,75 @@ def is_chunked_array(self, data: duckarray[Any, Any]) -> bool:
def chunks(self, data: Any) -> _NormalizedChunks:
return data.chunks # type: ignore[no-any-return]

def preserve_chunks(
self,
chunks: T_Chunks,
shape: tuple[int, ...],
target: int,
typesize: int,
previous_chunks: tuple[int],
) -> tuple[int]:
"""Determine meta chunks

This takes in a chunks value that contains ``"preserve"`` values in certain
dimensions and replaces those values with concrete dimension sizes that try
to get chunks to be of a certain size in bytes, provided by the ``limit=``
keyword. Any dimensions marked as ``"preserve"`` will potentially be multiplied
to get close to the byte target, while never splitting ``previous_chunks``.

Parameters
----------
chunks: tuple[int | str | tuple, ...]
A tuple of either dimensions or tuples of explicit chunk dimensions
Some entries should be "preserve". Any explicit dimensions must match or
be multiple of ``previous_chunks``
shape: tuple[int]
The shape of the array
target: int
The target size of the chunk in bytes.
typesize: int
The size, in bytes, of each element of the chunk.
previous_chunks: tuple[int]
Size of chunks being preserved. Expressed as a tuple of ints which matches
the way chunks are encoded in Zarr.
"""
shape = np.array(shape)
previous_chunks = np.array(previous_chunks)

# "preserve" stays as "preserve"
# empty tuple means match previous chunks
# -1 means whole dim is in one chunk
desired_chunks = np.array(
[
c or previous_chunks[i] if c != -1 else shape[i]
for i, c in enumerate(chunks)
]
)

preserve_chunks = desired_chunks == "preserve"
chunks = np.where(preserve_chunks, previous_chunks, desired_chunks).astype(int)

while True:
# Repeatedly loop over the ``previous_chunks``, multiplying them by 2.
# Stop when:
# 1a. we are larger than the target chunk size OR
# 1b. we are within 50% of the target chunk size OR
# 2. the chunk covers the entire array

num_chunks = shape / chunks * preserve_chunks
idx = np.argmax(num_chunks)
chunk_bytes = np.prod(chunks) * typesize

if chunk_bytes > target or abs(chunk_bytes - target) / target < 0.5:
break

if (num_chunks <= 1).all():
break

chunks[idx] = min(chunks[idx] * 2, shape[idx])

return tuple(int(x) for x in chunks)

def normalize_chunks(
self,
chunks: T_Chunks | _NormalizedChunks,
Expand All @@ -54,6 +123,18 @@ def normalize_chunks(
"""Called by open_dataset"""
from dask.array.core import normalize_chunks

if any(c == "preserve" for c in chunks) and any(c == "auto" for c in chunks):
raise ValueError('chunks cannot use a combination of "auto" and "preserve"')

if previous_chunks and any(c == "preserve" for c in chunks):
chunks = self.preserve_chunks(
chunks,
shape=shape,
target=96 * 1024 * 1024,
typesize=dtype.itemsize,
previous_chunks=previous_chunks,
)

return normalize_chunks(
chunks,
shape=shape,
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _get_chunk( # type: ignore[no-untyped-def]
preferred_chunk_shape = tuple(
itertools.starmap(preferred_chunks.get, zip(dims, shape, strict=True))
)
if isinstance(chunks, Number) or (chunks == "auto"):
if isinstance(chunks, (Number, str)):
chunks = dict.fromkeys(dims, chunks)
chunk_shape = tuple(
chunks.get(dim, None) or preferred_chunk_sizes
Expand Down
Loading