Skip to content

Commit

Permalink
MINOR: Extract truncation logic out of partial concatenation in P2P rechunking (#8826)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Aug 12, 2024
1 parent e364f42 commit 86dc83c
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions distributed/shuffle/_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,20 +399,21 @@ def _construct_graph(self) -> _T_LowLevelGraph:
chunked_shape = tuple(len(axis) for axis in self.chunks)

for ndpartial in _split_partials(_old_to_new, chunked_shape):
output_count = np.sum(self.keepmap[ndpartial.new])
partial_keepmap = self.keepmap[ndpartial.new]
output_count = np.sum(partial_keepmap)
if output_count == 0:
continue
elif output_count == 1:
# Single output chunk
# TODO: Create new partial that contains ONLY the relevant chunk
ndindex = np.argwhere(partial_keepmap)[0]
ndpartial = _truncate_partial(ndindex, ndpartial, _old_to_new)

dsk.update(
partial_concatenate(
input_name=self.name_input,
input_chunks=self.chunks_input,
ndpartial=ndpartial,
token=self.token,
keepmap=self.keepmap,
old_to_new=_old_to_new,
)
)
else:
Expand Down Expand Up @@ -516,8 +517,6 @@ def partial_concatenate(
input_chunks: ChunkedAxes,
ndpartial: _NDPartial,
token: str,
keepmap: np.ndarray,
old_to_new: list[Any],
) -> dict[Key, Any]:
import numpy as np

Expand All @@ -528,31 +527,6 @@ def partial_concatenate(

slice_group = f"rechunk-slice-{token}"

partial_keepmap = keepmap[ndpartial.new]
assert np.sum(partial_keepmap) == 1

ndindex = np.argwhere(partial_keepmap)[0]

partial_per_axis = []
for axis_index, index in enumerate(ndindex):
slc = slice(
ndpartial.new[axis_index].start + index,
ndpartial.new[axis_index].start + index + 1,
)
first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0]
last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1]
partial_per_axis.append(
_Partial(
old=slice(first_old_chunk, last_old_chunk + 1),
new=slc,
left_start=first_old_slice.start,
right_stop=last_old_slice.stop,
)
)

old, new, left_starts, right_stops = zip(*partial_per_axis)
ndpartial = _NDPartial(old, new, left_starts, right_stops, ndpartial.ix)

old_offset = tuple(slice_.start for slice_ in ndpartial.old)

shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old)
Expand Down Expand Up @@ -588,6 +562,32 @@ def partial_concatenate(
return dsk


def _truncate_partial(
    ndindex: NDIndex,
    ndpartial: _NDPartial,
    old_to_new: list[Any],
) -> _NDPartial:
    """Narrow ``ndpartial`` to a single new chunk along every axis.

    Parameters
    ----------
    ndindex
        One offset per axis, counted from ``ndpartial.new[axis].start``,
        selecting the new chunk to keep on that axis.
    ndpartial
        The partial to truncate. Only its ``new`` slices and its ``ix`` are
        read here.
    old_to_new
        Indexed as ``old_to_new[axis][new_chunk]``; each entry is a sequence
        of ``(old_chunk, old_slice)`` pairs, of which the first and the last
        are used.

    Returns
    -------
    _NDPartial
        A partial whose ``new`` slice on each axis has length one, whose
        ``old`` slice runs from the first to the last old chunk listed for
        that new chunk, and whose ``ix`` is carried over from ``ndpartial``.
    """

    def _one_chunk_partial(axis, offset):
        # Absolute position of the selected new chunk along this axis.
        position = ndpartial.new[axis].start + offset
        contributions = old_to_new[axis][position]
        head_chunk, head_slice = contributions[0]
        tail_chunk, tail_slice = contributions[-1]
        return _Partial(
            # ``old`` is half-open, hence the +1 on the last old chunk.
            old=slice(head_chunk, tail_chunk + 1),
            new=slice(position, position + 1),
            left_start=head_slice.start,
            right_stop=tail_slice.stop,
        )

    per_axis = [
        _one_chunk_partial(axis, offset) for axis, offset in enumerate(ndindex)
    ]

    # Transpose the per-axis records into per-field tuples.
    old, new, left_starts, right_stops = zip(*per_axis)
    return _NDPartial(old, new, left_starts, right_stops, ndpartial.ix)


def _compute_partial_old_chunks(
partial: _NDPartial, chunks: ChunkedAxes
) -> ChunkedAxes:
Expand Down

0 comments on commit 86dc83c

Please sign in to comment.