@@ -310,20 +310,25 @@ def split_by_worker(
     constructor = df._constructor_sliced
     assert isinstance(constructor, type)
     worker_for = constructor(worker_for)
-    df = df.merge(
+    df["_worker"] = df[[column]].reset_index(drop=True).merge(
         right=worker_for.cat.codes.rename("_worker"),
         left_on=column,
         right_index=True,
-        how="inner",
-    )
+    ).sort_index()["_worker"]
+    # df = df.merge(
+    #     right=worker_for.cat.codes.rename("_worker"),
+    #     left_on=column,
+    #     right_index=True,
+    #     how="inner",
+    # )
     nrows = len(df)
     if not nrows:
         return {}
     # assert len(df) == nrows  # Not true if some outputs aren't wanted
     # FIXME: If we do not preserve the index something is corrupting the
     # bytestream such that it cannot be deserialized anymore
-    t = to_pyarrow_table_dispatch(df, preserve_index=True)
-    t = t.sort_by("_worker")
+    t = to_pyarrow_table_dispatch(df.sort_values("_worker"), preserve_index=True)
+    # t = t.sort_by("_worker")
     codes = np.asarray(t["_worker"])
     t = t.drop(["_worker"])
     del df
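Aside: the hunk above assigns each row a destination-worker code by merging the shuffle column against worker_for.cat.codes (restoring the original row order via reset_index/sort_index), and then sorts the dataframe by that code before the Arrow conversion instead of calling Table.sort_by afterwards. Below is a minimal, self-contained sketch of that pattern on toy data; the column name "_partitions", the worker labels "w0"/"w1", and pa.Table.from_pandas standing in for to_pyarrow_table_dispatch are assumptions of this illustration, not part of the commit.

# Illustrative sketch (not part of the commit): the worker-assignment and
# sort-then-slice pattern from the hunk above, on toy data.
import numpy as np
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"_partitions": [2, 0, 1, 0, 2], "x": range(5)})
# Maps output partition number -> destination worker, as a categorical Series.
worker_for = pd.Series(pd.Categorical(["w0", "w1", "w0"]), index=[0, 1, 2])

# Same pattern as the new code path: look up each row's worker code via a merge
# against worker_for.cat.codes, then restore the original row order.
df["_worker"] = (
    df[["_partitions"]]
    .reset_index(drop=True)
    .merge(
        right=worker_for.cat.codes.rename("_worker"),
        left_on="_partitions",
        right_index=True,
    )
    .sort_index()["_worker"]
)

# Sort rows by destination before converting, so each worker's rows are contiguous.
t = pa.Table.from_pandas(df.sort_values("_worker"), preserve_index=True)
codes = np.asarray(t["_worker"])
t = t.drop(["_worker"])

# Boundaries between runs of equal codes delimit one shard per destination worker.
splits = np.concatenate([[0], np.where(codes[1:] != codes[:-1])[0] + 1, [len(codes)]])
shards = {
    worker_for.cat.categories[codes[int(a)]]: t.slice(int(a), int(b - a))
    for a, b in zip(splits[:-1], splits[1:])
}
print({k: v.num_rows for k, v in shards.items()})  # {'w0': 4, 'w1': 1}

Doing the sort on the dataframe rather than on the Arrow table keeps that step inside pandas (or, presumably, cudf, given the "(cudf support)" notes below), which appears to be the point of this experiment.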
@@ -346,6 +351,94 @@ def split_by_worker(
     return out


+# def split_by_worker(
+#     df: pd.DataFrame,
+#     column: str,
+#     meta: pd.DataFrame,
+#     worker_for: pd.Series,
+# ) -> dict[Any, pa.Table]:
+#     """
+#     Split data into many arrow batches, partitioned by destination worker
+#     """
+#     import numpy as np
+
+#     from dask.dataframe.dispatch import to_pyarrow_table_dispatch
+
+#     # (cudf support) Avoid pd.Series
+#     constructor = df._constructor_sliced
+#     assert isinstance(constructor, type)
+#     worker_for = constructor(worker_for)
+
+#     df["_worker"] = df[[column]].reset_index(drop=True).merge(
+#         right=worker_for.cat.codes.rename("_worker"),
+#         left_on=column,
+#         right_index=True,
+#     ).sort_index()["_worker"]
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     nrows = len(df)
+#     if not nrows:
+#         return {}
+
+#     c = df["_worker"]
+#     k = len(worker_for.cat.categories)
+#     out = {
+#         worker_for.cat.categories[code]: to_pyarrow_table_dispatch(shard, preserve_index=True)
+#         for code, shard in enumerate(
+#             df.scatter_by_map(
+#                 c.astype(np.int32, copy=False),
+#                 map_size=k,
+#                 keep_index=True,
+#             )
+#         )
+#     }
+#     assert sum(map(len, out.values())) == nrows
+#     return out
+
+#     # # (cudf support) Avoid pd.Series
+#     # constructor = df._constructor_sliced
+#     # assert isinstance(constructor, type)
+#     # worker_for = constructor(worker_for)
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     # nrows = len(df)
+#     # if not nrows:
+#     #     return {}
+#     # # assert len(df) == nrows  # Not true if some outputs aren't wanted
+#     # # FIXME: If we do not preserve the index something is corrupting the
+#     # # bytestream such that it cannot be deserialized anymore
+#     # t = to_pyarrow_table_dispatch(df, preserve_index=True)
+#     # t = t.sort_by("_worker")
+#     # codes = np.asarray(t["_worker"])
+#     # t = t.drop(["_worker"])
+#     # del df
+
+#     # splits = np.where(codes[1:] != codes[:-1])[0] + 1
+#     # splits = np.concatenate([[0], splits])
+
+#     # shards = [
+#     #     t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits)
+#     # ]
+#     # shards.append(t.slice(offset=splits[-1], length=None))
+
+#     # unique_codes = codes[splits]
+#     # out = {
+#     #     # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43
+#     #     worker_for.cat.categories[code]: shard
+#     #     for code, shard in zip(unique_codes, shards)
+#     # }
+#     # assert sum(map(len, out.values())) == nrows
+#     # return out
+
+
 def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
     """
     Split data into many arrow batches, partitioned by final partition