diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index 80afb1f438f..36e5c6e9c25 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -24,7 +24,7 @@
 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
-from distributed.protocol import to_serialize
+from distributed.protocol import dask_deserialize, dask_serialize, to_serialize
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
 from distributed.shuffle._exceptions import ShuffleClosedError
@@ -298,14 +298,17 @@ def _shard_partition(
     ) -> dict[str, tuple[_T_partition_id, Any]]:
         """Shard an input partition by the assigned output workers"""

-    def get_output_partition(
-        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
-    ) -> _T_partition_type:
+    def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> None:
         self.raise_if_closed()
         sync(self._loop, self._ensure_output_worker, partition_id, key)
         if not self.transferred:
             raise RuntimeError("`get_output_partition` called before barrier task")
         sync(self._loop, self.flush_receive)
+
+    def get_output_partition(
+        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
+    ) -> _T_partition_type:
+        self._sync_output_partition(partition_id, key)
         return self._get_output_partition(partition_id, key, **kwargs)

     @abc.abstractmethod
@@ -314,6 +317,12 @@ def _get_output_partition(
     ) -> _T_partition_type:
         """Get an output partition to the shuffle run"""

+    def get_unloaded_output_partition(
+        self, partition_id: _T_partition_id, key: Key, **kwargs: Any
+    ) -> UnloadedPartition:
+        self._sync_output_partition(partition_id, key)
+        return UnloadedPartition(self, partition_id, key, **kwargs)
+
     @abc.abstractmethod
     def read(self, path: Path) -> tuple[Any, int]:
         """Read shards from disk"""
@@ -457,3 +466,71 @@ def _mean_shard_size(shards: Iterable) -> int:
             if count == 10:
                 break
     return size // count if count else 0
+
+
+class UnloadedPartition:
+    """Wrap unloaded shuffle output
+
+    The purpose of this class is to keep a shuffled partition
+    on disk until it is needed by one of its dependent tasks.
+    Otherwise, the in-memory partition may need to be spilled
+    back to disk before the dependent task is executed anyway.
+
+    If ``shuffle_unpack`` returns an ``UnloadedPartition`` object,
+    ``P2PShuffleLayer`` must be followed by an extra ``Blockwise``
+    call to ``load_output_partition`` (to load and convert the
+    data). We want an extra ``Blockwise`` layer here so that the
+    loading and conversion can be fused into downstream tasks. We
+    do NOT want the original ``shuffle_unpack`` tasks to be fused
+    into dependent tasks, because this would prevent effective load
+    balancing after the shuffle (long-running post-shuffle tasks
+    may be pinned to specific workers, while others sit idle).
+
+    Note that serialization automatically loads the wrapped
+    partition, because the object may be moved to a worker that
+    doesn't have access to the same local storage.
+    """
+
+    def __init__(
+        self,
+        shuffle_run: ShuffleRun,
+        partition_id: _T_partition_id,
+        key: Key,
+        **kwargs: Any,
+    ):
+        self.shuffle_run = shuffle_run
+        self.partition_id = partition_id
+        self.key = key
+        self.kwargs = kwargs
+
+    def load(self) -> Any:
+        with handle_unpack_errors(self.shuffle_run.id):
+            return self.shuffle_run._get_output_partition(
+                self.partition_id, self.key, **self.kwargs
+            )
+
+
+@dask_serialize.register(UnloadedPartition)
+def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]:
+    # Load the partition before serializing, since the receiving
+    # worker may not have access to this worker's local storage
+    return None, [pickle.dumps(obj.load())]
+
+
+@dask_deserialize.register(UnloadedPartition)
+def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any:
+    return pickle.loads(frames[0])
+
+
+def load_output_partition(
+    data: UnloadedPartition | _T_partition_type, barrier_key: int
+) -> _T_partition_type:
+    # Used by rearrange_by_column_p2p to "unwrap" UnloadedPartition
+    # data after a P2PShuffleLayer; already-loaded partitions are
+    # passed through unchanged
+    assert barrier_key
+    if isinstance(data, UnloadedPartition):
+        data = data.load()
+    return data
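Note for reviewers on the serialization hooks above: registering `dask_serialize`/`dask_deserialize` handlers is what guarantees an `UnloadedPartition` never travels between workers in unloaded form; it is materialized on the sending side and arrives as plain data. A minimal, self-contained sketch of the same pattern (the `LazyValue` class is hypothetical, standing in for `UnloadedPartition`):

import pickle

from distributed.protocol import dask_deserialize, dask_serialize
from distributed.protocol.serialize import deserialize, serialize


class LazyValue:
    """Hypothetical stand-in for UnloadedPartition: holds a recipe, not data."""

    def __init__(self, fn):
        self.fn = fn

    def load(self):
        return self.fn()


@dask_serialize.register(LazyValue)
def _serialize_lazy(obj):
    # Materialize on the sending side, mirroring _serialize_unloaded
    return {}, [pickle.dumps(obj.load())]


@dask_deserialize.register(LazyValue)
def _deserialize_lazy(header, frames):
    # The receiver gets the loaded value, never a LazyValue
    return pickle.loads(frames[0])


header, frames = serialize(LazyValue(lambda: 42), serializers=("dask",))
assert deserialize(header, frames) == 42

As in `_deserialize_unloaded`, the deserializer deliberately returns the payload rather than reconstructing the wrapper.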
+ """ + + def __init__( + self, + shuffle_run: ShuffleRun, + partition_id: _T_partition_id, + key: Key, + **kwargs: Any, + ): + self.shuffle_run = shuffle_run + self.partition_id = partition_id + self.key = key + self.kwargs = kwargs + + def load(self) -> Any: + with handle_unpack_errors(self.shuffle_run.id): + return self.shuffle_run._get_output_partition( + self.partition_id, self.key, **self.kwargs + ) + + +@dask_serialize.register(UnloadedPartition) +def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]: + # Convert to LoadedPartition before serializing. Note that + # we don't convert all the way to DataFrame, because this + # adds unnecessary overhead and memory pressure for the + # cudf backend (and minor overhead for pandas) + return None, [pickle.dumps(obj.load())] + + +@dask_deserialize.register(UnloadedPartition) +def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any: + return pickle.loads(frames[0]) + + +def load_output_partition( + data: UnloadedPartition | _T_partition_type, barrier_key: int +) -> _T_partition_type: + # Used by rearrange_by_column_p2p to "unwrap" + # UnloadedPartition/LoadedPartition data after + # a P2PShuffleLayer + assert barrier_key + if isinstance(data, UnloadedPartition): + data = data.load() + return data diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index c42f6d94426..aac566b658b 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -2,7 +2,6 @@ import logging import os -import pickle from collections import defaultdict from collections.abc import Callable, Collection, Iterable, Iterator, Sequence from concurrent.futures import ThreadPoolExecutor @@ -22,7 +21,6 @@ from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule -from distributed.protocol import dask_deserialize, dask_serialize from distributed.shuffle._arrow import ( check_dtype_support, check_minimal_arrow_version, @@ -41,6 +39,7 @@ get_worker_plugin, handle_transfer_errors, handle_unpack_errors, + load_output_partition, ) from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin @@ -58,108 +57,6 @@ from dask.dataframe import DataFrame -class UnloadedPartition: - """Wrap unloaded shuffle output - - The purpose of this class is to keep a shuffled partition - on disk until it is needed by one of its dependent tasks. - Otherwise, the in-memory partition may need to be spilled - back to disk before the dependent task is executed anyway. - - If ``shuffle_unpack`` returns an ``UnloadedPartition`` object, - ``P2PShuffleLayer`` must be followed by an extra ``Blockwise`` - call to ``_get_partition_data`` (to load and covert the data). - We want an extra ``Blockwise`` layer here so that the loading - and conversion can be fused into down-stream tasks. We do NOT - want the original ``shuffle_unpack`` tasks to be fused into - dependent tasks, because this would prevent effective load - balancing after the shuffle (long-running post-shuffle tasks - may be pinned to specific workers, while others sit idle). - - Note that serialization automatically converts to - ``LoadedPartition``, because the object may be moved to a - worker that doesn't have access to the same local storage. 
- """ - - def __init__( - self, - shuffle_run: DataFrameShuffleRun, - partition_id: int, - ): - self.shuffle_run = shuffle_run - self.partition_id = partition_id - - def load(self) -> LoadedPartition: - with handle_unpack_errors(self.shuffle_run.id): - try: - data = self.shuffle_run._read_from_disk((self.partition_id,)) - except KeyError: - data = None - return LoadedPartition(data, self.shuffle_run.meta, self.shuffle_run.id) - - -class LoadedPartition: - def __init__( - self, - data: list[pa.Table] | None, - meta: pd.DataFrame, - shuffle_id: ShuffleId, - ): - self.data = data - self.meta = meta - self.shuffle_id = shuffle_id - - def convert(self) -> pd.DataFrame: - with handle_unpack_errors(self.shuffle_id): - if self.data is None: - data = self.meta.copy() - else: - data = convert_shards(self.data, self.meta) - return data - - -@dask_serialize.register(UnloadedPartition) -def _serialize_unloaded(obj: UnloadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]: - # Convert to LoadedPartition before serializing. Note that - # we don't convert all the way to DataFrame, because this - # adds unnecessary overhead and memory pressure for the - # cudf backend (and minor overhead for pandas) - loaded = obj.load() - return (loaded.shuffle_id,), [ - pickle.dumps(loaded.meta), - pickle.dumps(loaded.data), - ] - - -@dask_serialize.register(LoadedPartition) -def _serialize_loaded(obj: LoadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]: - return (obj.shuffle_id,), [pickle.dumps(obj.meta), pickle.dumps(obj.data)] - - -@dask_deserialize.register((UnloadedPartition, LoadedPartition)) -def _deserialize_loaded( - header: tuple[ShuffleId], frames: list[bytes] -) -> LoadedPartition: - shuffle_id = header[0] - meta = pickle.loads(frames[0]) - data = pickle.loads(frames[1]) - return LoadedPartition(data, meta, shuffle_id) - - -def _get_partition_data( - part: UnloadedPartition | LoadedPartition | pd.DataFrame, barrier_key: int -) -> pd.DataFrame: - # Used by rearrange_by_column_p2p to "unwrap" - # UnloadedPartition/LoadedPartition data after - # a P2PShuffleLayer - assert barrier_key - if isinstance(part, UnloadedPartition): - part = part.load() - if isinstance(part, LoadedPartition): - part = part.convert() - return part - - def shuffle_transfer( input: pd.DataFrame, id: ShuffleId, @@ -194,6 +91,15 @@ def shuffle_unpack( ) +def shuffle_unpack_unloaded( + id: ShuffleId, output_partition: int, barrier_run_id: int +) -> pd.DataFrame: + with handle_unpack_errors(id): + return get_worker_plugin().get_output_partition( + id, barrier_run_id, output_partition, load=False + ) + + def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int: try: return get_worker_plugin().barrier(id, run_ids) @@ -239,15 +145,14 @@ def rearrange_by_column_p2p( meta_input=meta, disk=disk, ) - _barrier_key = layer._tokens[1] return new_dd_object( HighLevelGraph.from_collections(name, layer, [df]), name, meta, [None] * (npartitions + 1), ).map_partitions( - _get_partition_data, - _barrier_key, + load_output_partition, + layer._tokens[1], meta=meta, enforce_metadata=False, ) @@ -400,7 +305,7 @@ def _construct_graph(self) -> _T_LowLevelGraph: name = self.name for part_out in self.parts_out: dsk[(name, part_out)] = ( - shuffle_unpack, + shuffle_unpack_unloaded, token, part_out, _barrier_key, @@ -616,7 +521,11 @@ def _get_output_partition( key: Key, **kwargs: Any, ) -> pd.DataFrame: - return UnloadedPartition(self, partition_id) + try: + data = self._read_from_disk((partition_id,)) + return convert_shards(data, self.meta) + except 
diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py
index 39b1c126d2a..3cbb6f594b3 100644
--- a/distributed/shuffle/_worker_plugin.py
+++ b/distributed/shuffle/_worker_plugin.py
@@ -417,6 +417,7 @@ def get_output_partition(
         run_id: int,
         partition_id: int | NDIndex,
         meta: pd.DataFrame | None = None,
+        load: bool = True,
     ) -> Any:
         """
         Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin.
@@ -425,7 +426,13 @@
         """
         shuffle_run = self.get_shuffle_run(shuffle_id, run_id)
         key = thread_state.key
-        return shuffle_run.get_output_partition(
+        if load:
+            return shuffle_run.get_output_partition(
+                partition_id=partition_id,
+                key=key,
+                meta=meta,
+            )
+        return shuffle_run.get_unloaded_output_partition(
             partition_id=partition_id,
             key=key,
             meta=meta,
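Since `load` defaults to `True`, the existing `shuffle_unpack` task keeps its eager behaviour and only `shuffle_unpack_unloaded` opts into the lazy path. None of this changes user-facing behaviour; any P2P shuffle exercises it. A smoke-test sketch (data and column names made up; at the time of this PR the keyword was `shuffle=`, newer dask spells it `method=`):

import dask.dataframe as dd
import pandas as pd
from distributed import Client

if __name__ == "__main__":
    client = Client(processes=False)
    df = pd.DataFrame({"x": range(100), "y": range(100)})
    ddf = dd.from_pandas(df, npartitions=4)
    # shuffle_unpack_unloaded leaves each output partition on disk;
    # the load step fuses into the following `assign` task
    out = ddf.shuffle("x", shuffle="p2p").assign(z=lambda d: d.x + d.y)
    print(out.compute().head())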