Commit
Serialize dataframes manually
crusaderky committed Apr 5, 2024
1 parent 37119a9 commit b9725c5
Showing 9 changed files with 150 additions and 109 deletions.
11 changes: 11 additions & 0 deletions distributed/protocol/serialize.py
@@ -839,6 +839,17 @@ def _deserialize_memoryview(header, frames):
return out


@dask_serialize.register(PickleBuffer)
def _serialize_picklebuffer(obj):
return _serialize_memoryview(obj.raw())


@dask_deserialize.register(PickleBuffer)
def _deserialize_picklebuffer(header, frames):
out = _deserialize_memoryview(header, frames)
return PickleBuffer(out)


#########################
# Descend into __dict__ #
#########################
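For readers unfamiliar with dask's custom serialization hooks, here is a minimal, self-contained sketch of what the new PickleBuffer handlers accomplish: the wrapped buffer is exposed as a memoryview via .raw() and shipped as a zero-copy frame, then rewrapped on the receiving side. The header contents and function names below are illustrative stand-ins, not distributed's actual _serialize_memoryview pair.

from pickle import PickleBuffer

def serialize_picklebuffer(obj: PickleBuffer) -> tuple[dict, list[memoryview]]:
    # Expose the wrapped buffer as a zero-copy frame.
    mv = obj.raw()
    header = {"readonly": mv.readonly}
    return header, [mv]

def deserialize_picklebuffer(header: dict, frames: list) -> PickleBuffer:
    # Rewrap the received frame without copying the data.
    return PickleBuffer(frames[0])

buf = PickleBuffer(b"some payload")
header, frames = serialize_picklebuffer(buf)
assert bytes(deserialize_picklebuffer(header, frames).raw()) == b"some payload"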
6 changes: 3 additions & 3 deletions distributed/shuffle/_core.py
@@ -37,7 +37,6 @@
from distributed.shuffle._exceptions import ShuffleClosedError
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._memory import MemoryShardsBuffer
from distributed.sizeof import safe_sizeof as sizeof
from distributed.utils import run_in_executor_with_context, sync
from distributed.utils_comm import retry

@@ -215,7 +214,7 @@ async def send(
# and unpickle it on the other side.
# Performance tests informing the size threshold:
# https://github.com/dask/distributed/pull/8318
shards_or_bytes: list | bytes = pickle.dumps(shards)
shards_or_bytes: list | bytes = pickle.dumps(shards, protocol=5)
else:
shards_or_bytes = shards
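When the batch of shards is small, the whole list is pickled and sent as bytes. One side effect worth illustrating: without a buffer_callback, protocol-5 pickling stores any PickleBuffer contents in-band, so they come back as plain bytes or bytearray after unpickling. This is why the dataframe receive path later accepts PickleBuffer | bytes | bytearray. A small sketch (the shard shape below only mirrors what split_by_worker produces; sizes and thresholds are omitted):

import pickle
from pickle import PickleBuffer

shards = [(0, [(1, [PickleBuffer(b"pickle header"), PickleBuffer(bytearray(b"payload"))])])]

blob = pickle.dumps(shards, protocol=5)  # no buffer_callback: buffers go in-band
restored = pickle.loads(blob)

input_id, parts = restored[0]
output_id, frames = parts[0]
assert isinstance(frames[0], bytes)      # read-only buffer comes back as bytes
assert isinstance(frames[1], bytearray)  # writable buffer comes back as bytearray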

@@ -334,6 +333,7 @@ def add_partition(
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to {self}")
# Log metrics both in the "execute" and in the "p2p" contexts
context_meter.digest_metric("p2p-partitions", 1, "count")
with self._capture_metrics("foreground"):
with (
context_meter.meter("p2p-shard-partition-noncpu"),
@@ -509,7 +509,7 @@ def _mean_shard_size(shards: Iterable) -> int:
if not isinstance(shard, int):
# This also asserts that shard is a Buffer and that we didn't forget
# a container or metadata type above
size += sizeof(shard)
size += memoryview(shard).nbytes
count += 1
if count == 10:
break
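The sizing change above replaces safe_sizeof() with memoryview(shard).nbytes, i.e. the raw length of the sampled buffers. A rough sketch of the sampling idea (this mirrors, but is not, the _mean_shard_size helper; it assumes each sampled shard is an int or a buffer-like object):

from itertools import islice
from typing import Iterable

def mean_shard_size(shards: Iterable, samples: int = 10) -> int:
    # Only look at the first few shards; buffers report their exact byte size.
    sizes = [
        memoryview(shard).nbytes
        for shard in islice(shards, samples)
        if not isinstance(shard, int)
    ]
    return sum(sizes) // len(sizes) if sizes else 0

print(mean_shard_size([b"x" * 100, bytearray(300), memoryview(b"y" * 200)]))  # 200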
19 changes: 14 additions & 5 deletions distributed/shuffle/_disk.py
@@ -139,7 +139,7 @@ def __init__(
self._directory_lock = ReadWriteLock()

@log_errors
async def _process(self, id: str, shards: list[object]) -> None:
async def _process(self, id: str, shards: list[Any]) -> None:
"""Write one buffer to file
This function was built to offload the disk IO, but since then we've
@@ -154,12 +154,21 @@ async def _process(self, id: str, shards: list[object]) -> None:
"""
nbytes_acc = 0

def pickle_and_tally() -> Iterator[bytes | memoryview]:
def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
nonlocal nbytes_acc
for shard in shards:
for frame in pickle_bytelist(shard):
nbytes_acc += nbytes(frame)
yield frame
if isinstance(shard, list) and isinstance(
shard[0], (bytes, bytearray, memoryview)
):
# list[bytes | bytearray | memoryview] for dataframe shuffle
# Shard was pre-serialized before being sent over the network.
nbytes_acc += sum(map(nbytes, shard))
yield from shard
else:
# tuple[NDIndex, ndarray] for array rechunk
frames = [s.raw() for s in pickle_bytelist(shard)]
nbytes_acc += sum(frame.nbytes for frame in frames)
yield from frames

with (
self._directory_lock.read(),
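The new pickle_and_tally branch distinguishes dataframe shards, which arrive pre-serialized as a list of raw buffers, from array-rechunk shards, which still need pickling before they hit disk. A self-contained sketch of the same idea, with a hypothetical spill file path and a plain-pickle stand-in for pickle_bytelist:

import pickle
from typing import Any, Iterator

def frames_for(shard: Any) -> Iterator[bytes | bytearray | memoryview]:
    # Pre-serialized shards (list of buffers) pass through untouched;
    # anything else is pickled here with out-of-band buffers.
    if isinstance(shard, list) and shard and isinstance(
        shard[0], (bytes, bytearray, memoryview)
    ):
        yield from shard
    else:
        buffers: list[pickle.PickleBuffer] = []
        header = pickle.dumps(shard, protocol=5, buffer_callback=buffers.append)
        yield header
        for buf in buffers:
            yield buf.raw()

nbytes_acc = 0
with open("/tmp/shard-spill.bin", "wb") as f:  # hypothetical spill file
    for shard in [[b"pre", b"serialized"], ((0, 1), list(range(8)))]:
        for frame in frames_for(shard):
            nbytes_acc += memoryview(frame).nbytes
            f.write(frame)
print(nbytes_acc)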
3 changes: 1 addition & 2 deletions distributed/shuffle/_memory.py
@@ -20,8 +20,6 @@ def __init__(self) -> None:

@log_errors
async def _process(self, id: str, shards: list[Any]) -> None:
# TODO: This can be greatly simplified, there's no need for
# background threads at all.
self._shards[id].extend(shards)

def read(self, id: str) -> Any:
@@ -39,6 +37,7 @@ def read(self, id: str) -> Any:
data = []
while shards:
shard = shards.pop()
# TODO unpickle dataframes
data.append(shard)

return data
11 changes: 5 additions & 6 deletions distributed/shuffle/_pickle.py
@@ -7,7 +7,7 @@
from distributed.protocol.utils import pack_frames_prelude, unpack_frames


def pickle_bytelist(obj: object) -> list[bytes | memoryview]:
def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
"""Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
defined classes, or any of its other fancy features but runs 10x faster for numpy
arrays
@@ -18,11 +18,10 @@ def pickle_bytelist(obj: object) -> list[bytes | memoryview]:
unpickle_bytestream
"""
frames: list = []
pik = pickle.dumps(
obj, protocol=5, buffer_callback=lambda pb: frames.append(pb.raw())
)
frames.insert(0, pik)
frames.insert(0, pack_frames_prelude(frames))
pik = pickle.dumps(obj, protocol=5, buffer_callback=frames.append)
frames.insert(0, pickle.PickleBuffer(pik))
if prelude:
frames.insert(0, pickle.PickleBuffer(pack_frames_prelude(frames)))
return frames


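pickle_bytelist now returns PickleBuffer frames throughout: one wrapping the pickle stream itself (plus, optionally, the frames prelude) and one per out-of-band buffer exposed by the object. The sketch below shows what that boils down to in plain pickle and how the frames round-trip; numpy is only an assumption here, chosen because its arrays expose protocol-5 out-of-band buffers.

import pickle

import numpy as np  # assumption: any buffer-heavy payload works

obj = {"idx": (0, 1), "block": np.arange(8)}

# Roughly what pickle_bytelist(obj, prelude=False) produces: the pickle stream
# plus zero-copy frames for each out-of-band buffer.
buffers: list[pickle.PickleBuffer] = []
pik = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
frames = [pickle.PickleBuffer(pik), *buffers]

# Round trip: hand the out-of-band buffers back to pickle.loads.
restored = pickle.loads(frames[0].raw(), buffers=[f.raw() for f in frames[1:]])
assert (restored["block"] == obj["block"]).all()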
182 changes: 112 additions & 70 deletions distributed/shuffle/_shuffle.py
@@ -14,9 +14,10 @@
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pickle import PickleBuffer
from typing import TYPE_CHECKING, Any

from toolz import concat, first, second
from toolz import first, second
from tornado.ioloop import IOLoop

import dask
@@ -28,6 +29,7 @@
from distributed.core import PooledRPCCall
from distributed.exceptions import Reschedule
from distributed.metrics import context_meter
from distributed.protocol.utils import pack_frames_prelude
from distributed.shuffle._core import (
NDIndex,
ShuffleId,
@@ -40,8 +42,9 @@
)
from distributed.shuffle._exceptions import DataUnavailable
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import pickle_bytelist
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof
from distributed.utils import nbytes

logger = logging.getLogger("distributed.shuffle")
if TYPE_CHECKING:
@@ -297,36 +300,51 @@ def _construct_graph(self) -> _T_LowLevelGraph:
def split_by_worker(
df: pd.DataFrame,
column: str,
worker_for: pd.Series,
) -> dict[str, pd.DataFrame]:
"""Split data into many horizontal slices, partitioned by destination worker"""
nrows = len(df)

# (cudf support) Avoid pd.Series
constructor = df._constructor_sliced
assert isinstance(constructor, type)
if type(worker_for) is not constructor:
worker_for = constructor(worker_for)

df = df.merge(
right=worker_for,
left_on=column,
right_index=True,
how="inner",
)
out = dict(split_by_partition(df, "_workers", drop_column=True))
assert sum(map(len, out.values())) == nrows
return out

drop_column: bool,
worker_for: dict[int, str],
input_part_id: int,
) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]:
"""Split data into many horizontal slices, partitioned by destination worker,
and serialize them once.
Returns
-------
{worker addr: (input_part_id, [(output_part_id, buffers), ...]), ...}
where buffers is a list of
[
PickleBuffer(pickle bytes) # includes input_part_id
buffer,
buffer,
...
]
**Notes**
- The pickle header, which is a bytes object, is wrapped in PickleBuffer so
that it's not unnecessarily deep-copied when it's deserialized by the network
stack.
- We are not delegating serialization to the network stack because (1) it's quicker
with plain pickle and (2) we want to avoid deserializing everything on receive()
only to re-serialize it again immediately afterwards when writing it to disk.
So we serialize it once now and deserialize it once after reading back from disk.
See Also
--------
distributed.protocol.serialize._deserialize_bytes
distributed.protocol.serialize._deserialize_picklebuffer
"""
out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list)

def split_by_partition(
df: pd.DataFrame, column: str, drop_column: bool
) -> Iterator[tuple[Any, pd.DataFrame]]:
"""Split data into many horizontal slices, partitioned by final partition"""
for k, group in df.groupby(column, observed=True):
for output_part_id, part in df.groupby(column, observed=False):
assert isinstance(output_part_id, int)
if drop_column:
del group[column]
yield k, group
del part[column]
frames = pickle_bytelist((input_part_id, part), prelude=False)
out[worker_for[output_part_id]].append((output_part_id, frames))

return {k: (input_part_id, v) for k, v in out.items()}
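To make the documented return shape concrete, here is a self-contained imitation (not the real helper; pandas, the toy column name and the worker addresses are assumptions) that produces the same {worker addr: (input_part_id, [(output_part_id, frames), ...])} structure:

import pickle
from collections import defaultdict

import pandas as pd

def toy_split_by_worker(df, column, worker_for, input_part_id):
    out = defaultdict(list)
    for output_part_id, part in df.groupby(column):
        part = part.drop(columns=[column])
        buffers: list[pickle.PickleBuffer] = []
        pik = pickle.dumps(
            (input_part_id, part), protocol=5, buffer_callback=buffers.append
        )
        out[worker_for[int(output_part_id)]].append(
            (int(output_part_id), [pickle.PickleBuffer(pik), *buffers])
        )
    return {addr: (input_part_id, shards) for addr, shards in out.items()}

df = pd.DataFrame({"_partitions": [0, 1, 0, 2], "x": [10, 20, 30, 40]})
worker_for = {0: "tcp://a:1234", 1: "tcp://b:1234", 2: "tcp://a:1234"}
out = toy_split_by_worker(df, "_partitions", worker_for, input_part_id=7)
for addr, (part_id, shards) in out.items():
    print(addr, part_id, [pid for pid, _ in shards])
# tcp://a:1234 7 [0, 2]
# tcp://b:1234 7 [1]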


class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
@@ -376,7 +394,7 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
column: str
meta: pd.DataFrame
partitions_of: dict[str, list[int]]
worker_for: pd.Series
worker_for: dict[int, str]
drop_column: bool

def __init__(
@@ -399,8 +417,6 @@ def __init__(
drop_column: bool,
loop: IOLoop,
):
import pandas as pd

super().__init__(
id=id,
run_id=run_id,
@@ -422,55 +438,81 @@ def __init__(
for part, addr in worker_for.items():
partitions_of[addr].append(part)
self.partitions_of = dict(partitions_of)
self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
self.worker_for = worker_for
self.drop_column = drop_column

async def _receive(self, data: list[tuple[int, pd.DataFrame]]) -> None:
async def _receive(
# See split_by_worker to understand annotation of data.
# PickleBuffer objects may have been converted to bytearray by the
# pickle roundtrip that is done by _core.py when buffers are too small
self,
data: list[
tuple[int, list[tuple[int, list[PickleBuffer | bytes | bytearray]]]]
],
) -> None:
self.raise_if_closed()

filtered = []
for partition_id, part in data:
if partition_id not in self.received:
filtered.append((partition_id, part))
self.received.add(partition_id)
self.total_recvd += sizeof(part)
del data
if not filtered:
to_write: defaultdict[
NDIndex, list[bytes | bytearray | memoryview]
] = defaultdict(list)

for input_part_id, parts in data:
if input_part_id not in self.received:
self.received.add(input_part_id)
for output_part_id, frames in parts:
frames_raw = [
frame.raw() if isinstance(frame, PickleBuffer) else frame
for frame in frames
]
self.total_recvd += sum(map(nbytes, frames_raw))
to_write[output_part_id,] += [
pack_frames_prelude(frames_raw),
*frames_raw,
]

if not to_write:
return
try:
groups = await self.offload(self._repartition_buffers, filtered)
del filtered
await self._write_to_disk(groups)
await self._write_to_disk(to_write)
except Exception as e:
self._exception = e
raise
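The prelude prepended to each group of frames is what later allows the flat byte stream read back from disk to be split into frames again. An illustrative round trip of that idea (the struct layout below is an assumption; distributed.protocol.utils.pack_frames_prelude / unpack_frames are the real implementations):

import struct

def pack_frames_prelude(frames):
    # Frame count followed by each frame's length, as little-endian uint64.
    lengths = [memoryview(f).nbytes for f in frames]
    return struct.pack(f"<{len(lengths) + 1}Q", len(lengths), *lengths)

def unpack_frames(blob):
    # Split one concatenated blob back into its frames.
    mv = memoryview(blob)
    (n,) = struct.unpack_from("<Q", mv, 0)
    lengths = struct.unpack_from(f"<{n}Q", mv, 8)
    offset = 8 + 8 * n
    frames = []
    for length in lengths:
        frames.append(mv[offset : offset + length])
        offset += length
    return frames

frames = [b"pickle-header", b"buffer-1", b"buffer-2"]
blob = pack_frames_prelude(frames) + b"".join(frames)
assert [bytes(f) for f in unpack_frames(blob)] == frames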

def _repartition_buffers(
self, data: list[tuple[int, pd.DataFrame]]
) -> dict[NDIndex, list[tuple[int, pd.DataFrame]]]:
out: dict[NDIndex, list[tuple[int, pd.DataFrame]]] = defaultdict(list)

for input_part_id, part in data:
groups = split_by_partition(part, self.column, self.drop_column)
for output_part_id, part in groups:
out[output_part_id,].append((input_part_id, part))

assert sum(len(part) for _, part in data) == sum(
len(part) for parts in out.values() for _, part in parts
)
return out

def _shard_partition(
self,
data: pd.DataFrame,
partition_id: int,
**kwargs: Any,
) -> dict[str, tuple[int, pd.DataFrame]]:
out = split_by_worker(data, self.column, self.worker_for)
nbytes = sum(map(sizeof, out.values()))
context_meter.digest_metric("p2p-shards", nbytes, "bytes")
context_meter.digest_metric("p2p-shards", len(out), "count")
return {k: (partition_id, s) for k, s in out.items()}
# See split_by_worker to understand annotation
) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]:
out = split_by_worker(
df=data,
column=self.column,
drop_column=self.drop_column,
worker_for=self.worker_for,
input_part_id=partition_id,
)

# Log metrics
# Note: more metrics for this function call are logged by _core.add_partition()
overhead_nbytes = 0
buffers_nbytes = 0
shards_count = 0
buffers_count = 0
for _, shards in out.values():
shards_count += len(shards)
for _, frames in shards:
# frames = [pickle bytes, buffer, buffer, ...]
buffers_count += len(frames) - 2
overhead_nbytes += frames[0].raw().nbytes
buffers_nbytes += sum(frame.raw().nbytes for frame in frames[1:])

context_meter.digest_metric("p2p-shards-overhead", overhead_nbytes, "bytes")
context_meter.digest_metric("p2p-shards-buffers", buffers_nbytes, "bytes")
context_meter.digest_metric("p2p-shards-buffers", buffers_count, "count")
context_meter.digest_metric("p2p-shards", shards_count, "count")
# End log metrics

return out

def _get_output_partition(
self,
@@ -488,8 +530,8 @@ def _get_output_partition(
result = self.meta.drop(columns=self.column)
return result

# [[(input_partition_id, part), (...), ...], [...]] -> [part, ...]
shards = list(map(second, sorted(concat(parts), key=first)))
# [(input_partition_id, part), ...] -> [part, ...]
shards = list(map(second, sorted(parts, key=first)))
# Actually load memory-mapped buffers into memory and close the file
# descriptors
return pd.concat(shards, copy=True)
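A tiny illustration of the reordering step above, with strings standing in for the unpickled dataframe shards; sorting by the input partition id restores a deterministic concatenation order:

from toolz import first, second

# (input_partition_id, shard) pairs as read back from disk, possibly out of order
parts = [(2, "shard-from-2"), (0, "shard-from-0"), (1, "shard-from-1")]

shards = list(map(second, sorted(parts, key=first)))
print(shards)  # ['shard-from-0', 'shard-from-1', 'shard-from-2']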