From 0ae9291f94a198b4b13dd44ba8496461395f6bd9 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Wed, 26 Mar 2025 10:48:27 +0100
Subject: [PATCH 1/6] Use dask array native reduction

---
 icechunk-python/python/icechunk/dask.py    | 171 ++----------------
 .../python/icechunk/distributed.py         |  15 +-
 icechunk-python/python/icechunk/xarray.py  |  24 ++-
 .../tests/run_xarray_backends_tests.py     |   3 +-
 .../tests/test_distributed_writers.py      |   2 +-
 5 files changed, 47 insertions(+), 168 deletions(-)

diff --git a/icechunk-python/python/icechunk/dask.py b/icechunk-python/python/icechunk/dask.py
index fcf4f5f9c..bbdb710d7 100644
--- a/icechunk-python/python/icechunk/dask.py
+++ b/icechunk-python/python/icechunk/dask.py
@@ -1,24 +1,13 @@
-import itertools
-from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Literal,
-    TypeAlias,
-    overload,
-)
+from collections.abc import Mapping
+from typing import Any, TypeAlias
 
 from packaging.version import Version
 
 import dask
 import dask.array
+import dask.array as da
 import zarr
-from dask import config
 from dask.array.core import Array
-from dask.base import compute_as_if_collection, tokenize
-from dask.core import flatten
-from dask.delayed import Delayed
-from dask.highlevelgraph import HighLevelGraph
 
 from icechunk import Session
 from icechunk.distributed import extract_session, merge_sessions
@@ -68,6 +57,7 @@
         Arbitrary keyword arguments passed to `dask.array.store`. Notably `compute`,
         `return_stored`, `load_stored`, and `lock` are unsupported.
     """
+    _assert_correct_dask_version()
     stored_arrays = dask.array.store(
         sources=sources,
         targets=targets,  # type: ignore[arg-type]
@@ -79,141 +69,20 @@
         **store_kwargs,
     )
     # Now we tree-reduce all changesets
-    merged_session = stateful_store_reduce(
-        stored_arrays,
-        prefix="ice-changeset",
-        chunk=extract_session,
-        aggregate=merge_sessions,
-        split_every=split_every,
-        compute=True,
-        **store_kwargs,
-    )
-    session.merge(merged_session)
-
-
-# tree-reduce all changesets, regardless of array
-def _partial_reduce(
-    aggregate: Callable[..., Any],
-    keys: Iterable[tuple[Any, ...]],
-    *,
-    layer_name: str,
-    split_every: int,
-) -> SimpleGraph:
-    """
-    Creates a new dask graph layer, that aggregates `split_every` keys together.
-    """
-    from toolz import partition_all
-
-    return {
-        (layer_name, i): (aggregate, *keys_batch)
-        for i, keys_batch in enumerate(partition_all(split_every, keys))
-    }
-
-
-@overload
-def stateful_store_reduce(
-    stored_arrays: Sequence[Array],
-    *,
-    chunk: Callable[..., Any],
-    aggregate: Callable[..., Any],
-    prefix: str | None = None,
-    split_every: int | None = None,
-    compute: Literal[False] = False,
-    **kwargs: Any,
-) -> Delayed: ...
-
-
-@overload
-def stateful_store_reduce(
-    stored_arrays: Sequence[Array],
-    *,
-    chunk: Callable[..., Any],
-    aggregate: Callable[..., Any],
-    compute: Literal[True] = True,
-    prefix: str | None = None,
-    split_every: int | None = None,
-    **kwargs: Any,
-) -> Session: ...
-
-
-def stateful_store_reduce(
-    stored_arrays: Sequence[Array],
-    *,
-    chunk: Callable[..., Any],
-    aggregate: Callable[..., Any],
-    compute: bool = True,
-    prefix: str | None = None,
-    split_every: int | None = None,
-    **kwargs: Any,
-) -> Session | Delayed:
-    _assert_correct_dask_version()
-
-    split_every = split_every or config.get("split_every", 8)
-
-    layers: MutableMapping[str, SimpleGraph] = {}
-    dependencies: MutableMapping[str, set[str]] = {}
-
-    array_names = tuple(a.name for a in stored_arrays)
-    all_array_keys = list(
-        # flatten is untyped
-        itertools.chain(*[flatten(array.__dask_keys__()) for array in stored_arrays])  # type: ignore[no-untyped-call]
-    )
-    token = tokenize(array_names, chunk, aggregate, split_every)
-
-    # Each write task returns one Zarr array,
-    # now extract the changeset (as bytes) from each of those Zarr arrays
-    map_layer_name = f"{prefix}-blockwise-{token}"
-    map_dsk: SimpleGraph = {
-        (map_layer_name, i): (chunk, key) for i, key in enumerate(all_array_keys)
-    }
-    layers[map_layer_name] = map_dsk
-    dependencies[map_layer_name] = set(array_names)
-    latest_layer = map_layer_name
-
-    if aggregate is not None:
-        # Now tree-reduce across *all* write tasks,
-        # regardless of which Array the task belongs to
-        aggprefix = f"{prefix}-merge"
-
-        depth = 0
-        keys = map_dsk.keys()
-        while len(keys) > split_every:
-            latest_layer = f"{aggprefix}-{depth}-{token}"
-
-            layers[latest_layer] = _partial_reduce(
-                aggregate, keys, layer_name=latest_layer, split_every=split_every
-            )
-            previous_layer, *_ = next(iter(keys))
-            dependencies[latest_layer] = {previous_layer}
-
-            keys = layers[latest_layer].keys()
-            depth += 1
-
-        # last one
-        latest_layer = f"{aggprefix}-final-{token}"
-        layers[latest_layer] = _partial_reduce(
-            aggregate, keys, layer_name=latest_layer, split_every=split_every
+    # reduce the individual arrays since concatenation isn't always trivial due
+    # to different shapes
+    merged_sessions = [
+        da.reduction(
+            arr,
+            name="ice-changeset",
+            chunk=extract_session,
+            aggregate=merge_sessions,
+            split_every=split_every,
+            concatenate=False,
+            dtype=object,
+            **store_kwargs,
         )
-        previous_layer, *_ = next(iter(keys))
-        dependencies[latest_layer] = {previous_layer}
-
-    store_dsk = HighLevelGraph.merge(
-        HighLevelGraph(layers, dependencies),  # type: ignore[arg-type]
-        *[array.__dask_graph__() for array in stored_arrays],
-    )
-    if compute:
-        # copied from dask.array.store
-        merged_session, *_ = compute_as_if_collection(  # type: ignore[no-untyped-call]
-            Array, store_dsk, list(layers[latest_layer].keys()), **kwargs
-        )
-        if TYPE_CHECKING:
-            assert isinstance(merged_session, Session)
-        return merged_session
-
-    else:
-        key = "stateful-store-" + tokenize(array_names)
-        store_dsk = HighLevelGraph.merge(
-            HighLevelGraph({key: {key: (latest_layer, 0)}}, {key: {latest_layer}}),
-            store_dsk,
-        )
-        return Delayed(key, store_dsk)  # type: ignore[no-untyped-call]
+        for arr in stored_arrays
+    ]
+    merged_session = merge_sessions(*da.compute(*merged_sessions))
+    session.merge(merged_session)
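
For context on the contract the new code relies on: with `concatenate=False`, `dask.array.reduction` applies `chunk` to every block of the input and then hands `aggregate` (and any intermediate combine step) nested lists of the chunk results, at most `split_every` per task. A minimal self-contained sketch of that pattern, with a plain dict standing in for an icechunk `Session` (the function names here are illustrative, not icechunk API):

```python
import dask.array as da
from dask.core import flatten

def summarize(block, axis=None, keepdims=None):
    # stand-in for extract_session: turn each block into one Python object
    return {"nbytes": block.nbytes}

def merge_results(results, axis=None, keepdims=None):
    # stand-in for merge_sessions: with concatenate=False, `results` arrives
    # as (possibly nested) lists of chunk outputs, so flatten before merging
    merged = {"nbytes": 0}
    for r in flatten([results]):
        merged["nbytes"] += r["nbytes"]
    return merged

arr = da.zeros((8, 8), chunks=(4, 4))  # four 4x4 float64 blocks
total = da.reduction(
    arr,
    chunk=summarize,
    aggregate=merge_results,
    concatenate=False,  # hand lists to `aggregate` instead of np.concatenate
    dtype=object,
    split_every=4,
).compute()
print(total)  # {'nbytes': 512} -- 4 blocks x 128 bytes each
```
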
diff --git a/icechunk-python/python/icechunk/distributed.py b/icechunk-python/python/icechunk/distributed.py
index 065eb4f0d..222971f73 100644
--- a/icechunk-python/python/icechunk/distributed.py
+++ b/icechunk-python/python/icechunk/distributed.py
@@ -1,17 +1,24 @@
 # distributed utility functions
-from typing import cast
+from typing import Any, cast
 
 import zarr
+from dask.core import flatten
 from icechunk import IcechunkStore, Session
 
 
-def extract_session(zarray: zarr.Array) -> Session:
+def extract_session(
+    zarray: zarr.Array, axis: Any = None, keepdims: Any = None
+) -> Session:
     store = cast(IcechunkStore, zarray.store)
     return store.session
 
 
-def merge_sessions(*sessions: Session) -> Session:
-    session, *rest = sessions
+def merge_sessions(
+    *sessions: Session | list[Session] | list[list[Session]],
+    axis: Any = None,
+    keepdims: Any = None,
+) -> Session:
+    session, *rest = list(flatten(sessions))
     for other in rest:
         session.merge(other)
     return session
diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py
index dd3473e14..34f12a222 100644
--- a/icechunk-python/python/icechunk/xarray.py
+++ b/icechunk-python/python/icechunk/xarray.py
@@ -6,10 +6,10 @@
 import numpy as np
 from packaging.version import Version
 
+import dask.array as da
 import xarray as xr
 import zarr
 from icechunk import IcechunkStore, Session
-from icechunk.dask import stateful_store_reduce
 from icechunk.distributed import extract_session, merge_sessions
 from icechunk.vendor.xarray import _choose_default_mode
 from xarray import DataArray, Dataset
@@ -179,15 +179,19 @@
     )  # type: ignore[no-untyped-call]
 
     # Now we tree-reduce all changesets
-    merged_session = stateful_store_reduce(
-        stored_arrays,
-        prefix="ice-changeset",
-        chunk=extract_session,
-        aggregate=merge_sessions,
-        split_every=split_every,
-        compute=True,
-        **chunkmanager_store_kwargs,
-    )
+    merged_sessions = [
+        da.reduction(
+            arr,
+            name="ice-changeset",
+            chunk=extract_session,
+            aggregate=merge_sessions,
+            split_every=split_every,
+            concatenate=False,
+            dtype=object,
+        )
+        for arr in stored_arrays
+    ]
+    merged_session = merge_sessions(*da.compute(*merged_sessions))
     self.store.session.merge(merged_session)
diff --git a/icechunk-python/tests/run_xarray_backends_tests.py b/icechunk-python/tests/run_xarray_backends_tests.py
index 4a1090c40..9440dc4fe 100644
--- a/icechunk-python/tests/run_xarray_backends_tests.py
+++ b/icechunk-python/tests/run_xarray_backends_tests.py
@@ -20,8 +20,7 @@
     TestZarrRegionAuto as ZarrRegionAutoTests,
 )
 from xarray.tests.test_backends import (
-    ZarrBase,
-    default_zarr_format,  # noqa: F401; needed otherwise not discovered
+    ZarrBase,  # noqa: F401; needed otherwise not discovered
 )
diff --git a/icechunk-python/tests/test_distributed_writers.py b/icechunk-python/tests/test_distributed_writers.py
index 78cae9253..eec9e53d4 100644
--- a/icechunk-python/tests/test_distributed_writers.py
+++ b/icechunk-python/tests/test_distributed_writers.py
@@ -110,7 +110,7 @@ async def verify(branch_name: str) -> None:
         warnings.simplefilter("ignore", category=UserWarning)
         assert_eq(roundtripped, dask_array)  # type: ignore [no-untyped-call]
 
-    with Client(n_workers=8):  # type: ignore[no-untyped-call]
+    with Client(dashboard_address=":0"):  # type: ignore[no-untyped-call]
        do_writes("with-processes")
        await verify("with-processes")
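
A note on the signature changes in this patch: `da.reduction` always invokes its callbacks as `fn(block_or_results, axis=..., keepdims=...)`, which is why `extract_session` and `merge_sessions` grew `axis`/`keepdims` parameters they simply ignore, and why `merge_sessions` now flattens its inputs: each level of the tree hands it a list (or list of lists) of the previous level's outputs. A toy illustration of that flattening contract (`ToySession` is a stand-in, not icechunk's `Session`):

```python
from dask.core import flatten

class ToySession:
    """Stand-in for icechunk's Session: accumulates changeset labels."""
    def __init__(self, changes):
        self.changes = list(changes)

    def merge(self, other):
        self.changes.extend(other.changes)

def merge_sessions(*sessions, axis=None, keepdims=None):
    # axis/keepdims are accepted (and ignored) because da.reduction passes them
    session, *rest = list(flatten(sessions))
    for other in rest:
        session.merge(other)
    return session

# aggregate may see one level of nesting per tree-reduce layer:
merged = merge_sessions([[ToySession(["a"]), ToySession(["b"])], [ToySession(["c"])]])
print(merged.changes)  # ['a', 'b', 'c']
```
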
---
 icechunk-python/tests/run_xarray_backends_tests.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/icechunk-python/tests/run_xarray_backends_tests.py b/icechunk-python/tests/run_xarray_backends_tests.py
index 9440dc4fe..4a1090c40 100644
--- a/icechunk-python/tests/run_xarray_backends_tests.py
+++ b/icechunk-python/tests/run_xarray_backends_tests.py
@@ -20,7 +20,8 @@
     TestZarrRegionAuto as ZarrRegionAutoTests,
 )
 from xarray.tests.test_backends import (
-    ZarrBase,  # noqa: F401; needed otherwise not discovered
+    ZarrBase,
+    default_zarr_format,  # noqa: F401; needed otherwise not discovered
 )

From b42c76e8f0976e439d3e1686b255c84f15c51853 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 28 Mar 2025 10:07:02 -0600
Subject: [PATCH 3/6] Fix types

---
 .../python/icechunk/distributed.py        | 21 ++++++++++++++++---
 icechunk-python/python/icechunk/xarray.py |  5 +++--
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/icechunk-python/python/icechunk/distributed.py b/icechunk-python/python/icechunk/distributed.py
index 222971f73..f609b5e77 100644
--- a/icechunk-python/python/icechunk/distributed.py
+++ b/icechunk-python/python/icechunk/distributed.py
@@ -2,9 +2,24 @@
 from typing import Any, cast
 
 import zarr
-from dask.core import flatten
 from icechunk import IcechunkStore, Session
 
+__all__ = [
+    "extract_session",
+    "merge_sessions",
+]
+
+
+def _flatten(seq, container=list):
+    if isinstance(seq, str):
+        yield seq
+    else:
+        for item in seq:
+            if isinstance(item, container):
+                yield from _flatten(item, container=container)
+            else:
+                yield item
+
 
 def extract_session(
     zarray: zarr.Array, axis: Any = None, keepdims: Any = None
@@ -18,7 +33,7 @@ def merge_sessions(
     axis: Any = None,
     keepdims: Any = None,
 ) -> Session:
-    session, *rest = list(flatten(sessions))
+    session, *rest = list(_flatten(sessions))  # type: ignore[no-untyped-call]
     for other in rest:
         session.merge(other)
-    return session
+    return cast(Session, session)
diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py
index 34f12a222..9ad3e21b3 100644
--- a/icechunk-python/python/icechunk/xarray.py
+++ b/icechunk-python/python/icechunk/xarray.py
@@ -180,7 +180,7 @@
 
     # Now we tree-reduce all changesets
     merged_sessions = [
-        da.reduction(
+        da.reduction(  # type: ignore[no-untyped-call]
             arr,
             name="ice-changeset",
             chunk=extract_session,
@@ -191,7 +191,8 @@
         )
         for arr in stored_arrays
     ]
-    merged_session = merge_sessions(*da.compute(*merged_sessions))
+    computed = da.compute(*merged_sessions)  # type: ignore[no-untyped-call]
+    merged_session = merge_sessions(*computed)
     self.store.session.merge(merged_session)

From e16d25d7dec0a74e69a55703a76368a5a3bfdd66 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 28 Mar 2025 10:07:22 -0600
Subject: [PATCH 4/6] More explicit `da.reduction`

---
 icechunk-python/python/icechunk/dask.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/icechunk-python/python/icechunk/dask.py b/icechunk-python/python/icechunk/dask.py
index bbdb710d7..049d41d15 100644
--- a/icechunk-python/python/icechunk/dask.py
+++ b/icechunk-python/python/icechunk/dask.py
@@ -79,6 +79,7 @@
             aggregate=merge_sessions,
             split_every=split_every,
             concatenate=False,
+            keepdims=False,
             dtype=object,
             **store_kwargs,
         )
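
Patch 3 above vendors `_flatten` rather than importing `dask.core.flatten`, presumably so the type checker no longer sees an untyped third-party call (patch 1 had imported the dask version). Its behavior, for reference: arbitrarily nested lists are walked depth-first, strings are treated as leaves, and non-`container` items (tuples, dicts, sessions) pass through untouched:

```python
# copy of the vendored helper from the diff, exercised on toy inputs
def _flatten(seq, container=list):
    if isinstance(seq, str):
        yield seq
    else:
        for item in seq:
            if isinstance(item, container):
                yield from _flatten(item, container=container)
            else:
                yield item

print(list(_flatten([1, [2, [3, 4]], 5])))  # [1, 2, 3, 4, 5]
print(list(_flatten(["ab", ["cd"]])))       # ['ab', 'cd'] -- strings stay whole
print(list(_flatten([(1, 2), [3]])))        # [(1, 2), 3]  -- tuples are leaves
```
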
From f3a03767649a42711131a68bc28188847eafbf5b Mon Sep 17 00:00:00 2001
From: fjetter
Date: Tue, 1 Apr 2025 10:33:41 +0200
Subject: [PATCH 5/6] fix typing

---
 icechunk-python/python/icechunk/dask.py        | 6 ++++--
 icechunk-python/python/icechunk/distributed.py | 8 +++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/icechunk-python/python/icechunk/dask.py b/icechunk-python/python/icechunk/dask.py
index 049d41d15..ec71abbc9 100644
--- a/icechunk-python/python/icechunk/dask.py
+++ b/icechunk-python/python/icechunk/dask.py
@@ -72,7 +72,7 @@
     # reduce the individual arrays since concatenation isn't always trivial due
     # to different shapes
     merged_sessions = [
-        da.reduction(
+        da.reduction(  # type: ignore[no-untyped-call]
             arr,
             name="ice-changeset",
             chunk=extract_session,
@@ -85,5 +85,7 @@
         )
         for arr in stored_arrays
     ]
-    merged_session = merge_sessions(*da.compute(*merged_sessions))
+    merged_session = merge_sessions(
+        *da.compute(*merged_sessions)  # type: ignore[no-untyped-call]
+    )
     session.merge(merged_session)
diff --git a/icechunk-python/python/icechunk/distributed.py b/icechunk-python/python/icechunk/distributed.py
index f609b5e77..2b8dec8bb 100644
--- a/icechunk-python/python/icechunk/distributed.py
+++ b/icechunk-python/python/icechunk/distributed.py
@@ -1,7 +1,8 @@
 # distributed utility functions
-from typing import Any, cast
+from typing import Any, Generator, Iterable, cast
 
 import zarr
+
 from icechunk import IcechunkStore, Session
 
 __all__ = [
@@ -10,12 +11,13 @@
 ]
 
 
-def _flatten(seq, container=list):
+def _flatten(seq: Iterable[Any], container: type = list) -> Generator[Any, None, None]:
     if isinstance(seq, str):
         yield seq
     else:
         for item in seq:
             if isinstance(item, container):
+                assert isinstance(item, Iterable)
                 yield from _flatten(item, container=container)
             else:
                 yield item
@@ -33,7 +35,7 @@ def merge_sessions(
     axis: Any = None,
     keepdims: Any = None,
 ) -> Session:
-    session, *rest = list(_flatten(sessions))  # type: ignore[no-untyped-call]
+    session, *rest = list(_flatten(sessions))
     for other in rest:
         session.merge(other)
     return cast(Session, session)

From 1da310ad51dfee79c3956181ac4f836197d9a54c Mon Sep 17 00:00:00 2001
From: fjetter
Date: Tue, 1 Apr 2025 18:13:44 +0200
Subject: [PATCH 6/6] ruff

---
 icechunk-python/python/icechunk/distributed.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/icechunk-python/python/icechunk/distributed.py b/icechunk-python/python/icechunk/distributed.py
index 2b8dec8bb..1a4757cca 100644
--- a/icechunk-python/python/icechunk/distributed.py
+++ b/icechunk-python/python/icechunk/distributed.py
@@ -1,8 +1,8 @@
 # distributed utility functions
-from typing import Any, Generator, Iterable, cast
+from collections.abc import Generator, Iterable
+from typing import Any, cast
 
 import zarr
-
 from icechunk import IcechunkStore, Session
 
 __all__ = [
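
One knob worth noting across the series: `split_every` caps how many partial sessions a single `merge_sessions` task combines, so the merge tree for `n` written chunks is roughly `ceil(log_split_every(n))` levels deep (dask may organize the tree per axis, so treat this as the idealized count). A quick illustrative helper, not part of icechunk:

```python
def merge_levels(n_chunks: int, split_every: int = 8) -> int:
    """Idealized number of tree-reduce levels needed to merge n_chunks
    partial sessions when each task merges at most split_every inputs."""
    levels = 0
    while n_chunks > 1:
        n_chunks = -(-n_chunks // split_every)  # ceiling division
        levels += 1
    return levels

for n in (8, 64, 1_000, 1_000_000):
    print(n, merge_levels(n))  # 8 -> 1, 64 -> 2, 1000 -> 4, 1000000 -> 7
```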