
Commit

Merge remote-tracking branch 'upstream/main' into cudf-spilling-dashboard

charlesbluca committed Dec 12, 2023
2 parents f96b8a4 + 87576ae commit 936f0f6
Showing 6 changed files with 70 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-pre-commit.yml
@@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/[email protected]
- uses: actions/setup-python@v4
- uses: actions/setup-python@v5
with:
python-version: '3.9'
- uses: pre-commit/[email protected]
21 changes: 11 additions & 10 deletions distributed/shuffle/_core.py
@@ -7,7 +7,7 @@
import pickle
import time
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
@@ -40,7 +40,6 @@
_P = ParamSpec("_P")

# circular dependencies
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

ShuffleId = NewType("ShuffleId", str)
@@ -375,22 +374,24 @@ class ShuffleSpec(abc.ABC, Generic[_T_partition_id]):
id: ShuffleId
disk: bool

@abc.abstractproperty
def output_partitions(self) -> Generator[_T_partition_id, None, None]:
"""Output partitions"""
raise NotImplementedError

@abc.abstractmethod
def pick_worker(self, partition: _T_partition_id, workers: Sequence[str]) -> str:
"""Pick a worker for a partition"""

def create_new_run(
self,
plugin: ShuffleSchedulerPlugin,
worker_for: dict[_T_partition_id, str],
) -> SchedulerShuffleState:
worker_for = self._pin_output_workers(plugin)
return SchedulerShuffleState(
run_spec=ShuffleRunSpec(spec=self, worker_for=worker_for),
participating_workers=set(worker_for.values()),
)

@abc.abstractmethod
def _pin_output_workers(
self, plugin: ShuffleSchedulerPlugin
) -> dict[_T_partition_id, str]:
"""Pin output tasks to workers and return the mapping of partition ID to worker."""

@abc.abstractmethod
def create_run_on_worker(
self,
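
Reviewer note, not part of this commit: with this change a concrete ShuffleSpec no longer calls back into the scheduler plugin. It only enumerates output_partitions and implements pick_worker, while create_new_run now receives a precomputed worker_for mapping. A minimal sketch of the new contract, using a made-up ToySpec and a round-robin policy purely for illustration (not library code):

# Hypothetical sketch of the refactored ShuffleSpec contract; ToySpec and the
# round-robin policy are invented for illustration only.
from collections.abc import Generator, Sequence
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    npartitions: int

    @property
    def output_partitions(self) -> Generator[int, None, None]:
        # Enumerate every output partition this shuffle will produce.
        yield from range(self.npartitions)

    def pick_worker(self, partition: int, workers: Sequence[str]) -> str:
        # Any deterministic policy works; here, simple round-robin.
        return workers[partition % len(workers)]


spec = ToySpec(npartitions=5)
workers = ["tcp://a:1234", "tcp://b:1234"]
worker_for = {p: spec.pick_worker(p, workers) for p in spec.output_partitions}
# worker_for == {0: 'tcp://a:1234', 1: 'tcp://b:1234', 2: 'tcp://a:1234', ...}
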
32 changes: 17 additions & 15 deletions distributed/shuffle/_rechunk.py
@@ -99,7 +99,7 @@
import mmap
import os
from collections import defaultdict
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from itertools import product
@@ -125,7 +125,6 @@
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._pickle import unpickle_bytestream
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._shuffle import barrier_key, shuffle_barrier
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof
@@ -456,11 +455,22 @@ class ArrayRechunkSpec(ShuffleSpec[NDIndex]):
new: ChunkedAxes
old: ChunkedAxes

def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[NDIndex, str]:
parts_out = product(*(range(len(c)) for c in self.new))
return plugin._pin_output_workers(
self.id, parts_out, _get_worker_for_hash_sharding
)
@property
def output_partitions(self) -> Generator[NDIndex, None, None]:
yield from product(*(range(len(c)) for c in self.new))

def pick_worker(self, partition: NDIndex, workers: Sequence[str]) -> str:
npartitions = 1
for c in self.new:
npartitions *= len(c)
ix = 0
for dim, pos in enumerate(partition):
if dim > 0:
ix += len(self.new[dim - 1]) * pos
else:
ix += pos
i = len(workers) * ix // npartitions
return workers[i]

def create_run_on_worker(
self,
@@ -487,11 +497,3 @@ def create_run_on_worker(
disk=self.disk,
loop=plugin.worker.loop,
)


def _get_worker_for_hash_sharding(
output_partition: NDIndex, workers: Sequence[str]
) -> str:
"""Get address of target worker for this output partition using hash sharding"""
i = hash(output_partition) % len(workers)
return workers[i]
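
Reviewer note, not part of this commit: the removed _get_worker_for_hash_sharding scattered output chunks across workers by hashing their ND index. The new ArrayRechunkSpec.pick_worker replaces that with range sharding: it linearizes the ND output index (first axis varies fastest) and assigns contiguous ranges of the linear index to workers. The same arithmetic written out for a 2-D rechunk, with assumed example chunking and worker addresses:

# Standalone sketch of the range-sharding arithmetic in ArrayRechunkSpec.pick_worker,
# specialized to two dimensions; `new` and the worker list are assumed examples.
from collections.abc import Sequence

new = ((5, 5), (12, 18))                  # output chunk sizes along each axis
npartitions = len(new[0]) * len(new[1])   # 2 * 2 = 4 output partitions


def pick_worker_2d(partition: tuple[int, int], workers: Sequence[str]) -> str:
    # Linearize the ND index: the first axis varies fastest.
    ix = partition[0] + len(new[0]) * partition[1]
    # Contiguous ranges of the linear index map to the same worker.
    return workers[len(workers) * ix // npartitions]


workers = ["tcp://a:1234", "tcp://b:1234"]
assert [pick_worker_2d(p, workers) for p in [(0, 0), (1, 0), (0, 1), (1, 1)]] == [
    "tcp://a:1234",
    "tcp://a:1234",
    "tcp://b:1234",
    "tcp://b:1234",
]
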
18 changes: 7 additions & 11 deletions distributed/shuffle/_scheduler_plugin.py
@@ -3,7 +3,6 @@
import contextlib
import logging
from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any

from dask.typing import Key
@@ -145,7 +144,8 @@ def get_or_create(
# that the shuffle works as intended and should fail instead.
self._raise_if_barrier_unknown(spec.id)
self._raise_if_task_not_processing(key)
state = spec.create_new_run(self)
worker_for = self._calculate_worker_for(spec)
state = spec.create_new_run(worker_for)
self.active_shuffles[spec.id] = state
self._shuffles[spec.id].add(state)
state.participating_workers.add(worker)
@@ -167,12 +167,7 @@ def _raise_if_task_not_processing(self, key: Key) -> None:
if task.state != "processing":
raise RuntimeError(f"Expected {task} to be processing, is {task.state}.")

def _pin_output_workers(
self,
id: ShuffleId,
output_partitions: Iterable[Any],
pick: Callable[[Any, Sequence[str]], str],
) -> dict[Any, str]:
def _calculate_worker_for(self, spec: ShuffleSpec) -> dict[Any, str]:
"""Pin the outputs of a P2P shuffle to specific workers.
Parameters
@@ -186,15 +181,16 @@ def _pin_output_workers(
the same worker restrictions.
"""
mapping = {}
barrier = self.scheduler.tasks[barrier_key(id)]
shuffle_id = spec.id
barrier = self.scheduler.tasks[barrier_key(shuffle_id)]

if barrier.worker_restrictions:
workers = list(barrier.worker_restrictions)
else:
workers = list(self.scheduler.workers)

for partition in output_partitions:
worker = pick(partition, workers)
for partition in spec.output_partitions:
worker = spec.pick_worker(partition, workers)
mapping[partition] = worker
return mapping

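
Reviewer note, not part of this commit: responsibility for building the partition-to-worker mapping now lives in the scheduler plugin's _calculate_worker_for, which restricts the candidate workers to the barrier task's worker restrictions (if any) and then asks the spec to pick a worker for each output partition. The core of that loop, sketched in isolation with duck-typed inputs instead of a real scheduler:

# Isolated sketch of the _calculate_worker_for logic; the real method looks the
# barrier task up on the scheduler, this version takes its inputs directly.
from collections.abc import Sequence
from typing import Any


def calculate_worker_for(
    spec: Any, barrier_restrictions: set[str], all_workers: Sequence[str]
) -> dict[Any, str]:
    # Honour worker restrictions placed on the barrier task, if there are any.
    workers = list(barrier_restrictions) if barrier_restrictions else list(all_workers)
    # Let the spec decide where each of its output partitions should live.
    return {p: spec.pick_worker(p, workers) for p in spec.output_partitions}
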
20 changes: 14 additions & 6 deletions distributed/shuffle/_shuffle.py
@@ -3,10 +3,16 @@
import logging
import os
from collections import defaultdict
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from collections.abc import (
Callable,
Collection,
Generator,
Iterable,
Iterator,
Sequence,
)
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any

@@ -41,7 +47,6 @@
handle_unpack_errors,
)
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
from distributed.sizeof import sizeof

@@ -524,9 +529,12 @@ class DataFrameShuffleSpec(ShuffleSpec[int]):
meta: pd.DataFrame
parts_out: set[int]

def _pin_output_workers(self, plugin: ShuffleSchedulerPlugin) -> dict[int, str]:
pick_worker = partial(_get_worker_for_range_sharding, self.npartitions)
return plugin._pin_output_workers(self.id, self.parts_out, pick_worker)
@property
def output_partitions(self) -> Generator[int, None, None]:
yield from self.parts_out

def pick_worker(self, partition: int, workers: Sequence[str]) -> str:
return _get_worker_for_range_sharding(self.npartitions, partition, workers)

def create_run_on_worker(
self, run_id: int, worker_for: dict[int, str], plugin: ShuffleWorkerPlugin
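
Reviewer note, not part of this commit: DataFrameShuffleSpec.pick_worker keeps delegating to _get_worker_for_range_sharding, only without the functools.partial indirection that previously bound npartitions. That helper's body is not shown in this diff; an equivalent range-sharding sketch, under the assumption that it maps partition i of npartitions to worker len(workers) * i // npartitions:

# Equivalent range-sharding sketch (assumed behaviour of the helper, which is
# not part of this diff): consecutive dataframe output partitions cluster on
# the same worker instead of being scattered by a hash.
from collections.abc import Sequence


def get_worker_for_range_sharding(
    npartitions: int, output_partition: int, workers: Sequence[str]
) -> str:
    return workers[len(workers) * output_partition // npartitions]


workers = ["tcp://a:1234", "tcp://b:1234", "tcp://c:1234"]
print([get_worker_for_range_sharding(6, p, workers) for p in range(6)])
# ['tcp://a:1234', 'tcp://a:1234', 'tcp://b:1234',
#  'tcp://b:1234', 'tcp://c:1234', 'tcp://c:1234']
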
24 changes: 20 additions & 4 deletions distributed/shuffle/tests/test_rechunk.py
@@ -26,8 +26,8 @@
from distributed.shuffle._limiter import ResourceLimiter
from distributed.shuffle._rechunk import (
ArrayRechunkRun,
ArrayRechunkSpec,
Split,
_get_worker_for_hash_sharding,
split_axes,
)
from distributed.shuffle.tests.utils import AbstractShuffleTestPool
@@ -103,11 +103,12 @@ async def test_lowlevel_rechunk(tmp_path, n_workers, barrier_first_worker, disk)

worker_for_mapping = {}

spec = ArrayRechunkSpec(id=ShuffleId("foo"), disk=disk, new=new, old=old)
new_indices = list(product(*(range(len(dim)) for dim in new)))
for i, idx in enumerate(new_indices):
worker_for_mapping[idx] = _get_worker_for_hash_sharding(i, workers)

for idx in new_indices:
worker_for_mapping[idx] = spec.pick_worker(idx, workers)
assert len(set(worker_for_mapping.values())) == min(n_workers, len(new_indices))
# scheduler_state = spec.create_new_run(worker_for_mapping)

with ArrayRechunkTestPool() as local_shuffle_pool:
shuffles = []
@@ -1200,3 +1201,18 @@ def blocked(chunk, in_map, block_map):
buf_ids = {id(get_host_array(shard)) for shard in shards}
assert len(buf_ids) == len(shards)
await block_map.set()


@pytest.mark.parametrize("nworkers", [1, 2, 41, 50])
def test_worker_for_homogeneous_distribution(nworkers):
old = ((1, 2, 3, 4), (5,) * 6)
new = ((5, 5), (12, 18))
workers = [str(i) for i in range(nworkers)]
spec = ArrayRechunkSpec(ShuffleId("foo"), disk=False, new=new, old=old)
count = {w: 0 for w in workers}
for nidx in spec.output_partitions:
count[spec.pick_worker(nidx, workers)] += 1

assert sum(count.values()) > 0
assert sum(count.values()) == len(list(spec.output_partitions))
assert abs(max(count.values()) - min(count.values())) <= 1
