Skip to content

Commit a468c4b

Browse files
committed
Fix up typing
1 parent 478af0e commit a468c4b

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

xarray/namedarray/daskmanager.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
if TYPE_CHECKING:
1313
from xarray.namedarray._typing import (
14+
T_ChunkDim,
1415
T_Chunks,
1516
_DType_co,
1617
_NormalizedChunks,
@@ -45,24 +46,24 @@ def chunks(self, data: Any) -> _NormalizedChunks:
4546

4647
def normalize_chunks(
4748
self,
48-
chunks: T_Chunks | _NormalizedChunks,
49+
chunks: tuple[T_ChunkDim, ...] | _NormalizedChunks,
4950
shape: tuple[int, ...] | None = None,
5051
limit: int | None = None,
5152
dtype: _DType_co | None = None,
52-
previous_chunks: _NormalizedChunks | None = None,
53+
previous_chunks: tuple[int, ...] | _NormalizedChunks | None = None,
5354
) -> Any:
5455
"""Called by open_dataset"""
5556
from dask.array.core import normalize_chunks
5657

5758
if any(c == "preserve" for c in chunks) and any(c == "auto" for c in chunks):
5859
raise ValueError('chunks cannot use a combination of "auto" and "preserve"')
5960

60-
if previous_chunks and any(c == "preserve" for c in chunks):
61+
if shape and previous_chunks and any(c == "preserve" for c in chunks):
6162
chunks = self.preserve_chunks(
6263
chunks,
6364
shape=shape,
6465
target=self.get_auto_chunk_size(),
65-
typesize=dtype.itemsize,
66+
typesize=getattr(dtype, "itemsize", 8),
6667
previous_chunks=previous_chunks,
6768
)
6869

xarray/namedarray/parallelcompat.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
if TYPE_CHECKING:
2222
from xarray.namedarray._typing import (
23+
T_ChunkDim,
2324
T_Chunks,
2425
_Chunks,
2526
_DType,
@@ -780,12 +781,12 @@ def get_auto_chunk_size(
780781

781782
@staticmethod
782783
def preserve_chunks(
783-
chunks: T_Chunks,
784+
chunks: tuple[T_ChunkDim, ...],
784785
shape: tuple[int, ...],
785786
target: int,
786787
typesize: int,
787-
previous_chunks: tuple[int],
788-
) -> tuple[int]:
788+
previous_chunks: tuple[int, ...] | _NormalizedChunks,
789+
) -> tuple[T_ChunkDim, ...]:
789790
"""Determine meta chunks
790791
791792
This takes in a chunks value that contains ``"preserve"`` values in certain
@@ -810,21 +811,28 @@ def preserve_chunks(
810811
Size of chunks being preserved. Expressed as a tuple of ints which matches
811812
the way chunks are encoded in Zarr.
812813
"""
813-
shape = np.array(shape)
814-
previous_chunks = np.array(previous_chunks)
814+
# pop the first item off in case it's a tuple of tuples
815+
preferred_chunks = np.array(
816+
[c if isinstance(c, int) else c[0] for c in previous_chunks]
817+
)
815818

816819
# "preserve" stays as "preserve"
817-
# empty tuple means match previous chunks
820+
# None or empty tuple means match previous chunks
818821
# -1 means whole dim is in one chunk
819822
desired_chunks = np.array(
820823
[
821-
c or previous_chunks[i] if c != -1 else shape[i]
824+
c or preferred_chunks[i] if c != -1 else shape[i]
822825
for i, c in enumerate(chunks)
823826
]
824827
)
825-
826828
preserve_chunks = desired_chunks == "preserve"
827-
chunks = np.where(preserve_chunks, previous_chunks, desired_chunks).astype(int)
829+
830+
if not preserve_chunks.any():
831+
return chunks
832+
833+
new_chunks = np.where(preserve_chunks, preferred_chunks, desired_chunks).astype(
834+
int
835+
)
828836

829837
while True:
830838
# Repeatedly loop over the ``previous_chunks``, multiplying them by 2.
@@ -833,16 +841,16 @@ def preserve_chunks(
833841
# 1b. we are within 50% of the target chunk size OR
834842
# 2. the chunk covers the entire array
835843

836-
num_chunks = shape / chunks * preserve_chunks
844+
num_chunks = np.array(shape) / new_chunks * preserve_chunks
837845
idx = np.argmax(num_chunks)
838-
chunk_bytes = np.prod(chunks) * typesize
846+
chunk_bytes = np.prod(new_chunks) * typesize
839847

840848
if chunk_bytes > target or abs(chunk_bytes - target) / target < 0.5:
841849
break
842850

843851
if (num_chunks <= 1).all():
844852
break
845853

846-
chunks[idx] = min(chunks[idx] * 2, shape[idx])
854+
new_chunks[idx] = min(new_chunks[idx] * 2, shape[idx])
847855

848-
return tuple(int(x) for x in chunks)
856+
return tuple(int(x) for x in new_chunks)

0 commit comments

Comments
 (0)