172 changes: 22 additions & 150 deletions icechunk-python/python/icechunk/dask.py
@@ -1,24 +1,13 @@
import itertools
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeAlias,
overload,
)
from collections.abc import Mapping
from typing import Any, TypeAlias

from packaging.version import Version

import dask
import dask.array
import dask.array as da
import zarr
from dask import config
from dask.array.core import Array
from dask.base import compute_as_if_collection, tokenize
from dask.core import flatten
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from icechunk import Session
from icechunk.distributed import extract_session, merge_sessions

@@ -68,6 +57,7 @@ def store_dask(
Arbitrary keyword arguments passed to `dask.array.store`. Notably `compute`,
`return_stored`, `load_stored`, and `lock` are unsupported.
"""
_assert_correct_dask_version()
stored_arrays = dask.array.store(
sources=sources,
targets=targets, # type: ignore[arg-type]
@@ -79,141 +69,23 @@
**store_kwargs,
)
# Now we tree-reduce all changesets
merged_session = stateful_store_reduce(
stored_arrays,
prefix="ice-changeset",
chunk=extract_session,
aggregate=merge_sessions,
split_every=split_every,
compute=True,
**store_kwargs,
)
session.merge(merged_session)


# tree-reduce all changesets, regardless of array
def _partial_reduce(
aggregate: Callable[..., Any],
keys: Iterable[tuple[Any, ...]],
*,
layer_name: str,
split_every: int,
) -> SimpleGraph:
"""
Creates a new dask graph layer, that aggregates `split_every` keys together.
"""
from toolz import partition_all

return {
(layer_name, i): (aggregate, *keys_batch)
for i, keys_batch in enumerate(partition_all(split_every, keys))
}


@overload
def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
prefix: str | None = None,
split_every: int | None = None,
compute: Literal[False] = False,
**kwargs: Any,
) -> Delayed: ...


@overload
def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
compute: Literal[True] = True,
prefix: str | None = None,
split_every: int | None = None,
**kwargs: Any,
) -> Session: ...


def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
compute: bool = True,
prefix: str | None = None,
split_every: int | None = None,
**kwargs: Any,
) -> Session | Delayed:
_assert_correct_dask_version()

split_every = split_every or config.get("split_every", 8)

layers: MutableMapping[str, SimpleGraph] = {}
dependencies: MutableMapping[str, set[str]] = {}

array_names = tuple(a.name for a in stored_arrays)
all_array_keys = list(
# flatten is untyped
itertools.chain(*[flatten(array.__dask_keys__()) for array in stored_arrays]) # type: ignore[no-untyped-call]
)
token = tokenize(array_names, chunk, aggregate, split_every)

# Each write task returns one Zarr array,
# now extract the changeset (as bytes) from each of those Zarr arrays
map_layer_name = f"{prefix}-blockwise-{token}"
map_dsk: SimpleGraph = {
(map_layer_name, i): (chunk, key) for i, key in enumerate(all_array_keys)
}
layers[map_layer_name] = map_dsk
dependencies[map_layer_name] = set(array_names)
latest_layer = map_layer_name

if aggregate is not None:
# Now tree-reduce across *all* write tasks,
# regardless of which Array the task belongs to
aggprefix = f"{prefix}-merge"

depth = 0
keys = map_dsk.keys()
while len(keys) > split_every:
latest_layer = f"{aggprefix}-{depth}-{token}"

layers[latest_layer] = _partial_reduce(
aggregate, keys, layer_name=latest_layer, split_every=split_every
)
previous_layer, *_ = next(iter(keys))
dependencies[latest_layer] = {previous_layer}

keys = layers[latest_layer].keys()
depth += 1

# last one
latest_layer = f"{aggprefix}-final-{token}"
layers[latest_layer] = _partial_reduce(
aggregate, keys, layer_name=latest_layer, split_every=split_every
# reduce the individual arrays since concatenation isn't always trivial due
# to different shapes
merged_sessions = [
da.reduction( # type: ignore[no-untyped-call]
arr,
name="ice-changeset",
chunk=extract_session,
aggregate=merge_sessions,
split_every=split_every,
concatenate=False,
keepdims=False,
dtype=object,
**store_kwargs,
)
previous_layer, *_ = next(iter(keys))
dependencies[latest_layer] = {previous_layer}

store_dsk = HighLevelGraph.merge(
HighLevelGraph(layers, dependencies), # type: ignore[arg-type]
*[array.__dask_graph__() for array in stored_arrays],
for arr in stored_arrays
]
merged_session = merge_sessions(
*da.compute(*merged_sessions) # type: ignore[no-untyped-call]
)
if compute:
# copied from dask.array.store
merged_session, *_ = compute_as_if_collection( # type: ignore[no-untyped-call]
Array, store_dsk, list(layers[latest_layer].keys()), **kwargs
)
if TYPE_CHECKING:
assert isinstance(merged_session, Session)
return merged_session

else:
key = "stateful-store-" + tokenize(array_names)
store_dsk = HighLevelGraph.merge(
HighLevelGraph({key: {key: (latest_layer, 0)}}, {key: {latest_layer}}),
store_dsk,
)
return Delayed(key, store_dsk) # type: ignore[no-untyped-call]
session.merge(merged_session)
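The rewritten `store_dask` swaps the hand-built `HighLevelGraph`/`compute_as_if_collection` machinery for `dask.array.reduction`, applied once per stored array and merged eagerly at the end (per-array reductions because the stored arrays can have different shapes, so their blocks cannot simply be concatenated into one reduction). The sketch below is a minimal, self-contained illustration of that pattern using toy dictionaries in place of icechunk `Session` objects; the helper names are invented for illustration. Note that `chunk` runs once per block and must accept the `axis`/`keepdims` keywords dask passes along, and that with `concatenate=False` the `combine`/`aggregate` callables receive (possibly nested) lists of per-block results.

```python
import dask.array as da


def summarize_block(block, axis=None, keepdims=None):
    # Stand-in for extract_session: return one opaque object per block.
    # dask passes axis/keepdims, which we simply ignore.
    return {"count": int(block.size)}


def merge_summaries(items, axis=None, keepdims=None, **kwargs):
    # Stand-in for merge_sessions. With concatenate=False, dask hands over
    # (possibly nested) lists of per-block objects, so flatten defensively.
    flat, stack = [], [items]
    while stack:
        item = stack.pop()
        if isinstance(item, list):
            stack.extend(item)
        else:
            flat.append(item)
    return {"count": sum(d["count"] for d in flat)}


x = da.ones((4, 6), chunks=(2, 3))  # four blocks
summary = da.reduction(
    x,
    chunk=summarize_block,
    aggregate=merge_summaries,
    combine=merge_summaries,  # used at intermediate tree levels
    concatenate=False,        # pass raw per-block objects, don't np.concatenate them
    keepdims=False,
    dtype=object,
    split_every=2,
)
print(summary.compute())  # {'count': 24}
```

Computing each per-array reduction and then calling `merge_sessions` once on the results keeps the graph simple while still tree-reducing the changesets within every array.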
34 changes: 29 additions & 5 deletions icechunk-python/python/icechunk/distributed.py
@@ -1,17 +1,41 @@
# distributed utility functions
from typing import cast
from collections.abc import Generator, Iterable
from typing import Any, cast

import zarr
from icechunk import IcechunkStore, Session

__all__ = [
"extract_session",
"merge_sessions",
]

def extract_session(zarray: zarr.Array) -> Session:

def _flatten(seq: Iterable[Any], container: type = list) -> Generator[Any, None, None]:
if isinstance(seq, str):
yield seq
else:
for item in seq:
if isinstance(item, container):
assert isinstance(item, Iterable)
yield from _flatten(item, container=container)
else:
yield item


def extract_session(
zarray: zarr.Array, axis: Any = None, keepdims: Any = None
) -> Session:
store = cast(IcechunkStore, zarray.store)
return store.session


def merge_sessions(*sessions: Session) -> Session:
session, *rest = sessions
def merge_sessions(
*sessions: Session | list[Session] | list[list[Session]],
axis: Any = None,
keepdims: Any = None,
) -> Session:
session, *rest = list(_flatten(sessions))
for other in rest:
session.merge(other)
return session
return cast(Session, session)
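For the `distributed.py` changes: `extract_session` and `merge_sessions` now accept (and ignore) the `axis`/`keepdims` keywords that `dask.array.reduction` passes, and `merge_sessions` flattens its inputs because dask's tree reduce does not hand the aggregator a flat sequence. A toy illustration of the call shapes it must tolerate (`FakeSession` and `merge_fake_sessions` are invented stand-ins, not icechunk code):

```python
class FakeSession:
    def __init__(self, changes):
        self.changes = set(changes)

    def merge(self, other):
        self.changes |= other.changes


def merge_fake_sessions(*sessions, axis=None, keepdims=None):
    # Mirror the real merge_sessions: flatten nested lists, then fold
    # everything into the first session.
    def leaves(seq):
        for item in seq:
            if isinstance(item, list):
                yield from leaves(item)
            else:
                yield item

    first, *rest = list(leaves(sessions))
    for other in rest:
        first.merge(other)
    return first


a, b, c = FakeSession({"a"}), FakeSession({"b"}), FakeSession({"c"})

# dask may call the aggregator with bare sessions, a list, or nested lists:
merged = merge_fake_sessions([[a], [b, c]], axis=(0, 1), keepdims=False)
print(merged.changes)  # {'a', 'b', 'c'}
```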
25 changes: 15 additions & 10 deletions icechunk-python/python/icechunk/xarray.py
@@ -6,10 +6,10 @@
import numpy as np
from packaging.version import Version

import dask.array as da
import xarray as xr
import zarr
from icechunk import IcechunkStore, Session
from icechunk.dask import stateful_store_reduce
from icechunk.distributed import extract_session, merge_sessions
from icechunk.vendor.xarray import _choose_default_mode
from xarray import DataArray, Dataset
@@ -179,15 +179,20 @@ def write_lazy(
) # type: ignore[no-untyped-call]

# Now we tree-reduce all changesets
merged_session = stateful_store_reduce(
stored_arrays,
prefix="ice-changeset",
chunk=extract_session,
aggregate=merge_sessions,
split_every=split_every,
compute=True,
**chunkmanager_store_kwargs,
)
merged_sessions = [
da.reduction( # type: ignore[no-untyped-call]
arr,
name="ice-changeset",
chunk=extract_session,
aggregate=merge_sessions,
split_every=split_every,
concatenate=False,
dtype=object,
)
for arr in stored_arrays
]
computed = da.compute(*merged_sessions) # type: ignore[no-untyped-call]
merged_session = merge_sessions(*computed)
self.store.session.merge(merged_session)
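The `xarray.py` hunk applies the same per-array `da.reduction` pattern inside `write_lazy`, so a lazy (dask-backed) Dataset write ends with one merged session on the writer's store. Below is a hedged end-to-end usage sketch of the path this code sits behind, assuming icechunk's `Repository`/`writable_session` API and the `to_icechunk` helper; exact entry-point names and signatures may differ between icechunk releases.

```python
import numpy as np
import xarray as xr

import icechunk
from icechunk.xarray import to_icechunk

# In-memory repo for illustration; a real workflow would use object storage.
repo = icechunk.Repository.create(icechunk.in_memory_storage())
session = repo.writable_session("main")

# Dask-backed variables take the write_lazy path shown in the diff above.
ds = xr.Dataset({"t": (("x", "y"), np.ones((4, 6)))}).chunk({"x": 2, "y": 3})

to_icechunk(ds, session, mode="w")  # lazy writes are tree-reduced into one session
print(session.commit("write t"))    # id of the new snapshot
```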


2 changes: 1 addition & 1 deletion icechunk-python/tests/test_distributed_writers.py
@@ -110,7 +110,7 @@ async def verify(branch_name: str) -> None:
warnings.simplefilter("ignore", category=UserWarning)
assert_eq(roundtripped, dask_array) # type: ignore [no-untyped-call]

with Client(n_workers=8): # type: ignore[no-untyped-call]
with Client(dashboard_address=":0"): # type: ignore[no-untyped-call]
do_writes("with-processes")
await verify("with-processes")
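On the test change: dropping `n_workers=8` falls back to the local cluster's default worker count, and `dashboard_address=":0"` asks the scheduler to bind its dashboard to any free port, which avoids port-collision warnings when several test runs share a machine. A minimal sketch of that pattern, assuming `dask.distributed` is installed:

```python
from distributed import Client

# ":0" means "any free port", so parallel test runs don't fight over 8787.
with Client(dashboard_address=":0") as client:
    print(client.dashboard_link)  # e.g. http://127.0.0.1:<random port>/status
```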
