diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py
index e3f5c13569..19524834ed 100644
--- a/distributed/shuffle/_core.py
+++ b/distributed/shuffle/_core.py
@@ -296,7 +296,7 @@ def fail(self, exception: Exception) -> None:
         if not self.closed:
             self._exception = exception

-    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
+    def _read_from_disk(self, id: NDIndex) -> Any:
         self.raise_if_closed()
         return self._disk_buffer.read("_".join(str(i) for i in id))

diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py
index 153127d5d9..0cb9d3ff68 100644
--- a/distributed/shuffle/_disk.py
+++ b/distributed/shuffle/_disk.py
@@ -6,7 +6,7 @@
 import shutil
 import threading
 from collections.abc import Generator, Iterator
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path
 from typing import Any

@@ -123,6 +123,11 @@ class DiskShardsBuffer(ShardsBuffer):
     implementation of this scheme.
     """

+    directory: pathlib.Path
+    _closed: bool
+    _use_raw_buffers: bool | None
+    _directory_lock: ReadWriteLock
+
     def __init__(
         self,
         directory: str | pathlib.Path,
@@ -136,6 +141,7 @@ def __init__(
         self.directory = pathlib.Path(directory)
         self.directory.mkdir(exist_ok=True)
         self._closed = False
+        self._use_raw_buffers = None
         self._directory_lock = ReadWriteLock()

     @log_errors
@@ -152,14 +158,23 @@ async def _process(self, id: str, shards: list[Any]) -> None:
         future then we should consider simplifying this considerably and
         dropping the write into communicate above.
         """
+        assert shards
+        if self._use_raw_buffers is None:
+            self._use_raw_buffers = isinstance(shards[0], list) and isinstance(
+                shards[0][0], (bytes, bytearray, memoryview)
+            )
+        serialize_ctx = (
+            nullcontext()
+            if self._use_raw_buffers
+            else context_meter.meter("serialize", func=thread_time)
+        )
+
         nbytes_acc = 0

         def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
             nonlocal nbytes_acc
             for shard in shards:
-                if isinstance(shard, list) and isinstance(
-                    shard[0], (bytes, bytearray, memoryview)
-                ):
+                if self._use_raw_buffers:
                     # list[bytes | bytearray | memoryview] for dataframe shuffle
                     # Shard was pre-serialized before being sent over the network.
                     nbytes_acc += sum(map(nbytes, shard))
@@ -173,7 +188,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
         with (
             self._directory_lock.read(),
             context_meter.meter("disk-write"),
-            context_meter.meter("serialize", func=thread_time),
+            serialize_ctx,
         ):
             if self._closed:
                 raise RuntimeError("Already closed")
@@ -184,7 +199,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
         context_meter.digest_metric("disk-write", 1, "count")
         context_meter.digest_metric("disk-write", nbytes_acc, "bytes")

-    def read(self, id: str) -> list[Any]:
+    def read(self, id: str) -> Any:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
@@ -211,8 +226,7 @@ def read(self, id: str) -> list[Any]:
         else:
             raise DataUnavailable(id)

-    @staticmethod
-    def _read(path: Path) -> tuple[list[Any], int]:
+    def _read(self, path: Path) -> tuple[Any, int]:
         """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
         all arrays. This is a fast sequence of short reads interleaved with seeks.
         Do not read in memory the actual data; the arrays' buffers will point to the
@@ -224,10 +238,14 @@ def _read(path: Path) -> tuple[list[Any], int]:
         """
         with path.open(mode="r+b") as fh:
             buffer = memoryview(mmap.mmap(fh.fileno(), 0))
-        # The file descriptor has *not* been closed!
-        shards = list(unpickle_bytestream(buffer))
-        return shards, buffer.nbytes
+
+        assert self._use_raw_buffers is not None
+        if self._use_raw_buffers:
+            return buffer, buffer.nbytes
+        else:
+            shards = list(unpickle_bytestream(buffer))
+            return shards, buffer.nbytes

     async def close(self) -> None:
         await super().close()
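
# Illustrative sketch, not from the patch above: the mmap pattern that _read relies
# on. A memoryview over an mmap stays usable after the `with` block closes the
# Python file object; the mapping is only released once all views into it are
# dropped. Names below (read_raw, the demo file) are hypothetical.
import mmap
import tempfile
from pathlib import Path


def read_raw(path: Path) -> memoryview:
    with path.open(mode="r+b") as fh:
        buffer = memoryview(mmap.mmap(fh.fileno(), 0))
    # The caller decides what to do with the raw bytes: unpickle them shard by
    # shard, or hand the whole buffer to a dataframe-aware helper unchanged.
    return buffer


# Tiny round trip through a temporary file
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello shards")
view = read_raw(Path(tmp.name))
assert bytes(view) == b"hello shards"
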
diff --git a/distributed/shuffle/_pickle.py b/distributed/shuffle/_pickle.py
index 4db706565b..49250a0c0f 100644
--- a/distributed/shuffle/_pickle.py
+++ b/distributed/shuffle/_pickle.py
@@ -2,10 +2,15 @@

 import pickle
 from collections.abc import Iterator
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+from toolz import first

 from distributed.protocol.utils import pack_frames_prelude, unpack_frames

+if TYPE_CHECKING:
+    import pandas as pd
+

 def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
     """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
@@ -39,3 +44,68 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
         if remainder.nbytes == 0:
             break
         b = remainder
+
+
+def pickle_dataframe_shard(
+    input_part_id: int,
+    shard: pd.DataFrame,
+) -> list[pickle.PickleBuffer]:
+    """Optimized pickler for pandas DataFrames. Discard all unnecessary metadata
+    (like the columns header).
+
+    Parameters:
+        input_part_id: id of the input partition
+        shard: DataFrame shard to pickle
+    """
+    return pickle_bytelist(
+        (input_part_id, shard.index, *shard._mgr.blocks), prelude=False
+    )
+
+
+def unpickle_and_concat_dataframe_shards(
+    b: bytes | bytearray | memoryview, meta: pd.DataFrame
+) -> pd.DataFrame:
+    """Optimized unpickler for pandas DataFrames.
+
+    Parameters
+    ----------
+    b:
+        raw buffer, containing the concatenation of the outputs of
+        :func:`pickle_dataframe_shard`, in arbitrary order
+    meta:
+        DataFrame header
+
+    Returns
+    -------
+    Reconstructed output shard, sorted by input partition ID
+
+    **Roundtrip example**
+
+    >>> import random
+    >>> import pandas as pd
+    >>> from toolz import concat
+
+    >>> df = pd.DataFrame(...)  # Input partition
+    >>> meta = df.iloc[:0].copy()
+    >>> shards = df.iloc[0:10], df.iloc[10:20], ...
+    >>> frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)]
+    >>> random.shuffle(frames)  # Simulate the frames arriving in arbitrary order
+    >>> blob = bytearray(b"".join(concat(frames)))  # Simulate disk roundtrip
+    >>> df2 = unpickle_and_concat_dataframe_shards(blob, meta)
+    """
+    import pandas as pd
+    from pandas.core.internals import BlockManager
+
+    parts = list(unpickle_bytestream(b))
+    # [(input_part_id, index, *blocks), ...]
+    parts.sort(key=first)
+    shards = []
+    for _, idx, *blocks in parts:
+        axes = [meta.columns, idx]
+        df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
+            BlockManager(blocks, axes, verify_integrity=False), axes
+        )
+        shards.append(df)
+
+    # Actually load memory-mapped buffers into memory and close the file
+    # descriptors
+    return pd.concat(shards, copy=True)
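
# Illustrative sketch, not from the patch above: the reconstruction trick used by
# unpickle_and_concat_dataframe_shards, shown on toy data. It leans on the same
# pandas internals as the new helpers (DataFrame._mgr, DataFrame._from_mgr,
# BlockManager); pandas 2.x is assumed.
import pandas as pd
from pandas.core.internals import BlockManager

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
meta = df.iloc[:0]  # header only: column names and dtypes, no rows

# What pickle_dataframe_shard serializes per shard: the index plus the raw
# blocks, but not the columns header (it is recovered from meta on the way back)
idx, blocks = df.index, df._mgr.blocks

axes = [meta.columns, idx]
rebuilt = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
    BlockManager(blocks, axes, verify_integrity=False), axes
)
pd.testing.assert_frame_equal(rebuilt, df)
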
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 218f660023..c9db53dbef 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -17,7 +17,6 @@
 from pickle import PickleBuffer
 from typing import TYPE_CHECKING, Any

-from toolz import first, second
 from tornado.ioloop import IOLoop

 import dask
@@ -42,7 +41,10 @@
 )
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
-from distributed.shuffle._pickle import pickle_bytelist
+from distributed.shuffle._pickle import (
+    pickle_dataframe_shard,
+    unpickle_and_concat_dataframe_shards,
+)
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.utils import nbytes

@@ -311,14 +313,8 @@ def split_by_worker(
     -------
     {worker addr: (input_part_id, [(output_part_id, buffers), ...]), ...}

-    where buffers is a list of
-
-    [
-        PickleBuffer(pickle bytes)  # includes input_part_id
-        buffer,
-        buffer,
-        ...
-    ]
+    where buffers is the serialized output (pickle bytes, buffer, buffer, ...) of
+    (input_part_id, index, *blocks)

     **Notes**

@@ -341,7 +337,7 @@ def split_by_worker(
         assert isinstance(output_part_id, int)
         if drop_column:
             del part[column]
-        frames = pickle_bytelist((input_part_id, part), prelude=False)
+        frames = pickle_dataframe_shard(input_part_id, part)
         out[worker_for[output_part_id]].append((output_part_id, frames))

     return {k: (input_part_id, v) for k, v in out.items()}

@@ -520,21 +516,16 @@ def _get_output_partition(
         key: Key,
         **kwargs: Any,
     ) -> pd.DataFrame:
-        import pandas as pd
+        meta = self.meta.copy()
+        if self.drop_column:
+            meta = self.meta.drop(columns=self.column)

         try:
-            parts = self._read_from_disk((partition_id,))
+            buffer = self._read_from_disk((partition_id,))
         except DataUnavailable:
-            result = self.meta.copy()
-            if self.drop_column:
-                result = self.meta.drop(columns=self.column)
-            return result
-
-        # [(input_partition_id, part), ...]] -> [part, ...]
-        shards = list(map(second, sorted(parts, key=first)))
-        # Actually load memory-mapped buffers into memory and close the file
-        # descriptors
-        return pd.concat(shards, copy=True)
+            return meta
+
+        return unpickle_and_concat_dataframe_shards(buffer, meta)

     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
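
# Illustrative sketch, not from the patch above: what split_by_worker now appends
# per output shard, i.e. the frames of (input_part_id, index, *blocks) produced by
# pickle_dataframe_shard. Toy data; the part and input_part_id values are made up.
import pandas as pd

from distributed.shuffle._pickle import pickle_dataframe_shard

part = pd.DataFrame({"x": [1, 2], "y": [3.0, 4.0]})
frames = pickle_dataframe_shard(7, part)  # input_part_id=7, no columns header

# frames[0] holds the pickle-protocol-5 stream; any further frames are the
# out-of-band block buffers. All of them expose the buffer protocol, which is
# what ends up being written to disk verbatim on the receiving side.
total_nbytes = sum(memoryview(f).nbytes for f in frames)
print(len(frames), total_nbytes)
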
diff --git a/distributed/shuffle/tests/test_core.py b/distributed/shuffle/tests/test_core.py
index deb9d2a0bb..3d310d492b 100644
--- a/distributed/shuffle/tests/test_core.py
+++ b/distributed/shuffle/tests/test_core.py
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from pickle import PickleBuffer
+
 import pytest

 from distributed.shuffle._core import _mean_shard_size
@@ -12,7 +14,17 @@ def test_mean_shard_size():
     # Don't fully iterate over large collections
     assert _mean_shard_size([b"12" * n for n in range(1000)]) == 9
     # Support any Buffer object
-    assert _mean_shard_size([b"12", bytearray(b"1234"), memoryview(b"123456")]) == 4
+    assert (
+        _mean_shard_size(
+            [
+                b"12",
+                bytearray(b"1234"),
+                memoryview(b"123456"),
+                PickleBuffer(b"12345678"),
+            ]
+        )
+        == 5
+    )
     # Recursion into lists or tuples; ignore int
     assert _mean_shard_size([(1, 2, [3, b"123456"])]) == 6
     # Don't blindly call sizeof() on unexpected objects
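
A quick sanity check of the updated assertion, run against the same shard list; it assumes an environment where distributed (with this change) and its private _mean_shard_size helper are importable:

from pickle import PickleBuffer

from distributed.shuffle._core import _mean_shard_size

shards = [b"12", bytearray(b"1234"), memoryview(b"123456"), PickleBuffer(b"12345678")]
# nbytes via the buffer protocol: 2, 4, 6, 8 -> mean 5, hence the new expected value
assert [memoryview(s).nbytes for s in shards] == [2, 4, 6, 8]
assert _mean_shard_size(shards) == 5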