Skip to content

Commit c730903

Browse files
committed
back to simple load, but leave space for future dispatching
1 parent 61a3eb2 commit c730903

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

distributed/shuffle/_core.py

+24-7
Original file line number | Diff line number | Diff line change
@@ -302,13 +302,14 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> Non
302302
self.raise_if_closed()
303303
sync(self._loop, self._ensure_output_worker, partition_id, key)
304304
if not self.transferred:
305-
raise RuntimeError("`get_output_partition` called before barrier task")
305+
raise RuntimeError("`_sync_output_partition` called before barrier task")
306306
sync(self._loop, self.flush_receive)
307307

308308
def get_output_partition(
309309
self, partition_id: _T_partition_id, key: Key, **kwargs: Any
310310
) -> _T_partition_type:
311-
self._sync_output_partition(partition_id, key)
311+
if kwargs.pop("sync", True):
312+
self._sync_output_partition(partition_id, key)
312313
return self._get_output_partition(partition_id, key, **kwargs)
313314

314315
@abc.abstractmethod
@@ -504,24 +505,39 @@ def __init__(
504505
self.key = key
505506
self.kwargs = kwargs
506507

508+
def pre_serialize(self) -> Any:
509+
"""Make the unloaded partition serializable"""
510+
# TODO: Add mechanism to dispatch on meta.
511+
# Right now, serializing an UnloadedPartition object
512+
# will convert it to `type(self.shuffle_run.meta)`.
513+
# However, it may be beneficial to further delay the
514+
# use of GPU memory for cudf/cupy-based data.
515+
return self.load()
516+
507517
def load(self) -> Any:
518+
"""Load the shuffle output partition into memory"""
508519
with handle_unpack_errors(self.shuffle_run.id):
509-
return self.shuffle_run._get_output_partition(
510-
self.partition_id, self.key, **self.kwargs
520+
return self.shuffle_run.get_output_partition(
521+
self.partition_id,
522+
self.key,
523+
# We need sync=False, because `_sync_output_partition`
524+
# was already called for the current shuffle run
525+
sync=False,
526+
**self.kwargs,
511527
)
512528

513529

514530
@dask_serialize.register(UnloadedPartition)
515-
def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]:
531+
def _serialize_unloaded(obj):
516532
# Convert to LoadedPartition before serializing. Note that
517533
# we don't convert all the way to DataFrame, because this
518534
# adds unnecessary overhead and memory pressure for the
519535
# cudf backend (and minor overhead for pandas)
520-
return None, [pickle.dumps(obj.load())]
536+
return None, [pickle.dumps(obj.pre_serialize())]
521537

522538

523539
@dask_deserialize.register(UnloadedPartition)
524-
def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any:
540+
def _deserialize_unloaded(header, frames):
525541
return pickle.loads(frames[0])
526542

527543

@@ -534,4 +550,5 @@ def load_output_partition(
534550
assert barrier_key
535551
if isinstance(data, UnloadedPartition):
536552
data = data.load()
553+
assert not isinstance(data, UnloadedPartition)
537554
return data

distributed/shuffle/_shuffle.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,7 @@ def shuffle_unpack(
9191
)
9292

9393

94-
def shuffle_unpack_unloaded(
94+
def delayed_shuffle_unpack(
9595
id: ShuffleId, output_partition: int, barrier_run_id: int
9696
) -> pd.DataFrame:
9797
with handle_unpack_errors(id):
@@ -305,7 +305,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
305305
name = self.name
306306
for part_out in self.parts_out:
307307
dsk[(name, part_out)] = (
308-
shuffle_unpack_unloaded,
308+
delayed_shuffle_unpack,
309309
token,
310310
part_out,
311311
_barrier_key,

0 commit comments

Comments
 (0)