@@ -302,13 +302,14 @@ def _sync_output_partition(self, partition_id: _T_partition_id, key: Key) -> None:
         self.raise_if_closed()
         sync(self._loop, self._ensure_output_worker, partition_id, key)
         if not self.transferred:
-            raise RuntimeError("`get_output_partition` called before barrier task")
+            raise RuntimeError("`_sync_output_partition` called before barrier task")
         sync(self._loop, self.flush_receive)

     def get_output_partition(
         self, partition_id: _T_partition_id, key: Key, **kwargs: Any
     ) -> _T_partition_type:
-        self._sync_output_partition(partition_id, key)
+        if kwargs.pop("sync", True):
+            self._sync_output_partition(partition_id, key)
         return self._get_output_partition(partition_id, key, **kwargs)

     @abc.abstractmethod
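
Note: the new `sync` keyword lets callers that have already synchronized skip the redundant barrier/flush step. A minimal stand-alone sketch of just that control flow, using a hypothetical `_Run` stand-in rather than the real `ShuffleRun` class:

```python
# Hypothetical stand-in for ShuffleRun; illustrates only the sync-kwarg dispatch
class _Run:
    def _sync_output_partition(self, partition_id, key):
        print(f"syncing partition {partition_id}")  # barrier/flush would happen here

    def _get_output_partition(self, partition_id, key, **kwargs):
        return f"partition-{partition_id}"

    def get_output_partition(self, partition_id, key, **kwargs):
        # Default path syncs first; callers that already synced
        # (e.g. UnloadedPartition.load below) pass sync=False
        if kwargs.pop("sync", True):
            self._sync_output_partition(partition_id, key)
        return self._get_output_partition(partition_id, key, **kwargs)

run = _Run()
run.get_output_partition(0, "key")              # syncs, then fetches
run.get_output_partition(0, "key", sync=False)  # fetch only
```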
@@ -504,24 +505,39 @@ def __init__(
         self.key = key
         self.kwargs = kwargs

+    def pre_serialize(self) -> Any:
+        """Make the unloaded partition serializable"""
+        # TODO: Add mechanism to dispatch on meta.
+        # Right now, serializing an UnloadedPartition object
+        # will convert it to `type(self.shuffle_run.meta)`.
+        # However, it may be beneficial to further delay the
+        # use of GPU memory for cudf/cupy-based data.
+        return self.load()
+
     def load(self) -> Any:
+        """Load the shuffle output partition into memory"""
         with handle_unpack_errors(self.shuffle_run.id):
-            return self.shuffle_run._get_output_partition(
-                self.partition_id, self.key, **self.kwargs
+            return self.shuffle_run.get_output_partition(
+                self.partition_id,
+                self.key,
+                # We need sync=False, because `_sync_output_partition`
+                # was already called for the current shuffle run
+                sync=False,
+                **self.kwargs,
             )


 @dask_serialize.register(UnloadedPartition)
-def _serialize_unloaded(obj: UnloadedPartition) -> tuple[None, list[bytes]]:
+def _serialize_unloaded(obj):
     # Convert to LoadedPartition before serializing. Note that
     # we don't convert all the way to DataFrame, because this
     # adds unnecessary overhead and memory pressure for the
     # cudf backend (and minor overhead for pandas)
-    return None, [pickle.dumps(obj.load())]
+    return None, [pickle.dumps(obj.pre_serialize())]


 @dask_deserialize.register(UnloadedPartition)
-def _deserialize_unloaded(header: None, frames: list[bytes]) -> Any:
+def _deserialize_unloaded(header, frames):
     return pickle.loads(frames[0])

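
For context, `dask_serialize.register` / `dask_deserialize.register` plug a type into distributed's "dask" serializer family via (header, frames) pairs. A minimal sketch of the same pattern, using a hypothetical `Lazy` class in place of `UnloadedPartition` (and a plain dict header, following the distributed docs' convention):

```python
import pickle

from distributed.protocol import (
    dask_deserialize,
    dask_serialize,
    deserialize,
    serialize,
)

class Lazy:
    """Hypothetical stand-in for UnloadedPartition."""

    def __init__(self, payload):
        self.payload = payload

    def pre_serialize(self):
        # Mirror of the pre_serialize() hook above: return something picklable
        return self.payload

@dask_serialize.register(Lazy)
def _ser(obj):
    # Return (header, frames); the frame holds the pickled, materialized payload
    return {}, [pickle.dumps(obj.pre_serialize())]

@dask_deserialize.register(Lazy)
def _deser(header, frames):
    return pickle.loads(frames[0])

# Round-trip through distributed's serialization machinery
header, frames = serialize(Lazy({"a": 1}), serializers=["dask"])
assert deserialize(header, frames) == {"a": 1}
```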
@@ -534,4 +550,5 @@ def load_output_partition(
     assert barrier_key
     if isinstance(data, UnloadedPartition):
         data = data.load()
+    assert not isinstance(data, UnloadedPartition)
     return data
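
And a small self-contained sketch of the materialization contract that the new assertion enforces, again with a hypothetical `LazyPart` wrapper standing in for `UnloadedPartition`:

```python
class LazyPart:
    """Hypothetical lazy wrapper, standing in for UnloadedPartition."""

    def __init__(self, value):
        self._value = value

    def load(self):
        return self._value

def load_output_partition(data, barrier_key):
    assert barrier_key
    if isinstance(data, LazyPart):
        data = data.load()
    # Downstream tasks must only ever see concrete data, never the wrapper
    assert not isinstance(data, LazyPart)
    return data

assert load_output_partition(LazyPart([1, 2, 3]), "barrier-0") == [1, 2, 3]
assert load_output_partition([4, 5], "barrier-0") == [4, 5]  # pass-through
```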