2020
2121if TYPE_CHECKING :
2222 from xarray .namedarray ._typing import (
23+ T_ChunkDim ,
2324 T_Chunks ,
2425 _Chunks ,
2526 _DType ,
@@ -780,12 +781,12 @@ def get_auto_chunk_size(
780781
781782 @staticmethod
782783 def preserve_chunks (
783- chunks : T_Chunks ,
784+ chunks : tuple [ T_ChunkDim , ...] ,
784785 shape : tuple [int , ...],
785786 target : int ,
786787 typesize : int ,
787- previous_chunks : tuple [int ] ,
788- ) -> tuple [int ]:
788+ previous_chunks : tuple [int , ...] | _NormalizedChunks ,
789+ ) -> tuple [T_ChunkDim , ... ]:
789790 """Determine meta chunks
790791
791792 This takes in a chunks value that contains ``"preserve"`` values in certain
@@ -810,21 +811,28 @@ def preserve_chunks(
810811 Size of chunks being preserved. Expressed as a tuple of ints which matches
811812 the way chunks are encoded in Zarr.
812813 """
813- shape = np .array (shape )
814- previous_chunks = np .array (previous_chunks )
814+ # pop the first item off in case it's a tuple of tuples
815+ preferred_chunks = np .array (
816+ [c if isinstance (c , int ) else c [0 ] for c in previous_chunks ]
817+ )
815818
816819 # "preserve" stays as "preserve"
817- # empty tuple means match previous chunks
820+ # None or empty tuple means match previous chunks
818821 # -1 means whole dim is in one chunk
819822 desired_chunks = np .array (
820823 [
821- c or previous_chunks [i ] if c != - 1 else shape [i ]
824+ c or preferred_chunks [i ] if c != - 1 else shape [i ]
822825 for i , c in enumerate (chunks )
823826 ]
824827 )
825-
826828 preserve_chunks = desired_chunks == "preserve"
827- chunks = np .where (preserve_chunks , previous_chunks , desired_chunks ).astype (int )
829+
830+ if not preserve_chunks .any ():
831+ return chunks
832+
833+ new_chunks = np .where (preserve_chunks , preferred_chunks , desired_chunks ).astype (
834+ int
835+ )
828836
829837 while True :
830838 # Repeatedly loop over the ``previous_chunks``, multiplying them by 2.
@@ -833,16 +841,16 @@ def preserve_chunks(
833841 # 1b. we are within 50% of the target chunk size OR
834842 # 2. the chunk covers the entire array
835843
836- num_chunks = shape / chunks * preserve_chunks
844+ num_chunks = np . array ( shape ) / new_chunks * preserve_chunks
837845 idx = np .argmax (num_chunks )
838- chunk_bytes = np .prod (chunks ) * typesize
846+ chunk_bytes = np .prod (new_chunks ) * typesize
839847
840848 if chunk_bytes > target or abs (chunk_bytes - target ) / target < 0.5 :
841849 break
842850
843851 if (num_chunks <= 1 ).all ():
844852 break
845853
846- chunks [idx ] = min (chunks [idx ] * 2 , shape [idx ])
854+ new_chunks [idx ] = min (new_chunks [idx ] * 2 , shape [idx ])
847855
848- return tuple (int (x ) for x in chunks )
856+ return tuple (int (x ) for x in new_chunks )
0 commit comments