Commit 166eb51

play with split_by_worker
1 parent f16c51d commit 166eb51

1 file changed: +98 -5 lines

distributed/shuffle/_shuffle.py

@@ -310,20 +310,25 @@ def split_by_worker(
     constructor = df._constructor_sliced
     assert isinstance(constructor, type)
     worker_for = constructor(worker_for)
-    df = df.merge(
+    df["_worker"] = df[[column]].reset_index(drop=True).merge(
         right=worker_for.cat.codes.rename("_worker"),
         left_on=column,
         right_index=True,
-        how="inner",
-    )
+    ).sort_index()["_worker"]
+    # df = df.merge(
+    #     right=worker_for.cat.codes.rename("_worker"),
+    #     left_on=column,
+    #     right_index=True,
+    #     how="inner",
+    # )
     nrows = len(df)
     if not nrows:
         return {}
     # assert len(df) == nrows  # Not true if some outputs aren't wanted
     # FIXME: If we do not preserve the index something is corrupting the
     # bytestream such that it cannot be deserialized anymore
-    t = to_pyarrow_table_dispatch(df, preserve_index=True)
-    t = t.sort_by("_worker")
+    t = to_pyarrow_table_dispatch(df.sort_values("_worker"), preserve_index=True)
+    #t = t.sort_by("_worker")
     codes = np.asarray(t["_worker"])
     t = t.drop(["_worker"])
     del df
@@ -346,6 +351,94 @@ def split_by_worker(
     return out


+# def split_by_worker(
+#     df: pd.DataFrame,
+#     column: str,
+#     meta: pd.DataFrame,
+#     worker_for: pd.Series,
+# ) -> dict[Any, pa.Table]:
+#     """
+#     Split data into many arrow batches, partitioned by destination worker
+#     """
+#     import numpy as np
+
+#     from dask.dataframe.dispatch import to_pyarrow_table_dispatch
+
+#     # (cudf support) Avoid pd.Series
+#     constructor = df._constructor_sliced
+#     assert isinstance(constructor, type)
+#     worker_for = constructor(worker_for)
+
+#     df["_worker"] = df[[column]].reset_index(drop=True).merge(
+#         right=worker_for.cat.codes.rename("_worker"),
+#         left_on=column,
+#         right_index=True,
+#     ).sort_index()["_worker"]
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     nrows = len(df)
+#     if not nrows:
+#         return {}
+
+#     c = df["_worker"]
+#     k = len(worker_for.cat.categories)
+#     out = {
+#         worker_for.cat.categories[code] : to_pyarrow_table_dispatch(shard, preserve_index=True)
+#         for code, shard in enumerate(
+#             df.scatter_by_map(
+#                 c.astype(np.int32, copy=False),
+#                 map_size=k,
+#                 keep_index=True,
+#             )
+#         )
+#     }
+#     assert sum(map(len, out.values())) == nrows
+#     return out
+
+#     # (cudf support) Avoid pd.Series
+#     # constructor = df._constructor_sliced
+#     # assert isinstance(constructor, type)
+#     # worker_for = constructor(worker_for)
+#     # df = df.merge(
+#     #     right=worker_for.cat.codes.rename("_worker"),
+#     #     left_on=column,
+#     #     right_index=True,
+#     #     how="inner",
+#     # )
+#     # nrows = len(df)
+#     # if not nrows:
+#     #     return {}
+#     # # assert len(df) == nrows  # Not true if some outputs aren't wanted
+#     # # FIXME: If we do not preserve the index something is corrupting the
+#     # # bytestream such that it cannot be deserialized anymore
+#     # t = to_pyarrow_table_dispatch(df, preserve_index=True)
+#     # t = t.sort_by("_worker")
+#     # codes = np.asarray(t["_worker"])
+#     # t = t.drop(["_worker"])
+#     # del df
+
+#     # splits = np.where(codes[1:] != codes[:-1])[0] + 1
+#     # splits = np.concatenate([[0], splits])
+
+#     # shards = [
+#     #     t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits)
+#     # ]
+#     # shards.append(t.slice(offset=splits[-1], length=None))
+
+#     # unique_codes = codes[splits]
+#     # out = {
+#     #     # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43
+#     #     worker_for.cat.categories[code]: shard
+#     #     for code, shard in zip(unique_codes, shards)
+#     # }
+#     # assert sum(map(len, out.values())) == nrows
+#     # return out
+
+
 def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
     """
     Split data into many arrow batches, partitioned by final partition
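For context, the change above keeps the merge-based lookup that tags every row with a destination-worker code, but moves the grouping step from Arrow (Table.sort_by("_worker")) to pandas (df.sort_values("_worker")) before the table conversion. Below is a minimal, self-contained sketch of that flow on toy data; it is not the distributed implementation: split_by_worker_sketch, the toy frame, and the worker addresses are made up, pa.Table.from_pandas stands in for dask's to_pyarrow_table_dispatch, and the commented-out cudf scatter_by_map variant is omitted.

import numpy as np
import pandas as pd
import pyarrow as pa


def split_by_worker_sketch(df: pd.DataFrame, column: str, worker_for: pd.Series) -> dict:
    """Split ``df`` into one pyarrow.Table per destination worker."""
    # worker_for maps output-partition id -> worker address; categorical codes
    # give each worker a small integer label.
    worker_for = worker_for.astype("category")

    # The lookup this commit switches to: find the worker code for each row by
    # merging the partition column against the codes, then restore the original
    # positional order with sort_index() in case the merge reordered rows.
    # (Assumes every partition value in `column` has an entry in worker_for.)
    codes_for_rows = (
        df[[column]]
        .reset_index(drop=True)
        .merge(
            right=worker_for.cat.codes.rename("_worker"),
            left_on=column,
            right_index=True,
        )
        .sort_index()["_worker"]
        .to_numpy()
    )
    df = df.copy()
    df["_worker"] = codes_for_rows
    if not len(df):
        return {}

    # Sort by worker on the pandas side (instead of Table.sort_by on the Arrow
    # side), convert once, then slice contiguous runs of equal codes into shards.
    t = pa.Table.from_pandas(df.sort_values("_worker"), preserve_index=True)
    codes = np.asarray(t["_worker"])
    t = t.drop(["_worker"])

    splits = np.where(codes[1:] != codes[:-1])[0] + 1
    splits = np.concatenate([[0], splits])
    shards = [t.slice(int(a), int(b - a)) for a, b in zip(splits[:-1], splits[1:])]
    shards.append(t.slice(int(splits[-1])))

    unique_codes = codes[splits]
    return {
        worker_for.cat.categories[code]: shard
        for code, shard in zip(unique_codes, shards)
    }


# Toy usage: three output partitions spread over two (made-up) worker addresses.
frame = pd.DataFrame({"_partitions": [0, 2, 1, 0, 2], "x": range(5)})
workers = pd.Series({0: "tcp://a:1", 1: "tcp://a:1", 2: "tcp://b:2"})
for worker, shard in split_by_worker_sketch(frame, "_partitions", workers).items():
    print(worker, shard.num_rows)

In the sketch, sort_index() realigns the merge result with the original row order before the codes are written back as a column; the commented-out alternative in the diff would instead hand those integer codes to cudf's scatter_by_map to produce one shard per worker directly.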
