
Commit cfe4ad5
try delaying conversion only
Parent: 4c16a25

File tree: 3 files changed (+20, -102 lines)
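As the commit message suggests, this commit narrows the earlier experiment: rather than keeping shuffled partitions unloaded on disk behind the UnloadedPartition wrapper (removed below), the unpack tasks now read the raw Arrow shards eagerly and defer only the conversion to the output DataFrame type, which moves into a follow-up Blockwise map_partitions step. A comment-only sketch of the two flows, using names from the diffs below plus hypothetical stand-ins (run, pid, key, meta):

# Before (parent commit): delay the load itself behind a wrapper.
#   wrapper = UnloadedPartition(run, pid, key)      # no disk read yet
#   df = wrapper.load()                             # read + convert on demand
#
# After (this commit): read eagerly, delay only the conversion.
#   raw = run.get_raw_output_partition(pid, key)    # disk read now, still a pa.Table
#   df = _convert_output_partition(raw, meta=meta)  # DataFrame conversion later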

distributed/shuffle/_core.py (+5, -91)
@@ -24,7 +24,7 @@

 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
-from distributed.protocol import dask_deserialize, dask_serialize, to_serialize
+from distributed.protocol import to_serialize
 from distributed.shuffle._comms import CommShardsBuffer
 from distributed.shuffle._disk import DiskShardsBuffer
 from distributed.shuffle._exceptions import ShuffleClosedError
@@ -308,8 +308,7 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> None:
     def get_output_partition(
         self, partition_id: _T_partition_id, key: Key, **kwargs: Any
     ) -> _T_partition_type:
-        if kwargs.pop("sync", True):
-            self._sync_output_partition(partition_id, key)
+        self._sync_output_partition(partition_id, key)
         return self._get_output_partition(partition_id, key, **kwargs)

     @abc.abstractmethod
@@ -318,11 +317,11 @@ def _get_output_partition(
     ) -> _T_partition_type:
         """Get an output partition to the shuffle run"""

-    def get_unloaded_output_partition(
+    def get_raw_output_partition(
         self, partition_id: _T_partition_id, key: Key, **kwargs: Any
-    ) -> UnloadedPartition:
+    ) -> Any:
         self._sync_output_partition(partition_id, key)
-        return UnloadedPartition(self, partition_id, key, **kwargs)
+        return self._get_output_partition(partition_id, key, convert=False, **kwargs)

     @abc.abstractmethod
     def read(self, path: Path) -> tuple[Any, int]:
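After this hunk, get_output_partition and get_raw_output_partition differ only in the convert flag forwarded to the abstract _get_output_partition hook, instead of the raw variant returning a lazy wrapper. A self-contained toy sketch of that shape (toy names, not the distributed API; "conversion" is just str() here):

import abc
from typing import Any

class ToyRun(abc.ABC):
    # Toy sketch: the public methods synchronize first, then delegate to
    # the backend-specific hook; the "raw" variant simply forwards
    # convert=False instead of wrapping the partition lazily.
    def get_output_partition(self, pid: int, **kwargs: Any) -> Any:
        self._sync(pid)
        return self._get_output_partition(pid, **kwargs)

    def get_raw_output_partition(self, pid: int, **kwargs: Any) -> Any:
        self._sync(pid)
        return self._get_output_partition(pid, convert=False, **kwargs)

    def _sync(self, pid: int) -> None:
        pass  # the real code waits on the shuffle run here

    @abc.abstractmethod
    def _get_output_partition(self, pid: int, convert: bool = True) -> Any:
        ...

class ListRun(ToyRun):
    def _get_output_partition(self, pid: int, convert: bool = True) -> Any:
        raw = [pid, pid + 1]                   # pretend this came from disk
        return str(raw) if convert else raw    # "conversion" is str() here

assert ListRun().get_output_partition(1) == "[1, 2]"
assert ListRun().get_raw_output_partition(1) == [1, 2]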
@@ -467,88 +466,3 @@ def _mean_shard_size(shards: Iterable) -> int:
         if count == 10:
             break
     return size // count if count else 0
-
-
-class UnloadedPartition:
-    """Wrap unloaded shuffle output
-
-    The purpose of this class is to keep a shuffled partition
-    on disk until it is needed by one of its dependent tasks.
-    Otherwise, the in-memory partition may need to be spilled
-    back to disk before the dependent task is executed anyway.
-
-    If the output tasks of a ``P2PShuffleLayer`` return objects
-    of type ``UnloadedPartition``, that layer must be followed
-    by an extra ``Blockwise`` call to ``load_output_partition``
-    (to ensure the partitions are actually loaded). We want this
-    extra layer to be ``Blockwise`` so that the loading can be
-    fused into downstream tasks. We do NOT want the original
-    ``shuffle_unpack`` tasks to be fused into dependent tasks,
-    because this would prevent load balancing after the shuffle
-    (long-running post-shuffle tasks may be pinned to specific
-    workers, while others sit idle).
-
-    Note that serialization automatically loads the wrapped
-    data, because the object may be moved to a worker that
-    doesn't have access to the same local storage.
-    """
-
-    def __init__(
-        self,
-        shuffle_run: ShuffleRun,
-        partition_id: _T_partition_id,
-        key: Key,
-        **kwargs: Any,
-    ):
-        self.shuffle_run = shuffle_run
-        self.partition_id = partition_id
-        self.key = key
-        self.kwargs = kwargs
-
-    def pre_serialize(self) -> Any:
-        """Make the unloaded partition serializable"""
-        # TODO: Add mechanism to dispatch on meta.
-        # Right now, serializing an UnloadedPartition object
-        # will convert it to `type(self.shuffle_run.meta)`.
-        # However, it may be beneficial to further delay the
-        # use of GPU memory for cudf/cupy-based data.
-        return self.load()
-
-    def load(self) -> Any:
-        """Load the shuffle output partition into memory"""
-        with handle_unpack_errors(self.shuffle_run.id):
-            return self.shuffle_run.get_output_partition(
-                self.partition_id,
-                self.key,
-                # We need sync=False, because `_sync_output_partition`
-                # was already called for the current shuffle run
-                sync=False,
-                **self.kwargs,
-            )
-
-
-@dask_serialize.register(UnloadedPartition)
-def _serialize_unloaded(obj):
-    # Load and convert via pre_serialize before serializing. Note that
-    # we don't convert all the way to DataFrame, because this
-    # adds unnecessary overhead and memory pressure for the
-    # cudf backend (and minor overhead for pandas)
-    return None, [pickle.dumps(obj.pre_serialize())]
-
-
-@dask_deserialize.register(UnloadedPartition)
-def _deserialize_unloaded(header, frames):
-    return pickle.loads(frames[0])
-
-
-def load_output_partition(
-    data: UnloadedPartition | _T_partition_type, barrier_key: int
-) -> _T_partition_type:
-    # Used by rearrange_by_column_p2p to "unwrap"
-    # UnloadedPartition data after a P2PShuffleLayer
-    assert barrier_key
-    if isinstance(data, UnloadedPartition):
-        data = data.load()
-    assert not isinstance(data, UnloadedPartition)
-    return data
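With unpack tasks now returning plain pyarrow data instead of UnloadedPartition, the dask_serialize/dask_deserialize registrations deleted above have nothing left to handle; one enabling fact is that Arrow tables already round-trip through ordinary pickle. A quick sanity check (assumes pyarrow is installed):

import pickle

import pyarrow as pa

t = pa.table({"x": [1, 2, 3]})
assert pickle.loads(pickle.dumps(t)).equals(t)  # no custom hooks needed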

distributed/shuffle/_shuffle.py (+12, -8)
@@ -39,7 +39,6 @@
     get_worker_plugin,
     handle_transfer_errors,
     handle_unpack_errors,
-    load_output_partition,
 )
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
@@ -91,12 +90,12 @@ def shuffle_unpack(
     )


-def delayed_shuffle_unpack(
+def shuffle_unpack_partial(
     id: ShuffleId, output_partition: int, barrier_run_id: int
 ) -> pd.DataFrame:
     with handle_unpack_errors(id):
         return get_worker_plugin().get_output_partition(
-            id, barrier_run_id, output_partition, load=False
+            id, barrier_run_id, output_partition, convert=False
         )
@@ -151,10 +150,10 @@ def rearrange_by_column_p2p(
         meta,
         [None] * (npartitions + 1),
     ).map_partitions(
-        load_output_partition,
-        layer._tokens[1],
+        partial(_convert_output_partition, meta=meta),
         meta=meta,
         enforce_metadata=False,
+        align_dataframes=False,
     )
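The new map_partitions call binds meta with functools.partial rather than passing the barrier token positionally, and opts out of dask's pre-alignment step, which an elementwise conversion does not need. A minimal runnable sketch of the same pattern, with a hypothetical stand-in for _convert_output_partition:

import functools

import dask.dataframe as dd
import pandas as pd

def _convert(df: pd.DataFrame, meta: pd.DataFrame = None) -> pd.DataFrame:
    # Stand-in conversion: coerce each partition to meta's dtypes.
    return df.astype(meta.dtypes.to_dict())

pdf = pd.DataFrame({"x": [1, 2, 3, 4]})
meta = pdf.astype({"x": "float64"}).iloc[:0]  # empty frame carrying dtypes
ddf = dd.from_pandas(pdf, npartitions=2)
out = ddf.map_partitions(
    functools.partial(_convert, meta=meta),
    meta=meta,
    enforce_metadata=False,
    align_dataframes=False,
)
assert out.compute()["x"].dtype == "float64"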
@@ -305,7 +304,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
         name = self.name
         for part_out in self.parts_out:
             dsk[(name, part_out)] = (
-                delayed_shuffle_unpack,
+                shuffle_unpack_partial,
                 token,
                 part_out,
                 _barrier_key,
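For reference, the graph entry above is a plain dask task tuple, (callable, *args), in which arguments that name other keys are resolved before the call; swapping the callable is all this hunk does. A toy illustration using dask's synchronous low-level scheduler:

from dask.core import get

def unpack(token, part, barrier):
    return f"{token}:{part} (after {barrier})"

dsk = {
    "barrier-1": "barrier done",
    ("shuffle", 0): (unpack, "tok", 0, "barrier-1"),  # like dsk[(name, part_out)]
}
assert get(dsk, ("shuffle", 0)) == "tok:0 (after barrier done)"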
@@ -519,13 +518,14 @@ def _get_output_partition(
         self,
         partition_id: int,
         key: Key,
+        convert: bool = True,
         **kwargs: Any,
     ) -> pd.DataFrame:
         try:
             data = self._read_from_disk((partition_id,))
-            return convert_shards(data, self.meta)
+            return convert_shards(data, self.meta) if convert else data
         except KeyError:
-            return self.meta.copy()
+            return self.meta.copy() if convert else None

     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
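A self-contained sketch of the two read paths this hunk introduces, with a dict standing in for _read_from_disk and to_pandas() for convert_shards: convert=True yields the final DataFrame (or an empty meta copy for a missing partition), while convert=False returns the stored representation untouched (or None):

import pandas as pd
import pyarrow as pa

def read_partition(storage, pid, meta, convert=True):
    try:
        data = storage[pid]  # plays the role of _read_from_disk
    except KeyError:
        return meta.copy() if convert else None
    return data.to_pandas() if convert else data

meta = pd.DataFrame({"x": pd.Series(dtype="int64")})
storage = {0: pa.table({"x": [1, 2]})}
assert isinstance(read_partition(storage, 0, meta), pd.DataFrame)
assert isinstance(read_partition(storage, 0, meta, convert=False), pa.Table)
assert read_partition(storage, 1, meta, convert=False) is None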
@@ -580,3 +580,7 @@ def _get_worker_for_range_sharding(
     """Get address of target worker for this output partition using range sharding"""
     i = len(workers) * output_partition // npartitions
     return workers[i]
+
+
+def _convert_output_partition(data: pa.Table, meta: Any = None) -> pd.DataFrame:
+    return meta.copy() if data is None else convert_shards(data, meta)
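Note how the two halves fit together: on the raw path a missing partition travels as None (see _get_output_partition above), and _convert_output_partition reconstitutes it as meta.copy(), so downstream tasks see the same empty-partition behavior the old single-step path produced.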

distributed/shuffle/_worker_plugin.py (+3, -3)
@@ -417,7 +417,7 @@ def get_output_partition(
         run_id: int,
         partition_id: int | NDIndex,
         meta: pd.DataFrame | None = None,
-        load: bool = True,
+        convert: bool = True,
     ) -> Any:
         """
         Task: Retrieve a shuffled output partition from the ShuffleWorkerPlugin.
@@ -426,13 +426,13 @@
         """
         shuffle_run = self.get_shuffle_run(shuffle_id, run_id)
         key = thread_state.key
-        if load:
+        if convert:
             return shuffle_run.get_output_partition(
                 partition_id=partition_id,
                 key=key,
                 meta=meta,
             )
-        return shuffle_run.get_unloaded_output_partition(
+        return shuffle_run.get_raw_output_partition(
             partition_id=partition_id,
             key=key,
             meta=meta,
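The rename of load to convert here (and in ShuffleRun) mirrors the commit message: after this change the partition is always read from disk inside the unpack task, and only the conversion to the output DataFrame type is deferred.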
