refactor and temporarily combine loading and conversion
rjzamora committed Dec 7, 2023
1 parent 69a89a2 commit 18b49ca
Showing 3 changed files with 107 additions and 114 deletions.
85 changes: 81 additions & 4 deletions distributed/shuffle/_core.py
@@ -24,7 +24,7 @@

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import to_serialize
from distributed.protocol import dask_deserialize, dask_serialize, to_serialize
from distributed.shuffle._comms import CommShardsBuffer
from distributed.shuffle._disk import DiskShardsBuffer
from distributed.shuffle._exceptions import ShuffleClosedError
@@ -298,14 +298,17 @@ def _shard_partition(
) -> dict[str, tuple[_T_partition_id, Any]]:
"""Shard an input partition by the assigned output workers"""

def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> None:
self.raise_if_closed()
sync(self._loop, self._ensure_output_worker, partition_id, key)
if not self.transferred:
raise RuntimeError("`get_output_partition` called before barrier task")
sync(self._loop, self.flush_receive)

def get_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> _T_partition_type:
self._sync_output_partition(partition_id, key)
return self._get_output_partition(partition_id, key, **kwargs)

@abc.abstractmethod
@@ -314,6 +317,12 @@ def _get_output_partition(
) -> _T_partition_type:
"""Get an output partition to the shuffle run"""

def get_unloaded_output_partition(
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
) -> UnloadedPartition:
self._sync_output_partition(partition_id, key)
return UnloadedPartition(self, partition_id, key, **kwargs)

@abc.abstractmethod
def read(self, path: Path) -> tuple[Any, int]:
"""Read shards from disk"""
@@ -457,3 +466,71 @@ def _mean_shard_size(shards: Iterable) -> int:
if count == 10:
break
return size // count if count else 0


class UnloadedPartition:
"""Wrap unloaded shuffle output
The purpose of this class is to keep a shuffled partition
on disk until it is needed by one of its dependent tasks.
Otherwise, the in-memory partition may need to be spilled
back to disk before the dependent task is executed anyway.
If ``shuffle_unpack`` returns an ``UnloadedPartition`` object,
``P2PShuffleLayer`` must be followed by an extra ``Blockwise``
call to ``_get_partition_data`` (to load and covert the data).
We want an extra ``Blockwise`` layer here so that the loading
and conversion can be fused into down-stream tasks. We do NOT
want the original ``shuffle_unpack`` tasks to be fused into
dependent tasks, because this would prevent effective load
balancing after the shuffle (long-running post-shuffle tasks
may be pinned to specific workers, while others sit idle).
Note that serialization automatically converts to
``LoadedPartition``, because the object may be moved to a
worker that doesn't have access to the same local storage.
"""

def __init__(
self,
shuffle_run: ShuffleRun,
partition_id: _T_partition_id,
key: Key,
**kwargs: Any,
):
self.shuffle_run = shuffle_run
self.partition_id = partition_id
self.key = key
self.kwargs = kwargs

def load(self) -> Any:
with handle_unpack_errors(self.shuffle_run.id):
return self.shuffle_run._get_output_partition(
self.partition_id, self.key, **self.kwargs
)


@dask_serialize.register(UnloadedPartition)
def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]:
# Load the partition before serializing, because the object
# may be moved to a worker that doesn't have access to the
# same local storage. With loading and conversion temporarily
# combined, this also converts the shards to the backend
# DataFrame type.
return None, [pickle.dumps(obj.load())]


@dask_deserialize.register(UnloadedPartition)
def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any:
return pickle.loads(frames[0])


def load_output_partition(
data: UnloadedPartition | _T_partition_type, barrier_key: int
) -> _T_partition_type:
# Used by rearrange_by_column_p2p to "unwrap"
# UnloadedPartition data after a P2PShuffleLayer
assert barrier_key
if isinstance(data, UnloadedPartition):
data = data.load()
return data
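
To make the deferred-load pattern above concrete, the following is a minimal, self-contained sketch of how an ``UnloadedPartition``-style wrapper is consumed. It uses toy stand-ins (``FakeRun``, ``UnloadedPartitionSketch``, ``load_output_partition_sketch``) rather than the real ``ShuffleRun`` machinery, so the names and behavior are illustrative only, not the actual distributed API.

from typing import Any

class FakeRun:
    """Toy stand-in for a ShuffleRun that materializes a partition on demand."""

    def _get_output_partition(self, partition_id: int, key: str, **kwargs: Any) -> Any:
        # A real run would read shards from disk and convert them here.
        return {"partition": partition_id, "key": key}

class UnloadedPartitionSketch:
    """Defers materialization until a dependent task actually needs the data."""

    def __init__(self, run: FakeRun, partition_id: int, key: str, **kwargs: Any):
        self.run = run
        self.partition_id = partition_id
        self.key = key
        self.kwargs = kwargs

    def load(self) -> Any:
        return self.run._get_output_partition(self.partition_id, self.key, **self.kwargs)

def load_output_partition_sketch(data: Any, barrier_key: int) -> Any:
    # The barrier argument only creates a graph dependency; it is not used otherwise.
    assert barrier_key
    if isinstance(data, UnloadedPartitionSketch):
        data = data.load()
    return data

# The unpack task returns a lightweight wrapper; the follow-up Blockwise task
# (which can be fused into downstream tasks) is what pulls the data into memory.
wrapper = UnloadedPartitionSketch(FakeRun(), partition_id=0, key="('x', 0)")
print(load_output_partition_sketch(wrapper, barrier_key=123))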
127 changes: 18 additions & 109 deletions distributed/shuffle/_shuffle.py
@@ -2,7 +2,6 @@

import logging
import os
import pickle
from collections import defaultdict
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from concurrent.futures import ThreadPoolExecutor
@@ -22,7 +21,6 @@

from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.protocol import dask_deserialize, dask_serialize
from distributed.shuffle._arrow import (
check_dtype_support,
check_minimal_arrow_version,
@@ -41,6 +39,7 @@
get_worker_plugin,
handle_transfer_errors,
handle_unpack_errors,
load_output_partition,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
@@ -58,108 +57,6 @@
from dask.dataframe import DataFrame


class UnloadedPartition:
"""Wrap unloaded shuffle output
The purpose of this class is to keep a shuffled partition
on disk until it is needed by one of its dependent tasks.
Otherwise, the in-memory partition may need to be spilled
back to disk before the dependent task is executed anyway.
If ``shuffle_unpack`` returns an ``UnloadedPartition`` object,
``P2PShuffleLayer`` must be followed by an extra ``Blockwise``
call to ``_get_partition_data`` (to load and covert the data).
We want an extra ``Blockwise`` layer here so that the loading
and conversion can be fused into down-stream tasks. We do NOT
want the original ``shuffle_unpack`` tasks to be fused into
dependent tasks, because this would prevent effective load
balancing after the shuffle (long-running post-shuffle tasks
may be pinned to specific workers, while others sit idle).
Note that serialization automatically converts to
``LoadedPartition``, because the object may be moved to a
worker that doesn't have access to the same local storage.
"""

def __init__(
self,
shuffle_run: DataFrameShuffleRun,
partition_id: int,
):
self.shuffle_run = shuffle_run
self.partition_id = partition_id

def load(self) -> LoadedPartition:
with handle_unpack_errors(self.shuffle_run.id):
try:
data = self.shuffle_run._read_from_disk((self.partition_id,))
except KeyError:
data = None
return LoadedPartition(data, self.shuffle_run.meta, self.shuffle_run.id)


class LoadedPartition:
def __init__(
self,
data: list[pa.Table] | None,
meta: pd.DataFrame,
shuffle_id: ShuffleId,
):
self.data = data
self.meta = meta
self.shuffle_id = shuffle_id

def convert(self) -> pd.DataFrame:
with handle_unpack_errors(self.shuffle_id):
if self.data is None:
data = self.meta.copy()
else:
data = convert_shards(self.data, self.meta)
return data


@dask_serialize.register(UnloadedPartition)
def _serialize_unloaded(obj: UnloadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]:
# Convert to LoadedPartition before serializing. Note that
# we don't convert all the way to DataFrame, because this
# adds unnecessary overhead and memory pressure for the
# cudf backend (and minor overhead for pandas)
loaded = obj.load()
return (loaded.shuffle_id,), [
pickle.dumps(loaded.meta),
pickle.dumps(loaded.data),
]


@dask_serialize.register(LoadedPartition)
def _serialize_loaded(obj: LoadedPartition) -> tuple[tuple[ShuffleId], list[bytes]]:
return (obj.shuffle_id,), [pickle.dumps(obj.meta), pickle.dumps(obj.data)]


@dask_deserialize.register((UnloadedPartition, LoadedPartition))
def _deserialize_loaded(
header: tuple[ShuffleId], frames: list[bytes]
) -> LoadedPartition:
shuffle_id = header[0]
meta = pickle.loads(frames[0])
data = pickle.loads(frames[1])
return LoadedPartition(data, meta, shuffle_id)


def _get_partition_data(
part: UnloadedPartition | LoadedPartition | pd.DataFrame, barrier_key: int
) -> pd.DataFrame:
# Used by rearrange_by_column_p2p to "unwrap"
# UnloadedPartition/LoadedPartition data after
# a P2PShuffleLayer
assert barrier_key
if isinstance(part, UnloadedPartition):
part = part.load()
if isinstance(part, LoadedPartition):
part = part.convert()
return part


def shuffle_transfer(
input: pd.DataFrame,
id: ShuffleId,
@@ -194,6 +91,15 @@ def shuffle_unpack(
)


def shuffle_unpack_unloaded(
id: ShuffleId, output_partition: int, barrier_run_id: int
) -> Any:
with handle_unpack_errors(id):
return get_worker_plugin().get_output_partition(
id, barrier_run_id, output_partition, load=False
)


def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int:
try:
return get_worker_plugin().barrier(id, run_ids)
@@ -239,15 +145,14 @@ def rearrange_by_column_p2p(
meta_input=meta,
disk=disk,
)
_barrier_key = layer._tokens[1]
return new_dd_object(
HighLevelGraph.from_collections(name, layer, [df]),
name,
meta,
[None] * (npartitions + 1),
).map_partitions(
_get_partition_data,
_barrier_key,
load_output_partition,
layer._tokens[1],
meta=meta,
enforce_metadata=False,
)
@@ -400,7 +305,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
name = self.name
for part_out in self.parts_out:
dsk[(name, part_out)] = (
shuffle_unpack,
shuffle_unpack_unloaded,
token,
part_out,
_barrier_key,
@@ -616,7 +521,11 @@ def _get_output_partition(
key: Key,
**kwargs: Any,
) -> pd.DataFrame:
return UnloadedPartition(self, partition_id)
try:
data = self._read_from_disk((partition_id,))
return convert_shards(data, self.meta)
except KeyError:
return self.meta.copy()

def _get_assigned_worker(self, id: int) -> str:
return self.worker_for[id]
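
The changes to ``rearrange_by_column_p2p`` and ``_construct_graph`` above keep the unpack tasks unfused (so the scheduler can still balance them across workers) while the follow-up ``Blockwise`` call to ``load_output_partition`` is free to fuse with whatever consumes each partition. The toy task graph below mirrors that two-step shape with stand-in functions (``toy_unpack``, ``toy_load``); it is a sketch of the graph structure only, not the real P2P layers.

import dask

def toy_unpack(partition_id):
    # Stands in for shuffle_unpack_unloaded: returns a cheap placeholder, not the data.
    return ("unloaded", partition_id)

def toy_load(placeholder, barrier_key):
    # Stands in for load_output_partition: depends on the barrier, then materializes.
    assert barrier_key
    return {"partition": placeholder[1]}

graph = {
    "barrier": 123,
    ("unpack", 0): (toy_unpack, 0),
    ("load", 0): (toy_load, ("unpack", 0), "barrier"),
}
print(dask.get(graph, ("load", 0)))  # {'partition': 0}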
9 changes: 8 additions & 1 deletion distributed/shuffle/_worker_plugin.py
@@ -417,6 +417,7 @@ def get_output_partition(
run_id: int,
partition_id: int | NDIndex,
meta: pd.DataFrame | None = None,
load: bool = True,
) -> Any:
"""
Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin.
@@ -425,7 +426,13 @@
"""
shuffle_run = self.get_shuffle_run(shuffle_id, run_id)
key = thread_state.key
return shuffle_run.get_output_partition(
if load:
return shuffle_run.get_output_partition(
partition_id=partition_id,
key=key,
meta=meta,
)
return shuffle_run.get_unloaded_output_partition(
partition_id=partition_id,
key=key,
meta=meta,
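
With the plugin accepting a ``load=`` flag, the split between the loaded and unloaded unpack paths stays invisible to users: a P2P shuffle still produces ordinary DataFrames. The snippet below is a rough end-to-end check of that behavior; it assumes a recent dask/distributed with pyarrow installed and uses the ``dataframe.shuffle.method`` config key to force the P2P backend.

import dask
import dask.dataframe as dd
import pandas as pd
from distributed import Client, LocalCluster

if __name__ == "__main__":
    with dask.config.set({"dataframe.shuffle.method": "p2p"}):
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            pdf = pd.DataFrame({"x": range(100), "y": [i % 7 for i in range(100)]})
            ddf = dd.from_pandas(pdf, npartitions=4)
            # The extra Blockwise load step is fused into this compute; the result
            # is a plain pandas DataFrame, not an UnloadedPartition.
            result = ddf.shuffle("y").compute()
            assert sorted(result["x"]) == list(range(100))
            print(result.head())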
