24
24
25
25
from distributed .core import PooledRPCCall
26
26
from distributed .exceptions import Reschedule
27
- from distributed .protocol import dask_deserialize , dask_serialize , to_serialize
27
+ from distributed .protocol import to_serialize
28
28
from distributed .shuffle ._comms import CommShardsBuffer
29
29
from distributed .shuffle ._disk import DiskShardsBuffer
30
30
from distributed .shuffle ._exceptions import ShuffleClosedError
@@ -308,8 +308,7 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> Non
308
308
def get_output_partition (
309
309
self , partition_id : _T_partition_id , key : Key , ** kwargs : Any
310
310
) -> _T_partition_type :
311
- if kwargs .pop ("sync" , True ):
312
- self ._sync_output_partition (partition_id , key )
311
+ self ._sync_output_partition (partition_id , key )
313
312
return self ._get_output_partition (partition_id , key , ** kwargs )
314
313
315
314
@abc .abstractmethod
@@ -318,11 +317,11 @@ def _get_output_partition(
318
317
) -> _T_partition_type :
319
318
"""Get an output partition to the shuffle run"""
320
319
321
- def get_unloaded_output_partition (
320
+ def get_raw_output_partition (
322
321
self , partition_id : _T_partition_id , key : Key , ** kwargs : Any
323
- ) -> UnloadedPartition :
322
+ ) -> Any :
324
323
self ._sync_output_partition (partition_id , key )
325
- return UnloadedPartition ( self , partition_id , key , ** kwargs )
324
+ return self . _get_output_partition ( partition_id , key , convert = False , ** kwargs )
326
325
327
326
@abc .abstractmethod
328
327
def read (self , path : Path ) -> tuple [Any , int ]:
@@ -467,88 +466,3 @@ def _mean_shard_size(shards: Iterable) -> int:
467
466
if count == 10 :
468
467
break
469
468
return size // count if count else 0
470
-
471
-
472
- class UnloadedPartition :
473
- """Wrap unloaded shuffle output
474
-
475
- The purpose of this class is to keep a shuffled partition
476
- on disk until it is needed by one of its dependent tasks.
477
- Otherwise, the in-memory partition may need to be spilled
478
- back to disk before the dependent task is executed anyway.
479
-
480
- If the output tasks of a ``P2PShuffleLayer`` return objects
481
- of type ``UnloadedPartition``, that layer must be followed
482
- by an extra ``Blockwise`` call to ``load_output_partition``
483
- (to ensure the partitions are actually loaded). We want this
484
- extra layer to be ``Blockwise`` so that the loading can be
485
- fused into down-stream tasks. We do NOT want the original
486
- ``shuffle_unpack`` tasks to be fused into dependent tasks,
487
- because this would prevent load balancing after the shuffle
488
- (long-running post-shuffle tasks may be pinned to specific
489
- workers, while others sit idle).
490
-
491
- Note that serialization automatically loads the wrapped
492
- data, because the object may be moved to a worker that
493
- doesn't have access to the same local storage.
494
- """
495
-
496
- def __init__ (
497
- self ,
498
- shuffle_run : ShuffleRun ,
499
- partition_id : _T_partition_id ,
500
- key : Key ,
501
- ** kwargs : Any ,
502
- ):
503
- self .shuffle_run = shuffle_run
504
- self .partition_id = partition_id
505
- self .key = key
506
- self .kwargs = kwargs
507
-
508
- def pre_serialize (self ) -> Any :
509
- """Make the unloaded partition serializable"""
510
- # TODO: Add mechanism to dispatch on meta.
511
- # Right now, serializing an UnloadedPartition object
512
- # will convert it to `type(self.shuffle_run.meta)`.
513
- # However, it may be beneficial to futher delay the
514
- # use of GPU memory for cudf/cupy-based data.
515
- return self .load ()
516
-
517
- def load (self ) -> Any :
518
- """Load the shuffle output partition into memory"""
519
- with handle_unpack_errors (self .shuffle_run .id ):
520
- return self .shuffle_run .get_output_partition (
521
- self .partition_id ,
522
- self .key ,
523
- # We need sync=False, because `_sync_output_partition`
524
- # was already called for the current shuffle run
525
- sync = False ,
526
- ** self .kwargs ,
527
- )
528
-
529
-
530
- @dask_serialize .register (UnloadedPartition )
531
- def _serialize_unloaded (obj ):
532
- # Convert to LoadedPartition before serializing. Note that
533
- # we don't convert all the way to DataFrame, because this
534
- # adds unnecessary overhead and memory pressure for the
535
- # cudf backend (and minor overhead for pandas)
536
- return None , [pickle .dumps (obj .pre_serialize ())]
537
-
538
-
539
- @dask_deserialize .register (UnloadedPartition )
540
- def _deserialize_unloaded (header , frames ):
541
- return pickle .loads (frames [0 ])
542
-
543
-
544
- def load_output_partition (
545
- data : UnloadedPartition | _T_partition_type , barrier_key : int
546
- ) -> _T_partition_type :
547
- # Used by rearrange_by_column_p2p to "unwrap"
548
- # UnloadedPartition/LoadedPartition data after
549
- # a P2PShuffleLayer
550
- assert barrier_key
551
- if isinstance (data , UnloadedPartition ):
552
- data = data .load ()
553
- assert not isinstance (data , UnloadedPartition )
554
- return data
0 commit comments