From a4aa29f6c2d2251e4b9aaebc1a8d3c5fcbd463bb Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 17 May 2024 15:05:46 -0500 Subject: [PATCH 001/138] bump version to 2024.5.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 089e681561..1ef856d899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.5.0", + "dask == 2024.5.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From e0367577b6fefb4c18bc3252c869ee1d780febdb Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 20 May 2024 10:42:49 -0500 Subject: [PATCH 002/138] Avoid multiple ``WorkerState`` sphinx error (#8643) --- distributed/diagnostics/plugin.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index cd935140f7..7b497d511c 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -21,9 +21,12 @@ if TYPE_CHECKING: # circular imports + # Needed to avoid Sphinx WARNING: more than one target found for cross-reference + # 'WorkerState'" + # https://github.com/agronholm/sphinx-autodoc-typehints#dealing-with-circular-imports + from distributed import scheduler as scheduler_module from distributed.scheduler import Scheduler from distributed.scheduler import TaskStateState as SchedulerTaskStateState - from distributed.scheduler import WorkerState from distributed.worker import Worker from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState @@ -207,8 +210,8 @@ def remove_client(self, scheduler: Scheduler, client: str) -> None: """Run when a client disconnects""" def valid_workers_downscaling( - self, scheduler: Scheduler, workers: list[WorkerState] - ) -> list[WorkerState]: + self, scheduler: Scheduler, workers: list[scheduler_module.WorkerState] + ) -> list[scheduler_module.WorkerState]: """Determine which workers can be removed from the cluster This method is called when the scheduler is about to downscale the cluster From 854a280dd8b6f485a0f682d4990b9f8c15012149 Mon Sep 17 00:00:00 2001 From: Ray Bell Date: Mon, 20 May 2024 11:51:52 -0400 Subject: [PATCH 003/138] Fix indent in code example in ``task-launch.rst`` (#8650) --- docs/source/task-launch.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/task-launch.rst b/docs/source/task-launch.rst index 80b0c1f8c8..2731a67ce1 100644 --- a/docs/source/task-launch.rst +++ b/docs/source/task-launch.rst @@ -204,17 +204,17 @@ worker. .. code-block:: python - from dask.distributed import worker_client + from dask.distributed import Client, worker_client def fib(n): if n < 2: return n - with worker_client() as client: - a_future = client.submit(fib, n - 1) - b_future = client.submit(fib, n - 2) - a, b = client.gather([a_future, b_future]) - return a + b + with worker_client() as client: + a_future = client.submit(fib, n - 1) + b_future = client.submit(fib, n - 2) + a, b = client.gather([a_future, b_future]) + return a + b if __name__ == "__main__": client = Client() From b3f8fcd0c8df9f131c5f4ab88ff95eb09e6e81df Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Wed, 22 May 2024 19:45:58 +0200 Subject: [PATCH 004/138] Submit collections metadata to scheduler (#8612) Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Co-authored-by: Hendrik Makait --- distributed/client.py | 22 ++++++++++++++++++-- distributed/recreate_tasks.py | 2 +- distributed/scheduler.py | 9 ++++++-- distributed/spans.py | 32 +++++++++++++++++++++++++++-- distributed/tests/test_scheduler.py | 1 + distributed/tests/test_spans.py | 23 +++++++++++++++++++++ 6 files changed, 82 insertions(+), 7 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index df976d0e05..837360217b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -92,6 +92,7 @@ from distributed.pubsub import PubSubClientExtension from distributed.security import Security from distributed.sizeof import sizeof +from distributed.spans import SpanMetadata from distributed.threadpoolexecutor import rejoin from distributed.utils import ( CancelledError, @@ -1946,7 +1947,6 @@ def submit( dsk = {key: (apply, func, list(args), kwargs)} else: dsk = {key: (func,) + tuple(args)} - futures = self._graph_to_futures( dsk, [key], @@ -1958,6 +1958,7 @@ def submit( retries=retries, fifo_timeout=fifo_timeout, actors=actor, + span_metadata=SpanMetadata(collections=[{"type": "Future"}]), ) logger.debug("Submit %s(...), %s", funcname(func), key) @@ -2164,6 +2165,7 @@ def map( user_priority=priority, fifo_timeout=fifo_timeout, actors=actor, + span_metadata=SpanMetadata(collections=[{"type": "Future"}]), ) logger.debug("map(%s, ...)", funcname(func)) @@ -3103,6 +3105,7 @@ def _graph_to_futures( self, dsk, keys, + span_metadata, workers=None, allow_other_workers=None, internal_priority=None, @@ -3179,6 +3182,7 @@ def _graph_to_futures( "actors": actors, "code": ToPickle(computations), "annotations": ToPickle(annotations), + "span_metadata": ToPickle(span_metadata), } ) return futures @@ -3266,6 +3270,7 @@ def get( retries=retries, user_priority=priority, actors=actors, + span_metadata=SpanMetadata(collections=[{"type": "low-level-graph"}]), ) packed = pack_data(keys, futures) if sync: @@ -3448,6 +3453,9 @@ def compute( ) variables = [a for a in collections if dask.is_dask_collection(a)] + metadata = SpanMetadata( + collections=[get_collections_metadata(v) for v in variables] + ) dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs) names = ["finalize-%s" % tokenize(v) for v in variables] @@ -3481,6 +3489,7 @@ def compute( user_priority=priority, fifo_timeout=fifo_timeout, actors=actors, + span_metadata=metadata, ) i = 0 @@ -3572,7 +3581,9 @@ def persist( collections = [collections] assert all(map(dask.is_dask_collection, collections)) - + metadata = SpanMetadata( + collections=[get_collections_metadata(v) for v in collections] + ) dsk = self.collections_to_dsk(collections, optimize_graph, **kwargs) names = {k for c in collections for k in flatten(c.__dask_keys__())} @@ -3587,6 +3598,7 @@ def persist( user_priority=priority, fifo_timeout=fifo_timeout, actors=actors, + span_metadata=metadata, ) postpersists = [c.__dask_postpersist__() for c in collections] @@ -6154,4 +6166,10 @@ def _close_global_client(): c.close(timeout=3) +def get_collections_metadata(collection): + return { + "type": type(collection).__name__, + } + + atexit.register(_close_global_client) diff --git a/distributed/recreate_tasks.py b/distributed/recreate_tasks.py index d3de76f576..78adfda09f 100644 --- a/distributed/recreate_tasks.py +++ b/distributed/recreate_tasks.py @@ -93,7 +93,7 @@ async def _prepare_raw_components(self, raw_components): Take raw components and resolve future dependencies. """ function, args, kwargs, deps = raw_components - futures = self.client._graph_to_futures({}, deps) + futures = self.client._graph_to_futures({}, deps, span_metadata={}) data = await self.client._gather(futures) args = pack_data(args, data) kwargs = pack_data(kwargs, data) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index d25793c08f..86cf067f91 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -113,7 +113,7 @@ from distributed.security import Security from distributed.semaphore import SemaphoreExtension from distributed.shuffle import ShuffleSchedulerPlugin -from distributed.spans import SpansSchedulerExtension +from distributed.spans import SpanMetadata, SpansSchedulerExtension from distributed.stealing import WorkStealing from distributed.utils import ( All, @@ -4524,6 +4524,7 @@ def _create_taskstate_from_graph( global_annotations: dict | None, stimulus_id: str, submitting_task: Key | None, + span_metadata: SpanMetadata, user_priority: int | dict[Key, int] = 0, actors: bool | list[Key] | None = None, fifo_timeout: float = 0.0, @@ -4632,7 +4633,9 @@ def _create_taskstate_from_graph( # _generate_taskstates is not the only thing that calls new_task(). A # TaskState may have also been created by client_desires_keys or scatter, # and only later gained a run_spec. - span_annotations = spans_ext.observe_tasks(runnable, code=code) + span_annotations = spans_ext.observe_tasks( + runnable, span_metadata=span_metadata, code=code + ) # In case of TaskGroup collision, spans may have changed # FIXME: Is this used anywhere besides tests? if span_annotations: @@ -4667,6 +4670,7 @@ async def update_graph( graph_header: dict, graph_frames: list[bytes], keys: set[Key], + span_metadata: SpanMetadata, internal_priority: dict[Key, int] | None, submitting_task: Key | None, user_priority: int | dict[Key, int] = 0, @@ -4724,6 +4728,7 @@ async def update_graph( actors=actors, fifo_timeout=fifo_timeout, code=code, + span_metadata=span_metadata, annotations_by_type=annotations_by_type, # FIXME: This is just used to attach to Computation # objects. This should be removed diff --git a/distributed/spans.py b/distributed/spans.py index 37d5ad011b..354ef62b47 100644 --- a/distributed/spans.py +++ b/distributed/spans.py @@ -1,12 +1,13 @@ from __future__ import annotations +import copy import uuid import weakref from collections import defaultdict from collections.abc import Hashable, Iterable, Iterator from contextlib import contextmanager from itertools import islice -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict import dask.config from dask.typing import Key @@ -28,6 +29,10 @@ CONTEXTS_WITH_SPAN_ID = ("execute", "p2p") +class SpanMetadata(TypedDict): + collections: list[dict] + + @contextmanager def span(*tags: str) -> Iterator[str]: """Tag group of tasks to be part of a certain group, called a span. @@ -116,6 +121,7 @@ class Span: #: Source code snippets, if it was sent by the client. #: We're using a dict without values as an insertion-sorted set. _code: dict[tuple[SourceCode, ...], None] + _metadata: SpanMetadata | None _cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float] @@ -128,6 +134,7 @@ class Span: __weakref__: Any __slots__ = tuple(__annotations__) + _metadata_seen: set[int] = set() def __init__( self, @@ -143,6 +150,7 @@ def __init__( self.children = [] self.groups = set() self._code = {} + self._metadata = None # Don't cast int metrics to float self._cumulative_worker_metrics = defaultdict(int) @@ -162,6 +170,17 @@ def parent(self) -> Span | None: return out return None + def add_metadata(self, metadata: SpanMetadata) -> None: + """Add metadata to the span, e.g. code snippets""" + id_ = id(metadata) + if id_ in self._metadata_seen: + return + self._metadata_seen.add(id_) + if self._metadata is None: + self._metadata = copy.deepcopy(metadata) + else: + self._metadata["collections"].extend(metadata["collections"]) + @property def annotation(self) -> dict[str, tuple[str, ...]] | None: """Rebuild the dask graph annotation which contains the full id history @@ -241,6 +260,10 @@ def stop(self) -> float: # being perfectly monotonic return max(out, self.enqueued) + @property + def metadata(self) -> SpanMetadata | None: + return self._metadata + @property def states(self) -> dict[TaskStateState, int]: """The number of tasks currently in each state in this span tree; @@ -481,7 +504,10 @@ def __init__(self, scheduler: Scheduler): self.spans_search_by_tag = defaultdict(list) def observe_tasks( - self, tss: Iterable[scheduler_module.TaskState], code: tuple[SourceCode, ...] + self, + tss: Iterable[scheduler_module.TaskState], + code: tuple[SourceCode, ...], + span_metadata: SpanMetadata, ) -> dict[Key, dict]: """Acknowledge the existence of runnable tasks on the scheduler. These may either be new tasks, tasks that were previously unrunnable, or tasks that were @@ -520,6 +546,8 @@ def observe_tasks( if code: span._code[code] = None + if span_metadata: + span.add_metadata(span_metadata) # The span may be completely different from the one referenced by the # annotation, due to the TaskGroup collision issue explained above. diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 68ccfbfccb..916ab3a76a 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1439,6 +1439,7 @@ async def test_update_graph_culls(s, a, b): client="client", internal_priority={k: 0 for k in "xyz"}, submitting_task=None, + span_metadata={}, ) assert "z" not in s.tasks diff --git a/distributed/tests/test_spans.py b/distributed/tests/test_spans.py index 89772ed5ea..8a1d90eaed 100644 --- a/distributed/tests/test_spans.py +++ b/distributed/tests/test_spans.py @@ -859,3 +859,26 @@ async def test_span_on_persist(c, s, a, b): assert s.tasks["x"].group.span_id == x_id assert s.tasks["y"].group.span_id == y_id + + +@pytest.mark.filterwarnings("ignore:Dask annotations") +@gen_cluster(client=True) +async def test_collections_metadata(c, s, a, b): + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") + np = pytest.importorskip("numpy") + df = pd.DataFrame( + {"x": np.random.random(1000), "y": np.random.random(1000)}, + index=np.arange(1000), + ) + ldf = dd.from_pandas(df, npartitions=10) + + with span("foo") as span_id: + await c.compute(ldf) + + ext = s.extensions["spans"] + span_ = ext.spans[span_id] + collections_meta = span_.metadata["collections"] + assert isinstance(collections_meta, list) + assert len(collections_meta) == 1 + assert collections_meta[0]["type"] == type(ldf).__name__ From 86e635a975110b7db1a6548b85e5b5cf593ce453 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 28 May 2024 10:46:36 -0500 Subject: [PATCH 005/138] Update gpuCI ``RAPIDS_VER`` to ``24.08`` (#8652) --- continuous_integration/gpuci/axis.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index f7dff8116c..77e9b9b45c 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -10,6 +10,6 @@ LINUX_VER: - ubuntu20.04 RAPIDS_VER: -- "24.06" +- "24.08" excludes: From 9fae5dacf4d2cfd5c659e472b0a3ef307d695863 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 28 May 2024 18:35:07 +0200 Subject: [PATCH 006/138] Reduce task group count for partial P2P rechunks (#8655) --- distributed/shuffle/_rechunk.py | 61 +++++++++++++++-------- distributed/shuffle/tests/test_rechunk.py | 50 ++++++++++++++++--- 2 files changed, 83 insertions(+), 28 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 6bb6e3f069..be36f47f56 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -199,6 +199,14 @@ class _NDPartial(NamedTuple): #: to exclude from this partial. #: This corresponds to `left_start` of the subsequent partial. right_stops: NDIndex + #: Index of the partial among all partials. + #: This corresponds to the position of the partial in the n-dimensional grid of + #: partials representing the full rechunk. + ix: NDIndex + + +def rechunk_name(token: str) -> str: + return f"rechunk-p2p-{token}" def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: @@ -210,14 +218,14 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: dsk = {} token = tokenize(x, chunks) - name = f"rechunk-p2p-{token}" for ndpartial in _split_partials(x, chunks): if all(slc.stop == slc.start + 1 for slc in ndpartial.new): # Single output chunk - dsk.update(partial_concatenate(x, chunks, ndpartial, name)) + dsk.update(partial_concatenate(x, chunks, ndpartial, token)) else: - dsk.update(partial_rechunk(x, chunks, ndpartial, name)) + dsk.update(partial_rechunk(x, chunks, ndpartial, token)) layer = MaterializedLayer(dsk) + name = rechunk_name(token) graph = HighLevelGraph.from_collections(name, layer, dependencies=[x]) arr = da.Array(graph, name, chunks, meta=x) return arr @@ -228,9 +236,12 @@ def _split_partials( ) -> Generator[_NDPartial, None, None]: """Split the rechunking into partials that can be performed separately""" partials_per_axis = _split_partials_per_axis(x, chunks) - for partial_per_axis in product(*partials_per_axis): + indices_per_axis = (range(len(partials)) for partials in partials_per_axis) + for nindex, partial_per_axis in zip( + product(*indices_per_axis), product(*partials_per_axis) + ): old, new, left_starts, right_stops = zip(*partial_per_axis) - yield _NDPartial(old, new, left_starts, right_stops) + yield _NDPartial(old, new, left_starts, right_stops, nindex) def _split_partials_per_axis( @@ -311,7 +322,7 @@ def partial_concatenate( x: da.Array, chunks: ChunkedAxes, ndpartial: _NDPartial, - name: str, + token: str, ) -> dict[Key, Any]: import numpy as np @@ -320,8 +331,7 @@ def partial_concatenate( dsk: dict[Key, Any] = {} - partial_token = tokenize(x, chunks, ndpartial.new) - slice_name = f"rechunk-slice-{partial_token}" + slice_group = f"rechunk-slice-{token}" old_offset = tuple(slice_.start for slice_ in ndpartial.old) shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old) @@ -340,8 +350,9 @@ def partial_concatenate( axis[index] for index, axis in zip(old_global_index, x.chunks) ) if _slicing_is_necessary(ndslice, original_shape): - rec_cat_arg[old_partial_index] = (slice_name,) + old_global_index - dsk[(slice_name,) + old_global_index] = ( + key = (slice_group,) + ndpartial.ix + old_global_index + rec_cat_arg[old_partial_index] = key + dsk[key] = ( getitem, (x.name,) + old_global_index, ndslice, @@ -349,7 +360,7 @@ def partial_concatenate( else: rec_cat_arg[old_partial_index] = (x.name,) + old_global_index global_index = tuple(int(slice_.start) for slice_ in ndpartial.new) - dsk[(name,) + global_index] = ( + dsk[(rechunk_name(token),) + global_index] = ( concatenate3, rec_cat_arg.tolist(), ) @@ -381,7 +392,7 @@ def partial_rechunk( x: da.Array, chunks: ChunkedAxes, ndpartial: _NDPartial, - name: str, + token: str, ) -> dict[Key, Any]: from dask.array.chunk import getitem @@ -389,10 +400,15 @@ def partial_rechunk( old_partial_offset = tuple(slice_.start for slice_ in ndpartial.old) - partial_token = tokenize(x, chunks, ndpartial.new) + partial_token = tokenize(token, ndpartial.ix) + # Use `token` to generate a canonical group for the entire rechunk + slice_group = f"rechunk-slice-{token}" + transfer_group = f"rechunk-transfer-{token}" + unpack_group = rechunk_name(token) + # We can use `partial_token` here because the barrier task share their + # group across all P2P shuffle-like operations + # FIXME: Make this group unique per individual P2P shuffle-like operation _barrier_key = barrier_key(ShuffleId(partial_token)) - slice_name = f"rechunk-slice-{partial_token}" - transfer_name = f"rechunk-transfer-{partial_token}" disk: bool = dask.config.get("distributed.p2p.disk") ndim = len(x.shape) @@ -414,19 +430,20 @@ def partial_rechunk( axis[index] for index, axis in zip(global_index, x.chunks) ) if _slicing_is_necessary(ndslice, original_shape): - input_task = (slice_name,) + global_index - dsk[(slice_name,) + global_index] = ( + input_key = (slice_group,) + ndpartial.ix + global_index + dsk[input_key] = ( getitem, (x.name,) + global_index, ndslice, ) else: - input_task = (x.name,) + global_index + input_key = (x.name,) + global_index - transfer_keys.append((transfer_name,) + global_index) - dsk[(transfer_name,) + global_index] = ( + key = (transfer_group,) + ndpartial.ix + global_index + transfer_keys.append(key) + dsk[key] = ( rechunk_transfer, - input_task, + input_key, partial_token, partial_index, partial_new, @@ -439,7 +456,7 @@ def partial_rechunk( new_partial_offset = tuple(axis.start for axis in ndpartial.new) for partial_index in _partial_ndindex(ndpartial.new): global_index = _global_index(partial_index, new_partial_offset) - dsk[(name,) + global_index] = ( + dsk[(unpack_group,) + global_index] = ( rechunk_unpack, partial_token, partial_index, diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 4744dfa6dc..0d282ec0ed 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -12,6 +12,7 @@ da = pytest.importorskip("dask.array") from concurrent.futures import ThreadPoolExecutor +from itertools import product from tornado.ioloop import IOLoop @@ -33,7 +34,7 @@ split_axes, ) from distributed.shuffle.tests.utils import AbstractShuffleTestPool -from distributed.utils_test import gen_cluster, gen_test +from distributed.utils_test import async_poll_for, gen_cluster, gen_test NUMPY_GE_124 = parse_version(np.__version__) >= parse_version("1.24") @@ -85,9 +86,6 @@ def new_shuffle( return s -from itertools import product - - @pytest.mark.parametrize("n_workers", [1, 10]) @pytest.mark.parametrize("barrier_first_worker", [True, False]) @pytest.mark.parametrize("disk", [True, False]) @@ -1259,8 +1257,6 @@ def test_pick_worker_homogeneous_distribution(nworkers): config={"distributed.scheduler.active-memory-manager.start": False}, ) async def test_partial_rechunk_homogeneous_distribution(c, s, *workers): - da = pytest.importorskip("dask.array") - # This rechunk operation can be split into 10 independent shuffles with 4 output # chunks each. This is less than the number of workers, so we are at risk of # choosing the same 4 output workers in each separate shuffle. @@ -1275,3 +1271,45 @@ async def test_partial_rechunk_homogeneous_distribution(c, s, *workers): nchunks = [len(w.data.keys() & out_keys) for w in workers] # There are 40 output chunks and 5 workers. Expect exactly 8 chunks per worker. assert nchunks == [8, 8, 8, 8, 8] + + +@gen_cluster(client=True, nthreads=[], config={"optimization.fuse.active": False}) +async def test_partial_rechunk_taskgroups(c, s): + """Regression test for https://github.com/dask/distributed/issues/8656""" + arr = da.random.random( + (10, 10, 10), + chunks=( + ( + 2, + 2, + 2, + 2, + 2, + ), + ) + * 3, + ) + arr = arr.rechunk( + ( + ( + 1, + 2, + 2, + 2, + 2, + 1, + ), + ) + * 3, + method="p2p", + ) + + _ = c.compute(arr) + await async_poll_for( + lambda: any( + isinstance(task, str) and task.startswith("shuffle-barrier") + for task in s.tasks + ), + timeout=5, + ) + assert len(s.task_groups) < 7 From 3bf0ea6542b00e9dcfa65b38bb864f6146f7b198 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 31 May 2024 16:11:18 -0500 Subject: [PATCH 007/138] bump version to 2024.5.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1ef856d899..b9752af0e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.5.1", + "dask == 2024.5.2", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From cbc21dff47dbc40ecc15664dc47d4f0170b36600 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 3 Jun 2024 16:02:34 +0200 Subject: [PATCH 008/138] Reduce noise from erring tasks that are not supposed to be running (#8664) --- distributed/tests/test_worker.py | 22 ++++++++++++++++++++++ distributed/worker.py | 30 ++++++++++++++++-------------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index af93668c6e..969a98ef87 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3777,3 +3777,25 @@ async def test_suppress_keyerror_for_cancelled_tasks(c, s, a, state): await async_poll_for(lambda: not b.state.tasks, timeout=5) assert not log.getvalue() + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_suppress_compute_failure_for_cancelled_tasks(c, s, a): + with captured_logger("distributed.worker", level=logging.WARNING) as log: + in_event = Event() + block_event = Event() + + def block_and_raise(in_event, block_event): + in_event.set() + block_event.wait() + return 1 / 0 + + x = c.submit(block_and_raise, in_event, block_event, key="x") + await in_event.wait() + del x + + await wait_for_state("x", "cancelled", a) + await block_event.set() + await async_poll_for(lambda: not a.state.tasks, timeout=5) + + assert not log.getvalue() diff --git a/distributed/worker.py b/distributed/worker.py index cd7f60efea..df5bdb0c59 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2340,19 +2340,21 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: key=ts.key, stimulus_id=f"cancelled-by-worker-close-{time()}" ) - logger.warning( - "Compute Failed\n" - "Key: %s\n" - "Function: %s\n" - "args: %s\n" - "kwargs: %s\n" - "Exception: %r\n", - key, - str(funcname(function))[:1000], - convert_args_to_str(args2, max_len=1000), - convert_kwargs_to_str(kwargs2, max_len=1000), - result["exception_text"], - ) + if ts.state in ("executing", "long-running", "resumed"): + logger.warning( + "Compute Failed\n" + "Key: %s\n" + "Function: %s\n" + "args: %s\n" + "kwargs: %s\n" + "Exception: %r\n", + key, + str(funcname(function))[:1000], + convert_args_to_str(args2, max_len=1000), + convert_kwargs_to_str(kwargs2, max_len=1000), + result["exception_text"], + ) + return ExecuteFailureEvent.from_exception( result, key=key, @@ -2370,7 +2372,7 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: # _prepare_args_for_execution() to raise KeyError; # - A dependency was unspilled but failed to deserialize due to a bug in # user-defined or third party classes. - if ts.state == "executing": + if ts.state in ("executing", "long-running"): logger.error( f"Exception during execution of task {key!r}", exc_info=True, From 25f0732065c393f8f6652aa1990ad8435032bde6 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 4 Jun 2024 12:55:28 +0200 Subject: [PATCH 009/138] Fix too strict assertion in shuffle code for pandas subclasses (#8667) --- distributed/shuffle/_shuffle.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 21e3e38836..d08d579778 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -322,8 +322,7 @@ def split_by_worker( # (cudf support) Avoid pd.Series constructor = df._constructor_sliced - assert isinstance(constructor, type) - worker_for = constructor(worker_for) + worker_for = constructor(worker_for) # type: ignore df = df.merge( right=worker_for.cat.codes.rename("_worker"), left_on=column, From c455e4d300e79f318194c62b3edaf3ba74528214 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Jun 2024 15:38:58 +0200 Subject: [PATCH 010/138] Add Prometheus gauge for task groups (#8661) --- distributed/http/scheduler/prometheus/core.py | 6 ++ .../scheduler/tests/test_scheduler_http.py | 63 +++++++++++++++++++ docs/source/prometheus.rst | 2 + 3 files changed, 71 insertions(+) diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index b3ac06b6b2..0a5bbf3020 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -102,6 +102,12 @@ def collect(self) -> Iterator[GaugeMetricFamily | CounterMetricFamily]: tasks.add_metric([state], task_counter.get(state, 0.0)) yield tasks + yield GaugeMetricFamily( + self.build_name("task_groups"), + "Number of task groups known by scheduler", + value=len(self.server.task_groups), + ) + time_spent_compute_tasks = CounterMetricFamily( self.build_name("tasks_compute"), "Total amount of compute time spent in each prefix", diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index 25ede716fb..78a9744ce1 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -3,6 +3,7 @@ import asyncio import json import re +from typing import Any from unittest import mock import pytest @@ -121,6 +122,7 @@ async def test_prometheus(c, s, a, b): "dask_scheduler_tasks_output_bytes", "dask_scheduler_tasks_compute_seconds", "dask_scheduler_tasks_transfer_seconds", + "dask_scheduler_task_groups", "dask_scheduler_prefix_state_totals", "dask_scheduler_tick_count", "dask_scheduler_tick_duration_maximum_seconds", @@ -217,6 +219,67 @@ async def fetch_state_metrics(): assert sum(forgotten_tasks) == 0.0 +@gen_cluster(client=True) +async def test_prometheus_collect_task_groups(c, s, a, b): + pytest.importorskip("prometheus_client") + + async def fetch_task_groups_metric(): + families = await fetch_metrics(s.http_server.port, prefix="dask_scheduler_") + return families["dask_scheduler_task_groups"] + + assert not s.task_groups + metric = await fetch_task_groups_metric() + assert len(metric.samples) == 1 + assert metric.samples[0].value == 0 + + # submit a task which should show up in the prometheus scraping + def block(x: Any, in_event: Event, block_event: Event) -> Any: + in_event.set() + block_event.wait() + return x + + in_event = Event() + block_event = Event() + + future = c.submit(block, 1, in_event, block_event, key=("block-first", 1)) + + await in_event.wait() + assert len(s.task_groups) == 1 + metric = await fetch_task_groups_metric() + assert len(metric.samples) == 1 + assert metric.samples[0].value == 1 + + in_event_2 = Event() + block_event_2 = Event() + + future2 = c.submit(block, 2, in_event_2, block_event_2, key=("block-second", 2)) + + await in_event_2.wait() + assert len(s.task_groups) == 2 + metric = await fetch_task_groups_metric() + assert len(metric.samples) == 1 + assert metric.samples[0].value == 2 + + await block_event.set() + res = await c.gather(future) + assert res == 1 + + await block_event_2.set() + res2 = await c.gather(future2) + assert res2 == 2 + + future.release() + future2.release() + + while s.task_groups: + await asyncio.sleep(0.001) + + assert not s.task_groups + metric = await fetch_task_groups_metric() + assert len(metric.samples) == 1 + assert metric.samples[0].value == 0 + + @gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_prometheus_collect_task_prefix_counts(c, s, a, b): pytest.importorskip("prometheus_client") diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst index 8ca3165cd4..6317e57195 100644 --- a/docs/source/prometheus.rst +++ b/docs/source/prometheus.rst @@ -62,6 +62,8 @@ dask_scheduler_tasks_output_bytes Note that when a task output is transferred between worker, you'll typically end up with a duplicate, so this measure is going to be lower than the actual cluster-wide managed memory. See also ``dask_worker_memory_bytes``, which does count duplicates. +dask_scheduler_task_groups + Number of task groups known by scheduler dask_scheduler_prefix_state_totals_total Accumulated count of task prefix in each state dask_scheduler_tick_count_total From 45e8091b49e407c94e7332da57dcf66363c56d69 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Jun 2024 15:40:15 +0200 Subject: [PATCH 011/138] Log task state in Compute Failed (#8668) --- distributed/tests/test_worker.py | 1 + distributed/worker.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 969a98ef87..33820c0e3a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -174,6 +174,7 @@ def reset(self): assert any("1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line) assert "Compute Failed" in hdlr.messages["warning"][0] assert y.key in hdlr.messages["warning"][0] + assert "executing" in hdlr.messages["warning"][0] logger.setLevel(old_level) # Now we check that both workers are still alive. diff --git a/distributed/worker.py b/distributed/worker.py index df5bdb0c59..b5f673633d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2344,11 +2344,13 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: logger.warning( "Compute Failed\n" "Key: %s\n" + "State: %s\n" "Function: %s\n" "args: %s\n" "kwargs: %s\n" "Exception: %r\n", key, + ts.state, str(funcname(function))[:1000], convert_args_to_str(args2, max_len=1000), convert_kwargs_to_str(kwargs2, max_len=1000), From 7425ed0ab4d437d2675f5739ede20b4f013b6079 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Jun 2024 15:47:15 +0200 Subject: [PATCH 012/138] Iterate over copy of `Server.digests_total_since_heartbeat` to avoid `RuntimeError` (#8670) --- distributed/spans.py | 13 +++++-------- distributed/worker.py | 13 ++++++------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/distributed/spans.py b/distributed/spans.py index 354ef62b47..1b7a5191c3 100644 --- a/distributed/spans.py +++ b/distributed/spans.py @@ -4,7 +4,7 @@ import uuid import weakref from collections import defaultdict -from collections.abc import Hashable, Iterable, Iterator +from collections.abc import Hashable, Iterable, Iterator, Mapping from contextlib import contextmanager from itertools import islice from typing import TYPE_CHECKING, Any, TypedDict @@ -655,18 +655,15 @@ def __init__(self, worker: Worker): self.worker = worker self.digests_total_since_heartbeat = {} - def collect_digests(self) -> None: - """Make a local copy of Worker.digests_total_since_heartbeat. We can't just - parse it directly in heartbeat() as the event loop may be yielded between its - call and `self.worker.digests_total_since_heartbeat.clear()`, causing the - scheduler to become misaligned with the workers. - """ + def collect_digests( + self, digests_total_since_heartbeat: Mapping[Hashable, float] + ) -> None: # Note: this method may be called spuriously by Worker._register_with_scheduler, # but when it does it's guaranteed not to find any metrics assert not self.digests_total_since_heartbeat self.digests_total_since_heartbeat = { k: v - for k, v in self.worker.digests_total_since_heartbeat.items() + for k, v in digests_total_since_heartbeat.items() if isinstance(k, tuple) and k[0] in CONTEXTS_WITH_SPAN_ID } diff --git a/distributed/worker.py b/distributed/worker.py index b5f673633d..5fe7a4da1b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1029,24 +1029,23 @@ async def get_metrics(self) -> dict: spilled_memory, spilled_disk = 0, 0 # Send Fine Performance Metrics - # Make sure we do not yield the event loop between the moment we parse - # self.digests_total_since_heartbeat to send it to the scheduler and the moment - # we clear it! + # Swap the dictionary to avoid updates while we iterate over it + digests_total_since_heartbeat = self.digests_total_since_heartbeat + self.digests_total_since_heartbeat = defaultdict(int) + spans_ext: SpansWorkerExtension | None = self.extensions.get("spans") if spans_ext: # Send metrics with disaggregated span_id - spans_ext.collect_digests() + spans_ext.collect_digests(digests_total_since_heartbeat) # Send metrics with squashed span_id # Don't cast int metrics to float digests: defaultdict[Hashable, float] = defaultdict(int) - for k, v in self.digests_total_since_heartbeat.items(): + for k, v in digests_total_since_heartbeat.items(): if isinstance(k, tuple) and k[0] in CONTEXTS_WITH_SPAN_ID: k = k[:1] + k[2:] digests[k] += v - self.digests_total_since_heartbeat.clear() - out: dict = dict( task_counts=self.state.task_counter.current_count(by_prefix=False), bandwidth={ From 366286e85f972b46f21f6468bff7e78eb7b7ed27 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Jun 2024 15:55:26 +0200 Subject: [PATCH 013/138] Adjust P2P tests for dask-expr (#8662) --- distributed/shuffle/tests/test_shuffle.py | 50 +++++++++++++---------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index e0b7ca7542..663b559eaf 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -535,18 +535,21 @@ async def test_crashed_worker_during_transfer(c, s, a): async def test_restarting_does_not_deadlock(c, s): """Regression test for https://github.com/dask/distributed/issues/8088""" async with Worker(s.address) as a: + # Ensure that a holds the input tasks to the shuffle + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + df = await c.persist(df) + expected = await c.compute(df) + async with Nanny(s.address) as b: - # Ensure that a holds the input tasks to the shuffle - with dask.annotate(workers=[a.address]): - df = dask.datasets.timeseries( - start="2000-01-01", - end="2000-03-01", - dtypes={"x": float, "y": float}, - freq="10 s", - ) with dask.config.set({"dataframe.shuffle.method": "p2p"}): out = df.shuffle("x") - fut = c.compute(out.x.size) + assert not s.workers[b.worker_address].has_what + result = c.compute(out) await wait_until_worker_has_tasks( "shuffle-transfer", b.worker_address, 1, s ) @@ -560,7 +563,8 @@ async def test_restarting_does_not_deadlock(c, s): a.status = Status.running await async_poll_for(lambda: s.running, timeout=5) - await fut + result = await result + assert dd.assert_eq(result, expected) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -631,19 +635,23 @@ def mock_mock_get_worker_for_range_sharding( await check_scheduler_cleanup(s) -@pytest.mark.slow +# @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)] * 3) async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3): - with dask.annotate(workers=[w1.address, w2.address], allow_other_workers=False): - df = dask.datasets.timeseries( - start="2000-01-01", - end="2000-02-01", - dtypes={"x": float, "y": float}, - freq="10 s", - ) - with dask.config.set({"dataframe.shuffle.method": "p2p"}): - shuffled = df.shuffle("x") - fut = c.compute([shuffled, df], sync=True) + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-02-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + with dask.config.set({"dataframe.shuffle.method": "p2p"}): + shuffled = df.shuffle("x") + fut = c.compute( + [shuffled, df], + sync=True, + workers=[w1.address, w2.address], + allow_other_workers=False, + ) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w1) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w2) await w3.close() From 7aea988722f285e9fc5a67a817f664bd14ef47a5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 4 Jun 2024 18:02:27 +0200 Subject: [PATCH 014/138] Improved errors and reduced logging for P2P RPC calls (#8666) --- distributed/shuffle/_comms.py | 15 +- distributed/shuffle/_core.py | 31 +++-- distributed/shuffle/_exceptions.py | 10 +- distributed/shuffle/_scheduler_plugin.py | 129 +++++++++++------- distributed/shuffle/_worker_plugin.py | 44 +++--- distributed/shuffle/tests/test_comm_buffer.py | 2 + distributed/shuffle/tests/test_shuffle.py | 26 ++-- distributed/shuffle/tests/utils.py | 2 +- 8 files changed, 164 insertions(+), 95 deletions(-) diff --git a/distributed/shuffle/_comms.py b/distributed/shuffle/_comms.py index 50094afddd..8886a1d757 100644 --- a/distributed/shuffle/_comms.py +++ b/distributed/shuffle/_comms.py @@ -5,10 +5,10 @@ from dask.utils import parse_bytes +from distributed.core import ErrorMessage, OKMessage, clean_exception from distributed.metrics import context_meter from distributed.shuffle._disk import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter -from distributed.utils import log_errors class CommShardsBuffer(ShardsBuffer): @@ -53,7 +53,9 @@ class CommShardsBuffer(ShardsBuffer): def __init__( self, - send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]], + send: Callable[ + [str, list[tuple[Any, Any]]], Awaitable[OKMessage | ErrorMessage] + ], memory_limiter: ResourceLimiter, concurrency_limit: int = 10, ): @@ -64,9 +66,14 @@ def __init__( ) self.send = send - @log_errors async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None: """Send one message off to a neighboring worker""" # Consider boosting total_size a bit here to account for duplication with context_meter.meter("send"): - await self.send(address, shards) + response = await self.send(address, shards) + status = response["status"] + if status == "error": + _, exc, tb = clean_exception(**response) + assert exc + raise exc.with_traceback(tb) + assert status == "OK" diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index f4f69266cc..43badfbf5e 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -29,13 +29,14 @@ from dask.typing import Key from dask.utils import parse_timedelta -from distributed.core import PooledRPCCall +from distributed.core import ErrorMessage, OKMessage, PooledRPCCall, error_message from distributed.exceptions import Reschedule from distributed.metrics import context_meter, thread_time from distributed.protocol import to_serialize +from distributed.protocol.serialize import ToPickle from distributed.shuffle._comms import CommShardsBuffer from distributed.shuffle._disk import DiskShardsBuffer -from distributed.shuffle._exceptions import ShuffleClosedError +from distributed.shuffle._exceptions import P2PConsistencyError, ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._memory import MemoryShardsBuffer from distributed.utils import run_in_executor_with_context, sync @@ -59,6 +60,10 @@ _T = TypeVar("_T") +class RunSpecMessage(OKMessage): + run_spec: ShuffleRunSpec | ToPickle[ShuffleRunSpec] + + class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): id: ShuffleId run_id: int @@ -199,7 +204,7 @@ async def barrier(self, run_ids: Sequence[int]) -> int: async def _send( self, address: str, shards: list[tuple[_T_partition_id, Any]] | bytes - ) -> None: + ) -> OKMessage | ErrorMessage: self.raise_if_closed() return await self.rpc(address).shuffle_receive( data=to_serialize(shards), @@ -209,7 +214,7 @@ async def _send( async def send( self, address: str, shards: list[tuple[_T_partition_id, Any]] - ) -> None: + ) -> OKMessage | ErrorMessage: if _mean_shard_size(shards) < 65536: # Don't send buffers individually over the tcp comms. # Instead, merge everything into an opaque bytes blob, send it all at once, @@ -220,7 +225,7 @@ async def send( else: shards_or_bytes = shards - def _send() -> Coroutine[Any, Any, None]: + def _send() -> Coroutine[Any, Any, OKMessage | ErrorMessage]: return self._send(address, shards_or_bytes) return await retry( @@ -302,11 +307,17 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing self.raise_if_closed() return self._disk_buffer.read("_".join(str(i) for i in id)) - async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None: - if isinstance(data, bytes): - # Unpack opaque blob. See send() - data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data)) - await self._receive(data) + async def receive( + self, data: list[tuple[_T_partition_id, Any]] | bytes + ) -> OKMessage | ErrorMessage: + try: + if isinstance(data, bytes): + # Unpack opaque blob. See send() + data = cast(list[tuple[_T_partition_id, Any]], pickle.loads(data)) + await self._receive(data) + return {"status": "OK"} + except P2PConsistencyError as e: + return error_message(e) async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None: assigned_worker = self._get_assigned_worker(i) diff --git a/distributed/shuffle/_exceptions.py b/distributed/shuffle/_exceptions.py index 8031b8f399..27bdb8dd56 100644 --- a/distributed/shuffle/_exceptions.py +++ b/distributed/shuffle/_exceptions.py @@ -1,7 +1,15 @@ from __future__ import annotations -class ShuffleClosedError(RuntimeError): +class P2PIllegalStateError(RuntimeError): + pass + + +class P2PConsistencyError(RuntimeError): + pass + + +class ShuffleClosedError(P2PConsistencyError): pass diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index ef646bcea0..132f34387c 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -8,11 +8,13 @@ from dask.typing import Key +from distributed.core import ErrorMessage, OKMessage, error_message from distributed.diagnostics.plugin import SchedulerPlugin from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.protocol.serialize import ToPickle from distributed.shuffle._core import ( + RunSpecMessage, SchedulerShuffleState, ShuffleId, ShuffleRunSpec, @@ -20,6 +22,7 @@ barrier_key, id_from_key, ) +from distributed.shuffle._exceptions import P2PConsistencyError, P2PIllegalStateError from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.utils import log_errors @@ -98,77 +101,97 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None: workers=list(shuffle.participating_workers), ) - def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict: - shuffle = self.active_shuffles[id] - if shuffle.run_id > run_id: - return { - "status": "error", - "message": f"Request stale, expected {run_id=} for {shuffle}", - } - elif shuffle.run_id < run_id: - return { - "status": "error", - "message": f"Request invalid, expected {run_id=} for {shuffle}", - } - ts = self.scheduler.tasks[key] - self._set_restriction(ts, worker) - return {"status": "OK"} + def restrict_task( + self, id: ShuffleId, run_id: int, key: Key, worker: str + ) -> OKMessage | ErrorMessage: + try: + shuffle = self.active_shuffles[id] + if shuffle.run_id > run_id: + raise P2PConsistencyError( + f"Request stale, expected {run_id=} for {shuffle}" + ) + elif shuffle.run_id < run_id: + raise P2PConsistencyError( + f"Request invalid, expected {run_id=} for {shuffle}" + ) + ts = self.scheduler.tasks[key] + self._set_restriction(ts, worker) + return {"status": "OK"} + except P2PConsistencyError as e: + return error_message(e) def heartbeat(self, ws: WorkerState, data: dict) -> None: for shuffle_id, d in data.items(): if shuffle_id in self.shuffle_ids(): self.heartbeats[shuffle_id][ws.address].update(d) - def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]: + def get(self, id: ShuffleId, worker: str) -> RunSpecMessage | ErrorMessage: + try: + try: + run_spec = self._get(id, worker) + return {"status": "OK", "run_spec": ToPickle(run_spec)} + except KeyError as e: + raise P2PConsistencyError( + f"No active shuffle with {id=!r} found" + ) from e + except P2PConsistencyError as e: + return error_message(e) + + def _get(self, id: ShuffleId, worker: str) -> ShuffleRunSpec: if worker not in self.scheduler.workers: # This should never happen - raise RuntimeError( + raise P2PConsistencyError( f"Scheduler is unaware of this worker {worker!r}" ) # pragma: nocover state = self.active_shuffles[id] state.participating_workers.add(worker) - return ToPickle(state.run_spec) + return state.run_spec + + def _create(self, spec: ShuffleSpec, key: Key, worker: str) -> ShuffleRunSpec: + # FIXME: The current implementation relies on the barrier task to be + # known by its name. If the name has been mangled, we cannot guarantee + # that the shuffle works as intended and should fail instead. + self._raise_if_barrier_unknown(spec.id) + self._raise_if_task_not_processing(key) + worker_for = self._calculate_worker_for(spec) + self._ensure_output_tasks_are_non_rootish(spec) + state = spec.create_new_run( + worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id + ) + self.active_shuffles[spec.id] = state + self._shuffles[spec.id].add(state) + state.participating_workers.add(worker) + logger.warning( + "Shuffle %s initialized by task %r executed on worker %s", + spec.id, + key, + worker, + ) + return state.run_spec def get_or_create( self, - # FIXME: This should never be ToPickle[ShuffleSpec] - spec: ShuffleSpec | ToPickle[ShuffleSpec], + spec: ShuffleSpec, key: Key, worker: str, - ) -> ToPickle[ShuffleRunSpec]: - # FIXME: Sometimes, this doesn't actually get pickled - if isinstance(spec, ToPickle): - spec = spec.data + ) -> RunSpecMessage | ErrorMessage: try: - return self.get(spec.id, worker) + run_spec = self._get(spec.id, worker) + except P2PConsistencyError as e: + return error_message(e) except KeyError: - # FIXME: The current implementation relies on the barrier task to be - # known by its name. If the name has been mangled, we cannot guarantee - # that the shuffle works as intended and should fail instead. - self._raise_if_barrier_unknown(spec.id) - self._raise_if_task_not_processing(key) - worker_for = self._calculate_worker_for(spec) - self._ensure_output_tasks_are_non_rootish(spec) - state = spec.create_new_run( - worker_for=worker_for, span_id=self.scheduler.tasks[key].group.span_id - ) - self.active_shuffles[spec.id] = state - self._shuffles[spec.id].add(state) - state.participating_workers.add(worker) - logger.warning( - "Shuffle %s initialized by task %r executed on worker %s", - spec.id, - key, - worker, - ) - return ToPickle(state.run_spec) + try: + run_spec = self._create(spec, key, worker) + except P2PConsistencyError as e: + return error_message(e) + return {"status": "OK", "run_spec": ToPickle(run_spec)} def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: key = barrier_key(id) try: self.scheduler.tasks[key] except KeyError: - raise RuntimeError( + raise P2PConsistencyError( f"Barrier task with key {key!r} does not exist. This may be caused by " "task fusion during graph generation. Please let us know that you ran " "into this by leaving a comment at distributed#7816." @@ -177,7 +200,9 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None: def _raise_if_task_not_processing(self, key: Key) -> None: task = self.scheduler.tasks[key] if task.state != "processing": - raise RuntimeError(f"Expected {task} to be processing, is {task.state}.") + raise P2PConsistencyError( + f"Expected {task} to be processing, is {task.state}." + ) def _calculate_worker_for(self, spec: ShuffleSpec) -> dict[Any, str]: """Pin the outputs of a P2P shuffle to specific workers. @@ -235,7 +260,7 @@ def _calculate_worker_for(self, spec: ShuffleSpec) -> dict[Any, str]: if existing: # pragma: nocover for shared_key in existing.keys() & current_worker_for.keys(): if existing[shared_key] != current_worker_for[shared_key]: - raise RuntimeError( + raise P2PIllegalStateError( f"Failed to initialize shuffle {spec.id} because " "it cannot align output partition mappings between " f"existing shuffles {seen}. " @@ -316,7 +341,7 @@ def _restart_recommendations(self, id: ShuffleId) -> Recs: if barrier_task.state == "erred": # This should never happen, a dependent of the barrier should already # be `erred` - raise RuntimeError( + raise P2PIllegalStateError( f"Expected dependents of {barrier_task=} to be 'erred' if " "the barrier is." ) # pragma: no cover @@ -326,7 +351,7 @@ def _restart_recommendations(self, id: ShuffleId) -> Recs: if dt.state == "erred": # This should never happen, a dependent of the barrier should already # be `erred` - raise RuntimeError( + raise P2PIllegalStateError( f"Expected barrier and its dependents to be " f"'erred' if the barrier's dependency {dt} is." ) # pragma: no cover @@ -366,7 +391,9 @@ def remove_worker( shuffle_id, stimulus_id, ) - exception = RuntimeError(f"Worker {worker} left during active {shuffle}") + exception = P2PConsistencyError( + f"Worker {worker} left during active {shuffle}" + ) self._fail_on_workers(shuffle, str(exception)) self._clean_on_scheduler(shuffle_id, stimulus_id) diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 55414021ff..57d2cfe369 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -10,6 +10,7 @@ from dask.typing import Key from dask.utils import parse_bytes +from distributed.core import ErrorMessage, OKMessage, clean_exception, error_message from distributed.diagnostics.plugin import WorkerPlugin from distributed.protocol.serialize import ToPickle from distributed.shuffle._core import ( @@ -19,7 +20,7 @@ ShuffleRunSpec, ShuffleSpec, ) -from distributed.shuffle._exceptions import ShuffleClosedError +from distributed.shuffle._exceptions import P2PConsistencyError, ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import log_errors, sync @@ -67,7 +68,7 @@ def fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None: if shuffle_run is None or shuffle_run.run_id != run_id: return self._active_runs.pop(shuffle_id) - exception = RuntimeError(message) + exception = P2PConsistencyError(message) shuffle_run.fail(exception) self._plugin.worker._ongoing_background_tasks.call_soon(self.close, shuffle_run) @@ -110,17 +111,19 @@ async def get_with_run_id(self, shuffle_id: ShuffleId, run_id: int) -> ShuffleRu ------ KeyError If the shuffle does not exist - RuntimeError + P2PConsistencyError If the run_id is stale + ShuffleClosedError + If the run manager has been closed """ shuffle_run = self._active_runs.get(shuffle_id, None) if shuffle_run is None or shuffle_run.run_id < run_id: shuffle_run = await self._refresh(shuffle_id=shuffle_id) if shuffle_run.run_id > run_id: - raise RuntimeError(f"{run_id=} stale, got {shuffle_run}") + raise P2PConsistencyError(f"{run_id=} stale, got {shuffle_run}") elif shuffle_run.run_id < run_id: - raise RuntimeError(f"{run_id=} invalid, got {shuffle_run}") + raise P2PConsistencyError(f"{run_id=} invalid, got {shuffle_run}") if self.closed: raise ShuffleClosedError(f"{self} has already been closed") @@ -172,7 +175,7 @@ async def get_most_recent( ------ KeyError If the shuffle does not exist - RuntimeError + P2PConsistencyError If the most recent run_id is stale """ return await self.get_with_run_id(shuffle_id=shuffle_id, run_id=max(run_ids)) @@ -183,22 +186,25 @@ async def _fetch( spec: ShuffleSpec | None = None, key: Key | None = None, ) -> ShuffleRunSpec: - # FIXME: This should never be ToPickle[ShuffleRunSpec] - result: ShuffleRunSpec | ToPickle[ShuffleRunSpec] if spec is None: - result = await self._plugin.worker.scheduler.shuffle_get( + response = await self._plugin.worker.scheduler.shuffle_get( id=shuffle_id, worker=self._plugin.worker.address, ) else: - result = await self._plugin.worker.scheduler.shuffle_get_or_create( + response = await self._plugin.worker.scheduler.shuffle_get_or_create( spec=ToPickle(spec), key=key, worker=self._plugin.worker.address, ) - if isinstance(result, ToPickle): - result = result.data - return result + + status = response["status"] + if status == "error": + _, exc, tb = clean_exception(**response) + assert exc + raise exc.with_traceback(tb) + assert status == "OK" + return response["run_spec"] @overload async def _refresh( @@ -236,7 +242,7 @@ async def _refresh( ) stale_run_id = self._stale_run_ids.get(shuffle_id, None) if stale_run_id is not None and stale_run_id >= result.run_id: - raise RuntimeError( + raise P2PConsistencyError( f"Received stale shuffle run with run_id={result.run_id};" f" expected run_id > {stale_run_id}" ) @@ -306,15 +312,17 @@ async def shuffle_receive( shuffle_id: ShuffleId, run_id: int, data: list[tuple[int, Any]] | bytes, - ) -> None: + ) -> OKMessage | ErrorMessage: """ Handler: Receive an incoming shard of data from a peer worker. Using an unknown ``shuffle_id`` is an error. """ - shuffle_run = await self._get_shuffle_run(shuffle_id, run_id) - await shuffle_run.receive(data) + try: + shuffle_run = await self._get_shuffle_run(shuffle_id, run_id) + return await shuffle_run.receive(data) + except P2PConsistencyError as e: + return error_message(e) - @log_errors async def shuffle_inputs_done(self, shuffle_id: ShuffleId, run_id: int) -> None: """ Handler: Inform the extension that all input partitions have been handed off to extensions. diff --git a/distributed/shuffle/tests/test_comm_buffer.py b/distributed/shuffle/tests/test_comm_buffer.py index 36896c547d..8cf58651af 100644 --- a/distributed/shuffle/tests/test_comm_buffer.py +++ b/distributed/shuffle/tests/test_comm_buffer.py @@ -62,6 +62,7 @@ async def send(address, shards): await block_send.wait() d[address].extend(shards) sending_first.set() + return {"status": "OK"} mc = CommShardsBuffer( send=send, concurrency_limit=1, memory_limiter=ResourceLimiter(None) @@ -131,6 +132,7 @@ async def send(address, shards): if counter == 5: raise OSError("error during send") d[address].extend(shards) + return {"status": "OK"} frac = 0.1 nshards = 10 diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 663b559eaf..3b6b75da4e 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -41,7 +41,7 @@ Scheduler, Worker, ) -from distributed.core import ConnectionPool +from distributed.core import ConnectionPool, ErrorMessage, OKMessage from distributed.scheduler import TaskState as SchedulerTaskState from distributed.shuffle._arrow import ( buffers_to_table, @@ -49,6 +49,7 @@ read_from_disk, serialize_table, ) +from distributed.shuffle._exceptions import P2PConsistencyError from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import ( @@ -1919,7 +1920,7 @@ async def test_error_send(tmp_path, loop_in_thread): partitions_for_worker[w].append(part) class ErrorSend(DataFrameShuffleRun): - async def send(self, *args: Any, **kwargs: Any) -> None: + async def send(self, *args: Any, **kwargs: Any) -> OKMessage | ErrorMessage: raise RuntimeError("Error during send") with DataFrameShuffleTestPool() as local_shuffle_pool: @@ -1951,8 +1952,9 @@ async def send(self, *args: Any, **kwargs: Any) -> None: @pytest.mark.skipif(not pa, reason="Requires PyArrow") +@pytest.mark.parametrize("Error", [P2PConsistencyError, ValueError]) @gen_test() -async def test_error_receive(tmp_path, loop_in_thread): +async def test_error_receive(tmp_path, loop_in_thread, Error): dfs = [] rows_per_df = 10 n_input_partitions = 1 @@ -1976,7 +1978,7 @@ async def test_error_receive(tmp_path, loop_in_thread): class ErrorReceive(DataFrameShuffleRun): async def _receive(self, data: list[tuple[int, bytes]]) -> None: - raise RuntimeError("Error during receive") + raise Error("Error during receive") with DataFrameShuffleTestPool() as local_shuffle_pool: sA = local_shuffle_pool.new_shuffle( @@ -2000,7 +2002,7 @@ async def _receive(self, data: list[tuple[int, bytes]]) -> None: ) try: sB.add_partition(dfs[0], 0) - with pytest.raises(RuntimeError, match="Error during receive"): + with pytest.raises(Error, match="Error during receive"): await sB.barrier(run_ids=[sB.run_id]) finally: await asyncio.gather(*[s.close() for s in [sA, sB]]) @@ -2012,7 +2014,9 @@ def setup(self, worker: Worker) -> None: self.in_shuffle_receive = asyncio.Event() self.block_shuffle_receive = asyncio.Event() - async def shuffle_receive(self, *args: Any, **kwargs: Any) -> None: + async def shuffle_receive( + self, *args: Any, **kwargs: Any + ) -> OKMessage | ErrorMessage: self.in_shuffle_receive.set() await self.block_shuffle_receive.wait() return await super().shuffle_receive(*args, **kwargs) @@ -2144,7 +2148,7 @@ async def test_shuffle_run_consistency(c, s, a): out = out.persist() shuffle_id = await wait_until_new_shuffle_is_initialized(s) - spec = scheduler_ext.get(shuffle_id, a.worker_address).data + spec = scheduler_ext.get(shuffle_id, a.worker_address)["run_spec"].data # Shuffle run manager can fetch the current run assert await run_manager.get_with_run_id(shuffle_id, spec.run_id) @@ -2170,7 +2174,7 @@ async def test_shuffle_run_consistency(c, s, a): new_shuffle_id = await wait_until_new_shuffle_is_initialized(s) assert shuffle_id == new_shuffle_id - new_spec = scheduler_ext.get(shuffle_id, a.worker_address).data + new_spec = scheduler_ext.get(shuffle_id, a.worker_address)["run_spec"].data # Check invariant that the new run ID is larger than the previous assert spec.run_id < new_spec.run_id @@ -2196,7 +2200,9 @@ async def test_shuffle_run_consistency(c, s, a): independent_shuffle_id = await wait_until_new_shuffle_is_initialized(s) assert shuffle_id != independent_shuffle_id - independent_spec = scheduler_ext.get(independent_shuffle_id, a.worker_address).data + independent_spec = scheduler_ext.get(independent_shuffle_id, a.worker_address)[ + "run_spec" + ].data # Check invariant that the new run ID is larger than the previous # for independent shuffles @@ -2236,7 +2242,7 @@ async def test_fail_fetch_race(c, s, a): out = out.persist() shuffle_id = await wait_until_new_shuffle_is_initialized(s) - spec = scheduler_ext.get(shuffle_id, a.worker_address).data + spec = scheduler_ext.get(shuffle_id, a.worker_address)["run_spec"].data await worker_plugin.in_barrier.wait() # Pretend that the fail from the scheduler arrives first run_manager.fail(shuffle_id, spec.run_id, "error") diff --git a/distributed/shuffle/tests/utils.py b/distributed/shuffle/tests/utils.py index 1d4fb7319a..693a044373 100644 --- a/distributed/shuffle/tests/utils.py +++ b/distributed/shuffle/tests/utils.py @@ -32,7 +32,7 @@ async def _(**kwargs): # here. kwargs = _nested_deserialize(kwargs) meth = getattr(self.shuffle, method_name) - return await meth(**kwargs) + return _nested_deserialize(await meth(**kwargs)) return _ From 7cbfc4d6c2ee9c320d53b8328b38439715536691 Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Wed, 5 Jun 2024 13:16:41 +0100 Subject: [PATCH 015/138] Add safe keyword to `remove-worker` event (#8647) --- distributed/scheduler.py | 1 + distributed/tests/test_worker.py | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 86cf067f91..f752b47bf4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5299,6 +5299,7 @@ async def remove_worker( "lost-computed-tasks": recompute_keys, "lost-scattered-tasks": lost_keys, "stimulus_id": stimulus_id, + "safe": safe, # TODO change this to expected to be clearer } self.log_event(address, event_msg.copy()) event_msg["worker"] = address diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 33820c0e3a..85c56b6a1f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2974,6 +2974,7 @@ async def test_worker_status_sync(s, a): "lost-computed-tasks": set(), "lost-scattered-tasks": set(), "processing-tasks": set(), + "safe": True, "stimulus_id": "retire-workers", }, {"action": "retired", "stimulus_id": "retire-workers"}, @@ -3052,6 +3053,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": set(), "lost-scattered-tasks": set(), "processing-tasks": {"y"}, + "safe": True, "stimulus_id": "graceful", }, {"action": "retired", "stimulus_id": "graceful"}, @@ -3075,6 +3077,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": {"x"}, "lost-scattered-tasks": {"z"}, "processing-tasks": {"y"}, + "safe": False, "stimulus_id": "ungraceful", }, {"action": "closing-worker", "reason": "scheduler-remove-worker"}, @@ -3085,6 +3088,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": set(), "lost-scattered-tasks": set(), "processing-tasks": {"y"}, + "safe": True, "stimulus_id": "graceful", "worker": a.address, }, @@ -3105,6 +3109,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": {"x"}, "lost-scattered-tasks": {"z"}, "processing-tasks": {"y"}, + "safe": False, "stimulus_id": "ungraceful", "worker": b.address, }, From 5708bdf37fe8daf697115cd6c90c07148b6dffb1 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Wed, 5 Jun 2024 17:00:02 +0200 Subject: [PATCH 016/138] Improve graph submission time for P2P rechunking by avoiding unpack recursion into indices (#8672) --- distributed/shuffle/_rechunk.py | 7 ++++--- distributed/tests/test_utils_comm.py | 21 +++++++++++++++++++++ distributed/utils_comm.py | 7 +++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index be36f47f56..b33e90730b 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -130,6 +130,7 @@ from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof +from distributed.utils_comm import DoNotUnpack if TYPE_CHECKING: import numpy as np @@ -445,9 +446,9 @@ def partial_rechunk( rechunk_transfer, input_key, partial_token, - partial_index, - partial_new, - partial_old, + DoNotUnpack(partial_index), + DoNotUnpack(partial_new), + DoNotUnpack(partial_old), disk, ) diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py index ee0eaff089..94b1c33f3a 100644 --- a/distributed/tests/test_utils_comm.py +++ b/distributed/tests/test_utils_comm.py @@ -13,6 +13,7 @@ from distributed.config import get_loop_factory from distributed.core import ConnectionPool, Status from distributed.utils_comm import ( + DoNotUnpack, WrappedKey, gather_from_workers, pack_data, @@ -261,3 +262,23 @@ def assert_eq(keys1: set[WrappedKey], keys2: set[WrappedKey]) -> None: res, keys = unpack_remotedata(dsk) assert res == (sc, "arg1") # Notice, the first item (the SC) has NOT been changed assert_eq(keys, set()) + + +def test_unpack_remotedata_custom_tuple(): + # We don't want to recurse into custom tuples. This is used as a sentinel to + # avoid recursion for performance reasons if we know that there are no + # nested futures. This test case is not how this feature should be used in + # practice. + + akey = WrappedKey("a") + + ordinary_tuple = (1, 2, akey) + dont_recurse = DoNotUnpack(ordinary_tuple) + + res, keys = unpack_remotedata(ordinary_tuple) + assert res is not ordinary_tuple + assert any(left != right for left, right in zip(ordinary_tuple, res)) + assert keys == {akey} + res, keys = unpack_remotedata(dont_recurse) + assert not keys + assert res is dont_recurse diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index e0a9eda88b..7c10c25635 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -269,6 +269,13 @@ def _unpack_remotedata_inner( return o +class DoNotUnpack(tuple): + """A tuple sublass to indicate that we should not unpack its contents + + See also unpack_remotedata + """ + + def unpack_remotedata(o: Any, byte_keys: bool = False) -> tuple[Any, set]: """Unpack WrappedKey objects from collection From 490b69649b4eb54bcb06921eb42b8fe48d12d8c7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 10 Jun 2024 11:01:24 +0200 Subject: [PATCH 017/138] Eagerly update aggregate statistics for `TaskPrefix` instead of calculating them on-demand (#8681) --- distributed/dashboard/components/scheduler.py | 18 +- .../diagnostics/tests/test_progress.py | 2 +- distributed/scheduler.py | 286 ++++++++++-------- distributed/tests/test_scheduler.py | 77 ++++- 4 files changed, 248 insertions(+), 135 deletions(-) diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 5dbc99c2fd..af669c8c23 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3342,15 +3342,15 @@ def update(self): } for tp in self.scheduler.task_prefixes.values(): - active_states = tp.active_states - if any(active_states.get(s) for s in state.keys()): - state["memory"][tp.name] = active_states["memory"] - state["erred"][tp.name] = active_states["erred"] - state["released"][tp.name] = active_states["released"] - state["processing"][tp.name] = active_states["processing"] - state["waiting"][tp.name] = active_states["waiting"] - state["queued"][tp.name] = active_states["queued"] - state["no_worker"][tp.name] = active_states["no-worker"] + states = tp.states + if any(states.get(s) for s in state.keys()): + state["memory"][tp.name] = states["memory"] + state["erred"][tp.name] = states["erred"] + state["released"][tp.name] = states["released"] + state["processing"][tp.name] = states["processing"] + state["waiting"][tp.name] = states["waiting"] + state["queued"][tp.name] = states["queued"] + state["no_worker"][tp.name] = states["no-worker"] state["all"] = {k: sum(v[k] for v in state.values()) for k in state["memory"]} diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index 2d4310b992..7be880088c 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -257,7 +257,7 @@ async def test_group_timing(c, s, a, b): assert s.task_groups.keys() == p.compute.keys() assert all( [ - abs(s.task_groups[k].all_durations["compute"] - sum(v)) < 1.0e-12 + abs(s.task_groups[k].all_durations["compute"] - sum(v)) < 1.0e-6 * len(v) for k, v in p.compute.items() ] ) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f752b47bf4..84904c9270 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -57,6 +57,7 @@ from dask.core import get_deps, iskey, validate_key from dask.typing import Key, no_default from dask.utils import ( + _deprecated, ensure_dict, format_bytes, format_time, @@ -921,46 +922,112 @@ def _repr_html_(self) -> str: ) -class TaskPrefix: - """Collection tracking all tasks within a group - - Keys often have a structure like ``("x-123", 0)`` - A group takes the first section, like ``"x"`` +class TaskCollection: + """Abstract collection tracking all tasks See Also -------- TaskGroup + TaskPrefix """ - #: The name of a group of tasks. - #: For a task like ``("x-123", 0)`` this is the text ``"x"`` + #: The name of a collection of tasks. name: str + #: The total number of bytes that tasks belonging to this collection have produced + nbytes_total: int + + #: The number of tasks belonging to this collection in each state, + #: like ``{"memory": 10, "processing": 3, "released": 4, ...}`` + states: dict[TaskStateState, int] + + _all_durations_us: defaultdict[str, int] + + _duration_us: int + + _types: defaultdict[str, int] + + __slots__ = tuple(__annotations__) + + def __init__(self, name: str): + self.name = name + self._all_durations_us = defaultdict(int) + self._duration_us = 0 + self.nbytes_total = 0 + self.states = dict.fromkeys(ALL_TASK_STATES, 0) + self._types = defaultdict(int) + + def add(self, other: TaskState) -> None: + self.states[other.state] += 1 + + def add_duration(self, action: str, start: float, stop: float) -> None: + duration_us = self._calculate_duration_us(start, stop) + self._duration_us += duration_us + self._all_durations_us[action] += duration_us + + def add_type(self, typename: str) -> None: + self._types[typename] += 1 + + @property + def all_durations(self) -> defaultdict[str, float]: + """Cumulative duration of all completed actions of tasks belonging to this collection, by action""" + return defaultdict( + float, + { + action: duration_us / 1e6 + for action, duration_us in self._all_durations_us.items() + }, + ) + + @property + def duration(self) -> float: + """The total amount of time spent on all tasks belonging to this collection""" + return self._duration_us / 1e6 + + def transition(self, old: TaskStateState, new: TaskStateState) -> None: + self.states[old] -= 1 + self.states[new] += 1 + + @property + def types(self) -> Set[str]: + """The result types of this collection""" + return self._types.keys() + + def update_nbytes(self, diff: int) -> None: + self.nbytes_total += diff + + @staticmethod + def _calculate_duration_us(start: float, stop: float) -> int: + return max(round((stop - start) * 1e6), 0) + + +class TaskPrefix(TaskCollection): + """Collection tracking all tasks within a prefix + + See Also + -------- + TaskGroup + """ + #: An exponentially weighted moving average duration of all tasks with this prefix duration_average: float #: Numbers of times a task was marked as suspicious with this prefix suspicious: int - #: Store timings for each prefix-action - all_durations: defaultdict[str, float] - #: This measures the maximum recorded live execution time and can be used to #: detect outliers max_exec_time: float - #: Task groups associated to this prefix - groups: list[TaskGroup] - #: Accumulate count of number of tasks in each state state_counts: defaultdict[TaskStateState, int] + _groups: dict[TaskGroup, None] + __slots__ = tuple(__annotations__) def __init__(self, name: str): - self.name = name - self.groups = [] - self.all_durations = defaultdict(float) + super().__init__(name) self.state_counts = defaultdict(int) task_durations = dask.config.get("distributed.scheduler.default-task-durations") if self.name in task_durations: @@ -969,6 +1036,7 @@ def __init__(self, name: str): self.duration_average = -1 self.max_exec_time = -1 self.suspicious = 0 + self._groups = {} def add_exec_time(self, duration: float) -> None: self.max_exec_time = max(duration, self.max_exec_time) @@ -976,33 +1044,48 @@ def add_exec_time(self, duration: float) -> None: self.duration_average = -1 def add_duration(self, action: str, start: float, stop: float) -> None: - duration = stop - start - self.all_durations[action] += duration + super().add_duration(action, start, stop) + duration_s = self._calculate_duration_us(start, stop) / 1e6 if action == "compute": old = self.duration_average if old < 0: - self.duration_average = duration + self.duration_average = duration_s else: - self.duration_average = 0.5 * duration + 0.5 * old + self.duration_average = 0.5 * duration_s + 0.5 * old + + def transition(self, old: TaskStateState, new: TaskStateState) -> None: + super().transition(old, new) + self.state_counts[new] += 1 + + def add_group(self, tg: TaskGroup) -> None: + self._groups[tg] = None + + def remove_group(self, tg: TaskGroup) -> None: + # This is important, we need to adjust the stats + self._groups.pop(tg) + for state, count in tg.states.items(): + self.states[state] -= count + self._duration_us -= tg._duration_us + self.nbytes_total -= tg.nbytes_total + for typename, count in tg._types.items(): + self._types[typename] -= count + if self._types[typename] == 0: + del self._types[typename] @property - def states(self) -> dict[TaskStateState, int]: - """The number of tasks in each state, - like ``{"memory": 10, "processing": 3, "released": 4, ...}`` - """ - return merge_with(sum, [tg.states for tg in self.groups]) + @_deprecated(use_instead="groups") # type: ignore[misc] + def active(self) -> Set[TaskGroup]: + return self.groups @property - def active(self) -> list[TaskGroup]: - return [ - tg - for tg in self.groups - if any(k != "forgotten" and v != 0 for k, v in tg.states.items()) - ] + def groups(self) -> Set[TaskGroup]: + """Insertion-sorted set-like of groups associated to this prefix""" + return self._groups.keys() @property + @_deprecated(use_instead="states") # type: ignore[misc] def active_states(self) -> dict[TaskStateState, int]: - return merge_with(sum, [tg.states for tg in self.active]) + return self.states def __repr__(self) -> str: return ( @@ -1015,53 +1098,18 @@ def __repr__(self) -> str: + ">" ) - @property - def nbytes_total(self) -> int: - return sum(tg.nbytes_total for tg in self.groups) - - def __len__(self) -> int: - return sum(map(len, self.groups)) - - @property - def duration(self) -> float: - return sum(tg.duration for tg in self.groups) - @property - def types(self) -> set[str]: - return {typ for tg in self.groups for typ in tg.types} - - -class TaskGroup: +class TaskGroup(TaskCollection): """Collection tracking all tasks within a group - Keys often have a structure like ``("x-123", 0)`` - A group takes the first section, like ``"x-123"`` - See also -------- TaskPrefix """ - #: The name of a group of tasks. - #: For a task like ``("x-123", 0)`` this is the text ``"x-123"`` - name: str - - #: The number of tasks in each state, - #: like ``{"memory": 10, "processing": 3, "released": 4, ...}`` - states: dict[TaskStateState, int] - #: The other TaskGroups on which this one depends dependencies: set[TaskGroup] - #: The total number of bytes that this task group has produced - nbytes_total: int - - #: The total amount of time spent on all tasks in this TaskGroup - duration: float - - #: The result types of this TaskGroup - types: set[str] - #: The worker most recently assigned a task from this group, or None when the group #: is not identified to be root-like by `SchedulerState.decide_worker`. last_worker: WorkerState | None @@ -1070,7 +1118,7 @@ class TaskGroup: #: subsequent tasks until a new worker is chosen. last_worker_tasks_left: int - prefix: TaskPrefix | None + prefix: TaskPrefix #: Earliest time when a task belonging to this group started computing; #: 0 if no task has *finished* computing yet @@ -1085,9 +1133,6 @@ class TaskGroup: #: 0 if no task has finished computing yet stop: float - #: Cumulative duration of all completed actions, by action - all_durations: defaultdict[str, float] - #: Span ID (see ``distributed.spans``). #: Matches ``distributed.worker_state_machine.TaskState.span_id``. #: It is possible to end up in situation where different tasks of the same TaskGroup @@ -1097,37 +1142,43 @@ class TaskGroup: __slots__ = tuple(__annotations__) - def __init__(self, name: str): - self.name = name - self.prefix = None - self.states = dict.fromkeys(ALL_TASK_STATES, 0) + def __init__(self, name: str, prefix: TaskPrefix): + super().__init__(name) self.dependencies = set() - self.nbytes_total = 0 - self.duration = 0 - self.types = set() self.start = 0.0 self.stop = 0.0 - self.all_durations = defaultdict(float) self.last_worker = None self.last_worker_tasks_left = 0 self.span_id = None + self.prefix = prefix + prefix.add_group(self) def add_duration(self, action: str, start: float, stop: float) -> None: - duration = stop - start - self.duration += duration - self.all_durations[action] += duration + super().add_duration(action, start, stop) if action == "compute": if self.stop < stop: self.stop = stop if self.start == 0.0 or self.start > start: self.start = start - assert self.prefix is not None self.prefix.add_duration(action, start, stop) def add(self, other: TaskState) -> None: - self.states[other.state] += 1 + super().add(other) + self.prefix.add(other) other.group = self + def add_type(self, typename: str) -> None: + super().add_type(typename) + self.prefix.add_type(typename) + + def transition(self, old: TaskStateState, new: TaskStateState) -> None: + super().transition(old, new) + self.prefix.transition(old, new) + + def update_nbytes(self, diff: int) -> None: + super().update_nbytes(diff) + self.prefix.update_nbytes(diff) + def __repr__(self) -> str: return ( "<" @@ -1183,9 +1234,6 @@ class TaskState: #: ``'inc-ab31c010444977004d656610d2d421ec'``. key: Key - #: The broad class of tasks to which this task belongs like "inc" or "read_csv" - prefix: TaskPrefix - #: A specification of how to run the task. The type and meaning of this value is #: opaque to the scheduler, as it is only interpreted by the worker to which the #: task is sent for executing. @@ -1360,9 +1408,6 @@ class TaskState: #: The group of tasks to which this one belongs group: TaskGroup - #: Same as of group.name - group_key: str - #: Metadata related to task metadata: dict[str, Any] | None @@ -1404,6 +1449,7 @@ def __init__( key: Key, run_spec: T_runspec | None, state: TaskStateState, + group: TaskGroup, ): # Most of the attributes below are not initialized since there are not # always required for every tasks. Particularly for large graphs, these @@ -1436,15 +1482,14 @@ def __init__( self.resource_restrictions = None self.loose_restrictions = False self.actor = False - self.prefix = None # type: ignore self.type = None # type: ignore - self.group_key = key_split_group(key) - self.group = None # type: ignore self.metadata = None self.annotations = None self.erred_on = None self._rootish = None self.run_id = None + self.group = group + group.add(self) TaskState._instances.add(self) def __hash__(self) -> int: @@ -1464,10 +1509,8 @@ def state(self) -> TaskStateState: @state.setter def state(self, value: TaskStateState) -> None: - self.group.states[self._state] -= 1 - self.group.states[value] += 1 + self.group.transition(self._state, value) self._state = value - self.prefix.state_counts[value] += 1 def add_dependency(self, other: TaskState) -> None: """Add another task as a dependency of this task""" @@ -1483,7 +1526,7 @@ def set_nbytes(self, nbytes: int) -> None: old_nbytes = self.nbytes if old_nbytes >= 0: diff -= old_nbytes - self.group.nbytes_total += diff + self.group.update_nbytes(diff) for ws in self.who_has or (): ws.nbytes += diff self.nbytes = nbytes @@ -1540,6 +1583,15 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]: """ return recursive_to_dict(self, exclude=exclude, members=True) + @property + def prefix(self) -> TaskPrefix: + """The broad class of tasks to which this task belongs like "inc" or "read_csv" """ + return self.group.prefix + + @property + def group_key(self) -> str: + return self.group.name + class Transition(NamedTuple): """An entry in :attr:`SchedulerState.transition_log`""" @@ -1827,23 +1879,21 @@ def new_task( computation: Computation | None = None, ) -> TaskState: """Create a new task, and associated states""" - ts = TaskState(key, spec, state) - prefix_key = key_split(key) - tp = self.task_prefixes.get(prefix_key) - if tp is None: - self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) - ts.prefix = tp + group_key = key_split_group(key) - group_key = ts.group_key tg = self.task_groups.get(group_key) if tg is None: - self.task_groups[group_key] = tg = TaskGroup(group_key) + tp = self.task_prefixes.get(prefix_key) + if tp is None: + self.task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) + + self.task_groups[group_key] = tg = TaskGroup(group_key, tp) + if computation: computation.groups.add(tg) - tg.prefix = tp - tp.groups.append(tg) - tg.add(ts) + + ts = TaskState(key, spec, state, tg) self.tasks[key] = ts @@ -2031,7 +2081,7 @@ def _transition( if ts.state == "forgotten" and tg.name in self.task_groups: # Remove TaskGroup if all tasks are in the forgotten state if all(v == 0 or k == "forgotten" for k, v in tg.states.items()): - ts.prefix.groups.remove(tg) + ts.prefix.remove_group(tg) del self.task_groups[tg.name] return recommendations, client_msgs, worker_msgs @@ -3325,7 +3375,7 @@ def _add_to_memory( ts.state = "memory" ts.type = typename # type: ignore - ts.group.types.add(typename) # type: ignore + ts.group.add_type(typename) # type: ignore cs = self.clients["fire-and-forget"] if ts in cs.wants_what: @@ -7814,7 +7864,7 @@ def get_task_prefix_states(self) -> dict[str, dict[str, int]]: state = {} for tp in self.task_prefixes.values(): - active_states = tp.active_states + states = tp.states ss: list[TaskStateState] = [ "memory", "erred", @@ -7822,13 +7872,13 @@ def get_task_prefix_states(self) -> dict[str, dict[str, int]]: "processing", "waiting", ] - if any(active_states.get(s) for s in ss): + if any(states.get(s) for s in ss): state[tp.name] = { - "memory": active_states["memory"], - "erred": active_states["erred"], - "released": active_states["released"], - "processing": active_states["processing"], - "waiting": active_states["waiting"], + "memory": states["memory"], + "erred": states["erred"], + "released": states["released"], + "processing": states["processing"], + "waiting": states["waiting"], } return state diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 916ab3a76a..cf2e6345fc 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2794,7 +2794,7 @@ async def test_no_dangling_asyncio_tasks(): @gen_cluster(client=True, Worker=NoSchedulerDelayWorker, config=NO_AMM) -async def test_task_groups(c, s, a, b, no_time_resync): +async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): start = time() da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(20,)) @@ -2808,20 +2808,44 @@ async def test_task_groups(c, s, a, b, no_time_resync): repr(tp) assert tg.states["memory"] == 0 assert tg.states["released"] == 5 - assert tp.states["memory"] == 0 - assert tp.states["released"] == 5 - assert tp.groups == [tg] + assert sum(tg.states.values()) == 5 + assert tg.nbytes_total == sum( + ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg + ) + assert tg.prefix is tp + assert tp.groups == {tg} + with pytest.warns(FutureWarning, match="active"): + assert tp.groups == tp.active # these must be true since in this simple case there is a 1to1 mapping # between prefix and group + assert tg.states == tp.states + with pytest.warns(FutureWarning, match="active_states"): + assert tp.states == tp.active_states assert tg.duration == tp.duration + assert tg.all_durations == tp.all_durations assert tg.nbytes_total == tp.nbytes_total + assert tg.types == tp.types + # It should map down to individual tasks - assert tg.nbytes_total == sum( - ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg - ) tg = s.task_groups[y.name] assert tg.states["memory"] == 5 + assert sum(tg.states.values()) == 5 + + tp = s.task_prefixes["add"] + assert tg.prefix is tp + assert tp.groups == {tg} + with pytest.warns(FutureWarning, match="active"): + assert tp.groups == tp.active + # these must be true since in this simple case there is a 1to1 mapping + # between prefix and group + assert tg.states == tp.states + with pytest.warns(FutureWarning, match="active_states"): + assert tp.states == tp.active_states + assert tg.duration == tp.duration + assert tg.all_durations == tp.all_durations + assert tg.nbytes_total == tp.nbytes_total + assert tg.types == tp.types assert s.task_groups[y.name].dependencies == {s.task_groups[x.name]} @@ -2830,17 +2854,56 @@ async def test_task_groups(c, s, a, b, no_time_resync): assert "array" in str(tg.types) assert "array" in str(tp.types) + z = y[:20].persist(optimize_graph=False) + z = await z del y + while len(s.tasks) > 3: + await asyncio.sleep(0.01) + + assert tg.prefix is tp + assert tp.groups == {tg} + with pytest.warns(FutureWarning, match="active"): + assert tp.groups == tp.active + assert tg.states["forgotten"] == 4 + assert tg.states["released"] == 1 + assert sum(tg.states.values()) == 5 + assert tg.states == tp.states + with pytest.warns(FutureWarning, match="active_states"): + assert tp.states == tp.active_states + assert tg.duration == tp.duration + assert tg.all_durations == tp.all_durations + assert tg.nbytes_total == tp.nbytes_total + assert tg.types == tp.types + + del z while s.tasks: await asyncio.sleep(0.01) + assert tg.states["forgotten"] == 5 + assert sum(tg.states.values()) == 5 + assert tg.states["forgotten"] == 5 assert tg.name not in s.task_groups assert tg.start > start assert tg.stop < stop assert "compute" in tg.all_durations + assert tg.prefix is tp + # all_durations is cumulative + assert tg.all_durations == tp.all_durations + # these must be zero because we remove fully-forgotten task groups + # from the prefixes + assert tp.groups == set() + with pytest.warns(FutureWarning, match="active"): + assert tp.groups == tp.active + assert all(count == 0 for count in tp.states.values()) + with pytest.warns(FutureWarning, match="active_states"): + assert tp.states == tp.active_states + assert tp.duration == 0 + assert tp.nbytes_total == 0 + assert tp.types == set() + @gen_cluster(client=True, nthreads=[("", 2)], Worker=NoSchedulerDelayWorker) async def test_task_groups_update_start_stop(c, s, a, no_time_resync): From 9672121ce115df9268b12ad74e108d71ec8104c0 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 10 Jun 2024 12:37:33 +0200 Subject: [PATCH 018/138] [P2P] Log event during failure (#8663) --- distributed/shuffle/_core.py | 1 + distributed/shuffle/_scheduler_plugin.py | 25 ++++++++++++++++++- distributed/shuffle/tests/test_shuffle.py | 29 +++++++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 43badfbf5e..48510dfd41 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -483,6 +483,7 @@ class SchedulerShuffleState(Generic[_T_partition_id]): run_spec: ShuffleRunSpec participating_workers: set[str] _archived_by: str | None = field(default=None, init=False) + _failed: bool = False @property def id(self) -> ShuffleId: diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 132f34387c..c6fbbe210a 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -410,8 +410,31 @@ def transition( **kwargs: Any, ) -> None: """Clean up scheduler and worker state once a shuffle becomes inactive.""" - if finish not in ("released", "forgotten"): + if finish not in ("released", "erred", "forgotten"): return + + if finish == "erred": + ts = self.scheduler.tasks[key] + for active_shuffle in self.active_shuffles.values(): + if active_shuffle._failed: + continue + barrier = self.scheduler.tasks[barrier_key(active_shuffle.id)] + if ( + ts == barrier + or ts in barrier.dependents + or ts in barrier.dependencies + ): + active_shuffle._failed = True + self.scheduler.log_event( + "p2p", + { + "action": "p2p-failed", + "shuffle": active_shuffle.id, + "stimulus": stimulus_id, + }, + ) + return + shuffle_id = id_from_key(key) if not shuffle_id: return diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 3b6b75da4e..0d3c962687 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -435,6 +435,7 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): with pytest.raises(KilledWorker): await out + assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1 await c.close() await check_worker_cleanup(a) @@ -442,6 +443,32 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): await check_scheduler_cleanup(s) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.allowed-failures": 1}, +) +async def test_restarting_does_not_log_p2p_failed(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-03-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + with dask.config.set({"dataframe.shuffle.method": "p2p"}): + out = df.shuffle("x") + out = c.compute(out.x.size) + await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) + await b.close() + + await out + assert not s.events["p2p"] + await c.close() + await check_worker_cleanup(a) + await check_worker_cleanup(b, closed=True) + await check_scheduler_cleanup(s) + + class BlockedGetOrCreateShuffleRunManager(_ShuffleRunManager): def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -806,6 +833,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): with pytest.raises(KilledWorker): await out + assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1 alive_shuffle.block_inputs_done.set() @@ -968,6 +996,7 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b): with pytest.raises(KilledWorker): await out + assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1 await c.close() await check_worker_cleanup(a) From 2482bd3439f99d4183776dd18e3234cff7de6e7a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 12 Jun 2024 14:22:54 +0200 Subject: [PATCH 019/138] Rename `safe` to `expected` in `Scheduler.remove_worker`. (#8686) --- distributed/scheduler.py | 21 ++++++++++++--------- distributed/tests/test_cancelled_state.py | 4 ++-- distributed/tests/test_scheduler.py | 7 +++++++ distributed/tests/test_worker.py | 10 +++++----- 4 files changed, 26 insertions(+), 16 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 84904c9270..dc9afd1697 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -58,6 +58,7 @@ from dask.typing import Key, no_default from dask.utils import ( _deprecated, + _deprecated_kwarg, ensure_dict, format_bytes, format_time, @@ -2492,16 +2493,12 @@ def _transition_processing_memory( return recommendations, client_msgs, {} - def _transition_memory_released( - self, key: Key, stimulus_id: str, *, safe: bool = False - ) -> RecsMsgs: + def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: ts = self.tasks[key] if self.validate: assert not ts.waiting_on assert not ts.processing_on - if safe: - assert not ts.waiters if ts.actor: for ws in ts.who_has or (): @@ -5230,9 +5227,15 @@ def close_worker(self, worker: str) -> None: self.log_event(worker, {"action": "close-worker"}) self.worker_send(worker, {"op": "close", "reason": "scheduler-close-worker"}) + @_deprecated_kwarg("safe", "expected") @log_errors async def remove_worker( - self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True + self, + address: str, + *, + stimulus_id: str, + expected: bool = False, + close: bool = True, ) -> Literal["OK", "already-removed"]: """Remove worker from cluster. @@ -5290,7 +5293,7 @@ async def remove_worker( for ts in list(ws.processing): k = ts.key recommendations[k] = "released" - if not safe: + if not expected: ts.suspicious += 1 ts.prefix.suspicious += 1 if ts.suspicious > self.allowed_failures: @@ -5349,7 +5352,7 @@ async def remove_worker( "lost-computed-tasks": recompute_keys, "lost-scattered-tasks": lost_keys, "stimulus_id": stimulus_id, - "safe": safe, # TODO change this to expected to be clearer + "expected": expected, } self.log_event(address, event_msg.copy()) event_msg["worker"] = address @@ -7501,7 +7504,7 @@ async def _track_retire_worker( if remove: await self.remove_worker( - ws.address, safe=True, close=close, stimulus_id=stimulus_id + ws.address, expected=True, close=close, stimulus_id=stimulus_id ) elif close: self.close_worker(ws.address) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 982947b0bd..ddfd4f7738 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -290,7 +290,7 @@ async def test_in_flight_lost_after_resumed(c, s, b): s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B - await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True) + await s.remove_worker(a.address, stimulus_id="foo", close=False, expected=True) await wait_for_state(fut1.key, "resumed", b, interval=0) @@ -850,7 +850,7 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( await s.remove_worker( address=x.address, - safe=True, + expected=True, close=close_worker, stimulus_id="remove-worker", ) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index cf2e6345fc..e73af206d7 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -845,6 +845,13 @@ async def test_remove_worker_from_scheduler(c, s, a, b): await c.gather(futs) +@gen_cluster(client=True) +async def test_remove_worker_from_scheduler_warns_on_safe(c, s, a, b): + with pytest.warns(FutureWarning, match="expected"): + await s.remove_worker(address=a.address, safe=True, stimulus_id="test") + assert a.address not in s.workers + + @gen_cluster() async def test_remove_worker_by_name_from_scheduler(s, a, b): assert a.address in s.stream_comms diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 85c56b6a1f..b9e51944a9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2974,7 +2974,7 @@ async def test_worker_status_sync(s, a): "lost-computed-tasks": set(), "lost-scattered-tasks": set(), "processing-tasks": set(), - "safe": True, + "expected": True, "stimulus_id": "retire-workers", }, {"action": "retired", "stimulus_id": "retire-workers"}, @@ -3053,7 +3053,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": set(), "lost-scattered-tasks": set(), "processing-tasks": {"y"}, - "safe": True, + "expected": True, "stimulus_id": "graceful", }, {"action": "retired", "stimulus_id": "graceful"}, @@ -3077,7 +3077,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": {"x"}, "lost-scattered-tasks": {"z"}, "processing-tasks": {"y"}, - "safe": False, + "expected": False, "stimulus_id": "ungraceful", }, {"action": "closing-worker", "reason": "scheduler-remove-worker"}, @@ -3088,7 +3088,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": set(), "lost-scattered-tasks": set(), "processing-tasks": {"y"}, - "safe": True, + "expected": True, "stimulus_id": "graceful", "worker": a.address, }, @@ -3109,7 +3109,7 @@ async def test_log_remove_worker(c, s, a, b): "lost-computed-tasks": {"x"}, "lost-scattered-tasks": {"z"}, "processing-tasks": {"y"}, - "safe": False, + "expected": False, "stimulus_id": "ungraceful", "worker": b.address, }, From af237f0d816e811563d6bce70c3f7e48f89edd22 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 12 Jun 2024 16:20:00 +0200 Subject: [PATCH 020/138] Fix log event with multiple topics (#8691) --- .../tests/test_scheduler_plugin.py | 25 +++++++++++++++++-- distributed/scheduler.py | 17 ++++++------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 2d35636b10..ee8c7f501c 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -442,8 +442,8 @@ async def start(self, scheduler: Scheduler) -> None: self.scheduler = scheduler self.scheduler._recorded_events = list() # type: ignore - def log_event(self, name, msg): - self.scheduler._recorded_events.append((name, msg)) + def log_event(self, topic, msg): + self.scheduler._recorded_events.append((topic, msg)) await c.register_plugin(EventPlugin()) @@ -455,6 +455,27 @@ def f(): assert ("foo", 123) in s._recorded_events +@gen_cluster(client=True) +async def test_log_event_plugin_multiple_topics(c, s, a, b): + class EventPlugin(SchedulerPlugin): + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler + self.scheduler._recorded_events = list() # type: ignore + + def log_event(self, topic, msg): + self.scheduler._recorded_events.append((topic, msg)) + + await c.register_plugin(EventPlugin()) + + def f(): + get_worker().log_event(["foo", "bar"], 123) + + await c.submit(f) + + assert ("foo", 123) in s._recorded_events + assert ("bar", 123) in s._recorded_events + + @gen_cluster(client=True) async def test_register_plugin_on_scheduler(c, s, a, b): class MyPlugin(SchedulerPlugin): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dc9afd1697..5ab1af0b9b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8413,19 +8413,16 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None: Client.log_event """ event = (time(), msg) - if not isinstance(topic, str): - for t in topic: - self.events[t].append(event) - self.event_counts[t] += 1 - self._report_event(t, event) - else: - self.events[topic].append(event) - self.event_counts[topic] += 1 - self._report_event(topic, event) + if isinstance(topic, str): + topic = [topic] + for t in topic: + self.events[t].append(event) + self.event_counts[t] += 1 + self._report_event(t, event) for plugin in list(self.plugins.values()): try: - plugin.log_event(topic, msg) + plugin.log_event(t, msg) except Exception: logger.info("Plugin failed with exception", exc_info=True) From d8dc8ad2172ff34113e4bc47d57ae55401cd6705 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Thu, 13 Jun 2024 13:55:09 +0100 Subject: [PATCH 021/138] Automate GitHub Releases when new tags are pushed (#8626) --- .github/release-drafter.yml | 13 +++++++++++ .github/workflows/release-drafter.yml | 32 ++++++++++++++++++++++++++ .github/workflows/release-publish.yml | 33 +++++++++++++++++++++++++++ 3 files changed, 78 insertions(+) create mode 100644 .github/release-drafter.yml create mode 100644 .github/workflows/release-drafter.yml create mode 100644 .github/workflows/release-publish.yml diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 0000000000..7d1d66b86a --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,13 @@ +# These will be overridden by the publish workflow and set to the new tag +name-template: 'Next Release' +tag-template: 'next' + +change-template: '- $TITLE @$AUTHOR (#$NUMBER)' +change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks. + +template: | + ## Changes + + $CHANGES + + See the [Changelog](https://docs.dask.org/en/stable/changelog.html) for more information. diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml new file mode 100644 index 0000000000..524ae21ebd --- /dev/null +++ b/.github/workflows/release-drafter.yml @@ -0,0 +1,32 @@ +name: Release Drafter + +on: + push: + branches: + - main + +permissions: + contents: read + +jobs: + update_release_draft: + if: github.repository == 'dask/distributed' + permissions: + # Write permission is required to create a GitHub release + contents: write + pull-requests: read + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required to get tag history + - name: Check if release commit + id: check_release_commit + run: git describe --exact-match --tags $(git rev-parse HEAD) + continue-on-error: true + - uses: release-drafter/release-drafter@v6 + if: steps.check_release_commit.outcome != 'success' + with: + disable-autolabeler: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml new file mode 100644 index 0000000000..5021e34393 --- /dev/null +++ b/.github/workflows/release-publish.yml @@ -0,0 +1,33 @@ +name: Release Publisher + +on: + push: + tags: + - "*.*.*" + +permissions: + contents: read + +jobs: + publish_release: + if: github.repository == 'dask/distributed' + permissions: + # Write permission is required to publish a GitHub release + contents: write + pull-requests: read + runs-on: ubuntu-latest + steps: + - name: Set version env + # Use a little bit of bash to extract the tag name from the GitHub ref + run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + - uses: release-drafter/release-drafter@v5 + with: + disable-autolabeler: true + # Override the Release name/tag/version with the actual tag name + name: ${{ env.RELEASE_VERSION }} + tag: ${{ env.RELEASE_VERSION }} + version: ${{ env.RELEASE_VERSION }} + # Publish the Release + publish: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From f7921a1a1675d3fb2b0525067100899a4b783b1e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 13 Jun 2024 18:22:33 +0200 Subject: [PATCH 022/138] Log key collision count in `update_graph` log event (#8692) --- distributed/scheduler.py | 26 +++++++++++++++++++------- distributed/tests/test_scheduler.py | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5ab1af0b9b..7d4e73ea63 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4590,11 +4590,6 @@ def _create_taskstate_from_graph( lost_keys = self._match_graph_with_tasks(dsk, dependencies, keys) - if len(dsk) > 1: - self.log_event( - ["all", client], {"action": "update_graph", "count": len(dsk)} - ) - if lost_keys: self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client) self.client_releases_keys( @@ -4616,13 +4611,28 @@ def _create_taskstate_from_graph( computation.annotations.update(global_annotations) del global_annotations - runnable, touched_tasks, new_tasks = self._generate_taskstates( + ( + runnable, + touched_tasks, + new_tasks, + colliding_task_count, + ) = self._generate_taskstates( keys=keys, dsk=dsk, dependencies=dependencies, computation=computation, ) + if len(dsk) > 1 or colliding_task_count: + self.log_event( + ["all", client], + { + "action": "update_graph", + "count": len(dsk), + "key-collisions": colliding_task_count, + }, + ) + keys_with_annotations = self._apply_annotations( tasks=new_tasks, annotations_by_type=annotations_by_type, @@ -4815,6 +4825,7 @@ def _generate_taskstates( touched_keys = set() touched_tasks = [] tgs_with_bad_run_spec = set() + colliding_task_count = 0 while stack: k = stack.pop() if k in touched_keys: @@ -4860,6 +4871,7 @@ def _generate_taskstates( # dask/dask#9888. dependencies[k] = deps_lhs + colliding_task_count += 1 if ts.group not in tgs_with_bad_run_spec: tgs_with_bad_run_spec.add(ts.group) logger.warning( @@ -4912,7 +4924,7 @@ def _generate_taskstates( len(touched_tasks), len(keys), ) - return runnable, touched_tasks, new_tasks + return runnable, touched_tasks, new_tasks, colliding_task_count def _apply_annotations( self, diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index e73af206d7..9e25fba6c5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4789,6 +4789,23 @@ async def test_resubmit_different_task_same_key_before_previous_is_done(c, s, de For a real world example where this can trigger, see https://github.com/dask/dask/issues/9888 """ + seen = False + + def _match(event): + _, msg = event + return ( + isinstance(msg, dict) + and msg.get("action", None) == "update_graph" + and msg["key-collisions"] > 0 + ) + + def handler(ev): + if _match(ev): + nonlocal seen + seen = True + + c.subscribe_topic("all", handler) + x1 = c.submit(inc, 1, key="x1") y_old = c.submit(inc, x1, key="y") @@ -4803,6 +4820,8 @@ async def test_resubmit_different_task_same_key_before_previous_is_done(c, s, de assert "Detected different `run_spec` for key 'y'" in log.getvalue() + await async_poll_for(lambda: seen, timeout=5) + async with Worker(s.address): # Used old run_spec assert await y_old == 3 From 8cbddde57b7385d441ad4666c28d8c100a0659fe Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 14 Jun 2024 12:05:48 +0200 Subject: [PATCH 023/138] Avoid rounding error in `test_prometheus_collect_count_total_by_cost_multipliers` (#8687) --- distributed/http/scheduler/tests/test_stealing_http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/http/scheduler/tests/test_stealing_http.py b/distributed/http/scheduler/tests/test_stealing_http.py index aa8ffa67e2..1272076374 100644 --- a/distributed/http/scheduler/tests/test_stealing_http.py +++ b/distributed/http/scheduler/tests/test_stealing_http.py @@ -91,7 +91,7 @@ async def fetch_metrics_by_cost_multipliers(): for request in event[1] if event[0] == "request" ) - assert count == expected_cost + assert count == pytest.approx(expected_cost) @gen_cluster( From f1d230c91016ea7889275709e0b47a235ee7e78e Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 14 Jun 2024 15:38:55 -0500 Subject: [PATCH 024/138] bump version to 2024.6.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b9752af0e3..95543b0aac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.5.2", + "dask == 2024.6.0", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From ac3854e4c463e80a0ff8b72a78c36972603ee391 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:20:26 +0100 Subject: [PATCH 025/138] Bump release-drafter/release-drafter from 5 to 6 (#8699) --- .github/workflows/release-publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release-publish.yml b/.github/workflows/release-publish.yml index 5021e34393..d39098b854 100644 --- a/.github/workflows/release-publish.yml +++ b/.github/workflows/release-publish.yml @@ -20,7 +20,7 @@ jobs: - name: Set version env # Use a little bit of bash to extract the tag name from the GitHub ref run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - - uses: release-drafter/release-drafter@v5 + - uses: release-drafter/release-drafter@v6 with: disable-autolabeler: true # Override the Release name/tag/version with the actual tag name From 03035dafb939166d241fb0db09433c2bb3bc369c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 19 Jun 2024 10:37:09 +0200 Subject: [PATCH 026/138] Fix deadlock (#8703) --- distributed/scheduler.py | 2 +- distributed/tests/test_scheduler.py | 37 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7d4e73ea63..4cba667d9e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2540,7 +2540,7 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs: recommendations[key] = "waiting" for dts in ts.waiters or (): - if dts.state in ("no-worker", "processing"): + if dts.state in ("no-worker", "processing", "queued"): recommendations[dts.key] = "waiting" elif dts.state == "waiting": if not dts.waiting_on: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9e25fba6c5..88eb3cbe11 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4650,6 +4650,43 @@ def assert_rootish(): await c.gather(fut3) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_deadlock_dependency_of_queued_released(c, s, a): + @delayed + def inc(input): + return input + 1 + + @delayed + def block_on_event(input, block, executing): + executing.set() + block.wait() + return input + + block = Event() + executing = Event() + + dep = inc(0) + futs = [ + block_on_event(dep, block, executing, dask_key_name=("rootish", i)) + for i in range(s.total_nthreads * 2 + 1) + ] + del dep + futs = c.compute(futs) + await executing.wait() + assert s.queued + await s.remove_worker(address=a.address, stimulus_id="test") + + s.validate_state() + + await block.set() + await executing.clear() + + async with Worker(s.address) as b: + s.validate_state() + await c.gather(*futs) + s.validate_state() + + @gen_cluster(client=True) async def test_submit_dependency_of_erred_task(c, s, a, b): x = c.submit(lambda: 1 / 0, key="x") From aeebb2dcc4d5d62de908ba3f364596c348ab530c Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 19 Jun 2024 19:16:40 +0200 Subject: [PATCH 027/138] bump version to 2024.6.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 95543b0aac..c2c0679ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.6.0", + "dask == 2024.6.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From 850616570c640776251b6d008100d0e2c44256a8 Mon Sep 17 00:00:00 2001 From: Adam Williamson Date: Thu, 20 Jun 2024 02:28:05 -0700 Subject: [PATCH 028/138] profile._f_lineno: handle next_line being None in Python 3.13 (#8710) In Python 3.13, it's possible for `dis.findlinestarts` to return a line number as `None`. We don't ever want to return `None` as the line number, so guard against this. Signed-off-by: Adam Williamson --- distributed/profile.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/profile.py b/distributed/profile.py index 46420ca4eb..1ba4e2994d 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -90,7 +90,8 @@ def _f_lineno(frame: FrameType) -> int: for start, next_line in dis.findlinestarts(code): if f_lasti < start: return prev_line - prev_line = next_line + if next_line: + prev_line = next_line return prev_line except Exception: From adcb0452a2e515284e32172ca79f87cff0127a8e Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 20 Jun 2024 14:55:13 -0500 Subject: [PATCH 029/138] bump version to 2024.6.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c2c0679ad9..f9b43e28fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.6.1", + "dask == 2024.6.2", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From e83246eda4b76f8fccda33abb342e3d8b7c3cfe9 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 24 Jun 2024 03:10:49 -0500 Subject: [PATCH 030/138] Add quotes to missing bokeh installation commands (#8717) --- distributed/http/scheduler/missing_bokeh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/http/scheduler/missing_bokeh.py b/distributed/http/scheduler/missing_bokeh.py index 3bfa99d866..94fe218305 100644 --- a/distributed/http/scheduler/missing_bokeh.py +++ b/distributed/http/scheduler/missing_bokeh.py @@ -10,8 +10,8 @@ class MissingBokeh(RequestHandler): def get(self): self.write( f"

Dask needs {BOKEH_REQUIREMENT} for the dashboard.

" - f"

Install with conda: conda install {BOKEH_REQUIREMENT}

" - f"

Install with pip: pip install {BOKEH_REQUIREMENT}

" + f'

Install with conda: conda install "{BOKEH_REQUIREMENT}"

' + f'

Install with pip: pip install "{BOKEH_REQUIREMENT}"

' ) From 01e5ff3a83540020ef1b17b019af821fa97f02f6 Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Mon, 24 Jun 2024 05:51:19 -0400 Subject: [PATCH 031/138] Fix cleanup iteration in save_sys_modules (#8713) When iterating `sys.path`, then if you delete index `i`, all elements >`i` will now be one less than what was put into `enumerate`. So we should remove elements in reverse to prevent indices getting out of sync. When iterating `sys.modules`, then modifying it _may_ break the `.keys()` iterator, so work on a copy of the necessary keys. --- distributed/utils_test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 63a7488efe..78943226bf 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1628,12 +1628,11 @@ def save_sys_modules(): try: yield finally: - for i, elem in enumerate(sys.path): + for i, elem in reversed(list(enumerate(sys.path))): if elem not in old_path: del sys.path[i] - for elem in sys.modules.keys(): - if elem not in old_modules: - del sys.modules[elem] + for elem in sys.modules.keys() - old_modules.keys(): + del sys.modules[elem] @contextmanager From c3fb7f29412ad8e75f847e6f86289f22ed2c4068 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jun 2024 11:54:49 +0200 Subject: [PATCH 032/138] Fix test_quiet_client_close (#8722) --- distributed/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0d25f0bfe4..c838e078e8 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5223,7 +5223,7 @@ def test_quiet_client_close(loop): futures = c.map(slowinc, range(1000), delay=0.01) # Stop part-way s = c.cluster.scheduler - while sum(ts.state == "memory" for ts in s.tasks.values()) < 20: + while sum(ts.state == "memory" for ts in list(s.tasks.values())) < 20: sleep(0.01) sleep(0.1) # let things settle From 9b8aff08b4ef16ebd26123742795143210dd5071 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jun 2024 12:07:05 +0200 Subject: [PATCH 033/138] Skip test_deadlock_dependency_of_queued_released (#8723) --- distributed/tests/test_scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 88eb3cbe11..d55e19fee1 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4650,6 +4650,10 @@ def assert_rootish(): await c.gather(fut3) +@pytest.mark.skipif( + not QUEUING_ON_BY_DEFAULT, + reason="The situation handled in this test requires queueing.", +) @gen_cluster(client=True, nthreads=[("", 1)]) async def test_deadlock_dependency_of_queued_released(c, s, a): @delayed From 958ce7ea00aa0093a7c2379b01826401ed1f8b5f Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Mon, 24 Jun 2024 06:11:04 -0400 Subject: [PATCH 034/138] TST: Fix wait condition on test_forget_errors (#8714) --- distributed/tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c838e078e8..c00e2c3938 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2213,7 +2213,7 @@ async def test_forget_errors(c, s, a, b): x = c.submit(div, 1, 0) y = c.submit(inc, x) z = c.submit(inc, y) - await wait([y]) + await wait([z]) assert s.tasks[x.key].exception assert s.tasks[x.key].exception_blame From 7daddc561dbabb60ef9b9fdfb553702d661d3066 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jun 2024 13:07:49 +0200 Subject: [PATCH 035/138] Deprecate Pub and Sub (#8724) --- distributed/pubsub.py | 4 +++- distributed/tests/test_pubsub.py | 33 +++++++++++++++++++++----------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 7a3dc0b574..3a678f0b66 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -6,7 +6,7 @@ import weakref from collections import defaultdict, deque -from dask.utils import parse_timedelta +from dask.utils import _deprecated, parse_timedelta from distributed.core import CommClosedError from distributed.metrics import time @@ -198,6 +198,7 @@ def cleanup(self): self.client.scheduler_comm.send(msg) +@_deprecated(use_instead="Client.log_event() or Worker.log_event()") class Pub: """Publish data with Publish-Subscribe pattern @@ -354,6 +355,7 @@ def __repr__(self): __str__ = __repr__ +@_deprecated(use_instead="Client.subscribe_topic()") class Sub: """Subscribe to a Publish/Subscribe topic diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 664ea4f4ab..6deb6f798c 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -23,8 +23,10 @@ async def test_speed(c, s, a, b): """ def pingpong(a, b, start=False, n=1000, msg=1): - sub = Sub(a) - pub = Pub(b) + with pytest.warns(FutureWarning, match="deprecated"): + sub = Sub(a) + with pytest.warns(FutureWarning, match="deprecated"): + pub = Pub(b) while not pub.subscribers: sleep(0.01) @@ -56,8 +58,10 @@ def pingpong(a, b, start=False, n=1000, msg=1): async def test_client(c, s): with pytest.raises(ValueError, match="No worker found"): get_worker() - sub = Sub("a") - pub = Pub("a") + with pytest.warns(FutureWarning, match="deprecated"): + sub = Sub("a") + with pytest.warns(FutureWarning, match="deprecated"): + pub = Pub("a") sps = s.extensions["pubsub"] cps = c.extensions["pubsub"] @@ -75,10 +79,12 @@ async def test_client(c, s): @gen_cluster(client=True) async def test_client_worker(c, s, a, b): - sub = Sub("a", client=c, worker=None) + with pytest.warns(FutureWarning, match="deprecated"): + sub = Sub("a", client=c, worker=None) def f(x): - pub = Pub("a") + with pytest.warns(FutureWarning, match="deprecated"): + pub = Pub("a") pub.put(x) futures = c.map(f, range(10)) @@ -120,7 +126,8 @@ def f(x): @gen_cluster(client=True) async def test_timeouts(c, s, a, b): - sub = Sub("a", client=c, worker=None) + with pytest.warns(FutureWarning, match="deprecated"): + sub = Sub("a", client=c, worker=None) start = time() with pytest.raises(TimeoutError): await sub.get(timeout="100ms") @@ -132,8 +139,10 @@ async def test_timeouts(c, s, a, b): @gen_cluster(client=True) async def test_repr(c, s, a, b): - pub = Pub("my-topic") - sub = Sub("my-topic") + with pytest.warns(FutureWarning, match="deprecated"): + pub = Pub("my-topic") + with pytest.warns(FutureWarning, match="deprecated"): + sub = Sub("my-topic") assert "my-topic" in str(pub) assert "Pub" in str(pub) assert "my-topic" in str(sub) @@ -144,7 +153,8 @@ async def test_repr(c, s, a, b): @gen_cluster(client=True) async def test_basic(c, s, a, b): async def publish(): - pub = Pub("a") + with pytest.warns(FutureWarning, match="deprecated"): + pub = Pub("a") i = 0 while True: @@ -153,7 +163,8 @@ async def publish(): i += 1 def f(_): - sub = Sub("a") + with pytest.warns(FutureWarning, match="deprecated"): + sub = Sub("a") return list(toolz.take(5, sub)) asyncio.ensure_future(c.run(publish, workers=[a.address])) From f9257669d5076da7998f6bc6c31dd9d271e202b7 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 24 Jun 2024 13:08:29 +0200 Subject: [PATCH 036/138] Improve error on cancelled tasks due to disconnect (#8705) --- distributed/client.py | 165 ++++++++++++++++++++++++------- distributed/tests/test_client.py | 41 ++++++++ 2 files changed, 171 insertions(+), 35 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 837360217b..f9c66018c1 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -135,6 +135,68 @@ TOPIC_PREFIX_FORWARDED_LOG_RECORD = "forwarded-log-record" +class FutureCancelledError(CancelledError): + key: str + reason: str + msg: str | None + + def __init__(self, key: str, reason: str | None, msg: str | None = None): + self.key = key + self.reason = reason if reason else "unknown" + self.msg = msg + + def __str__(self) -> str: + result = f"{self.key} cancelled for reason: {self.reason}." + if self.msg: + result = "\n".join([result, self.msg]) + return result + + +class FuturesCancelledError(CancelledError): + error_groups: list[CancelledFuturesGroup] + + def __init__(self, error_groups: list[CancelledFuturesGroup]): + self.error_groups = sorted( + error_groups, key=lambda group: len(group.errors), reverse=True + ) + + def __str__(self): + count = sum(map(lambda group: len(group.errors), self.error_groups)) + result = f"{count} Future{'s' if count > 1 else ''} cancelled:" + return "\n".join( + [result, "Reasons:"] + [str(group) for group in self.error_groups] + ) + + +class CancelledFuturesGroup: + #: Errors of the cancelled futures + errors: list[FutureCancelledError] + + #: Reason for cancelling the futures + reason: str + + __slots__ = tuple(__annotations__) + + def __init__(self, errors: list[FutureCancelledError], reason: str): + self.errors = errors + self.reason = reason + + def __str__(self): + keys = [error.key for error in self.errors] + example_message = None + + for error in self.errors: + if error.msg: + example_message = error.msg + break + + return ( + f"{len(keys)} Future{'s' if len(keys) > 1 else ''} cancelled for reason: " + f"{self.reason}.\nMessage: {example_message}\n" + f"Future{'s' if len(keys) > 1 else ''}: {keys}" + ) + + class SourceCode(NamedTuple): code: str lineno_frame: int @@ -245,7 +307,7 @@ def _bind_late(self): if self.key in self._client.futures: self._state = self._client.futures[self.key] else: - self._state = self._client.futures[self.key] = FutureState() + self._state = self._client.futures[self.key] = FutureState(self.key) if self._inform: self._client._send_to_scheduler( @@ -337,8 +399,10 @@ async def _result(self, raiseit=True): raise exc.with_traceback(tb) else: return exc - elif self.status == "cancelled": - exception = CancelledError(self.key) + elif self.cancelled(): + assert self._state + exception = self._state.exception + assert isinstance(exception, CancelledError) if raiseit: raise exception else: @@ -414,7 +478,7 @@ def execute_callback(fut): done_callback, self, partial(cls._cb_executor.submit, execute_callback) ) - def cancel(self, **kwargs): + def cancel(self, reason=None, msg=None, **kwargs): """Cancel the request to run this future See Also @@ -422,7 +486,7 @@ def cancel(self, **kwargs): Client.cancel """ self._verify_initialized() - return self.client.cancel([self], **kwargs) + return self.client.cancel([self], reason=reason, msg=msg, **kwargs) def retry(self, **kwargs): """Retry this future if it has failed @@ -552,11 +616,14 @@ class FutureState: This is shared between all Futures with the same key and client. """ - __slots__ = ("_event", "status", "type", "exception", "traceback") + __slots__ = ("_event", "key", "status", "type", "exception", "traceback") - def __init__(self): + def __init__(self, key: str): self._event = None + self.key = key + self.exception = None self.status = "pending" + self.traceback = None self.type = None def _get_event(self): @@ -568,10 +635,10 @@ def _get_event(self): event = self._event = asyncio.Event() return event - def cancel(self): + def cancel(self, reason=None, msg=None): """Cancels the operation""" self.status = "cancelled" - self.exception = CancelledError() + self.exception = FutureCancelledError(key=self.key, reason=reason, msg=msg) self._get_event().set() def finish(self, type=None): @@ -1321,7 +1388,13 @@ async def _reconnect(self): self.scheduler_comm = None for st in self.futures.values(): - st.cancel() + st.cancel( + reason="scheduler-connection-lost", + msg=( + "Client lost the connection to the scheduler. " + "Please check your connection and re-run your work." + ), + ) self.futures.clear() timeout = self._timeout @@ -1640,7 +1713,10 @@ def _handle_task_erred(self, key=None, exception=None, traceback=None): def _handle_restart(self): logger.info("Receive restart signal from scheduler") for state in self.futures.values(): - state.cancel() + state.cancel( + reason="scheduler-restart", + msg="Scheduler has restarted. Please re-run your work.", + ) self.futures.clear() self.generation += 1 with self._refcount_lock: @@ -2220,19 +2296,15 @@ async def wait(k): exceptions = set() bad_keys = set() - for key in keys: - if key not in self.futures or self.futures[key].status in failed: + for future in future_set: + key = future.key + if key not in self.futures or future.status in failed: exceptions.add(key) if errors == "raise": - try: - st = self.futures[key] - exception = st.exception - traceback = st.traceback - except (KeyError, AttributeError): - exc = CancelledError(key) - else: - raise exception.with_traceback(traceback) - raise exc + st = future._state + exception = st.exception + traceback = st.traceback + raise exception.with_traceback(traceback) if errors == "skip": bad_keys.add(key) bad_data[key] = None @@ -2602,16 +2674,16 @@ def scatter( hash=hash, ) - async def _cancel(self, futures, force=False): + async def _cancel(self, futures, reason=None, msg=None, force=False): # FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop. keys = list({f.key for f in futures_of(futures)}) self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force}) for k in keys: st = self.futures.pop(k, None) if st is not None: - st.cancel() + st.cancel(reason=reason, msg=msg) - def cancel(self, futures, asynchronous=None, force=False): + def cancel(self, futures, asynchronous=None, force=False, reason=None, msg=None): """ Cancel running futures This stops future tasks from being scheduled if they have not yet run @@ -2626,8 +2698,14 @@ def cancel(self, futures, asynchronous=None, force=False): If True the client is in asynchronous mode force : boolean (False) Cancel this future even if other clients desire it + reason: str + Reason for cancelling the futures + msg : str + Message that will be attached to the cancelled future """ - return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force) + return self.sync( + self._cancel, futures, asynchronous=asynchronous, force=force, msg=msg + ) async def _retry(self, futures): keys = list({f.key for f in futures_of(futures)}) @@ -5445,9 +5523,19 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED): {fu for fu in fs if fu.status != "pending"}, {fu for fu in fs if fu.status == "pending"}, ) - cancelled = [f.key for f in done if f.status == "cancelled"] - if cancelled: - raise CancelledError(cancelled) + cancelled_errors = defaultdict(list) + for f in done: + if not f.cancelled(): + continue + exception = f._state.exception + assert isinstance(exception, FutureCancelledError) + cancelled_errors[exception.reason].append(exception) + if cancelled_errors: + groups = [ + CancelledFuturesGroup(errors=errors, reason=reason) + for reason, errors in cancelled_errors.items() + ] + raise FuturesCancelledError(groups) return DoneAndNotDoneFutures(done, not_done) @@ -5678,8 +5766,6 @@ def _get_and_raise(self): if self.raise_errors and future.status == "error": typ, exc, tb = result raise exc.with_traceback(tb) - elif future.status == "cancelled": - res = (res[0], CancelledError(future.key)) return res def __next__(self): @@ -5891,10 +5977,19 @@ def futures_of(o, client=None): stack.extend(x.__dask_graph__().values()) if client is not None: - bad = {f for f in futures if f.cancelled()} - if bad: - raise CancelledError(bad) - + cancelled_errors = defaultdict(list) + for f in futures: + if not f.cancelled(): + continue + exception = f._state.exception + assert isinstance(exception, FutureCancelledError) + cancelled_errors[exception.reason].append(exception) + if cancelled_errors: + groups = [ + CancelledFuturesGroup(errors=errors, reason=reason) + for reason, errors in cancelled_errors.items() + ] + raise FuturesCancelledError(groups) return futures[::-1] diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c00e2c3938..704155c630 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -63,6 +63,8 @@ from distributed.client import ( Client, Future, + FutureCancelledError, + FuturesCancelledError, _get_global_client, _global_clients, as_completed, @@ -8566,3 +8568,42 @@ async def test_gather_race_vs_AMM(c, s, a, direct): b.block_get_data.set() assert await fut == 3 # It's from a; it would be 2 if it were from b + + +@gen_cluster(client=True) +async def test_client_disconnect_exception_on_cancelled_futures(c, s, a, b): + fut = c.submit(inc, 1) + await wait(fut) + + await s.close() + + with pytest.raises(FutureCancelledError, match="connection to the scheduler"): + await fut.result() + + with pytest.raises(FuturesCancelledError, match="connection to the scheduler"): + await wait(fut) + + with pytest.raises(FutureCancelledError, match="connection to the scheduler"): + await fut + + with pytest.raises(FutureCancelledError, match="connection to the scheduler"): + await c.gather([fut]) + + with pytest.raises(FuturesCancelledError, match="connection to the scheduler"): + futures_of(fut, client=c) + + async for fut, res in as_completed([fut], with_results=True): + assert isinstance(res, FutureCancelledError) + assert "connection to the scheduler" in res.msg + + +@pytest.mark.slow +@gen_cluster(client=True, Worker=Nanny, timeout=60) +async def test_scheduler_restart_exception_on_cancelled_futures(c, s, a, b): + fut = c.submit(inc, 1) + await wait(fut) + + await s.restart(stimulus_id="test") + + with pytest.raises(CancelledError, match="Scheduler has restarted"): + await fut.result() From 58384678648f523050af5f4219386725eab93dd3 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 24 Jun 2024 15:20:43 +0200 Subject: [PATCH 037/138] More useful warning if a plugin type is provided instead of instance (#8689) Co-authored-by: Hendrik Makait --- distributed/client.py | 2 ++ distributed/tests/test_client.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index f9c66018c1..5b35b69573 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4977,6 +4977,8 @@ def _register_plugin( name: str, idempotent: bool, ): + if isinstance(plugin, type): + raise TypeError("Please provide an instance of a plugin, not a type.") raise TypeError( "Registering duck-typed plugins is not allowed. Please inherit from " "NannyPlugin, WorkerPlugin, or SchedulerPlugin to create a plugin." diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 704155c630..6290d15a2e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6998,6 +6998,15 @@ async def test_get_task_metadata_multiple(c, s, a, b): assert metadata2[f2.key] == s.tasks.get(f2.key).metadata +@gen_cluster(client=True) +async def test_register_worker_plugin_instance_required(c, s, a, b): + class MyPlugin(WorkerPlugin): + ... + + with pytest.raises(TypeError, match="instance"): + await c.register_plugin(MyPlugin) + + @gen_cluster(client=True) async def test_register_worker_plugin_exception(c, s, a, b): class MyPlugin(WorkerPlugin): From b77f38dc59a338cfefd995764786b292908912fb Mon Sep 17 00:00:00 2001 From: Sultan Orazbayev Date: Mon, 24 Jun 2024 09:22:48 -0400 Subject: [PATCH 038/138] Fix type in actor docs (#8711) --- docs/source/actors.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/actors.rst b/docs/source/actors.rst index 8efd027e42..30a1b1f4a9 100644 --- a/docs/source/actors.rst +++ b/docs/source/actors.rst @@ -228,6 +228,6 @@ Actors offer advanced capabilities, but with some cost: 3. **No Load balancing:** Actors are allocated onto workers evenly, without serious consideration given to avoiding communication. 4. **No dynamic clusters:** Actors cannot be migrated to other workers. - A worker holding an actor can be retired neither through + A worker holding an actor can't be retired, neither through :meth:`~distributed.Client.retire_workers` nor through :class:`~distributed.deploy.Adaptive`. From 33a281fff69637bf589043c691cfb426e1a9e25e Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 24 Jun 2024 15:41:45 +0200 Subject: [PATCH 039/138] More robust bokeh test_shuffling (#8727) --- distributed/dashboard/tests/test_scheduler_bokeh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 3cdf8e70bf..a83bcf46bb 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -1350,8 +1350,8 @@ async def test_shuffling(c, s, a, b): start = time() while not ss.source.data["comm_written"]: ss.update() - await asyncio.sleep(0) - assert time() < start + 5 + await asyncio.sleep(0.05) + assert time() < start + 10 await df2 From 5b2e2b43d2e0caf5e25c4c491d270669010cd60d Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 24 Jun 2024 18:38:16 +0200 Subject: [PATCH 040/138] Cache URL encoding of worker addresses in dashboard (#8725) --- distributed/core.py | 2 ++ distributed/dashboard/components/nvml.py | 5 ++--- distributed/dashboard/components/rmm.py | 5 ++--- distributed/dashboard/components/scheduler.py | 13 ++++++------- .../http/scheduler/tests/test_scheduler_http.py | 3 +-- distributed/utils.py | 9 +++++++++ 6 files changed, 22 insertions(+), 15 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 66470ef224..d70ae3f383 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import functools import inspect import logging import math @@ -121,6 +122,7 @@ def _raise(*args, **kwargs): LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") +@functools.lru_cache def _expects_comm(func: Callable) -> bool: sig = inspect.signature(func) params = list(sig.parameters) diff --git a/distributed/dashboard/components/nvml.py b/distributed/dashboard/components/nvml.py index 20b0f6856b..6a381005c5 100644 --- a/distributed/dashboard/components/nvml.py +++ b/distributed/dashboard/components/nvml.py @@ -12,14 +12,13 @@ TapTool, ) from bokeh.plotting import figure -from tornado import escape from dask.utils import format_bytes from distributed.dashboard.components import DashboardComponent, add_periodic_callback from distributed.dashboard.components.scheduler import BOKEH_THEME, TICKS_1024, env from distributed.dashboard.utils import update -from distributed.utils import log_errors +from distributed.utils import log_errors, url_escape class GPUCurrentLoad(DashboardComponent): @@ -149,7 +148,7 @@ def update(self): "worker": worker, "gpu-index": gpu_index, "y": y, - "escaped_worker": [escape.url_escape(w) for w in worker], + "escaped_worker": [url_escape(w) for w in worker], } self.memory_figure.title.text = "GPU Memory: {} / {}".format( diff --git a/distributed/dashboard/components/rmm.py b/distributed/dashboard/components/rmm.py index 7376476570..f37380f68f 100644 --- a/distributed/dashboard/components/rmm.py +++ b/distributed/dashboard/components/rmm.py @@ -14,7 +14,6 @@ TapTool, ) from bokeh.plotting import figure -from tornado import escape from dask.utils import format_bytes @@ -26,7 +25,7 @@ MemoryColor, ) from distributed.dashboard.utils import update -from distributed.utils import log_errors +from distributed.utils import log_errors, url_escape T = TypeVar("T") @@ -191,7 +190,7 @@ def quadlist(i: Iterable[T]) -> list[T]: "color": color, "alpha": [1, 0.7, 0.4, 1] * len(workers), "worker": quadlist(ws.address for ws in workers), - "escaped_worker": quadlist(escape.url_escape(ws.address) for ws in workers), + "escaped_worker": quadlist(url_escape(ws.address) for ws in workers), "rmm_used": quadlist(rmm_used), "rmm_total": quadlist(rmm_total), "gpu_used": quadlist(gpu_used), diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index af669c8c23..7f5b391ca0 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -53,7 +53,6 @@ from jinja2 import Environment, FileSystemLoader from tlz import curry, pipe, second, valmap from tlz.curried import concat, groupby, map -from tornado import escape import dask from dask import config @@ -91,7 +90,7 @@ from distributed.metrics import time from distributed.scheduler import Scheduler from distributed.spans import SpansSchedulerExtension -from distributed.utils import Log, log_errors +from distributed.utils import Log, log_errors, url_escape if dask.config.get("distributed.dashboard.export-tool"): from distributed.dashboard.export_tool import ExportTool @@ -199,7 +198,7 @@ def update(self): "worker": [ws.address for ws in workers], "ms": ms, "color": color, - "escaped_worker": [escape.url_escape(ws.address) for ws in workers], + "escaped_worker": [url_escape(ws.address) for ws in workers], "x": x, "y": y, } @@ -581,7 +580,7 @@ def quadlist(i: Iterable[T]) -> list[T]: "color": color, "alpha": [1, 0.7, 0.4, 1] * len(workers), "worker": quadlist(ws.address for ws in workers), - "escaped_worker": quadlist(escape.url_escape(ws.address) for ws in workers), + "escaped_worker": quadlist(url_escape(ws.address) for ws in workers), "y": quadlist(range(len(workers))), "proc_memory": quadlist(procmemory), "managed": quadlist(managed), @@ -732,7 +731,7 @@ def update(self): ws.metrics["transfer"]["outgoing_bytes"] for ws in wss ] workers = [ws.address for ws in wss] - escaped_workers = [escape.url_escape(worker) for worker in workers] + escaped_workers = [url_escape(worker) for worker in workers] if wss: x_limit = max( @@ -1840,7 +1839,7 @@ def update(self): "nprocessing-half": [np / 2 for np in nprocessing], "nprocessing-color": nprocessing_color, "worker": [ws.address for ws in workers], - "escaped_worker": [escape.url_escape(ws.address) for ws in workers], + "escaped_worker": [url_escape(ws.address) for ws in workers], "y": list(range(len(workers))), } @@ -2381,7 +2380,7 @@ def add_new_nodes_edges(self, new, new_edges, update=False): continue xx = x[key] yy = y[key] - node_key.append(escape.url_escape(str(key))) + node_key.append(url_escape(str(key))) node_x.append(xx) node_y.append(yy) node_state.append(task.state) diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index 78a9744ce1..4b647dc6f2 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -7,7 +7,6 @@ from unittest import mock import pytest -from tornado.escape import url_escape from tornado.httpclient import AsyncHTTPClient, HTTPClientError import dask.config @@ -16,7 +15,7 @@ from distributed import Event, Lock, Scheduler from distributed.client import wait from distributed.core import Status -from distributed.utils import is_valid_xml +from distributed.utils import is_valid_xml, url_escape from distributed.utils_test import ( async_poll_for, div, diff --git a/distributed/utils.py b/distributed/utils.py index 6fbbf0f224..9658dffa8a 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -52,6 +52,7 @@ import click import psutil import tblib.pickling_support +from tornado import escape from distributed.compatibility import asyncio_run from distributed.config import get_loop_factory @@ -1994,3 +1995,11 @@ def __eq__(self, other): def __lt__(self, other): return self.obj < other.obj + + +@functools.lru_cache +def url_escape(url, *args, **kwargs): + """ + Escape a URL path segment. Cache results for better performance. + """ + return escape.url_escape(url, *args, **kwargs) From 97dbdaaf57937bfd44a5934273fe108e7fdd7e7f Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Tue, 25 Jun 2024 04:04:01 -0400 Subject: [PATCH 041/138] TST: Use safer context for ProcessPoolExecutor (#8715) --- distributed/tests/test_worker.py | 8 ++++---- distributed/tests/test_worker_metrics.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index b9e51944a9..602ca85402 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -49,7 +49,7 @@ from distributed.metrics import time from distributed.protocol import pickle from distributed.scheduler import KilledWorker, Scheduler -from distributed.utils import wait_for +from distributed.utils import get_mp_context, wait_for from distributed.utils_test import ( NO_AMM, BlockedExecute, @@ -2199,7 +2199,7 @@ async def test_bad_executor_annotation(c, s, a, b): @gen_cluster(client=True) async def test_process_executor(c, s, a, b): - with ProcessPoolExecutor() as e: + with ProcessPoolExecutor(mp_context=get_mp_context()) as e: a.executors["processes"] = e b.executors["processes"] = e @@ -2231,7 +2231,7 @@ def kill_process(): @gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) async def test_process_executor_kills_process(c, s, a): - with ProcessPoolExecutor() as e: + with ProcessPoolExecutor(mp_context=get_mp_context()) as e: a.executors["processes"] = e with dask.annotate(executor="processes", retries=1): future = c.submit(kill_process) @@ -2254,7 +2254,7 @@ def raise_exc(): @gen_cluster(client=True) async def test_process_executor_raise_exception(c, s, a, b): - with ProcessPoolExecutor() as e: + with ProcessPoolExecutor(mp_context=get_mp_context()) as e: a.executors["processes"] = e b.executors["processes"] = e with dask.annotate(executor="processes", retries=1): diff --git a/distributed/tests/test_worker_metrics.py b/distributed/tests/test_worker_metrics.py index e12c4902b1..4623cce59d 100644 --- a/distributed/tests/test_worker_metrics.py +++ b/distributed/tests/test_worker_metrics.py @@ -141,7 +141,7 @@ async def test_custom_executor(c, s, a): """Don't try to acquire in-thread metrics when the executor is a ProcessPoolExecutor or a custom, arbitrary executor. """ - with ProcessPoolExecutor(1) as e: + with ProcessPoolExecutor(1, mp_context=distributed.utils.get_mp_context()) as e: # Warm up executor - this can take up to 2s in Windows and MacOSX e.submit(inc, 1).result() From 0c585450421d1af5cd9ebd2b403d5ae23475f898 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 26 Jun 2024 14:38:01 +0200 Subject: [PATCH 042/138] Extract tests related to event-logging into separate file (#8733) --- distributed/tests/test_client.py | 201 +---------------- distributed/tests/test_event_logging.py | 284 ++++++++++++++++++++++++ distributed/tests/test_nanny.py | 32 +-- distributed/tests/test_scheduler.py | 16 -- distributed/tests/test_worker.py | 34 --- 5 files changed, 288 insertions(+), 279 deletions(-) create mode 100644 distributed/tests/test_event_logging.py diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 6290d15a2e..56d535bc06 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -79,7 +79,7 @@ from distributed.cluster_dump import load_cluster_dump from distributed.comm import CommClosedError from distributed.compatibility import LINUX, MACOS, WINDOWS -from distributed.core import Status, error_message +from distributed.core import Status from distributed.diagnostics.plugin import WorkerPlugin from distributed.metrics import time from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler @@ -7017,64 +7017,6 @@ def setup(self, worker=None): await c.register_plugin(MyPlugin()) -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_log_event(c, s, a): - # Log an event from inside a task - def foo(): - get_worker().log_event("topic1", {"foo": "bar"}) - - assert not await c.get_events("topic1") - await c.submit(foo) - events = await c.get_events("topic1") - assert len(events) == 1 - assert events[0][1] == {"foo": "bar", "worker": a.address} - - # Log an event while on the scheduler - def log_scheduler(dask_scheduler): - dask_scheduler.log_event("topic2", {"woo": "hoo"}) - - await c.run_on_scheduler(log_scheduler) - events = await c.get_events("topic2") - assert len(events) == 1 - assert events[0][1] == {"woo": "hoo"} - - # Log an event from the client process - await c.log_event("topic2", ("alice", "bob")) - events = await c.get_events("topic2") - assert len(events) == 2 - assert events[1][1] == ("alice", "bob") - - -@gen_cluster(client=True, nthreads=[]) -async def test_log_event_multiple_clients(c, s): - async with Client(s.address, asynchronous=True) as c2, Client( - s.address, asynchronous=True - ) as c3: - received_events = [] - - def get_event_handler(handler_id): - def handler(event): - received_events.append((handler_id, event)) - - return handler - - c.subscribe_topic("test-topic", get_event_handler(1)) - c2.subscribe_topic("test-topic", get_event_handler(2)) - - while len(s.event_subscriber["test-topic"]) != 2: - await asyncio.sleep(0.01) - - with captured_logger("distributed.client") as logger: - await c.log_event("test-topic", {}) - - while len(received_events) < 2: - await asyncio.sleep(0.01) - - assert len(received_events) == 2 - assert {handler_id for handler_id, _ in received_events} == {1, 2} - assert "ValueError" not in logger.getvalue() - - @gen_cluster(client=True) async def test_annotations_task_state(c, s, a, b): da = pytest.importorskip("dask.array") @@ -7739,119 +7681,8 @@ async def f(x, y): assert result == 12 -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_events_subscribe_topic(c, s, a): - log = [] - - def user_event_handler(event): - log.append(event) - - c.subscribe_topic("test-topic", user_event_handler) - - while not s.event_subscriber["test-topic"]: - await asyncio.sleep(0.01) - - a.log_event("test-topic", {"important": "event"}) - - while len(log) != 1: - await asyncio.sleep(0.01) - - time_, msg = log[0] - assert isinstance(time_, float) - assert msg == {"important": "event", "worker": a.address} - - c.unsubscribe_topic("test-topic") - - while s.event_subscriber["test-topic"]: - await asyncio.sleep(0.01) - - a.log_event("test-topic", {"forget": "me"}) - - while len(s.events["test-topic"]) == 1: - await asyncio.sleep(0.01) - - assert len(log) == 1 - - async def async_user_event_handler(event): - log.append(event) - await asyncio.sleep(0) - - c.subscribe_topic("test-topic", async_user_event_handler) - - while not s.event_subscriber["test-topic"]: - await asyncio.sleep(0.01) - - a.log_event("test-topic", {"async": "event"}) - - while len(log) == 1: - await asyncio.sleep(0.01) - - assert len(log) == 2 - time_, msg = log[1] - assert isinstance(time_, float) - assert msg == {"async": "event", "worker": a.address} - - # Even though the middle event was not subscribed to, the scheduler still - # knows about all and we can retrieve them - all_events = await c.get_events(topic="test-topic") - assert len(all_events) == 3 - - -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_events_subscribe_topic_cancelled(c, s, a): - event_handler_started = asyncio.Event() - exc_info = None - - async def user_event_handler(event): - nonlocal exc_info - c.unsubscribe_topic("test-topic") - event_handler_started.set() - with pytest.raises(asyncio.CancelledError) as exc_info: - await asyncio.sleep(0.5) - - c.subscribe_topic("test-topic", user_event_handler) - while not s.event_subscriber["test-topic"]: - await asyncio.sleep(0.01) - - a.log_event("test-topic", {}) - await event_handler_started.wait() - await c._close(fast=True) - assert exc_info is not None - - -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_events_all_servers_use_same_channel(c, s, a): - """Ensure that logs from all server types (scheduler, worker, nanny) - and the clients themselves arrive""" - - log = [] - - def user_event_handler(event): - log.append(event) - - c.subscribe_topic("test-topic", user_event_handler) - - while not s.event_subscriber["test-topic"]: - await asyncio.sleep(0.01) - - async with Nanny(s.address) as n: - a.log_event("test-topic", "worker") - n.log_event("test-topic", "nanny") - s.log_event("test-topic", "scheduler") - await c.log_event("test-topic", "client") - - while not len(log) == 4 == len(set(log)): - await asyncio.sleep(0.1) - - -@gen_cluster(client=True, nthreads=[]) -async def test_events_unsubscribe_raises_if_unknown(c, s): - with pytest.raises(ValueError, match="No event handler known for topic unknown"): - c.unsubscribe_topic("unknown") - - @gen_cluster(client=True) -async def test_log_event_warn(c, s, a, b): +async def test_warn_manual(c, s, a, b): def foo(): get_worker().log_event(["foo", "warn"], "Hello!") @@ -7874,34 +7705,8 @@ def no_category(): await c.submit(no_category) -@gen_cluster(client=True, nthreads=[]) -async def test_log_event_msgpack(c, s, a, b): - await c.log_event("test-topic", "foo") - with pytest.raises(TypeError, match="msgpack"): - - class C: - pass - - await c.log_event("test-topic", C()) - await c.log_event("test-topic", "bar") - await c.log_event("test-topic", error_message(Exception())) - - # assertion reversed for mock.ANY.__eq__(Serialized()) - assert [ - "foo", - "bar", - { - "status": "error", - "exception": mock.ANY, - "traceback": mock.ANY, - "exception_text": "Exception()", - "traceback_text": "", - }, - ] == [msg[1] for msg in s.get_events("test-topic")] - - @gen_cluster(client=True) -async def test_log_event_warn_dask_warns(c, s, a, b): +async def test_warn_remote(c, s, a, b): from dask.distributed import warn def warn_simple(): diff --git a/distributed/tests/test_event_logging.py b/distributed/tests/test_event_logging.py new file mode 100644 index 0000000000..2c74af3ca6 --- /dev/null +++ b/distributed/tests/test_event_logging.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import asyncio +from unittest import mock + +import pytest + +from distributed import Client, Nanny, get_worker +from distributed.core import error_message +from distributed.utils_test import captured_logger, gen_cluster + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_log_event(c, s, a): + # Log an event from inside a task + def foo(): + get_worker().log_event("topic1", {"foo": "bar"}) + + assert not await c.get_events("topic1") + await c.submit(foo) + events = await c.get_events("topic1") + assert len(events) == 1 + assert events[0][1] == {"foo": "bar", "worker": a.address} + + # Log an event while on the scheduler + def log_scheduler(dask_scheduler): + dask_scheduler.log_event("topic2", {"woo": "hoo"}) + + await c.run_on_scheduler(log_scheduler) + events = await c.get_events("topic2") + assert len(events) == 1 + assert events[0][1] == {"woo": "hoo"} + + # Log an event from the client process + await c.log_event("topic2", ("alice", "bob")) + events = await c.get_events("topic2") + assert len(events) == 2 + assert events[1][1] == ("alice", "bob") + + +@gen_cluster(client=True, nthreads=[]) +async def test_log_event_multiple_clients(c, s): + async with Client(s.address, asynchronous=True) as c2, Client( + s.address, asynchronous=True + ) as c3: + received_events = [] + + def get_event_handler(handler_id): + def handler(event): + received_events.append((handler_id, event)) + + return handler + + c.subscribe_topic("test-topic", get_event_handler(1)) + c2.subscribe_topic("test-topic", get_event_handler(2)) + + while len(s.event_subscriber["test-topic"]) != 2: + await asyncio.sleep(0.01) + + with captured_logger("distributed.client") as logger: + await c.log_event("test-topic", {}) + + while len(received_events) < 2: + await asyncio.sleep(0.01) + + assert len(received_events) == 2 + assert {handler_id for handler_id, _ in received_events} == {1, 2} + assert "ValueError" not in logger.getvalue() + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_events_subscribe_topic(c, s, a): + log = [] + + def user_event_handler(event): + log.append(event) + + c.subscribe_topic("test-topic", user_event_handler) + + while not s.event_subscriber["test-topic"]: + await asyncio.sleep(0.01) + + a.log_event("test-topic", {"important": "event"}) + + while len(log) != 1: + await asyncio.sleep(0.01) + + time_, msg = log[0] + assert isinstance(time_, float) + assert msg == {"important": "event", "worker": a.address} + + c.unsubscribe_topic("test-topic") + + while s.event_subscriber["test-topic"]: + await asyncio.sleep(0.01) + + a.log_event("test-topic", {"forget": "me"}) + + while len(s.events["test-topic"]) == 1: + await asyncio.sleep(0.01) + + assert len(log) == 1 + + async def async_user_event_handler(event): + log.append(event) + await asyncio.sleep(0) + + c.subscribe_topic("test-topic", async_user_event_handler) + + while not s.event_subscriber["test-topic"]: + await asyncio.sleep(0.01) + + a.log_event("test-topic", {"async": "event"}) + + while len(log) == 1: + await asyncio.sleep(0.01) + + assert len(log) == 2 + time_, msg = log[1] + assert isinstance(time_, float) + assert msg == {"async": "event", "worker": a.address} + + # Even though the middle event was not subscribed to, the scheduler still + # knows about all and we can retrieve them + all_events = await c.get_events(topic="test-topic") + assert len(all_events) == 3 + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_events_subscribe_topic_cancelled(c, s, a): + event_handler_started = asyncio.Event() + exc_info = None + + async def user_event_handler(event): + nonlocal exc_info + c.unsubscribe_topic("test-topic") + event_handler_started.set() + with pytest.raises(asyncio.CancelledError) as exc_info: + await asyncio.sleep(0.5) + + c.subscribe_topic("test-topic", user_event_handler) + while not s.event_subscriber["test-topic"]: + await asyncio.sleep(0.01) + + a.log_event("test-topic", {}) + await event_handler_started.wait() + await c._close(fast=True) + assert exc_info is not None + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_events_all_servers_use_same_channel(c, s, a): + """Ensure that logs from all server types (scheduler, worker, nanny) + and the clients themselves arrive""" + + log = [] + + def user_event_handler(event): + log.append(event) + + c.subscribe_topic("test-topic", user_event_handler) + + while not s.event_subscriber["test-topic"]: + await asyncio.sleep(0.01) + + async with Nanny(s.address) as n: + a.log_event("test-topic", "worker") + n.log_event("test-topic", "nanny") + s.log_event("test-topic", "scheduler") + await c.log_event("test-topic", "client") + + while not len(log) == 4 == len(set(log)): + await asyncio.sleep(0.1) + + +@gen_cluster(client=True, nthreads=[]) +async def test_events_unsubscribe_raises_if_unknown(c, s): + with pytest.raises(ValueError, match="No event handler known for topic unknown"): + c.unsubscribe_topic("unknown") + + +@gen_cluster(client=True, nthreads=[]) +async def test_log_event_msgpack(c, s, a, b): + await c.log_event("test-topic", "foo") + with pytest.raises(TypeError, match="msgpack"): + + class C: + pass + + await c.log_event("test-topic", C()) + await c.log_event("test-topic", "bar") + await c.log_event("test-topic", error_message(Exception())) + + # assertion reversed for mock.ANY.__eq__(Serialized()) + assert [ + "foo", + "bar", + { + "status": "error", + "exception": mock.ANY, + "traceback": mock.ANY, + "exception_text": "Exception()", + "traceback_text": "", + }, + ] == [msg[1] for msg in s.get_events("test-topic")] + + +@gen_cluster(client=True, config={"distributed.admin.low-level-log-length": 3}) +async def test_configurable_events_log_length(c, s, a, b): + s.log_event("test", "dummy message 1") + assert len(s.events["test"]) == 1 + s.log_event("test", "dummy message 2") + s.log_event("test", "dummy message 3") + assert len(s.events["test"]) == 3 + + # adding a fourth message will drop the first one and length stays at 3 + s.log_event("test", "dummy message 4") + assert len(s.events["test"]) == 3 + assert s.events["test"][0][1] == "dummy message 2" + assert s.events["test"][1][1] == "dummy message 3" + assert s.events["test"][2][1] == "dummy message 4" + + +@gen_cluster(client=True, nthreads=[]) +async def test_log_event_on_nanny(c, s): + async with Nanny(s.address) as n: + n.log_event("test-topic1", "foo") + + class C: + pass + + with pytest.raises(TypeError, match="msgpack"): + n.log_event("test-topic2", C()) + n.log_event("test-topic3", "bar") + n.log_event("test-topic4", error_message(Exception())) + + # Worker unaffected + assert await c.submit(lambda x: x + 1, 1) == 2 + + assert [msg[1] for msg in s.get_events("test-topic1")] == ["foo"] + assert [msg[1] for msg in s.get_events("test-topic3")] == ["bar"] + # assertion reversed for mock.ANY.__eq__(Serialized()) + assert [ + { + "status": "error", + "exception": mock.ANY, + "traceback": mock.ANY, + "exception_text": "Exception()", + "traceback_text": "", + }, + ] == [msg[1] for msg in s.get_events("test-topic4")] + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_log_event_on_worker(c, s, a): + def log_event(msg): + w = get_worker() + w.log_event("test-topic", msg) + + await c.submit(log_event, "foo") + + class C: + pass + + with pytest.raises(TypeError, match="msgpack"): + await c.submit(log_event, C()) + + # Worker still works + await c.submit(log_event, "bar") + await c.submit(log_event, error_message(Exception())) + + # assertion reversed for mock.ANY.__eq__(Serialized()) + assert [ + "foo", + "bar", + { + "status": "error", + "exception": mock.ANY, + "traceback": mock.ANY, + "exception_text": "Exception()", + "traceback_text": "", + "worker": a.address, + }, + ] == [msg[1] for msg in s.get_events("test-topic")] diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index f5995d4526..3afe8ae344 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -22,7 +22,7 @@ from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker from distributed.compatibility import LINUX, WINDOWS -from distributed.core import CommClosedError, Status, error_message +from distributed.core import CommClosedError, Status from distributed.diagnostics import SchedulerPlugin from distributed.diagnostics.plugin import NannyPlugin, WorkerPlugin from distributed.metrics import time @@ -798,36 +798,6 @@ async def test_worker_inherits_temp_config(c, s): assert out == 123 -@gen_cluster(client=True, nthreads=[]) -async def test_log_event(c, s): - async with Nanny(s.address) as n: - n.log_event("test-topic1", "foo") - - class C: - pass - - with pytest.raises(TypeError, match="msgpack"): - n.log_event("test-topic2", C()) - n.log_event("test-topic3", "bar") - n.log_event("test-topic4", error_message(Exception())) - - # Worker unaffected - assert await c.submit(lambda x: x + 1, 1) == 2 - - assert [msg[1] for msg in s.get_events("test-topic1")] == ["foo"] - assert [msg[1] for msg in s.get_events("test-topic3")] == ["bar"] - # assertion reversed for mock.ANY.__eq__(Serialized()) - assert [ - { - "status": "error", - "exception": mock.ANY, - "traceback": mock.ANY, - "exception_text": "Exception()", - "traceback_text": "", - }, - ] == [msg[1] for msg in s.get_events("test-topic4")] - - @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_nanny_plugin_simple(c, s, a): """A plugin should be registered to already existing workers but also to new ones.""" diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index d55e19fee1..3f6aa545bd 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3302,22 +3302,6 @@ async def test_retire_state_change(c, s, a, b): await asyncio.gather(*coros) -@gen_cluster(client=True, config={"distributed.admin.low-level-log-length": 3}) -async def test_configurable_events_log_length(c, s, a, b): - s.log_event("test", "dummy message 1") - assert len(s.events["test"]) == 1 - s.log_event("test", "dummy message 2") - s.log_event("test", "dummy message 3") - assert len(s.events["test"]) == 3 - - # adding a fourth message will drop the first one and length stays at 3 - s.log_event("test", "dummy message 4") - assert len(s.events["test"]) == 3 - assert s.events["test"][0][1] == "dummy message 2" - assert s.events["test"][1][1] == "dummy message 3" - assert s.events["test"][2][1] == "dummy message 4" - - @gen_cluster() async def test_get_worker_monitor_info(s, a, b): res = await s.get_worker_monitor_info() diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 602ca85402..067a025ae8 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -18,7 +18,6 @@ from numbers import Number from operator import add from time import sleep -from unittest import mock import psutil import pytest @@ -868,39 +867,6 @@ async def test_dont_overlap_communications_to_same_worker(c, s, a, b): assert l1["stop"] < l2["start"] -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_log_event(c, s, a): - def log_event(msg): - w = get_worker() - w.log_event("test-topic", msg) - - await c.submit(log_event, "foo") - - class C: - pass - - with pytest.raises(TypeError, match="msgpack"): - await c.submit(log_event, C()) - - # Worker still works - await c.submit(log_event, "bar") - await c.submit(log_event, error_message(Exception())) - - # assertion reversed for mock.ANY.__eq__(Serialized()) - assert [ - "foo", - "bar", - { - "status": "error", - "exception": mock.ANY, - "traceback": mock.ANY, - "exception_text": "Exception()", - "traceback_text": "", - "worker": a.address, - }, - ] == [msg[1] for msg in s.get_events("test-topic")] - - @gen_cluster(client=True) async def test_log_exception_on_failed_task(c, s, a, b): with captured_logger("distributed.worker") as logger: From 034b8fff435414e3c6d10a72b45d658142ebbca1 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 26 Jun 2024 16:02:01 +0200 Subject: [PATCH 043/138] Fix `mindeps`-testing on CI (#8728) --- .github/workflows/tests.yaml | 2 +- distributed/comm/tests/test_ws.py | 1 + distributed/shuffle/tests/test_merge.py | 5 +- .../tests/test_merge_column_and_index.py | 4 +- distributed/shuffle/tests/test_metrics.py | 10 +-- distributed/shuffle/tests/test_rechunk.py | 13 +++- distributed/shuffle/tests/test_shuffle.py | 7 +- .../tests/test_active_memory_manager.py | 1 + distributed/tests/test_client.py | 75 ++++++++++++++----- distributed/tests/test_computations.py | 1 + distributed/tests/test_nanny.py | 2 + distributed/tests/test_resources.py | 6 ++ distributed/tests/test_scheduler.py | 14 ++-- distributed/tests/test_spans.py | 4 +- distributed/tests/test_stress.py | 14 +++- distributed/tests/test_worker.py | 5 +- 16 files changed, 120 insertions(+), 44 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 57ea3ef6ba..cdd8150ec2 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -159,7 +159,7 @@ jobs: uses: actions/cache@v4 with: path: ${{ env.CONDA }}/envs - key: conda-${{ matrix.os }}-${{ steps.get-date.outputs.today }}-${{ hashFiles(env.CONDA_FILE) }}-${{ env.CACHE_NUMBER }} + key: conda-${{ matrix.os }}-${{ matrix.environment }}-${{ matrix.label }}-${{ steps.get-date.outputs.today }}-${{ hashFiles(env.CONDA_FILE) }}-${{ env.CACHE_NUMBER }} env: # Increase this value to reset cache if # continuous_integration/environment-${{ matrix.environment }}.yaml has not diff --git a/distributed/comm/tests/test_ws.py b/distributed/comm/tests/test_ws.py index 2b2c06ead5..83ff001c05 100644 --- a/distributed/comm/tests/test_ws.py +++ b/distributed/comm/tests/test_ws.py @@ -121,6 +121,7 @@ async def test_roundtrip(c, s, a, b): @gen_cluster(client=True, scheduler_kwargs={"protocol": "ws://"}) async def test_collections(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.random.random((1000, 1000), chunks=(100, 100)) x = x + x.T diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 27786e963c..42d1f443a0 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -14,10 +14,9 @@ from distributed.shuffle._worker_plugin import ShuffleRun, _ShuffleRunManager from distributed.utils_test import gen_cluster -dd = pytest.importorskip("dask.dataframe") -import pandas as pd - +pd = pytest.importorskip("pandas") import dask +import dask.dataframe as dd from dask.dataframe._compat import PANDAS_GE_200, tm from dask.dataframe.utils import assert_eq diff --git a/distributed/shuffle/tests/test_merge_column_and_index.py b/distributed/shuffle/tests/test_merge_column_and_index.py index ead09eebd2..6bc16f0878 100644 --- a/distributed/shuffle/tests/test_merge_column_and_index.py +++ b/distributed/shuffle/tests/test_merge_column_and_index.py @@ -14,10 +14,10 @@ import pytest np = pytest.importorskip("numpy") -dd = pytest.importorskip("dask.dataframe") -import pandas as pd +pd = pytest.importorskip("pandas") import dask +import dask.dataframe as dd from dask.dataframe.utils import assert_eq from distributed.utils_test import gen_cluster diff --git a/distributed/shuffle/tests/test_metrics.py b/distributed/shuffle/tests/test_metrics.py index 6956b5f755..19d65d1698 100644 --- a/distributed/shuffle/tests/test_metrics.py +++ b/distributed/shuffle/tests/test_metrics.py @@ -5,11 +5,8 @@ import dask.datasets from distributed import Scheduler -from distributed.utils_test import gen_cluster - -da = pytest.importorskip("dask.array") -dd = pytest.importorskip("dask.dataframe") from distributed.shuffle.tests.utils import UNPACK_PREFIX +from distributed.utils_test import gen_cluster def assert_metrics(s: Scheduler, *keys: tuple[str, ...]) -> None: @@ -32,6 +29,9 @@ def assert_metrics(s: Scheduler, *keys: tuple[str, ...]) -> None: @gen_cluster(client=True, config={"optimization.fuse.active": False}) async def test_rechunk(c, s, a, b): + pytest.importorskip("numpy") + import dask.array as da + x = da.random.random((10, 10), chunks=(-1, 1)) x = x.rechunk((1, -1), method="p2p") await c.compute(x) @@ -72,7 +72,7 @@ async def test_dataframe(c, s, a, b): """Metrics are *almost* agnostic in dataframe shuffle vs. array rechunk. The only exception is the 'p2p-shards' metric, which is implemented separately. """ - dd = pytest.importorskip("dask.dataframe") + pytest.importorskip("pandas") df = dask.datasets.timeseries( start="2000-01-01", diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 0d282ec0ed..89d9791f43 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -531,6 +531,7 @@ async def test_rechunk_same_fully_unknown(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_same_fully_unknown """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") x = da.ones(shape=(10, 10), chunks=(5, 10)) y = dd.from_array(x).values @@ -549,6 +550,8 @@ async def test_rechunk_same_fully_unknown_floats(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_same_fully_unknown_floats """ + + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") x = da.ones(shape=(10, 10), chunks=(5, 10)) y = dd.from_array(x).values @@ -564,6 +567,7 @@ async def test_rechunk_same_partially_unknown(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_same_partially_unknown """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") x = da.ones(shape=(10, 10), chunks=(5, 10)) y = dd.from_array(x).values @@ -621,8 +625,8 @@ async def test_rechunk_unknown_from_pandas(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_unknown_from_pandas """ - dd = pytest.importorskip("dask.dataframe") pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") arr = np.random.default_rng().standard_normal((50, 10)) x = dd.from_pandas(pd.DataFrame(arr), 2).values @@ -643,6 +647,7 @@ async def test_rechunk_unknown_from_array(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_unknown_from_array """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") x = dd.from_array(da.ones(shape=(4, 4), chunks=(2, 2))).values result = x.rechunk((None, 4), method="p2p") @@ -676,6 +681,7 @@ async def test_rechunk_with_fully_unknown_dimension(c, s, *ws, x, chunks): -------- dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") y = dd.from_array(x).values result = y.rechunk(chunks, method="p2p") @@ -718,6 +724,7 @@ async def test_rechunk_with_partially_unknown_dimension(c, s, *ws, x, chunks): -------- dask.array.tests.test_rechunk.test_rechunk_with_partially_unknown_dimension """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") y = dd.from_array(x).values z = da.concatenate([x, y]) @@ -743,6 +750,7 @@ async def test_rechunk_with_fully_unknown_dimension_explicit(c, s, *ws, new_chun -------- dask.array.tests.test_rechunk.test_rechunk_with_fully_unknown_dimension_explicit """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") x = da.ones(shape=(10, 10), chunks=(5, 2)) y = dd.from_array(x).values @@ -764,6 +772,7 @@ async def test_rechunk_unknown_raises(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_unknown_raises """ + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") x = da.ones(shape=(10, 10), chunks=(5, 5)) @@ -786,8 +795,6 @@ async def test_rechunk_zero_dim(c, s, *ws): -------- dask.array.tests.test_rechunk.test_rechunk_zero_dim """ - da = pytest.importorskip("dask.array") - x = da.ones((0, 10, 100), chunks=(0, 10, 10)).rechunk((0, 10, 50), method="p2p") assert len(await c.compute(x)) == 0 diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 0d3c962687..3d94ef81be 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -24,11 +24,10 @@ from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key from distributed.worker import Status -dd = pytest.importorskip("dask.dataframe") - -import numpy as np -import pandas as pd +np = pytest.importorskip("numpy") +pd = pytest.importorskip("pandas") +import dask.dataframe as dd from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200 from dask.typing import Key diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index eb13356836..cb573c3672 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -1278,6 +1278,7 @@ def run(self): async def tensordot_stress(c, s): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") rng = da.random.RandomState(0) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 56d535bc06..f589105e67 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -371,6 +371,7 @@ async def test_persist_retries_annotations(c, s, a, b): @gen_cluster(client=True) async def test_retries_dask_array(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.ones((10, 10), chunks=(3, 3)) future = c.compute(x.sum(), retries=2) @@ -398,6 +399,7 @@ async def test_future_repr(c, s, a, b): @gen_cluster(client=True) async def test_future_tuple_repr(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") y = da.arange(10, chunks=(5,)).persist() f = futures_of(y)[0] @@ -2538,9 +2540,9 @@ async def test_async_persist(c, s, a, b): @gen_cluster(client=True) -async def test__persist(c, s, a, b): - pytest.importorskip("dask.array") - import dask.array as da +async def test_persist_async(c, s, a, b): + pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") x = da.ones((10, 10), chunks=(5, 10)) y = 2 * (x + 1) @@ -2559,8 +2561,8 @@ async def test__persist(c, s, a, b): def test_persist(c): - pytest.importorskip("dask.array") - import dask.array as da + pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") x = da.ones((10, 10), chunks=(5, 10)) y = 2 * (x + 1) @@ -2623,7 +2625,9 @@ async def test_futures_of_get(c, s, a, b): def test_futures_of_class(): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + assert futures_of([da.Array]) == [] @@ -3357,6 +3361,7 @@ async def test_scheduler_saturates_cores_random(c, s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) async def test_cancel_clears_processing(c, s, *workers): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = c.submit(slowinc, 1, delay=0.2) while not s.tasks: @@ -4406,8 +4411,8 @@ async def test_compute_workers(e, s, a, b, c): @gen_cluster(client=True) async def test_compute_nested_containers(c, s, a, b): - da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") x = da.ones(10, chunks=(5,)) + 1 future = c.compute({"x": [x], "y": 123}) @@ -4562,6 +4567,7 @@ async def test_normalize_collection(c, s, a, b): @gen_cluster(client=True) async def test_normalize_collection_dask_array(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.ones(10, chunks=(5,)) @@ -4586,6 +4592,7 @@ async def test_normalize_collection_dask_array(c, s, a, b): @pytest.mark.slow def test_normalize_collection_with_released_futures(c): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.arange(2**20, chunks=2**10) @@ -4814,6 +4821,9 @@ async def test_recreate_error_futures(c, s, a, b): @gen_cluster(client=True) async def test_recreate_error_collection(c, s, a, b): + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") + b = db.range(10, npartitions=4) b = b.map(lambda x: 1 / x) b = b.persist() @@ -4824,9 +4834,6 @@ async def test_recreate_error_collection(c, s, a, b): with pytest.raises(ZeroDivisionError): function(*args, **kwargs) - dd = pytest.importorskip("dask.dataframe") - import pandas as pd - df = dd.from_pandas(pd.DataFrame({"a": [0, 1, 2, 3, 4]}), chunksize=2) def make_err(x): @@ -4852,6 +4859,7 @@ def make_err(x): @gen_cluster(client=True) async def test_recreate_error_array(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") pytest.importorskip("scipy") z = (da.linalg.inv(da.zeros((10, 10), chunks=10)) + 1).sum() @@ -4919,6 +4927,9 @@ async def test_recreate_task_futures(c, s, a, b): @gen_cluster(client=True) async def test_recreate_task_collection(c, s, a, b): + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") + b = db.range(10, npartitions=4) b = b.map(lambda x: int(3628800 / (x + 1))) b = b.persist() @@ -4938,9 +4949,6 @@ async def test_recreate_task_collection(c, s, a, b): 362880, ] - dd = pytest.importorskip("dask.dataframe") - import pandas as pd - df = dd.from_pandas(pd.DataFrame({"a": [0, 1, 2, 3, 4]}), chunksize=2) df2 = df.a.map(inc, meta=df.a) @@ -4963,6 +4971,7 @@ async def test_recreate_task_collection(c, s, a, b): @gen_cluster(client=True) async def test_recreate_task_array(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") z = (da.zeros((10, 10), chunks=10) + 1).sum() f = c.compute(z) @@ -4999,6 +5008,9 @@ def setup(self, worker): @pytest.mark.slow @gen_cluster(client=True, Worker=Nanny, worker_kwargs={"plugins": [WorkerStartTime()]}) async def test_restart_workers(c, s, a, b): + pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") + # Get initial worker start times results = await c.run(lambda dask_worker: dask_worker.start_time) a_start_time = results[a.worker_address] @@ -5006,7 +5018,6 @@ async def test_restart_workers(c, s, a, b): assert set(s.workers) == {a.worker_address, b.worker_address} # Persist futures and perform a computation - da = pytest.importorskip("dask.array") size = 100 x = da.ones(size, chunks=10) x = x.persist() @@ -5289,7 +5300,9 @@ def f(_): @pytest.mark.slow def test_threadsafe_get(c): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.arange(100, chunks=(10,)) def f(_): @@ -5308,7 +5321,9 @@ def f(_): @pytest.mark.slow def test_threadsafe_compute(c): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.arange(100, chunks=(10,)) def f(_): @@ -5372,7 +5387,9 @@ def test_get_client_no_cluster(): @gen_cluster(client=True) async def test_serialize_collections(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.arange(10, chunks=(5,)).persist() def f(x): @@ -5712,7 +5729,9 @@ async def test_call_stack_all(c, s, a, b): @gen_cluster([("127.0.0.1", 4)] * 2, client=True) async def test_call_stack_collections(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5).persist() while not a.state.executing_count and not b.state.executing_count: await asyncio.sleep(0.001) @@ -5722,7 +5741,9 @@ async def test_call_stack_collections(c, s, a, b): @gen_cluster([("127.0.0.1", 4)] * 2, client=True) async def test_call_stack_collections_all(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.random.random(100, chunks=(10,)).map_blocks(slowinc, delay=0.5).persist() while not a.state.executing_count and not b.state.executing_count: await asyncio.sleep(0.001) @@ -6423,8 +6444,8 @@ async def test_get_mix_futures_and_SubgraphCallable(c, s, a, b): @gen_cluster(client=True) async def test_get_mix_futures_and_SubgraphCallable_dask_dataframe(c, s, a, b): + pd = pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") - import pandas as pd df = pd.DataFrame({"x": range(1, 11)}) ddf = dd.from_pandas(df, npartitions=2).persist() @@ -6811,6 +6832,7 @@ async def f(dask_worker): @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 2) async def test_performance_report(c, s, a, b, local): pytest.importorskip("bokeh") + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") async def f(stacklevel, mode=None): @@ -6902,6 +6924,7 @@ def test_client_connectionpool_semaphore_loop(s, a, b, loop): @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") async def test_mixed_compression(c, s): pytest.importorskip("lz4") + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") async with Nanny( @@ -6928,9 +6951,8 @@ async def test_mixed_compression(c, s): def test_futures_in_subgraphs(loop_in_thread): """Regression test of """ - - dd = pytest.importorskip("dask.dataframe") pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") with cluster() as (s, [a, b]), Client(s["address"], loop=loop_in_thread) as c: ddf = dd.from_pandas( pd.DataFrame( @@ -7019,6 +7041,7 @@ def setup(self, worker=None): @gen_cluster(client=True) async def test_annotations_task_state(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") with dask.annotate(qux="bar", priority=100): @@ -7035,7 +7058,9 @@ async def test_annotations_task_state(c, s, a, b): @pytest.mark.parametrize("fn", ["compute", "persist"]) @gen_cluster(client=True) async def test_annotations_compute_time(c, s, a, b, fn): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.ones(10, chunks=(5,)) with dask.annotate(foo="bar"): @@ -7052,6 +7077,7 @@ async def test_annotations_compute_time(c, s, a, b, fn): @pytest.mark.xfail(reason="https://github.com/dask/dask/issues/7036") @gen_cluster(client=True) async def test_annotations_survive_optimization(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") with dask.annotate(foo="bar"): @@ -7070,6 +7096,7 @@ async def test_annotations_survive_optimization(c, s, a, b): @gen_cluster(client=True) async def test_annotations_priorities(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") with dask.annotate(priority=15): @@ -7085,6 +7112,7 @@ async def test_annotations_priorities(c, s, a, b): @gen_cluster(client=True) async def test_annotations_workers(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") with dask.annotate(workers=[a.address]): @@ -7103,6 +7131,7 @@ async def test_annotations_workers(c, s, a, b): @gen_cluster(client=True) async def test_annotations_retries(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") with dask.annotate(retries=2): @@ -7118,8 +7147,8 @@ async def test_annotations_retries(c, s, a, b): @gen_cluster(client=True) async def test_annotations_blockwise_unpack(c, s, a, b): - da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") from dask.array.utils import assert_eq # A flaky doubling function -- need extra args because it is called before @@ -7155,6 +7184,7 @@ def reliable_double(x): ], ) async def test_annotations_resources(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") with dask.annotate(resources={"GPU": 1}): @@ -7176,6 +7206,7 @@ async def test_annotations_resources(c, s, a, b): ], ) async def test_annotations_resources_culled(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.ones((2, 2, 2), chunks=1) @@ -7191,6 +7222,7 @@ async def test_annotations_resources_culled(c, s, a, b): @gen_cluster(client=True) async def test_annotations_loose_restrictions(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") # Eventually fails if allow_other_workers=False @@ -7244,6 +7276,7 @@ async def test_annotations_global_vs_local(c, s, a, b): @gen_cluster(client=True) async def test_workers_collection_restriction(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") future = c.compute(da.arange(10), workers=a.address) @@ -7413,7 +7446,9 @@ async def test_computation_store_annotations(c, s, a): def test_computation_object_code_dask_compute(client): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + with dask.config.set({"distributed.diagnostics.computations.nframes": 2}): x = da.ones((10, 10), chunks=(3, 3)) x.sum().compute() @@ -7433,7 +7468,9 @@ def fetch_comp_code(dask_scheduler): def test_computation_object_code_dask_compute_no_frames_default(client): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.ones((10, 10), chunks=(3, 3)) x.sum().compute() @@ -7469,7 +7506,9 @@ def fetch_comp_code(dask_scheduler): @gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2}) async def test_computation_object_code_dask_persist(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.ones((10, 10), chunks=(3, 3)) future = x.sum().persist() await future @@ -7576,7 +7615,9 @@ def func(x): @gen_cluster(client=True, config={"distributed.diagnostics.computations.nframes": 2}) async def test_computation_object_code_client_compute(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.ones((10, 10), chunks=(3, 3)) future = c.compute(x.sum(), retries=2) y = await future diff --git a/distributed/tests/test_computations.py b/distributed/tests/test_computations.py index e28a7875a7..70fd7633e8 100644 --- a/distributed/tests/test_computations.py +++ b/distributed/tests/test_computations.py @@ -9,6 +9,7 @@ @gen_cluster(client=True) async def test_computations(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.ones(100, chunks=(10,)) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 3afe8ae344..1314ea8e27 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -746,7 +746,9 @@ async def test_malloc_trim_threshold(c, s, a): This test may start failing in a future Python version if CPython switches to using mimalloc by default. If it does, a thorough benchmarking exercise is needed. """ + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + arr = da.random.random(2**29 // 8, chunks="512 kiB") # 0.5 GiB arr = arr.persist() await wait(arr) diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 5b1c2d5c35..c7001c0b00 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -403,7 +403,9 @@ async def test_set_resources(c, s, a): ], ) async def test_persist_collections(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.arange(10, chunks=(5,)) with dask.annotate(resources={"A": 1}): y = x.map_blocks(lambda x: x + 1) @@ -426,7 +428,9 @@ async def test_persist_collections(c, s, a, b): ], ) async def test_dont_optimize_out(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.arange(10, chunks=(5,)) y = x.map_blocks(lambda x: x + 1) z = y.map_blocks(lambda x: 2 * x) @@ -447,6 +451,7 @@ async def test_dont_optimize_out(c, s, a, b): ], ) async def test_full_collections(c, s, a, b): + pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") df = dd.demo.make_timeseries( freq="60s", partition_freq="1d", start="2000-01-01", end="2000-01-31" @@ -471,6 +476,7 @@ async def test_full_collections(c, s, a, b): ], ) def test_collections_get(client, optimize_graph, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") async def f(dask_worker): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 3f6aa545bd..7c2cc391ad 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -198,8 +198,8 @@ async def test_decide_worker_coschedule_order_neighbors_(c, s, *workers): generally, only one worker holds each row of the array, that the `random-` tasks are never transferred, and that there are few transfers overall. """ - da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") if ndeps == 0: x = da.random.random((100, 100), chunks=(10, 10)) @@ -1647,7 +1647,9 @@ async def test_retire_workers_no_suspicious_tasks(c, s, a, b): @gen_cluster(client=True, nthreads=[], timeout=120) async def test_file_descriptors(c, s): await asyncio.sleep(0.1) + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + proc = psutil.Process() num_fds_1 = proc.num_fds() @@ -2802,8 +2804,10 @@ async def test_no_dangling_asyncio_tasks(): @gen_cluster(client=True, Worker=NoSchedulerDelayWorker, config=NO_AMM) async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): - start = time() + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + + start = time() x = da.arange(100, chunks=(20,)) y = (x + 1).persist(optimize_graph=False) y = await y @@ -3033,6 +3037,7 @@ async def test_task_group_not_done_processing(c, s, a, b): @gen_cluster(client=True) async def test_task_prefix(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(20,)) y = (x + 1).sum().persist() @@ -3063,8 +3068,8 @@ async def test_failing_task_increments_suspicious(client, s, a, b): @gen_cluster(client=True) async def test_task_group_non_tuple_key(c, s, a, b): - da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") x = da.arange(100, chunks=(20,)) y = (x + 1).sum().persist() y = await y @@ -4458,9 +4463,8 @@ async def test_scheduler_close_fast_deprecated(s, w): def test_runspec_regression_sync(loop): # https://github.com/dask/distributed/issues/6624 - - da = pytest.importorskip("dask.array") np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") with Client(loop=loop): v = da.random.random((20, 20), chunks=(5, 5)) diff --git a/distributed/tests/test_spans.py b/distributed/tests/test_spans.py index 8a1d90eaed..2f54b4d3bf 100644 --- a/distributed/tests/test_spans.py +++ b/distributed/tests/test_spans.py @@ -214,7 +214,9 @@ async def test_no_extension(c, s, a, b): config={"optimization.fuse.active": False}, ) async def test_task_groups(c, s, a, b, release, no_time_resync): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + t0 = await padded_time(before=0) with span("wf"): @@ -864,9 +866,9 @@ async def test_span_on_persist(c, s, a, b): @pytest.mark.filterwarnings("ignore:Dask annotations") @gen_cluster(client=True) async def test_collections_metadata(c, s, a, b): + np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") dd = pytest.importorskip("dask.dataframe") - np = pytest.importorskip("numpy") df = pd.DataFrame( {"x": np.random.random(1000), "y": np.random.random(1000)}, index=np.arange(1000), diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 53801943cc..565f3d9792 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -63,7 +63,9 @@ def test_stress_gc(loop, func, n): @pytest.mark.skipif(WINDOWS, reason="test can leave dangling RPC objects") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 8) async def test_cancel_stress(c, s, *workers): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.random.random((50, 50), chunks=(2, 2)) x = c.persist(x) await wait([x]) @@ -80,7 +82,9 @@ async def test_cancel_stress(c, s, *workers): def test_cancel_stress_sync(loop): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.random.random((50, 50), chunks=(2, 2)) with cluster(active_rpc_timeout=10) as (s, [a, b]): with Client(s["address"], loop=loop) as c: @@ -101,6 +105,7 @@ def test_cancel_stress_sync(loop): ) async def test_stress_creation_and_deletion(c, s): # Assertions are handled by the validate mechanism in the scheduler + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") rng = da.random.RandomState(0) @@ -182,7 +187,9 @@ def vsum(*args): }, ) async def test_stress_communication(c, s, *workers): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + # Test consumes many file descriptors and can hang if the limit is too low resource = pytest.importorskip("resource") bump_rlimit(resource.RLIMIT_NOFILE, 8192) @@ -234,7 +241,9 @@ async def test_stress_steal(c, s, *workers): async def test_close_connections(c, s, *workers): # Schedule 600 slowinc's interleaved by worker-to-worker data transfers # The minimum time to compute this is (600 * 0.1 / 10 threads) = 6s + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + x = da.random.random(size=(100, 100), chunks=(-1, 1)) for _ in range(3): x = x.rechunk((1, -1)) @@ -306,14 +315,15 @@ async def test_no_delay_during_large_transfer(c, s, w): worker_kwargs={"transition_counter_max": 500_000}, ) async def test_chaos_rechunk(c, s, *workers): + pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") + s.allowed_failures = 10000 plugin = KillWorker(delay="4 s", mode="sys.exit") await c.register_plugin(plugin, name="kill") - da = pytest.importorskip("dask.array") - x = da.random.random((10000, 10000)) y = x.rechunk((10000, 20)).rechunk((20, 10000)).sum() z = c.compute(y) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 067a025ae8..3ea579487c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1183,7 +1183,9 @@ async def test_statistical_profiling(c, s, a, b): }, ) async def test_statistical_profiling_2(c, s, a, b): + pytest.importorskip("numpy") da = pytest.importorskip("dask.array") + while True: x = da.random.random(1000000, chunks=(10000,)) y = (x + x * 2) - x.sum().persist() @@ -3421,7 +3423,8 @@ def get_data(self, comm, **kwargs): @pytest.mark.slow @gen_cluster(client=True, Worker=BreakingWorker) async def test_broken_comm(c, s, a, b): - pytest.importorskip("dask.dataframe") + pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") df = dask.datasets.timeseries( start="2000-01-01", From c73aaa6bf6855603a53fc21b919bb349524271d8 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 26 Jun 2024 18:10:50 +0200 Subject: [PATCH 044/138] Remove `FutureWarning` in `test_task_state_instance_are_garbage_collected` (#8734) --- distributed/tests/test_worker_state_machine.py | 2 +- distributed/worker_state_machine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 361b930028..aa232ab6ec 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -922,7 +922,7 @@ async def test_task_state_instance_are_garbage_collected(c, s, a, b): f2 = c.submit(inc, red, pure=False) async def check(dask_worker): - while dask_worker.tasks: + while dask_worker.state.tasks: await asyncio.sleep(0.01) with profile.lock: gc.collect() diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index fd3c63e4cb..e88a026b48 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -336,7 +336,7 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: ----- This class uses ``_to_dict_no_nest`` instead of ``_to_dict``. When a task references another task, just print the task repr. All tasks - should neatly appear under Worker.tasks. This also prevents a RecursionError + should neatly appear under Worker.state.tasks. This also prevents a RecursionError during particularly heavy loads, which have been observed to happen whenever there's an acyclic dependency chain of ~200+ tasks. """ From a27a7b5656cb63c3bbb7d4f08f7186c867eb2ccb Mon Sep 17 00:00:00 2001 From: Adam Williamson Date: Thu, 27 Jun 2024 02:16:00 -0700 Subject: [PATCH 045/138] get_ip: handle getting 0.0.0.0 (#2554) (#8712) --- distributed/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 9658dffa8a..10585430ce 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -171,12 +171,20 @@ def get_fileno_limit(): @toolz.memoize def _get_ip(host, port, family): + def hostname_fallback(): + addr_info = socket.getaddrinfo( + socket.gethostname(), port, family, socket.SOCK_DGRAM, socket.IPPROTO_UDP + )[0] + return addr_info[4][0] + # By using a UDP socket, we don't actually try to connect but # simply select the local address through which *host* is reachable. sock = socket.socket(family, socket.SOCK_DGRAM) try: sock.connect((host, port)) ip = sock.getsockname()[0] + if ip == "0.0.0.0": + return hostname_fallback() return ip except OSError as e: warnings.warn( @@ -184,10 +192,7 @@ def _get_ip(host, port, family): "reaching %r, defaulting to hostname: %s" % (host, e), RuntimeWarning, ) - addr_info = socket.getaddrinfo( - socket.gethostname(), port, family, socket.SOCK_DGRAM, socket.IPPROTO_UDP - )[0] - return addr_info[4][0] + return hostname_fallback() finally: sock.close() From 920bd4678e73a8f4f4a92ab5a325109d8d71ef51 Mon Sep 17 00:00:00 2001 From: Benjamin Zaitlen Date: Thu, 27 Jun 2024 10:59:43 -0400 Subject: [PATCH 046/138] Fix pynvml handles (#8693) --- distributed/dashboard/components/nvml.py | 2 +- distributed/diagnostics/nvml.py | 75 ++++++++++++++++------ distributed/diagnostics/tests/test_nvml.py | 72 +++++++++++++++++---- 3 files changed, 119 insertions(+), 30 deletions(-) diff --git a/distributed/dashboard/components/nvml.py b/distributed/dashboard/components/nvml.py index 6a381005c5..92e4353354 100644 --- a/distributed/dashboard/components/nvml.py +++ b/distributed/dashboard/components/nvml.py @@ -131,7 +131,7 @@ def update(self): continue memory_max = max(memory_max, mem_total) memory_total += mem_total - utilization.append(int(u)) + utilization.append(int(u) if u else 0) memory.append(mem_used) worker.append(ws.address) gpu_index.append(idx) diff --git a/distributed/diagnostics/nvml.py b/distributed/diagnostics/nvml.py index 6548db6a0c..f621a7eb76 100644 --- a/distributed/diagnostics/nvml.py +++ b/distributed/diagnostics/nvml.py @@ -134,29 +134,64 @@ def _pynvml_handles(): count = device_get_count() if NVML_STATE == NVMLState.DISABLED_PYNVML_NOT_AVAILABLE: raise RuntimeError("NVML monitoring requires PyNVML and NVML to be installed") - elif NVML_STATE == NVMLState.DISABLED_LIBRARY_NOT_FOUND: + if NVML_STATE == NVMLState.DISABLED_LIBRARY_NOT_FOUND: raise RuntimeError("PyNVML is installed, but NVML is not") - elif NVML_STATE == NVMLState.DISABLED_WSL_INSUFFICIENT_DRIVER: + if NVML_STATE == NVMLState.DISABLED_WSL_INSUFFICIENT_DRIVER: raise RuntimeError( "Outdated NVIDIA drivers for WSL, please upgrade to " f"{MINIMUM_WSL_VERSION} or newer" ) - elif NVML_STATE == NVMLState.DISABLED_CONFIG: + if NVML_STATE == NVMLState.DISABLED_CONFIG: raise RuntimeError( "PyNVML monitoring disabled by 'distributed.diagnostics.nvml' " "config setting" ) - elif count == 0: + if count == 0: raise RuntimeError("No GPUs available") - else: - try: - gpu_idx = next( - map(int, os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")) + + device = 0 + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible_devices: + device = _parse_cuda_visible_device(cuda_visible_devices.split(",")[0]) + return _get_handle(device) + + +# Port from https://github.com/rapidsai/dask-cuda/blob/0f34116c4f3cdf5dfc0df0dbfeba92655f686716/dask_cuda/utils.py#L403-L437 +def _parse_cuda_visible_device(dev): + """Parses a single CUDA device identifier + + A device identifier must either be an integer, a string containing an + integer or a string containing the device's UUID, beginning with prefix + 'GPU-' or 'MIG-'. + + >>> parse_cuda_visible_device(2) + 2 + >>> parse_cuda_visible_device('2') + 2 + >>> parse_cuda_visible_device('GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005') + 'GPU-9baca7f5-0f2f-01ac-6b05-8da14d6e9005' + >>> parse_cuda_visible_device('Foo') + Traceback (most recent call last): + ... + ValueError: Devices in CUDA_VISIBLE_DEVICES must be comma-separated integers or + strings beginning with 'GPU-' or 'MIG-' prefixes. + """ + try: + return int(dev) + except ValueError: + if any( + dev.startswith(prefix) + for prefix in [ + "GPU-", + "MIG-", + ] + ): + return dev + else: + raise ValueError( + "Devices in CUDA_VISIBLE_DEVICES must be comma-separated integers " + "or strings beginning with 'GPU-' or 'MIG-' prefixes." ) - except ValueError: - # CUDA_VISIBLE_DEVICES is not set, take first device - gpu_idx = 0 - return pynvml.nvmlDeviceGetHandleByIndex(gpu_idx) def _running_process_matches(handle): @@ -281,18 +316,22 @@ def get_device_mig_mode(device): A ``list`` with two integers ``[current_mode, pending_mode]``. """ init_once() - try: - device_index = int(device) - handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) - except ValueError: - uuid = device if isinstance(device, bytes) else bytes(device, "utf-8") - handle = pynvml.nvmlDeviceGetHandleByUUID(uuid) + handle = _get_handle(device) try: return pynvml.nvmlDeviceGetMigMode(handle) except pynvml.NVMLError_NotSupported: return [0, 0] +def _get_handle(device): + try: + device_index = int(device) + return pynvml.nvmlDeviceGetHandleByIndex(device_index) + except ValueError: + uuid = device if isinstance(device, bytes) else bytes(device, "utf-8") + return pynvml.nvmlDeviceGetHandleByUUID(uuid) + + def _get_utilization(h): try: return pynvml.nvmlDeviceGetUtilizationRates(h).gpu diff --git a/distributed/diagnostics/tests/test_nvml.py b/distributed/diagnostics/tests/test_nvml.py index 7dd63d62fd..d9a95486d4 100644 --- a/distributed/diagnostics/tests/test_nvml.py +++ b/distributed/diagnostics/tests/test_nvml.py @@ -2,6 +2,7 @@ import multiprocessing as mp import os +from unittest import mock import pytest @@ -95,10 +96,10 @@ def test_1_visible_devices(): if nvml.device_get_count() < 1: pytest.skip("No GPUs available") - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - output = nvml.one_time() - h = nvml._pynvml_handles() - assert output["memory-total"] == pynvml.nvmlDeviceGetMemoryInfo(h).total + with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}): + output = nvml.one_time() + h = nvml._pynvml_handles() + assert output["memory-total"] == pynvml.nvmlDeviceGetMemoryInfo(h).total @pytest.mark.parametrize("CVD", ["1,0", "0,1"]) @@ -106,16 +107,65 @@ def test_2_visible_devices(CVD): if nvml.device_get_count() < 2: pytest.skip("Less than two GPUs available") - os.environ["CUDA_VISIBLE_DEVICES"] = CVD - idx = int(CVD.split(",")[0]) + with mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": CVD}): + idx = int(CVD.split(",")[0]) - h = nvml._pynvml_handles() - h2 = pynvml.nvmlDeviceGetHandleByIndex(idx) + h = nvml._pynvml_handles() + h2 = pynvml.nvmlDeviceGetHandleByIndex(idx) + + s = pynvml.nvmlDeviceGetSerial(h) + s2 = pynvml.nvmlDeviceGetSerial(h2) + + assert s == s2 + + +def test_visible_devices_uuid(): + if nvml.device_get_count() < 1: + pytest.skip("No GPUs available") + + info = nvml.get_device_index_and_uuid(0) + assert info.uuid + + with mock.patch.dict( + os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")} + ): + h = nvml._pynvml_handles() + h_expected = pynvml.nvmlDeviceGetHandleByIndex(0) + + s = pynvml.nvmlDeviceGetSerial(h) + s_expected = pynvml.nvmlDeviceGetSerial(h_expected) - s = pynvml.nvmlDeviceGetSerial(h) - s2 = pynvml.nvmlDeviceGetSerial(h2) + assert s == s_expected + + +@pytest.mark.parametrize("index", [0, 1]) +def test_visible_devices_uuid_2(index): + if nvml.device_get_count() < 2: + pytest.skip("Less than two GPUs available") + + info = nvml.get_device_index_and_uuid(index) + assert info.uuid + + with mock.patch.dict( + os.environ, {"CUDA_VISIBLE_DEVICES": info.uuid.decode("utf-8")} + ): + h = nvml._pynvml_handles() + h_expected = pynvml.nvmlDeviceGetHandleByIndex(index) + + s = pynvml.nvmlDeviceGetSerial(h) + s_expected = pynvml.nvmlDeviceGetSerial(h_expected) + + assert s == s_expected + + +def test_visible_devices_bad_uuid(): + if nvml.device_get_count() < 1: + pytest.skip("No GPUs available") - assert s == s2 + with mock.patch.dict( + os.environ, {"CUDA_VISIBLE_DEVICES": "NOT-A-GPU-UUID"} + ), pytest.raises(ValueError, match="Devices in CUDA_VISIBLE_DEVICES"): + nvml._pynvml_handles() @gen_cluster() From a336586330055d891b5247e57cd50f2798c311c1 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 28 Jun 2024 12:20:54 +0200 Subject: [PATCH 047/138] Fix floating-point inaccuracy (#8736) --- distributed/tests/test_deadline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_deadline.py b/distributed/tests/test_deadline.py index f816766526..97de85bed7 100644 --- a/distributed/tests/test_deadline.py +++ b/distributed/tests/test_deadline.py @@ -3,6 +3,8 @@ import asyncio from time import sleep +import pytest + from distributed.metrics import monotonic from distributed.utils import Deadline from distributed.utils_test import gen_test @@ -11,10 +13,10 @@ def test_deadline(): deadline = Deadline.after(5) - assert deadline.duration == 5 + assert deadline.duration == pytest.approx(5) assert deadline.expired is False assert deadline.expires is True - assert deadline.expires_at_mono - deadline.started_at_mono == 5 + assert deadline.expires_at_mono - deadline.started_at_mono == pytest.approx(5) assert 4 < deadline.expires_at - deadline.started_at < 6 assert 0 <= deadline.elapsed <= 1 assert 4 <= deadline.remaining <= 5 From 50700f381075c8d2be1b17dacc1b5f48dcc93a88 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 28 Jun 2024 14:09:27 +0200 Subject: [PATCH 048/138] Fix `test_task_state_instance_are_garbage_collected` (#8735) --- distributed/tests/test_worker_state_machine.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index aa232ab6ec..5508355929 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -935,9 +935,14 @@ async def check(dask_worker): async def check(dask_scheduler): while dask_scheduler.tasks: await asyncio.sleep(0.01) - with profile.lock: - gc.collect() - assert not SchedulerTaskState._instances + + gc.collect() + # Gargabe collection might already be running in which case gc.collect()'s behavior is undefined. + # Try again and hope for the best. + while SchedulerTaskState._instances: + await asyncio.sleep(0.01) + with profile.lock: + gc.collect() await c.run_on_scheduler(check) From 147c505bcfe6baf823ac31f508a79603a1e467b2 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 1 Jul 2024 15:05:56 +0200 Subject: [PATCH 049/138] Remove `is_python_shutting_down` (#8492) --- distributed/__init__.py | 21 --------------------- distributed/client.py | 11 +++++++---- distributed/comm/inproc.py | 11 ++++++++--- distributed/core.py | 7 ++++--- distributed/scheduler.py | 3 +-- distributed/utils.py | 13 ------------- distributed/worker.py | 7 ++++--- 7 files changed, 24 insertions(+), 49 deletions(-) diff --git a/distributed/__init__.py b/distributed/__init__.py index b8cf71ef24..29fa40e7fb 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -97,27 +97,6 @@ def __getattr__(name): raise AttributeError(f"module {__name__!r} has no attribute {name!r}") -_python_shutting_down = False - - -@atexit.register -def _(): - """Set a global when Python shuts down. - - Note - ---- - This function must be registered with atexit *after* any class that invokes - ``distributed.utils.is_python_shutting_down`` has been defined. This way it - will be called before the ``__del__`` method of those classes. - - See Also - -------- - distributed.utils.is_python_shutting_down - """ - global _python_shutting_down - _python_shutting_down = True - - __all__ = [ "Actor", "ActorFuture", diff --git a/distributed/client.py b/distributed/client.py index 5b35b69573..3cac5b3fc2 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -103,7 +103,6 @@ format_dashboard_link, has_keyword, import_term, - is_python_shutting_down, log_errors, nbytes, sync, @@ -272,6 +271,8 @@ class Future(WrappedKey): Client: Creates futures """ + _is_finalizing: staticmethod[[], bool] = staticmethod(sys.is_finalizing) + _cb_executor = None _cb_executor_pid = None _counter = itertools.count() @@ -586,7 +587,7 @@ def __del__(self): except AttributeError: # Occasionally we see this error when shutting down the client # https://github.com/dask/distributed/issues/4305 - if not is_python_shutting_down(): + if not self._is_finalizing(): raise except RuntimeError: # closed event loop pass @@ -900,6 +901,8 @@ class Client(SyncMethodMixin): distributed.LocalCluster: """ + _is_finalizing: staticmethod[[], bool] = staticmethod(sys.is_finalizing) + _instances: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet() _default_event_handlers = {"print": _handle_print, "warn": _handle_warn} @@ -1628,7 +1631,7 @@ async def _handle_report(self): try: msgs = await self.scheduler_comm.comm.read() except CommClosedError: - if is_python_shutting_down(): + if self._is_finalizing(): return if self.status == "running": if self.cluster and self.cluster.status in ( @@ -1852,7 +1855,7 @@ def close(self, timeout=no_default): sync(self.loop, self._close, fast=True, callback_timeout=timeout) assert self.status == "closed" - if not is_python_shutting_down(): + if not self._is_finalizing(): self._loop_runner.stop() async def _shutdown(self): diff --git a/distributed/comm/inproc.py b/distributed/comm/inproc.py index dad7350568..b6479eb3fa 100644 --- a/distributed/comm/inproc.py +++ b/distributed/comm/inproc.py @@ -4,6 +4,7 @@ import itertools import logging import os +import sys import threading import weakref from collections import deque, namedtuple @@ -14,7 +15,7 @@ from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector from distributed.comm.registry import Backend, backends from distributed.protocol.serialize import _nested_deserialize -from distributed.utils import get_ip, is_python_shutting_down +from distributed.utils import get_ip logger = logging.getLogger(__name__) @@ -187,9 +188,13 @@ def _get_finalizer(self): r = repr(self) def finalize( - read_q=self._read_q, write_q=self._write_q, write_loop=self._write_loop, r=r + read_q=self._read_q, + write_q=self._write_q, + write_loop=self._write_loop, + is_finalizing=sys.is_finalizing, + r=r, ): - if read_q.peek(None) is _EOF or is_python_shutting_down(): + if read_q.peek(None) is _EOF or is_finalizing(): return logger.warning(f"Closing dangling queue in {r}") write_loop.add_callback(write_q.put_nowait, _EOF) diff --git a/distributed/core.py b/distributed/core.py index d70ae3f383..90705e8051 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -55,7 +55,6 @@ get_traceback, has_keyword, import_file, - is_python_shutting_down, iscoroutinefunction, offload, recursive_to_dict, @@ -321,6 +320,8 @@ class Server: """ + _is_finalizing: staticmethod[[], bool] = staticmethod(sys.is_finalizing) + default_ip: ClassVar[str] = "" default_port: ClassVar[int] = 0 @@ -902,7 +903,7 @@ async def _handle_comm(self, comm: Comm) -> None: msg = await comm.read() logger.debug("Message from %r: %s", address, msg) except OSError as e: - if not is_python_shutting_down(): + if not self._is_finalizing(): logger.debug( "Lost connection to %r while reading message: %s." " Last operation: %s", @@ -1006,7 +1007,7 @@ async def _handle_comm(self, comm: Comm) -> None: finally: del self._comms[comm] - if not is_python_shutting_down() and not comm.closed(): + if not self._is_finalizing() and not comm.closed(): try: comm.abort() except Exception as e: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4cba667d9e..69ca80826e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -123,7 +123,6 @@ TimeoutError, format_dashboard_link, get_fileno_limit, - is_python_shutting_down, key_split_group, log_errors, offload, @@ -5786,7 +5785,7 @@ async def add_client( if not comm.closed(): self.client_comms[client].send({"op": "stream-closed"}) try: - if not is_python_shutting_down(): + if not self._is_finalizing(): await self.client_comms[client].close() del self.client_comms[client] if self.status == Status.running: diff --git a/distributed/utils.py b/distributed/utils.py index 10585430ce..3f659bff02 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1831,19 +1831,6 @@ def recursive_to_dict( tok.var.reset(tok) -def is_python_shutting_down() -> bool: - """Is the interpreter shutting down now? - - This is a variant of ``sys.is_finalizing`` which can return True inside the ``__del__`` - method of classes defined inside the distributed package. - """ - # This import must remain local for the global variable to be - # properly evaluated - from distributed import _python_shutting_down - - return _python_shutting_down - - class Deadline: """Utility class tracking a deadline and the progress toward it""" diff --git a/distributed/worker.py b/distributed/worker.py index 5fe7a4da1b..27be73e30a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -101,7 +101,6 @@ get_ip, has_arg, in_async_call, - is_python_shutting_down, iscoroutinefunction, json_load_robust, log_errors, @@ -1634,7 +1633,9 @@ def _close(executor, wait): # weird deadlocks particularly if the task that is executing in # the thread is waiting for a server reply, e.g. when using # worker clients, semaphores, etc. - if is_python_shutting_down(): + + # Are we shutting down the process? + if self._is_finalizing() or not threading.main_thread().is_alive(): # If we're shutting down there is no need to wait for daemon # threads to finish _close(executor=executor, wait=False) @@ -1643,7 +1644,7 @@ def _close(executor, wait): await asyncio.to_thread( _close, executor=executor, wait=executor_wait ) - except RuntimeError: # Are we shutting down the process? + except RuntimeError: logger.error( "Could not close executor %r by dispatching to thread. Trying synchronously.", executor, From f997f21c3c8be7149b76b8dde650ea22d64dc1a5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 1 Jul 2024 21:35:15 +0200 Subject: [PATCH 050/138] Drop support for pandas 1.X (#8741) --- distributed/shuffle/tests/test_merge.py | 10 ++-- distributed/shuffle/tests/test_shuffle.py | 69 +++++++++++------------ 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 42d1f443a0..8c196af3c0 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -17,7 +17,7 @@ pd = pytest.importorskip("pandas") import dask import dask.dataframe as dd -from dask.dataframe._compat import PANDAS_GE_200, tm +from dask.dataframe._compat import tm from dask.dataframe.utils import assert_eq from distributed import get_client @@ -293,7 +293,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GE_200 and expected.index.empty), + check_index=not expected.index.empty, ) expected = pdr.join(pdl, how=how) @@ -303,7 +303,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GE_200 and expected.index.empty), + check_index=not expected.index.empty, ) expected = pd.merge( @@ -323,7 +323,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GE_200 and expected.index.empty), + check_index=not expected.index.empty, ) expected = pd.merge( @@ -343,7 +343,7 @@ async def test_merge_by_multiple_columns(c, s, a, b, how): # FIXME: There's an discrepancy with an empty index for # pandas=2.0 (xref https://github.com/dask/dask/issues/9957). # Temporarily avoid index check until the discrepancy is fixed. - check_index=not (PANDAS_GE_200 and expected.index.empty), + check_index=not expected.index.empty, ) # hash join diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 3d94ef81be..24e199d1b6 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -28,7 +28,6 @@ pd = pytest.importorskip("pandas") import dask.dataframe as dd -from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200 from dask.typing import Key from distributed import ( @@ -1145,41 +1144,38 @@ def __init__(self, value: int) -> None: } ) - if PANDAS_GE_150: - columns.update( - { - # PyArrow dtypes - f"col{next(counter)}": pd.array( - [True, False] * 50, dtype="bool[pyarrow]" - ), - f"col{next(counter)}": pd.array(range(100), dtype="int8[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="int16[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="int32[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="int64[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="uint8[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="uint16[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="uint32[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="uint64[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="float32[pyarrow]"), - f"col{next(counter)}": pd.array(range(100), dtype="float64[pyarrow]"), - f"col{next(counter)}": pd.array( - [pd.Timestamp.fromtimestamp(1641034800 + i) for i in range(100)], - dtype=pd.ArrowDtype(pa.timestamp("ms")), - ), - f"col{next(counter)}": pd.array( - ["lorem ipsum"] * 100, - dtype="string[pyarrow]", - ), - f"col{next(counter)}": pd.array( - ["lorem ipsum"] * 100, - dtype=pd.StringDtype("pyarrow"), - ), - f"col{next(counter)}": pd.array( - ["lorem ipsum"] * 100, - dtype="string[python]", - ), - } - ) + columns.update( + { + # PyArrow dtypes + f"col{next(counter)}": pd.array([True, False] * 50, dtype="bool[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="int8[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="int16[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="int32[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="int64[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="uint8[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="uint16[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="uint32[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="uint64[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="float32[pyarrow]"), + f"col{next(counter)}": pd.array(range(100), dtype="float64[pyarrow]"), + f"col{next(counter)}": pd.array( + [pd.Timestamp.fromtimestamp(1641034800 + i) for i in range(100)], + dtype=pd.ArrowDtype(pa.timestamp("ms")), + ), + f"col{next(counter)}": pd.array( + ["lorem ipsum"] * 100, + dtype="string[pyarrow]", + ), + f"col{next(counter)}": pd.array( + ["lorem ipsum"] * 100, + dtype=pd.StringDtype("pyarrow"), + ), + f"col{next(counter)}": pd.array( + ["lorem ipsum"] * 100, + dtype="string[python]", + ), + } + ) df = pd.DataFrame(columns) df["_partitions"] = df.col4 % npartitions @@ -2399,7 +2395,6 @@ async def test_replace_stale_shuffle(c, s, a, b): await check_scheduler_cleanup(s) -@pytest.mark.skipif(not PANDAS_GE_200, reason="requires pandas >=2.0") @gen_cluster(client=True) async def test_handle_null_partitions(c, s, a, b): data = [ From 63e5108973f5528210254b2888fd3f192deca7ef Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 1 Jul 2024 21:39:10 +0200 Subject: [PATCH 051/138] Refactor event logging functionality into broker (#8731) --- distributed/broker.py | 106 ++++++++++++ distributed/dashboard/components/scheduler.py | 19 +- .../tests/test_scheduler_plugin.py | 43 +---- .../scheduler/tests/test_stealing_http.py | 4 +- distributed/scheduler.py | 59 +++---- distributed/shuffle/tests/test_shuffle.py | 8 +- distributed/stealing.py | 4 +- distributed/tests/test_client.py | 4 +- distributed/tests/test_event_logging.py | 163 ++++++++++++++++-- distributed/tests/test_nanny.py | 2 +- distributed/tests/test_scheduler.py | 24 +-- distributed/tests/test_steal.py | 6 +- distributed/tests/test_utils_test.py | 4 +- distributed/tests/test_worker.py | 6 +- distributed/utils_test.py | 15 +- 15 files changed, 327 insertions(+), 140 deletions(-) create mode 100644 distributed/broker.py diff --git a/distributed/broker.py b/distributed/broker.py new file mode 100644 index 0000000000..298225eb85 --- /dev/null +++ b/distributed/broker.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import logging +from collections import defaultdict, deque +from collections.abc import Collection +from functools import partial +from typing import TYPE_CHECKING, Any, overload + +from distributed.metrics import time + +if TYPE_CHECKING: + from distributed import Scheduler + +logger = logging.getLogger(__name__) + + +class Topic: + events: deque + count: int + subscribers: set + + def __init__(self, maxlen: int): + self.events = deque(maxlen=maxlen) + self.count = 0 + self.subscribers = set() + + def subscribe(self, subscriber: str) -> None: + self.subscribers.add(subscriber) + + def unsubscribe(self, subscriber: str) -> None: + self.subscribers.discard(subscriber) + + def publish(self, event: Any) -> None: + self.events.append(event) + self.count += 1 + + def truncate(self) -> None: + self.events.clear() + + +class Broker: + _scheduler: Scheduler + _topics: defaultdict[str, Topic] + + def __init__(self, maxlen: int, scheduler: Scheduler) -> None: + self._scheduler = scheduler + self._topics = defaultdict(partial(Topic, maxlen=maxlen)) + + def subscribe(self, topic: str, subscriber: str) -> None: + self._topics[topic].subscribe(subscriber) + + def unsubscribe(self, topic: str, subscriber: str) -> None: + self._topics[topic].unsubscribe(subscriber) + + def publish(self, topics: str | Collection[str], msg: Any) -> None: + event = (time(), msg) + if isinstance(topics, str): + topics = [topics] + for name in topics: + topic = self._topics[name] + topic.publish(event) + self._send_to_subscribers(name, event) + + for plugin in list(self._scheduler.plugins.values()): + try: + plugin.log_event(name, msg) + except Exception: + logger.info("Plugin failed with exception", exc_info=True) + + def truncate(self, topic: str | None = None) -> None: + if topic is None: + for _topic in self._topics.values(): + _topic.truncate() + elif topic in self._topics: + self._topics[topic].truncate() + + def _send_to_subscribers(self, topic: str, event: Any) -> None: + msg = { + "op": "event", + "topic": topic, + "event": event, + } + client_msgs = {client: [msg] for client in self._topics[topic].subscribers} + self._scheduler.send_all(client_msgs, worker_msgs={}) + + @overload + def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]: + ... + + @overload + def get_events( + self, topic: None = None + ) -> dict[str, tuple[tuple[float, Any], ...]]: + ... + + def get_events( + self, topic: str | None = None + ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]: + if topic is not None: + return tuple(self._topics[topic].events) + else: + return { + name: tuple(topic.events) + for name, topic in self._topics.items() + if topic.events + } diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 7f5b391ca0..982d234c9a 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1971,12 +1971,12 @@ def convert(self, msgs): @without_property_validation @log_errors def update(self): - log = self.scheduler.get_events(topic="stealing") - current = len(self.scheduler.events["stealing"]) - n = current - self.last - - log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"] - self.last = current + topic = self.scheduler._broker._topics["stealing"] + log = log = topic.events + n = min(topic.count - self.last, len(log)) + if log: + log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"] + self.last = topic.count if log: new = pipe( @@ -2041,11 +2041,12 @@ def __init__(self, scheduler, name, height=150, **kwargs): @without_property_validation @log_errors def update(self): - log = self.scheduler.events[self.name] - n = self.scheduler.event_counts[self.name] - self.last + topic = self.scheduler._broker._topics[self.name] + log = topic.events + n = min(topic.count - self.last, len(log)) if log: log = [log[-i] for i in range(1, n + 1)] - self.last = self.scheduler.event_counts[self.name] + self.last = topic.count if log: actions = [] diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index ee8c7f501c..c49bd96510 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -4,7 +4,7 @@ import pytest -from distributed import Nanny, Scheduler, SchedulerPlugin, Worker, get_worker +from distributed import Nanny, Scheduler, SchedulerPlugin, Worker from distributed.protocol.pickle import dumps from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc @@ -435,47 +435,6 @@ class Plugin(SchedulerPlugin): await c.unregister_scheduler_plugin(name="plugin") -@gen_cluster(client=True) -async def test_log_event_plugin(c, s, a, b): - class EventPlugin(SchedulerPlugin): - async def start(self, scheduler: Scheduler) -> None: - self.scheduler = scheduler - self.scheduler._recorded_events = list() # type: ignore - - def log_event(self, topic, msg): - self.scheduler._recorded_events.append((topic, msg)) - - await c.register_plugin(EventPlugin()) - - def f(): - get_worker().log_event("foo", 123) - - await c.submit(f) - - assert ("foo", 123) in s._recorded_events - - -@gen_cluster(client=True) -async def test_log_event_plugin_multiple_topics(c, s, a, b): - class EventPlugin(SchedulerPlugin): - async def start(self, scheduler: Scheduler) -> None: - self.scheduler = scheduler - self.scheduler._recorded_events = list() # type: ignore - - def log_event(self, topic, msg): - self.scheduler._recorded_events.append((topic, msg)) - - await c.register_plugin(EventPlugin()) - - def f(): - get_worker().log_event(["foo", "bar"], 123) - - await c.submit(f) - - assert ("foo", 123) in s._recorded_events - assert ("bar", 123) in s._recorded_events - - @gen_cluster(client=True) async def test_register_plugin_on_scheduler(c, s, a, b): class MyPlugin(SchedulerPlugin): diff --git a/distributed/http/scheduler/tests/test_stealing_http.py b/distributed/http/scheduler/tests/test_stealing_http.py index 1272076374..efd5589612 100644 --- a/distributed/http/scheduler/tests/test_stealing_http.py +++ b/distributed/http/scheduler/tests/test_stealing_http.py @@ -53,7 +53,7 @@ async def fetch_metrics_by_cost_multipliers(): count = sum(active_metrics.values()) assert count > 0 expected_count = sum( - len(event[1]) for _, event in s.events["stealing"] if event[0] == "request" + len(event[1]) for _, event in s.get_events("stealing") if event[0] == "request" ) assert count == expected_count @@ -87,7 +87,7 @@ async def fetch_metrics_by_cost_multipliers(): assert count > 0 expected_cost = sum( request[3] - for _, event in s.events["stealing"] + for _, event in s.get_events("stealing") for request in event[1] if event[0] == "request" ) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 69ca80826e..0273d333da 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -75,6 +75,7 @@ from distributed._stories import scheduler_story from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker from distributed.batched import BatchedSend +from distributed.broker import Broker from distributed.client import SourceCode from distributed.collections import HeapSet from distributed.comm import ( @@ -3804,9 +3805,7 @@ async def post(self): ] maxlen = dask.config.get("distributed.admin.low-level-log-length") - self.events = defaultdict(partial(deque, maxlen=maxlen)) - self.event_counts = defaultdict(int) - self.event_subscriber = defaultdict(set) + self._broker = Broker(maxlen, self) self.worker_plugins = {} self.nanny_plugins = {} self._starting_nannies = set() @@ -4002,7 +4001,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: "workers": self.workers, "clients": self.clients, "memory": self.memory, - "events": self.events, + "events": self._broker._topics, "extensions": self.extensions, } extra = {k: v for k, v in extra.items() if k not in exclude} @@ -5406,8 +5405,8 @@ async def remove_worker( async def remove_worker_from_events() -> None: # If the worker isn't registered anymore after the delay, remove from events - if address not in self.workers and address in self.events: - del self.events[address] + if address not in self.workers: + self._broker.truncate(address) cleanup_delay = parse_timedelta( dask.config.get("distributed.scheduler.events-cleanup-delay") @@ -5820,8 +5819,8 @@ def remove_client(self, client: str, stimulus_id: str | None = None) -> None: async def remove_client_from_events() -> None: # If the client isn't registered anymore after the delay, remove from events - if client not in self.clients and client in self.events: - del self.events[client] + if client not in self.clients: + self._broker.truncate(client) cleanup_delay = parse_timedelta( dask.config.get("distributed.scheduler.events-cleanup-delay") @@ -8423,40 +8422,26 @@ def log_event(self, topic: str | Collection[str], msg: Any) -> None: -------- Client.log_event """ - event = (time(), msg) - if isinstance(topic, str): - topic = [topic] - for t in topic: - self.events[t].append(event) - self.event_counts[t] += 1 - self._report_event(t, event) + self._broker.publish(topic, msg) - for plugin in list(self.plugins.values()): - try: - plugin.log_event(t, msg) - except Exception: - logger.info("Plugin failed with exception", exc_info=True) + def subscribe_topic(self, topic: str, client: str) -> None: + self._broker.subscribe(topic, client) - def _report_event(self, name, event): - msg = { - "op": "event", - "topic": name, - "event": event, - } - client_msgs = {client: [msg] for client in self.event_subscriber[name]} - self.send_all(client_msgs, worker_msgs={}) + def unsubscribe_topic(self, topic: str, client: str) -> None: + self._broker.unsubscribe(topic, client) - def subscribe_topic(self, topic, client): - self.event_subscriber[topic].add(client) + @overload + def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]: + ... - def unsubscribe_topic(self, topic, client): - self.event_subscriber[topic].discard(client) + @overload + def get_events(self) -> dict[str, tuple[tuple[float, Any], ...]]: + ... - def get_events(self, topic=None): - if topic is not None: - return tuple(self.events[topic]) - else: - return valmap(tuple, self.events) + def get_events( + self, topic: str | None = None + ) -> tuple[tuple[float, Any], ...] | dict[str, tuple[tuple[float, Any], ...]]: + return self._broker.get_events(topic) async def get_worker_monitor_info(self, recent=False, starts=None): if starts is None: diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 24e199d1b6..310c927fd2 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -433,7 +433,7 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): with pytest.raises(KilledWorker): await out - assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1 + assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 await c.close() await check_worker_cleanup(a) @@ -460,7 +460,7 @@ async def test_restarting_does_not_log_p2p_failed(c, s, a, b): await b.close() await out - assert not s.events["p2p"] + assert not s.get_events("p2p") await c.close() await check_worker_cleanup(a) await check_worker_cleanup(b, closed=True) @@ -831,7 +831,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): with pytest.raises(KilledWorker): await out - assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1 + assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 alive_shuffle.block_inputs_done.set() @@ -994,7 +994,7 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b): with pytest.raises(KilledWorker): await out - assert sum(event["action"] == "p2p-failed" for _, event in s.events["p2p"]) == 1 + assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 await c.close() await check_worker_cleanup(a) diff --git a/distributed/stealing.py b/distributed/stealing.py index 952aa55b1e..1d72e58a22 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -2,7 +2,7 @@ import asyncio import logging -from collections import defaultdict, deque +from collections import defaultdict from collections.abc import Container from functools import partial from math import log2 @@ -106,8 +106,6 @@ def __init__(self, scheduler: Scheduler): ) # `callback_time` is in milliseconds self.scheduler.add_plugin(self) - maxlen = dask.config.get("distributed.admin.low-level-log-length") - self.scheduler.events["stealing"] = deque(maxlen=maxlen) self.count = 0 self.in_flight = {} self.in_flight_occupancy = defaultdict(int) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index f589105e67..f35de46bf9 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6461,7 +6461,9 @@ def test_direct_to_workers(s, loop): with Client(s["address"], loop=loop, direct_to_workers=True) as client: future = client.scatter(1) future.result() - resp = client.run_on_scheduler(lambda dask_scheduler: dask_scheduler.events) + resp = client.run_on_scheduler( + lambda dask_scheduler: dask_scheduler.get_events() + ) assert "gather" not in str(resp) diff --git a/distributed/tests/test_event_logging.py b/distributed/tests/test_event_logging.py index 2c74af3ca6..43effcb12f 100644 --- a/distributed/tests/test_event_logging.py +++ b/distributed/tests/test_event_logging.py @@ -1,17 +1,52 @@ from __future__ import annotations import asyncio +from functools import partial from unittest import mock import pytest -from distributed import Client, Nanny, get_worker +from distributed import Client, Nanny, Scheduler, get_worker from distributed.core import error_message +from distributed.diagnostics import SchedulerPlugin +from distributed.metrics import time from distributed.utils_test import captured_logger, gen_cluster +@gen_cluster(nthreads=[]) +async def test_log_event(s): + before = time() + s.log_event("foo", {"action": "test", "value": 1}) + after = time() + assert len(s.get_events("foo")) == 1 + timestamp, event = s.get_events("foo")[0] + assert before <= timestamp <= after + assert event == {"action": "test", "value": 1} + + +@gen_cluster(nthreads=[]) +async def test_log_events(s): + s.log_event("foo", {"action": "test", "value": 1}) + s.log_event(["foo", "bar"], {"action": "test", "value": 2}) + + actual = [event for _, event in s.get_events("foo")] + assert actual == [{"action": "test", "value": 1}, {"action": "test", "value": 2}] + + actual = [event for _, event in s.get_events("bar")] + assert actual == [{"action": "test", "value": 2}] + + actual = { + topic: [event for _, event in events] + for topic, events in s.get_events().items() + } + assert actual == { + "foo": [{"action": "test", "value": 1}, {"action": "test", "value": 2}], + "bar": [{"action": "test", "value": 2}], + } + + @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_log_event(c, s, a): +async def test_log_event_e2e(c, s, a): # Log an event from inside a task def foo(): get_worker().log_event("topic1", {"foo": "bar"}) @@ -54,7 +89,7 @@ def handler(event): c.subscribe_topic("test-topic", get_event_handler(1)) c2.subscribe_topic("test-topic", get_event_handler(2)) - while len(s.event_subscriber["test-topic"]) != 2: + while len(s._broker._topics["test-topic"].subscribers) != 2: await asyncio.sleep(0.01) with captured_logger("distributed.client") as logger: @@ -77,7 +112,7 @@ def user_event_handler(event): c.subscribe_topic("test-topic", user_event_handler) - while not s.event_subscriber["test-topic"]: + while not s._broker._topics["test-topic"].subscribers: await asyncio.sleep(0.01) a.log_event("test-topic", {"important": "event"}) @@ -91,12 +126,12 @@ def user_event_handler(event): c.unsubscribe_topic("test-topic") - while s.event_subscriber["test-topic"]: + while s._broker._topics["test-topic"].subscribers: await asyncio.sleep(0.01) a.log_event("test-topic", {"forget": "me"}) - while len(s.events["test-topic"]) == 1: + while len(s.get_events("test-topic")) == 1: await asyncio.sleep(0.01) assert len(log) == 1 @@ -107,7 +142,7 @@ async def async_user_event_handler(event): c.subscribe_topic("test-topic", async_user_event_handler) - while not s.event_subscriber["test-topic"]: + while not s._broker._topics["test-topic"].subscribers: await asyncio.sleep(0.01) a.log_event("test-topic", {"async": "event"}) @@ -139,7 +174,7 @@ async def user_event_handler(event): await asyncio.sleep(0.5) c.subscribe_topic("test-topic", user_event_handler) - while not s.event_subscriber["test-topic"]: + while not s._broker._topics["test-topic"].subscribers: await asyncio.sleep(0.01) a.log_event("test-topic", {}) @@ -148,6 +183,62 @@ async def user_event_handler(event): assert exc_info is not None +@gen_cluster(nthreads=[]) +async def test_topic_subscribe_unsubscribe(s): + async with Client(s.address, asynchronous=True) as c1, Client( + s.address, asynchronous=True + ) as c2: + + def event_handler(recorded_events, event): + _, msg = event + recorded_events.append(msg) + + c1_events = [] + c1.subscribe_topic("foo", partial(event_handler, c1_events)) + while not s._broker._topics["foo"].subscribers: + await asyncio.sleep(0.01) + s.log_event("foo", {"value": 1}) + + c2_events = [] + c2.subscribe_topic("foo", partial(event_handler, c2_events)) + c2.subscribe_topic("bar", partial(event_handler, c2_events)) + + while ( + not s._broker._topics["bar"].subscribers + and len(s._broker._topics["foo"].subscribers) < 2 + ): + await asyncio.sleep(0.01) + + s.log_event("foo", {"value": 2}) + s.log_event("bar", {"value": 3}) + + c2.unsubscribe_topic("foo") + + while len(s._broker._topics["foo"].subscribers) > 1: + await asyncio.sleep(0.01) + + s.log_event("foo", {"value": 4}) + s.log_event("bar", {"value": 5}) + + c1.unsubscribe_topic("foo") + + while s._broker._topics["foo"].subscribers: + await asyncio.sleep(0.01) + + s.log_event("foo", {"value": 6}) + s.log_event("bar", {"value": 7}) + + c2.unsubscribe_topic("bar") + + while s._broker._topics["bar"].subscribers: + await asyncio.sleep(0.01) + + s.log_event("bar", {"value": 8}) + + assert c1_events == [{"value": 1}, {"value": 2}, {"value": 4}] + assert c2_events == [{"value": 2}, {"value": 3}, {"value": 5}, {"value": 7}] + + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_events_all_servers_use_same_channel(c, s, a): """Ensure that logs from all server types (scheduler, worker, nanny) @@ -160,7 +251,7 @@ def user_event_handler(event): c.subscribe_topic("test-topic", user_event_handler) - while not s.event_subscriber["test-topic"]: + while not s._broker._topics["test-topic"].subscribers: await asyncio.sleep(0.01) async with Nanny(s.address) as n: @@ -208,17 +299,18 @@ class C: @gen_cluster(client=True, config={"distributed.admin.low-level-log-length": 3}) async def test_configurable_events_log_length(c, s, a, b): s.log_event("test", "dummy message 1") - assert len(s.events["test"]) == 1 + assert len(s.get_events("test")) == 1 s.log_event("test", "dummy message 2") s.log_event("test", "dummy message 3") - assert len(s.events["test"]) == 3 + assert len(s.get_events("test")) == 3 + assert s._broker._topics["test"].count == 3 # adding a fourth message will drop the first one and length stays at 3 s.log_event("test", "dummy message 4") - assert len(s.events["test"]) == 3 - assert s.events["test"][0][1] == "dummy message 2" - assert s.events["test"][1][1] == "dummy message 3" - assert s.events["test"][2][1] == "dummy message 4" + assert len(s.get_events("test")) == 3 + assert s._broker._topics["test"].count == 4 + events = [event for _, event in s.get_events("test")] + assert events == ["dummy message 2", "dummy message 3", "dummy message 4"] @gen_cluster(client=True, nthreads=[]) @@ -282,3 +374,44 @@ class C: "worker": a.address, }, ] == [msg[1] for msg in s.get_events("test-topic")] + + +@gen_cluster(client=True) +async def test_log_event_plugin(c, s, a, b): + class EventPlugin(SchedulerPlugin): + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler + self.scheduler._recorded_events = list() # type: ignore + + def log_event(self, topic, msg): + self.scheduler._recorded_events.append((topic, msg)) + + await c.register_plugin(EventPlugin()) + + def f(): + get_worker().log_event("foo", 123) + + await c.submit(f) + + assert ("foo", 123) in s._recorded_events + + +@gen_cluster(client=True) +async def test_log_event_plugin_multiple_topics(c, s, a, b): + class EventPlugin(SchedulerPlugin): + async def start(self, scheduler: Scheduler) -> None: + self.scheduler = scheduler + self.scheduler._recorded_events = list() # type: ignore + + def log_event(self, topic, msg): + self.scheduler._recorded_events.append((topic, msg)) + + await c.register_plugin(EventPlugin()) + + def f(): + get_worker().log_event(["foo", "bar"], 123) + + await c.submit(f) + + assert ("foo", 123) in s._recorded_events + assert ("bar", 123) in s._recorded_events diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 1314ea8e27..52073c13d2 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -554,7 +554,7 @@ async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol): ) as n: await n.process.stopped.wait() # Check that the scheduler has been notified about the closed worker - assert "remove-worker" in str(s.events) + assert "remove-worker" in str(s.get_events()) class BrokenWorker(worker.Worker): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7c2cc391ad..9aaea1288a 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -865,39 +865,39 @@ async def test_remove_worker_by_name_from_scheduler(s, a, b): @gen_cluster(config={"distributed.scheduler.events-cleanup-delay": "500 ms"}) async def test_clear_events_worker_removal(s, a, b): - assert a.address in s.events + assert a.address in s._broker._topics assert a.address in s.workers - assert b.address in s.events + assert b.address in s._broker._topics assert b.address in s.workers await s.remove_worker(address=a.address, stimulus_id="test") # Shortly after removal, the events should still be there - assert a.address in s.events + assert s.get_events(a.address) assert a.address not in s.workers s.validate_state() start = time() - while a.address in s.events: + while s.get_events(a.address): await asyncio.sleep(0.01) assert time() < start + 2 - assert b.address in s.events + assert b.address in s._broker._topics @gen_cluster( config={"distributed.scheduler.events-cleanup-delay": "10 ms"}, client=True ) async def test_clear_events_client_removal(c, s, a, b): - assert c.id in s.events + assert s.get_events(c.id) s.remove_client(c.id) - assert c.id in s.events + assert s.get_events(c.id) assert c.id not in s.clients assert c not in s.clients s.remove_client(c.id) # If it doesn't reconnect after a given time, the events log should be cleared start = time() - while c.id in s.events: + while s.get_events(c.id): await asyncio.sleep(0.01) assert time() < start + 2 @@ -2108,14 +2108,16 @@ def g(_, ev1, ev2): await ev2.set() -@pytest.mark.slow +# @pytest.mark.slow @gen_cluster( client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False} ) async def test_log_tasks_during_restart(c, s, a, b): future = c.submit(sys.exit, 0) await wait(future) - assert "exit" in str(s.events) + assert "exit" in str( + {name: topic.events for name, topic in s._broker._topics.items()} + ) @gen_cluster(client=True) @@ -4436,7 +4438,7 @@ def block(x, event): await event.set() await c.gather(futs) - assert "TaskState" not in str(s.events) + assert not any("TaskState" in str(event) for event in s.get_events()) @gen_cluster(nthreads=[("", 1)]) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 539923dd1c..3976857c9e 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1178,7 +1178,7 @@ async def test_steal_worker_dies_same_ip(c, s, w0, w1): wsB = s.workers[w1.address] steal.move_task_request(victim_ts, wsA, wsB) - len_before = len(s.events["stealing"]) + len_before = len(s.get_events("stealing")) with freeze_batched_send(w0.batched_stream): while not any( isinstance(event, StealRequestEvent) for event in w0.state.stimulus_log @@ -1208,7 +1208,7 @@ async def test_steal_worker_dies_same_ip(c, s, w0, w1): assert hash(wsB2) != hash(wsB) # Wait for the steal response to arrive - while len_before == len(s.events["stealing"]): + while len_before == len(s.get_events("stealing")): await asyncio.sleep(0.1) assert victim_ts.processing_on != wsB @@ -1875,5 +1875,5 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): results = [dask.delayed(lambda *args: None)(root, i) for i in range(1000)] futs = c.compute(results) await c.gather(futs) - events = s.events["stealing"] + events = s.get_events("stealing") assert len(events) == 0 diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 7530b82ce4..c750e3faa5 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -712,7 +712,7 @@ async def test_log_invalid_transitions(c, s, a): with pytest.raises(InvalidTransition): a.handle_stimulus(ev) - while not s.events["invalid-worker-transition"]: + while not s.get_events("invalid-worker-transition"): await asyncio.sleep(0.01) with pytest.raises(Exception) as info: @@ -737,7 +737,7 @@ async def test_log_invalid_worker_task_state(c, s, a): with pytest.raises(InvalidTaskState): a.validate_state() - while not s.events["invalid-worker-task-state"]: + while not s.get_events("invalid-worker-task-state"): await asyncio.sleep(0.01) with pytest.raises(Exception) as info: diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3ea579487c..117ad00406 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2906,7 +2906,7 @@ async def test_worker_status_sync(s, a): while ws.status != Status.closed: await asyncio.sleep(0.01) - events = [ev for _, ev in s.events[ws.address] if ev["action"] != "heartbeat"] + events = [ev for _, ev in s.get_events(ws.address) if ev["action"] != "heartbeat"] for ev in events: if "stimulus_id" in ev: # Strip timestamp ev["stimulus_id"] = ev["stimulus_id"].rsplit("-", 1)[0] @@ -2963,7 +2963,7 @@ async def test_log_remove_worker(c, s, a, b): # Scattered task z = await c.scatter({"z": 3}, workers=a.address) - s.events.clear() + s._broker.truncate() with captured_logger("distributed.scheduler", level=logging.INFO) as log: # Successful graceful shutdown @@ -2999,7 +2999,7 @@ async def test_log_remove_worker(c, s, a, b): "Lost all workers", ] - events = {topic: [ev for _, ev in evs] for topic, evs in s.events.items()} + events = {topic: [ev for _, ev in evs] for topic, evs in s.get_events().items()} for evs in events.values(): for ev in evs: if ev["action"] == "retire-workers": diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 78943226bf..1fd59b5525 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -806,24 +806,25 @@ async def start_cluster( def check_invalid_worker_transitions(s: Scheduler) -> None: - if not s.events.get("invalid-worker-transition"): + if not s.get_events("invalid-worker-transition"): return - for _, msg in s.events["invalid-worker-transition"]: + for _, msg in s.get_events("invalid-worker-transition"): worker = msg.pop("worker") print("Worker:", worker) print(InvalidTransition(**msg)) raise ValueError( - "Invalid worker transitions found", len(s.events["invalid-worker-transition"]) + "Invalid worker transitions found", + len(s.get_events("invalid-worker-transition")), ) def check_invalid_task_states(s: Scheduler) -> None: - if not s.events.get("invalid-worker-task-state"): + if not s.get_events("invalid-worker-task-state"): return - for _, msg in s.events["invalid-worker-task-state"]: + for _, msg in s.get_events("invalid-worker-task-state"): print("Worker:", msg["worker"]) print("State:", msg["state"]) for line in msg["story"]: @@ -833,10 +834,10 @@ def check_invalid_task_states(s: Scheduler) -> None: def check_worker_fail_hard(s: Scheduler) -> None: - if not s.events.get("worker-fail-hard"): + if not s.get_events("worker-fail-hard"): return - for _, msg in s.events["worker-fail-hard"]: + for _, msg in s.get_events("worker-fail-hard"): msg = msg.copy() worker = msg.pop("worker") msg["exception"] = deserialize(msg["exception"].header, msg["exception"].frames) From d36fc3cca66a137d307a8b896e119db89169a2ec Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 3 Jul 2024 12:06:52 -0500 Subject: [PATCH 052/138] Bump ``pandas`` to 2.0 in mindeps build (#8743) --- .github/workflows/tests.yaml | 4 ++-- distributed/protocol/tests/test_highlevelgraph.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index cdd8150ec2..49c4531eb2 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -83,12 +83,12 @@ jobs: - os: ubuntu-latest environment: mindeps label: pandas - extra_packages: [numpy=1.21, pandas=1.3, pyarrow=7, pyarrow-hotfix] + extra_packages: [numpy=1.21, pandas=2.0, pyarrow=7, pyarrow-hotfix] partition: "ci1" - os: ubuntu-latest environment: mindeps label: pandas - extra_packages: [numpy=1.21, pandas=1.3, pyarrow=7, pyarrow-hotfix] + extra_packages: [numpy=1.21, pandas=2.0, pyarrow=7, pyarrow-hotfix] partition: "not ci1" - os: ubuntu-latest diff --git a/distributed/protocol/tests/test_highlevelgraph.py b/distributed/protocol/tests/test_highlevelgraph.py index 33de277a53..ba1f6478a8 100644 --- a/distributed/protocol/tests/test_highlevelgraph.py +++ b/distributed/protocol/tests/test_highlevelgraph.py @@ -6,12 +6,12 @@ np = pytest.importorskip("numpy") pd = pytest.importorskip("pandas") +dd = pytest.importorskip("dask.dataframe") +da = pytest.importorskip("dask.array") from numpy.testing import assert_array_equal import dask -import dask.array as da -import dask.dataframe as dd from distributed.diagnostics import SchedulerPlugin from distributed.utils_test import gen_cluster From 9d362285b6f7f78cc6e6c36bbb3e3e726ef55863 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 3 Jul 2024 16:52:34 -0400 Subject: [PATCH 053/138] Update system monitor when polling Prometheus metrics (#8745) --- distributed/http/scheduler/prometheus/core.py | 2 ++ distributed/http/worker/prometheus/core.py | 1 + 2 files changed, 3 insertions(+) diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index 0a5bbf3020..24034bee9d 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -22,6 +22,8 @@ def __init__(self, server: Scheduler): self.subsystem = "scheduler" def collect(self) -> Iterator[GaugeMetricFamily | CounterMetricFamily]: + self.server.monitor.update() + yield GaugeMetricFamily( self.build_name("clients"), "Number of clients connected", diff --git a/distributed/http/worker/prometheus/core.py b/distributed/http/worker/prometheus/core.py index 5354ee521c..dcb3d77d3d 100644 --- a/distributed/http/worker/prometheus/core.py +++ b/distributed/http/worker/prometheus/core.py @@ -32,6 +32,7 @@ def __init__(self, server: Worker): ) def collect(self) -> Iterator[Metric]: + self.server.monitor.update() ws = self.server.state tasks = GaugeMetricFamily( From 670fbfbcb3a6360da5dbd426a4fd83ea4634e1d4 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 4 Jul 2024 09:27:43 -0400 Subject: [PATCH 054/138] Log traceback upon task error (#8746) --- distributed/tests/test_worker.py | 1 + distributed/worker.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 117ad00406..fa78641828 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -878,6 +878,7 @@ async def test_log_exception_on_failed_task(c, s, a, b): text = logger.getvalue() assert "ZeroDivisionError" in text assert "Exception" in text + assert "Traceback" in text @gen_cluster(client=True) diff --git a/distributed/worker.py b/distributed/worker.py index 27be73e30a..17d715507e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2348,13 +2348,15 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: "Function: %s\n" "args: %s\n" "kwargs: %s\n" - "Exception: %r\n", + "Exception: %r\n" + "Traceback: %r\n", key, ts.state, str(funcname(function))[:1000], convert_args_to_str(args2, max_len=1000), convert_kwargs_to_str(kwargs2, max_len=1000), result["exception_text"], + result["traceback_text"], ) return ExecuteFailureEvent.from_exception( From 86ef1b92dd864c4e642e9c3d93c0a4a761745538 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 5 Jul 2024 12:48:27 -0500 Subject: [PATCH 055/138] Fix ``assert_eq`` import from ``cudf`` (#8747) --- distributed/comm/tests/test_ucx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index f2dd3ecb2e..fe2e06fc63 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -198,7 +198,7 @@ async def test_ping_pong_cudf(ucx_loop, g): # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11' # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12) cudf = pytest.importorskip("cudf") - from cudf.testing._utils import assert_eq + from cudf.testing import assert_eq cudf_obj = g(cudf) From a57ab42e2600d0589342fe596e443c0e40112976 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 5 Jul 2024 14:18:10 -0500 Subject: [PATCH 056/138] bump version to 2024.7.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f9b43e28fd..323475ed6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.6.2", + "dask == 2024.7.0", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From 8564dc79a1b9902eb0320d51b034e0623a2afe8b Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 10 Jul 2024 05:10:48 -0500 Subject: [PATCH 057/138] Add close worker button to worker info page (#8742) --- distributed/http/scheduler/info.py | 8 +++++ distributed/http/templates/worker-table.html | 37 ++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/distributed/http/scheduler/info.py b/distributed/http/scheduler/info.py index 6ca6f5d3b6..a84a6a362f 100644 --- a/distributed/http/scheduler/info.py +++ b/distributed/http/scheduler/info.py @@ -11,6 +11,7 @@ from tornado import escape from tornado.websocket import WebSocketHandler +import dask from dask.typing import Key from dask.utils import format_bytes, format_time @@ -32,6 +33,10 @@ logger = logging.getLogger(__name__) +API_ENABLED = "distributed.http.scheduler.api" in dask.config.get( + "distributed.scheduler.http.routes" +) + class Workers(RequestHandler): @log_errors @@ -40,6 +45,7 @@ def get(self): "workers.html", title="Workers", scheduler=self.server, + api_enabled=API_ENABLED, **merge( self.server.__dict__, self.server.__pdict__, @@ -62,6 +68,7 @@ def get(self, worker): "worker.html", title="Worker: " + worker, scheduler=self.server, + api_enabled=API_ENABLED, Worker=worker, **merge( self.server.__dict__, @@ -134,6 +141,7 @@ def get(self, task: str) -> None: title=f"Task: {key!r}", Task=key, scheduler=self.server, + api_enabled=API_ENABLED, **merge( self.server.__dict__, self.server.__pdict__, diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html index 1f7168c338..2ab60951fa 100644 --- a/distributed/http/templates/worker-table.html +++ b/distributed/http/templates/worker-table.html @@ -12,6 +12,10 @@ Services Logs Last seen + {% if api_enabled %} + + {% else %} + {% end %} @@ -33,7 +37,40 @@ {% end %} logs {{ format_time(time() - ws.last_seen) }} + {% if api_enabled %} + +
+ +
+ + {% else %} + {% end %} {% end %} + + From 4a6ca5f1e5a58ab51eb57cbd1b686368d55515cc Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Thu, 11 Jul 2024 10:44:28 +0100 Subject: [PATCH 058/138] Adding HLG to MAP (#8740) --- distributed/client.py | 226 +++++++++++++++++++++++-------- distributed/tests/test_client.py | 104 ++++++++------ 2 files changed, 231 insertions(+), 99 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 3cac5b3fc2..b5c6eb3789 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -17,7 +17,7 @@ import warnings import weakref from collections import defaultdict -from collections.abc import Collection, Coroutine, Iterator, Sequence +from collections.abc import Collection, Coroutine, Iterable, Iterator, Sequence from concurrent.futures import ThreadPoolExecutor from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import asynccontextmanager, contextmanager, suppress @@ -26,7 +26,19 @@ from importlib.metadata import PackageNotFoundError, version from numbers import Number from queue import Queue as pyQueue -from typing import Any, Callable, ClassVar, Literal, NamedTuple, TypedDict, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Literal, + NamedTuple, + TypedDict, + cast, +) + +if TYPE_CHECKING: + from typing_extensions import TypeAlias from packaging.version import parse as parse_version from tlz import first, groupby, merge, partition_all, valmap @@ -35,8 +47,9 @@ from dask.base import collections_to_dsk, tokenize from dask.core import flatten, validate_key from dask.highlevelgraph import HighLevelGraph +from dask.layers import Layer from dask.optimization import SubgraphCallable -from dask.typing import NoDefault, no_default +from dask.typing import Key, NoDefault, no_default from dask.utils import ( apply, ensure_dict, @@ -807,6 +820,135 @@ class VersionsDict(TypedDict): client: dict[str, dict[str, Any]] +_T_LowLevelGraph: TypeAlias = dict[Key, tuple] + + +def _is_nested(iterable): + for item in iterable: + if ( + isinstance(item, Iterable) + and not isinstance(item, str) + and not isinstance(item, bytes) + ): + return True + return False + + +class _MapLayer(Layer): + func: Callable + iterables: Iterable[Any] + key: str | Iterable[str] | None + pure: bool + annotations: dict[str, Any] | None + + def __init__( + self, + func: Callable, + iterables: Iterable[Any], + key: str | Iterable[str] | None = None, + pure: bool = True, + annotations: dict[str, Any] | None = None, + **kwargs, + ): + self.func: Callable = func + self.iterables: Iterable[Any] = ( + list(zip(*zip(*iterables))) if _is_nested(iterables) else [iterables] + ) + self.key: str | Iterable[str] | None = key + self.pure: bool = pure + self.kwargs = kwargs + super().__init__(annotations=annotations) + + def __repr__(self) -> str: + return f"{type(self).__name__} " + + @property + def _dict(self) -> _T_LowLevelGraph: + self._cached_dict: _T_LowLevelGraph + dsk: _T_LowLevelGraph + + if hasattr(self, "_cached_dict"): + return self._cached_dict + else: + dsk = self._construct_graph() + self._cached_dict = dsk + return self._cached_dict + + @property + def _keys(self) -> Iterable[Key]: + if hasattr(self, "_cached_keys"): + return self._cached_keys + else: + if isinstance(self.key, Iterable) and not isinstance(self.key, str): + self._cached_keys: Iterable[Key] = self.key + return self.key + + else: + if self.pure: + keys = [ + self.key + "-" + tokenize(self.func, self.kwargs, args) # type: ignore + for args in zip(*self.iterables) + ] + else: + uid = str(uuid.uuid4()) + keys = ( + [ + f"{self.key}-{uid}-{i}" + for i in range(min(map(len, self.iterables))) + ] + if self.iterables + else [] + ) + self._cached_keys = keys + return keys + + def get_output_keys(self) -> set[Key]: + return set(self._keys) + + def get_ordered_keys(self): + return list(self._keys) + + def is_materialized(self) -> bool: + return hasattr(self, "_cached_dict") + + def __getitem__(self, key: Key) -> tuple: + return self._dict[key] + + def __iter__(self) -> Iterator[Key]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def _construct_graph(self) -> _T_LowLevelGraph: + dsk: _T_LowLevelGraph = {} + + if not self.kwargs: + dsk = { + key: (self.func,) + args + for key, args in zip(self._keys, zip(*self.iterables)) + } + + else: + kwargs2 = {} + dsk = {} + for k, v in self.kwargs.items(): + if sizeof(v) > 1e5: + vv = dask.delayed(v) + kwargs2[k] = vv._key + dsk.update(vv.dask) + else: + kwargs2[k] = v + + dsk.update( + { + key: (apply, self.func, (tuple, list(args)), kwargs2) + for key, args in zip(self._keys, zip(*self.iterables)) + } + ) + return dsk + + class Client(SyncMethodMixin): """Connect to and submit computation to a Dask cluster @@ -2046,18 +2188,18 @@ def submit( def map( self, - func, - *iterables, - key=None, - workers=None, - retries=None, - resources=None, - priority=0, - allow_other_workers=False, - fifo_timeout="100 ms", - actor=False, - actors=False, - pure=True, + func: Callable, + *iterables: Collection, + key: str | list | None = None, + workers: str | Iterable[str] | None = None, + retries: int | None = None, + resources: dict[str, Any] | None = None, + priority: int = 0, + allow_other_workers: bool = False, + fifo_timeout: str = "100 ms", + actor: bool = False, + actors: bool = False, + pure: bool = True, batch_size=None, **kwargs, ): @@ -2148,11 +2290,11 @@ def map( "Consider using a normal for loop and Client.submit" ) total_length = sum(len(x) for x in iterables) - if batch_size and batch_size > 1 and total_length > batch_size: batches = list( zip(*(partition_all(batch_size, iterable) for iterable in iterables)) ) + keys: list[list[Any]] | list[Any] if isinstance(key, list): keys = [list(element) for element in partition_all(batch_size, key)] else: @@ -2187,45 +2329,14 @@ def map( if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") - iterables = list(zip(*zip(*iterables))) - if isinstance(key, list): - keys = key - else: - if pure: - keys = [ - key + "-" + tokenize(func, kwargs, *args) - for args in zip(*iterables) - ] - else: - uid = str(uuid.uuid4()) - keys = ( - [ - key + "-" + uid + "-" + str(i) - for i in range(min(map(len, iterables))) - ] - if iterables - else [] - ) - - if not kwargs: - dsk = {key: (func,) + args for key, args in zip(keys, zip(*iterables))} - else: - kwargs2 = {} - dsk = {} - for k, v in kwargs.items(): - if sizeof(v) > 1e5: - vv = dask.delayed(v) - kwargs2[k] = vv._key - dsk.update(vv.dask) - else: - kwargs2[k] = v - dsk.update( - { - key: (apply, func, (tuple, list(args)), kwargs2) - for key, args in zip(keys, zip(*iterables)) - } - ) - + dsk = _MapLayer( + func, + iterables, + key=key, + pure=pure, + **kwargs, + ) + keys = dsk.get_ordered_keys() if isinstance(workers, (str, Number)): workers = [workers] if workers is not None and not isinstance(workers, (list, set)): @@ -2246,8 +2357,10 @@ def map( actors=actor, span_metadata=SpanMetadata(collections=[{"type": "Future"}]), ) - logger.debug("map(%s, ...)", funcname(func)) + # make sure the graph is not materialized + assert not dsk.is_materialized(), "Graph must be non-materialized" + logger.debug("map(%s, ...)", funcname(func)) return [futures[k] for k in keys] async def _gather(self, futures, errors="raise", direct=None, local_worker=None): @@ -3199,7 +3312,6 @@ def _graph_to_futures( with self._refcount_lock: if actors is not None and actors is not True and actors is not False: actors = list(self._expand_key(actors)) - # Make sure `dsk` is a high level graph if not isinstance(dsk, HighLevelGraph): dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index f35de46bf9..28479ca512 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -155,6 +155,7 @@ async def test_map(c, s, a, b): assert all(isinstance(x, Future) for x in L1) result = await L1[0] + assert result == inc(0) assert len(s.tasks) == 5 @@ -2128,9 +2129,12 @@ async def test_repr_no_memory_limit(c, s, a, b): @gen_test() async def test_repr_localcluster(): - async with LocalCluster( - processes=False, dashboard_address=":0", asynchronous=True - ) as cluster, Client(cluster, asynchronous=True) as client: + async with ( + LocalCluster( + processes=False, dashboard_address=":0", asynchronous=True + ) as cluster, + Client(cluster, asynchronous=True) as client, + ): text = client._repr_html_() assert cluster.scheduler.address in text assert is_valid_xml(client._repr_html_()) @@ -2308,9 +2312,10 @@ async def test_cleanup_after_broken_client_connection(s, a, b): @gen_cluster() async def test_multi_garbage_collection(s, a, b): - async with Client(s.address, asynchronous=True) as c, Client( - s.address, asynchronous=True - ) as f: + async with ( + Client(s.address, asynchronous=True) as c, + Client(s.address, asynchronous=True) as f, + ): x = c.submit(inc, 1) y = f.submit(inc, 2) y2 = c.submit(inc, 2) @@ -3774,9 +3779,10 @@ async def test_reconnect(): stack = ExitStack() proc = popen(["dask", "scheduler", "--no-dashboard", f"--port={port}"]) stack.enter_context(proc) - async with Client(f"127.0.0.1:{port}", asynchronous=True) as c, Worker( - f"127.0.0.1:{port}" - ) as w: + async with ( + Client(f"127.0.0.1:{port}", asynchronous=True) as c, + Worker(f"127.0.0.1:{port}") as w, + ): await c.wait_for_workers(1, timeout=10) x = c.submit(inc, 1) assert (await x) == 2 @@ -3855,9 +3861,10 @@ def _(loop: object, context: dict[str, Any]) -> None: @gen_cluster(client=True, nthreads=[], client_kwargs={"timeout": 0.5}) async def test_reconnect_timeout(c, s): - with catch_unhandled_exceptions(), captured_logger( - logging.getLogger("distributed.client") - ) as logger: + with ( + catch_unhandled_exceptions(), + captured_logger(logging.getLogger("distributed.client")) as logger, + ): await s.close() while c.status != "closed": await asyncio.sleep(0.05) @@ -3939,9 +3946,10 @@ async def start_worker(sleep, duration, repeat=1): @gen_cluster() async def test_idempotence(s, a, b): - async with Client(s.address, asynchronous=True) as c, Client( - s.address, asynchronous=True - ) as f: + async with ( + Client(s.address, asynchronous=True) as c, + Client(s.address, asynchronous=True) as f, + ): # Submit x = c.submit(inc, 1) await x @@ -4154,9 +4162,10 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b): @gen_cluster() async def test_serialize_future(s, a, b): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): future = c1.submit(lambda: 1) result = await future @@ -4198,9 +4207,10 @@ def do_stuff(): @gen_cluster() async def test_temp_default_client(s, a, b): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): with temp_default_client(c1): assert default_client() is c1 assert default_client(c2) is c2 @@ -4212,9 +4222,10 @@ async def test_temp_default_client(s, a, b): @gen_cluster(client=True) async def test_as_current(c, s, a, b): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): with temp_default_client(c): assert Client.current() is c assert Client.current(allow_global=False) is c @@ -6284,9 +6295,10 @@ def f(): @gen_cluster() async def test_mixing_clients_same_scheduler(s, a, b): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): future = c1.submit(inc, 1) assert await c2.submit(inc, future) == 3 assert not s.tasks @@ -6294,9 +6306,12 @@ async def test_mixing_clients_same_scheduler(s, a, b): @gen_cluster() async def test_mixing_clients_different_scheduler(s, a, b): - async with Scheduler(port=open_port()) as s2, Worker(s2.address) as w1, Client( - s.address, asynchronous=True - ) as c1, Client(s2.address, asynchronous=True) as c2: + async with ( + Scheduler(port=open_port()) as s2, + Worker(s2.address) as w1, + Client(s.address, asynchronous=True) as c1, + Client(s2.address, asynchronous=True) as c2, + ): future = c1.submit(inc, 1) with pytest.raises(CancelledError): await c2.submit(inc, future) @@ -6505,8 +6520,10 @@ async def test_file_descriptors_dont_leak(Worker): proc = psutil.Process() before = proc.num_fds() async with Scheduler(dashboard_address=":0") as s: - async with Worker(s.address), Worker(s.address), Client( - s.address, asynchronous=True + async with ( + Worker(s.address), + Worker(s.address), + Client(s.address, asynchronous=True), ): assert proc.num_fds() > before await df.sum().persist() @@ -6929,16 +6946,19 @@ async def test_mixed_compression(c, s): pytest.importorskip("numpy") da = pytest.importorskip("dask.array") - async with Nanny( - s.address, - host="127.0.0.2", - nthreads=1, - config={"distributed.comm.compression": "lz4"}, - ), Nanny( - s.address, - host="127.0.0.3", - nthreads=1, - config={"distributed.comm.compression": "zlib"}, + async with ( + Nanny( + s.address, + host="127.0.0.2", + nthreads=1, + config={"distributed.comm.compression": "lz4"}, + ), + Nanny( + s.address, + host="127.0.0.3", + nthreads=1, + config={"distributed.comm.compression": "zlib"}, + ), ): await c.wait_for_workers(2) await c.get_versions() From b8022373324758dbe9a19aab03a12284a7439246 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 11 Jul 2024 12:37:59 +0200 Subject: [PATCH 059/138] Robuster deeply nested structures (#8730) --- distributed/protocol/serialize.py | 25 ++++++++++++--------- distributed/protocol/tests/test_protocol.py | 20 +++++++++++------ 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index fe2d6879a4..a2e593e06a 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -209,17 +209,20 @@ def register_serialization_family(name, dumps, loads): def check_dask_serializable(x): - if type(x) in (list, set, tuple) and len(x): - return check_dask_serializable(next(iter(x))) - elif type(x) is dict and len(x): - return check_dask_serializable(next(iter(x.items()))[1]) - else: - try: - dask_serialize.dispatch(type(x)) - return True - except TypeError: - pass - return False + try: + if type(x) in (list, set, tuple) and len(x): + return check_dask_serializable(next(iter(x))) + elif type(x) is dict and len(x): + return check_dask_serializable(next(iter(x.items()))[1]) + else: + try: + dask_serialize.dispatch(type(x)) + return True + except TypeError: + pass + return False + except RecursionError: + return False def serialize( # type: ignore[no-untyped-def] diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index c9db26fb79..abe7a63e04 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import sys import pytest @@ -161,19 +162,24 @@ def test_sizeof_serialize(Wrapper, Wrapped): @pytest.mark.skipif(WINDOWS, reason="On windows this is triggering a stackoverflow") def test_deeply_nested_structures(): # These kind of deeply nested structures are generated in our profiling code - def gen_deeply_nested(depth): - msg = {} - d = msg + def gen_deeply_nested(depth, msg=None): + d = msg or {} while depth: depth -= 1 - d["children"] = d = {} - return msg + d = {"children": d} + return d + + msg = {} + for _ in range(10): + msg = gen_deeply_nested(sys.getrecursionlimit() // 2, msg=msg) + + with pytest.raises(RecursionError): + copy.deepcopy(msg) - msg = gen_deeply_nested(sys.getrecursionlimit() - 100) with pytest.raises(TypeError, match="Could not serialize object"): serialize(msg, on_error="raise") - msg = gen_deeply_nested(sys.getrecursionlimit() // 4) + msg = gen_deeply_nested(sys.getrecursionlimit() // 2) assert isinstance(serialize(msg), tuple) From 782050a3a4cf2abd450caa8adfaa912c22829e78 Mon Sep 17 00:00:00 2001 From: Jonas Dedden Date: Thu, 11 Jul 2024 19:03:26 +0200 Subject: [PATCH 060/138] Use functools.cache instead of functools.lru_cache for extremely often called functions (#8762) --- distributed/core.py | 2 +- distributed/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 90705e8051..f5c876f324 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -121,7 +121,7 @@ def _raise(*args, **kwargs): LOG_PDB = dask.config.get("distributed.admin.pdb-on-err") -@functools.lru_cache +@functools.cache def _expects_comm(func: Callable) -> bool: sig = inspect.signature(func) params = list(sig.parameters) diff --git a/distributed/utils.py b/distributed/utils.py index 3f659bff02..320068fa6e 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1989,7 +1989,7 @@ def __lt__(self, other): return self.obj < other.obj -@functools.lru_cache +@functools.cache def url_escape(url, *args, **kwargs): """ Escape a URL path segment. Cache results for better performance. From f5ad974bb932c03193cbd4e8215ab023fdd370d8 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 12 Jul 2024 01:44:10 +0200 Subject: [PATCH 061/138] Don't sort keys lexicographically in worker table (#8753) --- distributed/http/templates/worker-table.html | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/http/templates/worker-table.html b/distributed/http/templates/worker-table.html index 2ab60951fa..925473f7b2 100644 --- a/distributed/http/templates/worker-table.html +++ b/distributed/http/templates/worker-table.html @@ -24,10 +24,10 @@ {{ws.address}} {{ ws.name if ws.name is not None else "" }} {{ ws.nthreads }} - {{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }} + {{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }} - {{ format_time(ws.occupancy) }} + {{ format_time(ws.occupancy) }} {{ len(ws.processing) }} {{ len(ws.has_what) }} {% if 'dashboard' in ws.services %} @@ -36,7 +36,7 @@ {% end %} logs - {{ format_time(time() - ws.last_seen) }} + {{ format_time(time() - ws.last_seen) }} {% if api_enabled %}
From eab58be01567bf8889f8edf598170d72c58fd818 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 12 Jul 2024 13:28:01 +0200 Subject: [PATCH 062/138] Factor out async taskgroup (#8756) --- distributed/_async_taskgroup.py | 159 +++++++++++++++++++++ distributed/core.py | 146 +------------------ distributed/nanny.py | 2 +- distributed/tests/test_async_task_group.py | 159 +++++++++++++++++++++ distributed/tests/test_core.py | 153 -------------------- 5 files changed, 320 insertions(+), 299 deletions(-) create mode 100644 distributed/_async_taskgroup.py create mode 100644 distributed/tests/test_async_task_group.py diff --git a/distributed/_async_taskgroup.py b/distributed/_async_taskgroup.py new file mode 100644 index 0000000000..a048491d30 --- /dev/null +++ b/distributed/_async_taskgroup.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import threading +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from typing_extensions import ParamSpec + + P = ParamSpec("P") + R = TypeVar("R") + T = TypeVar("T") + Coro = Coroutine[Any, Any, T] + + +class _LoopBoundMixin: + """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" + + _global_lock = threading.Lock() + + _loop = None + + def _get_loop(self): + loop = asyncio.get_running_loop() + + if self._loop is None: + with self._global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class AsyncTaskGroupClosedError(RuntimeError): + pass + + +def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: + """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" + + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + await asyncio.sleep(delay) + return await corofunc(*args, **kwargs) + + return wrapper + + +class AsyncTaskGroup(_LoopBoundMixin): + """Collection tracking all currently running asynchronous tasks within a group""" + + #: If True, the group is closed and does not allow adding new tasks. + closed: bool + + def __init__(self) -> None: + self.closed = False + self._ongoing_tasks: set[asyncio.Task[None]] = set() + + def call_soon( + self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs + ) -> None: + """Schedule a coroutine function to be executed as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task`. + + Parameters + ---------- + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + None + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. + """ + if self.closed: # Avoid creating a coroutine + raise AsyncTaskGroupClosedError( + "Cannot schedule a new coroutine function as the group is already closed." + ) + task = self._get_loop().create_task(afunc(*args, **kwargs)) + task.add_done_callback(self._ongoing_tasks.remove) + self._ongoing_tasks.add(task) + return None + + def call_later( + self, + delay: float, + afunc: Callable[P, Coro[None]], + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. + + The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments + as an `asyncio.Task` that is executed after `delay` seconds. + + Parameters + ---------- + delay + Delay in seconds. + afunc + Coroutine function to schedule. + *args + Arguments to be passed to `afunc`. + **kwargs + Keyword arguments to be passed to `afunc` + + Returns + ------- + The None + + Raises + ------ + AsyncTaskGroupClosedError + If the task group is closed. + """ + self.call_soon(_delayed(afunc, delay), *args, **kwargs) + + def close(self) -> None: + """Closes the task group so that no new tasks can be scheduled. + + Existing tasks continue to run. + """ + self.closed = True + + async def stop(self) -> None: + """Close the group and stop all currently running tasks. + + Closes the task group and cancels all tasks. All tasks are cancelled + an additional time for each time this task is cancelled. + """ + self.close() + + current_task = asyncio.current_task(self._get_loop()) + err = None + while tasks_to_stop := (self._ongoing_tasks - {current_task}): + for task in tasks_to_stop: + task.cancel() + try: + await asyncio.wait(tasks_to_stop) + except asyncio.CancelledError as e: + err = e + + if err is not None: + raise err + + def __len__(self): + return len(self._ongoing_tasks) diff --git a/distributed/core.py b/distributed/core.py index f5c876f324..28467594b5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -35,6 +35,7 @@ from dask.utils import parse_timedelta from distributed import profile, protocol +from distributed._async_taskgroup import AsyncTaskGroup, AsyncTaskGroupClosedError from distributed.comm import ( Comm, CommClosedError, @@ -138,151 +139,6 @@ def _expects_comm(func: Callable) -> bool: return False -class _LoopBoundMixin: - """Backport of the private asyncio.mixins._LoopBoundMixin from 3.11""" - - _global_lock = threading.Lock() - - _loop = None - - def _get_loop(self): - loop = asyncio.get_running_loop() - - if self._loop is None: - with self._global_lock: - if self._loop is None: - self._loop = loop - if loop is not self._loop: - raise RuntimeError(f"{self!r} is bound to a different event loop") - return loop - - -class AsyncTaskGroupClosedError(RuntimeError): - pass - - -def _delayed(corofunc: Callable[P, Coro[T]], delay: float) -> Callable[P, Coro[T]]: - """Decorator to delay the evaluation of a coroutine function by the given delay in seconds.""" - - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - await asyncio.sleep(delay) - return await corofunc(*args, **kwargs) - - return wrapper - - -class AsyncTaskGroup(_LoopBoundMixin): - """Collection tracking all currently running asynchronous tasks within a group""" - - #: If True, the group is closed and does not allow adding new tasks. - closed: bool - - def __init__(self) -> None: - self.closed = False - self._ongoing_tasks: set[asyncio.Task[None]] = set() - - def call_soon( - self, afunc: Callable[P, Coro[None]], /, *args: P.args, **kwargs: P.kwargs - ) -> None: - """Schedule a coroutine function to be executed as an `asyncio.Task`. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - as an `asyncio.Task`. - - Parameters - ---------- - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - None - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - if self.closed: # Avoid creating a coroutine - raise AsyncTaskGroupClosedError( - "Cannot schedule a new coroutine function as the group is already closed." - ) - task = self._get_loop().create_task(afunc(*args, **kwargs)) - task.add_done_callback(self._ongoing_tasks.remove) - self._ongoing_tasks.add(task) - return None - - def call_later( - self, - delay: float, - afunc: Callable[P, Coro[None]], - /, - *args: P.args, - **kwargs: P.kwargs, - ) -> None: - """Schedule a coroutine function to be executed after `delay` seconds as an `asyncio.Task`. - - The coroutine function `afunc` is scheduled with `args` arguments and `kwargs` keyword arguments - as an `asyncio.Task` that is executed after `delay` seconds. - - Parameters - ---------- - delay - Delay in seconds. - afunc - Coroutine function to schedule. - *args - Arguments to be passed to `afunc`. - **kwargs - Keyword arguments to be passed to `afunc` - - Returns - ------- - The None - - Raises - ------ - AsyncTaskGroupClosedError - If the task group is closed. - """ - self.call_soon(_delayed(afunc, delay), *args, **kwargs) - - def close(self) -> None: - """Closes the task group so that no new tasks can be scheduled. - - Existing tasks continue to run. - """ - self.closed = True - - async def stop(self) -> None: - """Close the group and stop all currently running tasks. - - Closes the task group and cancels all tasks. All tasks are cancelled - an additional time for each time this task is cancelled. - """ - self.close() - - current_task = asyncio.current_task(self._get_loop()) - err = None - while tasks_to_stop := (self._ongoing_tasks - {current_task}): - for task in tasks_to_stop: - task.cancel() - try: - await asyncio.wait(tasks_to_stop) - except asyncio.CancelledError as e: - err = e - - if err is not None: - raise err - - def __len__(self): - return len(self._ongoing_tasks) - - class Server: """Dask Distributed Server diff --git a/distributed/nanny.py b/distributed/nanny.py index 52e4ad5b36..af0d9a62ad 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -25,12 +25,12 @@ from dask.utils import parse_timedelta from distributed import preloading +from distributed._async_taskgroup import AsyncTaskGroupClosedError from distributed.comm import get_address_host from distributed.comm.addressing import address_from_user_args from distributed.compatibility import asyncio_run from distributed.config import get_loop_factory from distributed.core import ( - AsyncTaskGroupClosedError, CommClosedError, ErrorMessage, OKMessage, diff --git a/distributed/tests/test_async_task_group.py b/distributed/tests/test_async_task_group.py new file mode 100644 index 0000000000..12a019521c --- /dev/null +++ b/distributed/tests/test_async_task_group.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import time as timemod + +import pytest + +from distributed._async_taskgroup import AsyncTaskGroup, AsyncTaskGroupClosedError +from distributed.utils_test import gen_test + + +async def _wait_for_n_loop_cycles(n): + for _ in range(n): + await asyncio.sleep(0) + + +def test_async_task_group_initialization(): + group = AsyncTaskGroup() + assert not group.closed + assert len(group) == 0 + + +@gen_test() +async def test_async_task_group_call_soon_executes_task_in_background(): + group = AsyncTaskGroup() + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + + assert group.call_soon(set_flag) is None + assert len(group) == 1 + ev.set() + await _wait_for_n_loop_cycles(2) + assert len(group) == 0 + assert flag + + +@gen_test() +async def test_async_task_group_call_later_executes_delayed_task_in_background(): + group = AsyncTaskGroup() + ev = asyncio.Event() + + start = timemod.monotonic() + assert group.call_later(1, ev.set) is None + assert len(group) == 1 + await ev.wait() + end = timemod.monotonic() + # the task must be removed in exactly 1 event loop cycle + await _wait_for_n_loop_cycles(2) + assert len(group) == 0 + assert end - start > 1 - timemod.get_clock_info("monotonic").resolution + + +def test_async_task_group_close_closes(): + group = AsyncTaskGroup() + group.close() + assert group.closed + + # Test idempotency + group.close() + assert group.closed + + +@gen_test() +async def test_async_task_group_close_does_not_cancel_existing_tasks(): + group = AsyncTaskGroup() + + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + return None + + assert group.call_soon(set_flag) is None + + group.close() + + assert len(group) == 1 + + ev.set() + await _wait_for_n_loop_cycles(2) + assert len(group) == 0 + + +@gen_test() +async def test_async_task_group_close_prohibits_new_tasks(): + group = AsyncTaskGroup() + group.close() + + ev = asyncio.Event() + flag = False + + async def set_flag(): + nonlocal flag + await ev.wait() + flag = True + return True + + with pytest.raises(AsyncTaskGroupClosedError): + group.call_soon(set_flag) + assert len(group) == 0 + + with pytest.raises(AsyncTaskGroupClosedError): + group.call_later(1, set_flag) + assert len(group) == 0 + + await asyncio.sleep(0.01) + assert not flag + + +@gen_test() +async def test_async_task_group_stop_disallows_shutdown(): + group = AsyncTaskGroup() + + task = None + + async def set_flag(): + nonlocal task + task = asyncio.current_task() + + assert group.call_soon(set_flag) is None + assert len(group) == 1 + # tasks are not given a grace period, and are not even allowed to start + # if the group is closed immediately + await group.stop() + assert task is None + + +@gen_test() +async def test_async_task_group_stop_cancels_long_running(): + group = AsyncTaskGroup() + + task = None + flag = False + started = asyncio.Event() + + async def set_flag(): + nonlocal task + task = asyncio.current_task() + started.set() + await asyncio.sleep(10) + nonlocal flag + flag = True + return True + + assert group.call_soon(set_flag) is None + assert len(group) == 1 + await started.wait() + await group.stop() + assert task + assert task.cancelled() + assert not flag diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 93bb18f16d..af25353bf2 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -8,7 +8,6 @@ import socket import sys import threading -import time as timemod import weakref from unittest import mock @@ -22,8 +21,6 @@ from distributed.comm.registry import backends from distributed.comm.tcp import TCPBackend, TCPListener from distributed.core import ( - AsyncTaskGroup, - AsyncTaskGroupClosedError, ConnectionPool, Server, Status, @@ -84,156 +81,6 @@ def echo_no_serialize(comm, x): return {"result": x} -def test_async_task_group_initialization(): - group = AsyncTaskGroup() - assert not group.closed - assert len(group) == 0 - - -async def _wait_for_n_loop_cycles(n): - for _ in range(n): - await asyncio.sleep(0) - - -@gen_test() -async def test_async_task_group_call_soon_executes_task_in_background(): - group = AsyncTaskGroup() - ev = asyncio.Event() - flag = False - - async def set_flag(): - nonlocal flag - await ev.wait() - flag = True - - assert group.call_soon(set_flag) is None - assert len(group) == 1 - ev.set() - await _wait_for_n_loop_cycles(2) - assert len(group) == 0 - assert flag - - -@gen_test() -async def test_async_task_group_call_later_executes_delayed_task_in_background(): - group = AsyncTaskGroup() - ev = asyncio.Event() - - start = timemod.monotonic() - assert group.call_later(1, ev.set) is None - assert len(group) == 1 - await ev.wait() - end = timemod.monotonic() - # the task must be removed in exactly 1 event loop cycle - await _wait_for_n_loop_cycles(2) - assert len(group) == 0 - assert end - start > 1 - timemod.get_clock_info("monotonic").resolution - - -def test_async_task_group_close_closes(): - group = AsyncTaskGroup() - group.close() - assert group.closed - - # Test idempotency - group.close() - assert group.closed - - -@gen_test() -async def test_async_task_group_close_does_not_cancel_existing_tasks(): - group = AsyncTaskGroup() - - ev = asyncio.Event() - flag = False - - async def set_flag(): - nonlocal flag - await ev.wait() - flag = True - return None - - assert group.call_soon(set_flag) is None - - group.close() - - assert len(group) == 1 - - ev.set() - await _wait_for_n_loop_cycles(2) - assert len(group) == 0 - - -@gen_test() -async def test_async_task_group_close_prohibits_new_tasks(): - group = AsyncTaskGroup() - group.close() - - ev = asyncio.Event() - flag = False - - async def set_flag(): - nonlocal flag - await ev.wait() - flag = True - return True - - with pytest.raises(AsyncTaskGroupClosedError): - group.call_soon(set_flag) - assert len(group) == 0 - - with pytest.raises(AsyncTaskGroupClosedError): - group.call_later(1, set_flag) - assert len(group) == 0 - - await asyncio.sleep(0.01) - assert not flag - - -@gen_test() -async def test_async_task_group_stop_disallows_shutdown(): - group = AsyncTaskGroup() - - task = None - - async def set_flag(): - nonlocal task - task = asyncio.current_task() - - assert group.call_soon(set_flag) is None - assert len(group) == 1 - # tasks are not given a grace period, and are not even allowed to start - # if the group is closed immediately - await group.stop() - assert task is None - - -@gen_test() -async def test_async_task_group_stop_cancels_long_running(): - group = AsyncTaskGroup() - - task = None - flag = False - started = asyncio.Event() - - async def set_flag(): - nonlocal task - task = asyncio.current_task() - started.set() - await asyncio.sleep(10) - nonlocal flag - flag = True - return True - - assert group.call_soon(set_flag) is None - assert len(group) == 1 - await started.wait() - await group.stop() - assert task - assert task.cancelled() - assert not flag - - @gen_test() async def test_server_status_is_always_enum(): """Assignments with strings is forbidden""" From 767163adb0026019e8638d0b6c294037c5f482a6 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 12 Jul 2024 13:56:02 +0200 Subject: [PATCH 063/138] increase timeouts for pubsub::test_client_worker (#8765) --- distributed/tests/test_pubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 6deb6f798c..b83e009380 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -110,7 +110,7 @@ def f(x): or len(sps.client_subscribers["a"]) != 1 ): await asyncio.sleep(0.01) - assert time() < start + 3 + assert time() < start + 10 del sub @@ -121,7 +121,7 @@ def f(x): or any(bps.publish_to_scheduler.values()) ): await asyncio.sleep(0.01) - assert time() < start + 3 + assert time() < start + 10 @gen_cluster(client=True) From 0a8d8c9ebb742ebfeb1f01e4cc37ab5e4a35946f Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 12 Jul 2024 17:20:00 +0200 Subject: [PATCH 064/138] fix scheduler_bokeh::test_shuffling (#8766) --- distributed/dashboard/tests/test_scheduler_bokeh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index a83bcf46bb..cb3f37b1f6 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -1349,10 +1349,10 @@ async def test_shuffling(c, s, a, b): df2 = df.shuffle("x").persist() start = time() while not ss.source.data["comm_written"]: + await asyncio.gather(*[a.heartbeat(), b.heartbeat()]) ss.update() - await asyncio.sleep(0.05) + await asyncio.sleep(0.01) assert time() < start + 10 - await df2 @gen_cluster(client=True, scheduler_kwargs={"dashboard": True}, timeout=60) From 48eefeef5a20066e6560c71f30ca7b9ec61f6617 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 12 Jul 2024 17:21:30 +0200 Subject: [PATCH 065/138] Robuster faster tests memory sampler (#8758) --- distributed/diagnostics/memory_sampler.py | 14 ++- .../diagnostics/tests/test_memory_sampler.py | 107 ++++++++++-------- distributed/utils.py | 5 +- 3 files changed, 77 insertions(+), 49 deletions(-) diff --git a/distributed/diagnostics/memory_sampler.py b/distributed/diagnostics/memory_sampler.py index a805bdc9c1..bf1d3c6704 100644 --- a/distributed/diagnostics/memory_sampler.py +++ b/distributed/diagnostics/memory_sampler.py @@ -143,7 +143,7 @@ def to_pandas(self, *, align: bool = False) -> pd.DataFrame: if align: # convert datetime to timedelta from the first sample s.index -= s.index[0] - ss[label] = s + ss[label] = s[~s.index.duplicated()] # type: ignore[attr-defined] df = pd.DataFrame(ss) @@ -169,8 +169,16 @@ def plot(self, *, align: bool = False, **kwargs: Any) -> Any: ======= Output of :meth:`pandas.DataFrame.plot` """ - df = self.to_pandas(align=align).resample("1s").nearest() / 2**30 - return df.plot( + df = self.to_pandas(align=align) + resampled = df.resample("1s").nearest() / 2**30 + # If resampling collapses data onto one point, we'll run into + # https://stackoverflow.com/questions/58322744/matplotlib-userwarning-attempting-to-set-identical-left-right-737342-0 + # This should only happen in tests since users typically sample for more + # than a second + if len(resampled) == 1: + resampled = df.resample("1ms").nearest() / 2**30 + + return resampled.plot( xlabel="time", ylabel="Cluster memory (GiB)", **kwargs, diff --git a/distributed/diagnostics/tests/test_memory_sampler.py b/distributed/diagnostics/tests/test_memory_sampler.py index 2c591cac30..e5ec7bc560 100644 --- a/distributed/diagnostics/tests/test_memory_sampler.py +++ b/distributed/diagnostics/tests/test_memory_sampler.py @@ -5,34 +5,47 @@ import pytest +import dask + +from distributed import Client from distributed.diagnostics import MemorySampler -from distributed.utils_test import gen_cluster +from distributed.utils_test import SizeOf, cluster, gen_cluster, gen_test -@gen_cluster(client=True) -async def test_async(c, s, a, b): +@pytest.fixture(scope="module") +@gen_cluster(client=True, config={"distributed.admin.system-monitor.interval": "1ms"}) +async def some_sample(c, s, *workers): ms = MemorySampler() - async with ms.sample("foo", measure="managed", interval=0.1): + name = "foo" + async with ms.sample(name, measure="managed", interval=0.001): f = c.submit(lambda: 1) await f - await asyncio.sleep(0.5) + await asyncio.sleep(0.1) + f.release() + await asyncio.sleep(0.1) - assert ms.samples["foo"][0][1] == 0 - assert ms.samples["foo"][-1][1] > 0 + assert ms.samples[name][0][1] == 0 + assert sum([s[1] for s in ms.samples[name]]) > 0 # Test that there is no server-side memory leak assert not s.extensions["memory_sampler"].samples + return name, ms -def test_sync(client): - ms = MemorySampler() - with ms.sample("foo", measure="managed", interval=0.1): - f = client.submit(lambda: 1) - f.result() - time.sleep(0.5) +def test_sync(loop): + with ( + dask.config.set({"distributed.admin.system-monitor.interval": "1ms"}), + cluster() as (scheduler, _), + Client(scheduler["address"], loop=loop) as client, + ): + ms = MemorySampler() + with ms.sample("foo", measure="managed", interval=0.001): + f = client.submit(lambda: 1) + f.result() + time.sleep(0.1) - assert ms.samples["foo"][0][1] == 0 - assert ms.samples["foo"][-1][1] > 0 + assert ms.samples["foo"][0][1] == 0 + assert sum([s[1] for s in ms.samples["foo"]]) > 0 @gen_cluster(client=True) # MemorySampler internally fetches the client @@ -46,53 +59,55 @@ async def test_at_least_one_sample(c, s, a, b): assert len(next(iter(ms.samples.values()))) == 1 -@pytest.mark.slow -@gen_cluster(client=True) -async def test_multi_sample(c, s, a, b): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_multi_sample(c, s, a): + expected_process_memory = 20 * 1024**2 + + def mock_process_memory(): + return expected_process_memory if a.data.fast else 0 + + a.monitor.get_process_memory = mock_process_memory + a.monitor.update() + await a.heartbeat() + ms = MemorySampler() - s1 = ms.sample("managed", measure="managed", interval=0.15) - s2 = ms.sample("process", interval=0.2) + s1 = ms.sample("managed", measure="managed", interval=0.001) + s2 = ms.sample("process", interval=0.001) + + expected_managed_memory = 100 * 1024 + payload = SizeOf(expected_managed_memory) + async with s1, s2: - idle_mem = s.memory.process - f = c.submit(lambda: "x" * 100 * 2**20) # 100 MiB + f = c.submit(lambda: payload) await f - while s.memory.process < idle_mem + 80 * 2**20: - # Wait for heartbeat - await asyncio.sleep(0.01) - await asyncio.sleep(0.6) + a.monitor.update() + await a.heartbeat() + await asyncio.sleep(0.01) m = ms.samples["managed"] p = ms.samples["process"] + assert len(m) >= 2 assert m[0][1] == 0 - assert m[-1][1] >= 100 * 2**20 + assert m[-1][1] == expected_managed_memory assert len(p) >= 2 - assert p[0][1] > 2**20 # Assume > 1 MiB for idle process - assert p[-1][1] > p[0][1] + 80 * 2**20 - assert m[-1][1] < p[-1][1] + assert p[0][1] == 0 + assert p[-1][1] == expected_process_memory -@gen_cluster(client=True) +@gen_test() @pytest.mark.parametrize("align", [False, True]) -async def test_pandas(c, s, a, b, align): +async def test_pandas(some_sample, align): + name, ms = some_sample pd = pytest.importorskip("pandas") pytest.importorskip("matplotlib") - ms = MemorySampler() - async with ms.sample("foo", measure="managed", interval=0.15): - f = c.submit(lambda: 1) - await f - await asyncio.sleep(1.5) - - assert ms.samples["foo"][0][1] == 0 - assert ms.samples["foo"][-1][1] > 0 - df = ms.to_pandas(align=align) assert isinstance(df, pd.DataFrame) if align: assert isinstance(df.index, pd.TimedeltaIndex) - assert df["foo"].iloc[0] == 0 - assert df["foo"].iloc[-1] > 0 + assert df[name].iloc[0] == 0 + assert df[name].sum() > 1 assert df.index[0] == pd.Timedelta(0, unit="s") assert pd.Timedelta(0, unit="s") < df.index[1] assert df.index[1] < pd.Timedelta(1.5, unit="s") @@ -105,7 +120,6 @@ async def test_pandas(c, s, a, b, align): assert plt -@pytest.mark.slow @gen_cluster(client=True) @pytest.mark.parametrize("align", [False, True]) async def test_pandas_multiseries(c, s, a, b, align): @@ -113,7 +127,10 @@ async def test_pandas_multiseries(c, s, a, b, align): pd = pytest.importorskip("pandas") ms = MemorySampler() - for label, interval, final_sleep in (("foo", 0.15, 1.0), ("bar", 0.2, 0.6)): + for label, interval, final_sleep in ( + ("foo", 0.001, 0.2), + ("bar", 0.002, 0.01), + ): async with ms.sample(label, measure="managed", interval=interval): x = c.submit(lambda: 1, key="x") await x diff --git a/distributed/utils.py b/distributed/utils.py index 320068fa6e..6f23799ebd 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -49,7 +49,6 @@ from typing import Any as AnyType from typing import ClassVar, TypeVar, overload -import click import psutil import tblib.pickling_support from tornado import escape @@ -1274,6 +1273,10 @@ def has_keyword(func, keyword): @functools.lru_cache(1000) def command_has_keyword(cmd, k): + # Click is a relatively expensive import + # That hurts startup time a little + import click + if cmd is not None: if isinstance(cmd, str): try: From 110eac1c16cd0720b946886c84bf8a7af14cfe5d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 15 Jul 2024 14:35:00 -0400 Subject: [PATCH 066/138] Re-raise `P2PConsistencyError` from failed P2P tasks. (#8748) --- distributed/shuffle/_core.py | 4 ++++ distributed/shuffle/_shuffle.py | 4 +++- distributed/shuffle/tests/test_shuffle.py | 4 +--- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 48510dfd41..b0c4fc17e1 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -506,6 +506,8 @@ def handle_transfer_errors(id: ShuffleId) -> Iterator[None]: yield except ShuffleClosedError: raise Reschedule() + except P2PConsistencyError: + raise except Exception as e: raise RuntimeError(f"P2P shuffling {id} failed during transfer phase") from e @@ -518,6 +520,8 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]: raise e except ShuffleClosedError: raise Reschedule() + except P2PConsistencyError: + raise except Exception as e: raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index d08d579778..508e2f4823 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -48,7 +48,7 @@ handle_transfer_errors, handle_unpack_errors, ) -from distributed.shuffle._exceptions import DataUnavailable +from distributed.shuffle._exceptions import DataUnavailable, P2PConsistencyError from distributed.shuffle._limiter import ResourceLimiter from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin from distributed.sizeof import sizeof @@ -105,6 +105,8 @@ def shuffle_barrier(id: ShuffleId, run_ids: list[int]) -> int: return get_worker_plugin().barrier(id, run_ids) except Reschedule as e: raise e + except P2PConsistencyError: + raise except Exception as e: raise RuntimeError(f"shuffle_barrier failed during shuffle {id}") from e diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 310c927fd2..311b8f7939 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -254,9 +254,7 @@ async def test_shuffle_with_array_conversion(c, s, a, b, npartitions): if npartitions == 1: # FIXME: distributed#7816 - with raises_with_cause( - RuntimeError, "failed during transfer", RuntimeError, "Barrier task" - ): + with pytest.raises(P2PConsistencyError, match="Barrier task"): await c.compute(out) else: await c.compute(out) From ea46f7245e0269c0a7a566291cf29b1923cfab47 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 16 Jul 2024 14:12:53 +0200 Subject: [PATCH 067/138] Raise an error if compute on persisted collection with released futures (#8764) --- distributed/scheduler.py | 4 ++-- distributed/tests/test_client.py | 39 ++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0273d333da..85a4c62235 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4508,12 +4508,12 @@ def _match_graph_with_tasks( dependencies: dict[Key, set[Key]], keys: set[Key], ) -> set[Key]: - n = 0 + n = -1 lost_keys = set() while len(dsk) != n: # walk through new tasks, cancel any bad deps n = len(dsk) for k, deps in list(dependencies.items()): - if any( + if (k not in self.tasks and k not in dsk) or any( dep not in self.tasks and dep not in dsk for dep in deps ): # bad key lost_keys.add(k) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 28479ca512..c3c4051929 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -8484,3 +8484,42 @@ async def test_scheduler_restart_exception_on_cancelled_futures(c, s, a, b): with pytest.raises(CancelledError, match="Scheduler has restarted"): await fut.result() + + +def _release_persisted(obj): + return len([f.release() for f in futures_of(obj)]) + + +@gen_cluster(client=True) +async def test_release_persisted_collection(c, s, a, b): + np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") + + arr = c.persist(da.random.random((10,), chunks=(10,))) + + await wait(arr) + + _release_persisted(arr) + while s.tasks: + await asyncio.sleep(0.01) + + with pytest.raises(CancelledError): + await c.compute(arr) + + +def test_release_persisted_collection_sync(c): + np = pytest.importorskip("numpy") + da = pytest.importorskip("dask.array") + arr = da.random.random((10,), chunks=(10,)).persist() + + wait(arr) + _release_persisted(arr) + + while c.run_on_scheduler(lambda dask_scheduler: len(dask_scheduler.tasks)) > 0: + sleep(0.01) + + with pytest.raises(CancelledError): + # Note: dask.compute is actually calling client.get, i.e. what we are + # submitting to the scheduler is different to what we are in + # client.compute + arr.compute() From d68a5d9c57628732f0d020e8d46d424d5c67397c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 16 Jul 2024 16:52:57 +0200 Subject: [PATCH 068/138] Add another test for a possible deadlock scenario caused by #8703 (#8769) --- distributed/tests/test_scheduler.py | 61 +++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 9aaea1288a..069ca0f8f2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4644,8 +4644,11 @@ def assert_rootish(): not QUEUING_ON_BY_DEFAULT, reason="The situation handled in this test requires queueing.", ) +@pytest.mark.parametrize("validate", [True, False]) @gen_cluster(client=True, nthreads=[("", 1)]) -async def test_deadlock_dependency_of_queued_released(c, s, a): +async def test_deadlock_dependency_of_queued_released_when_worker_replaced( + c, s, a, validate +): @delayed def inc(input): return input + 1 @@ -4670,14 +4673,66 @@ def block_on_event(input, block, executing): assert s.queued await s.remove_worker(address=a.address, stimulus_id="test") - s.validate_state() + if validate: + s.validate_state() await block.set() await executing.clear() async with Worker(s.address) as b: + if validate: + s.validate_state() + await c.gather(futs) + if validate: + s.validate_state() + + +@pytest.mark.skipif( + not QUEUING_ON_BY_DEFAULT, + reason="The situation handled in this test requires queueing.", +) +@pytest.mark.parametrize("validate", [True, False]) +@gen_cluster(client=True) +async def test_deadlock_dependency_of_queued_released_when_worker_removed( + c, s, a, b, validate +): + @delayed + def inc(input): + return input + 1 + + @delayed + def block_on_event(input, block): + block.wait() + return input + + block = Event() + + with dask.annotate(workers=a.address, allow_other_workers=True): + dep = inc(0) + futs = [ + block_on_event(dep, block, dask_key_name=("rootish", i)) + for i in range(s.total_nthreads * 2 + 1) + ] + dep.release() + futs = c.compute(futs) + with freeze_batched_send(b.batched_stream): + await async_poll_for( + lambda: b.state.tasks.get(dep.key) is not None + and b.state.tasks.get(dep.key).state == "memory", + timeout=5, + ) + assert s.queued + await s.remove_worker(address=a.address, stimulus_id="test") + + if validate: + s.validate_state() + + await block.set() + + if validate: s.validate_state() - await c.gather(*futs) + await c.gather(futs) + if validate: s.validate_state() From ffcb2717f9a8653bb34c8406ec25839ce9ee1431 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Wed, 17 Jul 2024 16:05:58 +0200 Subject: [PATCH 069/138] Ensure Lock will not lock up in case of worker failures (#8770) --- distributed/lock.py | 171 ++++++++-------------------- distributed/scheduler.py | 2 - distributed/semaphore.py | 94 ++++++++------- distributed/tests/test_locks.py | 43 +++---- distributed/tests/test_semaphore.py | 24 ++-- 5 files changed, 130 insertions(+), 204 deletions(-) diff --git a/distributed/lock.py b/distributed/lock.py index 99ec34cd6f..4c79303bd2 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -1,79 +1,33 @@ from __future__ import annotations -import asyncio import logging -import uuid -from collections import defaultdict, deque -from dask.utils import parse_timedelta - -from distributed.utils import TimeoutError, log_errors, wait_for -from distributed.worker import get_client +from distributed.semaphore import Semaphore logger = logging.getLogger(__name__) +_no_value = object() -class LockExtension: - """An extension for the scheduler to manage Locks - This adds the following routes to the scheduler +class Lock(Semaphore): + """Distributed Centralized Lock - * lock_acquire - * lock_release - """ + .. warning:: - def __init__(self, scheduler): - self.scheduler = scheduler - self.events = defaultdict(deque) - self.ids = dict() + This is using the ``distributed.Semaphore`` as a backend, which is + susceptible to lease overbooking. For the Lock this means that if a + lease is timing out, two or more instances could acquire the lock at the + same time. To disable lease timeouts, set + ``distributed.scheduler.locks.lease-timeout`` to `inf`, e.g. - self.scheduler.handlers.update( - {"lock_acquire": self.acquire, "lock_release": self.release} - ) + .. code-block:: python - @log_errors - async def acquire(self, name=None, id=None, timeout=None): - if isinstance(name, list): - name = tuple(name) - if name not in self.ids: - result = True - else: - while name in self.ids: - event = asyncio.Event() - self.events[name].append(event) - future = event.wait() - if timeout is not None: - future = wait_for(future, timeout) - try: - await future - except TimeoutError: - result = False - break - else: - result = True - finally: - event2 = self.events[name].popleft() - assert event is event2 - if result: - assert name not in self.ids - self.ids[name] = id - return result - - @log_errors - def release(self, name=None, id=None): - if isinstance(name, list): - name = tuple(name) - if self.ids.get(name) != id: - raise ValueError("This lock has not yet been acquired") - del self.ids[name] - if self.events[name]: - self.scheduler.loop.add_callback(self.events[name][0].set) - else: - del self.events[name] - - -class Lock: - """Distributed Centralized Lock + with dask.config.set({"distributed.scheduler.locks.lease-timeout": "inf"}): + lock = Lock("x") + ... + + Note, that without lease timeouts, the Lock may deadlock in case of + cluster downscaling or worker failures. Parameters ---------- @@ -93,29 +47,31 @@ class Lock: >>> lock.release() # doctest: +SKIP """ - def __init__(self, name=None, client=None): - self._client = client - self.name = name or "lock-" + uuid.uuid4().hex - self.id = uuid.uuid4().hex - self._locked = False - - @property - def client(self): - if not self._client: - try: - self._client = get_client() - except ValueError: - pass - return self._client - - def _verify_running(self): - if not self.client: - raise RuntimeError( - f"{type(self)} object not properly initialized. This can happen" - " if the object is being deserialized outside of the context of" - " a Client or Worker." + def __init__( + self, + name=None, + client=_no_value, + register=True, + scheduler_rpc=None, + loop=None, + ): + if client is not _no_value: + import warnings + + warnings.warn( + "The `client` parameter is deprecated. It is no longer necessary to pass a client to Lock.", + DeprecationWarning, + stacklevel=2, ) + super().__init__( + max_leases=1, + name=name, + register=register, + scheduler_rpc=scheduler_rpc, + loop=loop, + ) + def acquire(self, blocking=True, timeout=None): """Acquire the lock @@ -139,50 +95,21 @@ def acquire(self, blocking=True, timeout=None): ------- True or False whether or not it successfully acquired the lock """ - self._verify_running() - timeout = parse_timedelta(timeout) - if not blocking: if timeout is not None: raise ValueError("can't specify a timeout for a non-blocking call") timeout = 0 + return super().acquire(timeout=timeout) - result = self.client.sync( - self.client.scheduler.lock_acquire, - name=self.name, - id=self.id, - timeout=timeout, - ) - self._locked = True - return result - - def release(self): - """Release the lock if already acquired""" - self._verify_running() - if not self.locked(): - raise ValueError("Lock is not yet acquired") - result = self.client.sync( - self.client.scheduler.lock_release, name=self.name, id=self.id - ) - self._locked = False - return result + async def _locked(self): + val = await self.scheduler.semaphore_value(name=self.name) + return val == 1 def locked(self): - return self._locked - - def __enter__(self): - self.acquire() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.release() - - async def __aenter__(self): - await self.acquire() - return self + return self.sync(self._locked) - async def __aexit__(self, exc_type, exc_value, traceback): - await self.release() + def __getstate__(self): + return self.name - def __reduce__(self): - return (Lock, (self.name,)) + def __setstate__(self, state): + self.__init__(name=state, register=False) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 85a4c62235..edd2f1b9e1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -101,7 +101,6 @@ from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension from distributed.http import get_handlers -from distributed.lock import LockExtension from distributed.metrics import time from distributed.multi_lock import MultiLockExtension from distributed.node import ServerNode @@ -179,7 +178,6 @@ STIMULUS_ID_UNSET = "" DEFAULT_EXTENSIONS = { - "locks": LockExtension, "multi_locks": MultiLockExtension, "publish": PublishExtension, "replay-tasks": ReplayTaskScheduler, diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 59d6951d7b..fd971d46e8 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -40,7 +40,7 @@ def __init__(self, scheduler): self.max_leases = dict() # {semaphore_name: {lease_id: lease_last_seen_timestamp}} self.leases = defaultdict(dict) - + self.lease_timeouts = dict() self.scheduler.handlers.update( { "semaphore_register": self.create, @@ -70,20 +70,18 @@ def __init__(self, scheduler): self._check_lease_timeout, validation_callback_time * 1000 ) pc.start() - self.lease_timeout = parse_timedelta( - dask.config.get("distributed.scheduler.locks.lease-timeout"), default="s" - ) def get_value(self, name=None): return len(self.leases[name]) # `comm` here is required by the handler interface - def create(self, name=None, max_leases=None): + def create(self, name, max_leases, lease_timeout): # We use `self.max_leases` as the point of truth to find out if a semaphore with a specific # `name` has been created. if name not in self.max_leases: assert isinstance(max_leases, int), max_leases self.max_leases[name] = max_leases + self.lease_timeouts[name] = lease_timeout else: if max_leases != self.max_leases[name]: raise ValueError( @@ -128,7 +126,7 @@ def _semaphore_exists(self, name): @log_errors async def acquire(self, name=None, timeout=None, lease_id=None): if not self._semaphore_exists(name): - raise RuntimeError(f"Semaphore `{name}` not known or already closed.") + raise RuntimeError(f"Semaphore or Lock `{name}` not known.") if isinstance(name, list): name = tuple(name) @@ -176,7 +174,7 @@ async def acquire(self, name=None, timeout=None, lease_id=None): def release(self, name=None, lease_id=None): if not self._semaphore_exists(name): logger.warning( - f"Tried to release semaphore `{name}` but it is not known or already closed." + f"Tried to release Lock or Semaphore `{name}` but it is not known." ) return if isinstance(name, list): @@ -185,9 +183,9 @@ def release(self, name=None, lease_id=None): self._release_value(name, lease_id) else: logger.warning( - "Tried to release semaphore but it was already released: " + f"Tried to release Lock or Semaphore but it was already released: " f"{name=}, {lease_id=}. " - "This can happen if the semaphore timed out before." + f"This can happen if the Lock or Semaphore timed out before." ) def _release_value(self, name, lease_id): @@ -201,23 +199,24 @@ def _check_lease_timeout(self): now = time() semaphore_names = list(self.leases.keys()) for name in semaphore_names: - ids = list(self.leases[name]) - logger.debug( - "Validating leases for %s at time %s. Currently known %s", - name, - now, - self.leases[name], - ) - for _id in ids: - time_since_refresh = now - self.leases[name][_id] - if time_since_refresh > self.lease_timeout: - logger.debug( - "Lease %s for %s timed out after %ss.", - _id, - name, - time_since_refresh, - ) - self._release_value(name=name, lease_id=_id) + if lease_timeout := self.lease_timeouts.get(name): + ids = list(self.leases[name]) + logger.debug( + "Validating leases for %s at time %s. Currently known %s", + name, + now, + self.leases[name], + ) + for _id in ids: + time_since_refresh = now - self.leases[name][_id] + if time_since_refresh > lease_timeout: + logger.debug( + "Lease %s for %s timed out after %ss.", + _id, + name, + time_since_refresh, + ) + self._release_value(name=name, lease_id=_id) @log_errors def close(self, name=None): @@ -226,6 +225,7 @@ def close(self, name=None): return del self.max_leases[name] + del self.lease_timeouts[name] if name in self.events: del self.events[name] if name in self.leases: @@ -320,14 +320,6 @@ class Semaphore(SyncMethodMixin): ----- If a client attempts to release the semaphore but doesn't have a lease acquired, this will raise an exception. - - When a semaphore is closed, if, for that closed semaphore, a client attempts to: - - - Acquire a lease: an exception will be raised. - - Release: a warning will be logged. - - Close: nothing will happen. - - dask executes functions by default assuming they are pure, when using semaphore acquire/releases inside such a function, it must be noted that there *are* in fact side-effects, thus, the function can no longer be considered pure. If this is not taken into account, this may lead to unexpected behavior. @@ -352,19 +344,23 @@ def __init__( self.refresh_leases = True - self._registered = None + self._do_register = None if register: - self._registered = self.register() + self._do_register = register # this should give ample time to refresh without introducing another # config parameter since this *must* be smaller than the timeout anyhow - refresh_leases_interval = ( - parse_timedelta( - dask.config.get("distributed.scheduler.locks.lease-timeout"), - default="s", - ) - / 5 + lease_timeout = dask.config.get("distributed.scheduler.locks.lease-timeout") + if lease_timeout == "inf": + return + + ## Below is all code for the lease timout validation + + lease_timeout = parse_timedelta( + dask.config.get("distributed.scheduler.locks.lease-timeout"), + default="s", ) + refresh_leases_interval = lease_timeout / 5 pc = PeriodicCallback( self._refresh_leases, callback_time=refresh_leases_interval * 1000 ) @@ -407,10 +403,17 @@ def _verify_running(self): ) async def _register(self): + lease_timeout = dask.config.get("distributed.scheduler.locks.lease-timeout") + + if lease_timeout == "inf": + lease_timeout = None + else: + lease_timeout = parse_timedelta(lease_timeout, "s") await retry_operation( self.scheduler.semaphore_register, name=self.name, max_leases=self.max_leases, + lease_timeout=lease_timeout, operation=f"semaphore register id={self.id} name={self.name}", ) @@ -419,8 +422,8 @@ def register(self, **kwargs): def __await__(self): async def create_semaphore(): - if self._registered: - await self._registered + if self._do_register: + await self._register() return self return create_semaphore().__await__() @@ -442,6 +445,7 @@ async def _refresh_leases(self): ) async def _acquire(self, timeout=None): + await self lease_id = uuid.uuid4().hex logger.debug( "%s requests lease for %s with ID %s", self.id, self.name, lease_id @@ -527,6 +531,7 @@ def get_value(self): return self.sync(self.scheduler.semaphore_value, name=self.name) def __enter__(self): + self.register() self._verify_running() self.acquire() return self @@ -535,6 +540,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.release() async def __aenter__(self): + await self self._verify_running() await self.acquire() return self diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index 7477f35460..d13d14577c 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -6,37 +6,35 @@ import pytest +import dask + from distributed import Lock, get_client from distributed.metrics import time from distributed.utils_test import gen_cluster -@gen_cluster(client=True, nthreads=[("127.0.0.1", 8)] * 2) +@gen_cluster(client=True, nthreads=[("", 8)] * 2) async def test_lock(c, s, a, b): await c.set_metadata("locked", False) def f(x): client = get_client() - with Lock("x") as lock: + with Lock("x"): assert client.get_metadata("locked") is False client.set_metadata("locked", True) - sleep(0.05) + sleep(0.01) assert client.get_metadata("locked") is True client.set_metadata("locked", False) futures = c.map(f, range(20)) await c.gather(futures) - assert not s.extensions["locks"].events - assert not s.extensions["locks"].ids @gen_cluster(client=True) async def test_timeout(c, s, a, b): - locks = s.extensions["locks"] lock = Lock("x") result = await lock.acquire() assert result is True - assert locks.ids["x"] == lock.id lock2 = Lock("x") assert lock.id != lock2.id @@ -46,9 +44,6 @@ async def test_timeout(c, s, a, b): stop = time() assert stop - start < 0.3 assert result is False - assert locks.ids["x"] == lock.id - assert not locks.events["x"] - await lock.release() @@ -56,7 +51,7 @@ async def test_timeout(c, s, a, b): async def test_acquires_with_zero_timeout(c, s, a, b): lock = Lock("x") await lock.acquire(timeout=0) - assert lock.locked() + assert await lock.locked() await lock.release() await lock.acquire(timeout="1s") @@ -69,12 +64,12 @@ async def test_acquires_with_zero_timeout(c, s, a, b): async def test_acquires_blocking(c, s, a, b): lock = Lock("x") await lock.acquire(blocking=False) - assert lock.locked() + assert await lock.locked() await lock.release() - assert not lock.locked() + assert not await lock.locked() with pytest.raises(ValueError): - lock.acquire(blocking=False, timeout=1) + lock.acquire(blocking=False, timeout=0.1) def test_timeout_sync(client): @@ -85,7 +80,7 @@ def test_timeout_sync(client): @gen_cluster(client=True) async def test_errors(c, s, a, b): lock = Lock("x") - with pytest.raises(ValueError): + with pytest.raises(RuntimeError): await lock.release() @@ -95,7 +90,7 @@ def f(x): client = get_client() assert client.get_metadata("locked") is False client.set_metadata("locked", True) - sleep(0.05) + sleep(0.01) assert client.get_metadata("locked") is True client.set_metadata("locked", False) @@ -113,8 +108,6 @@ async def test_lock_types(c, s, a, b): await lock.acquire() await lock.release() - assert not s.extensions["locks"].events - @gen_cluster(client=True) async def test_serializable(c, s, a, b): @@ -129,13 +122,21 @@ def f(x, lock=None): lock2 = pickle.loads(pickle.dumps(lock)) assert lock2.name == lock.name - assert lock2.client is lock.client @gen_cluster(client=True, nthreads=[]) async def test_locks(c, s): async with Lock("x") as l1: l2 = Lock("x") - assert l1.client is c - assert l2.client is c assert await l2.acquire(timeout=0.01) is False + + +@gen_cluster(client=True, nthreads=[]) +async def test_locks_inf_lease_timeout(c, s): + sem_ext = s.extensions["semaphores"] + async with Lock("x"): + assert sem_ext.lease_timeouts["x"] + + with dask.config.set({"distributed.scheduler.locks.lease-timeout": "inf"}): + async with Lock("y"): + assert sem_ext.lease_timeouts.get("y") is None diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index af759832a0..16a07362b5 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -100,8 +100,8 @@ def test_timeout_sync(client): @gen_cluster( client=True, config={ - "distributed.scheduler.locks.lease-validation-interval": "200ms", - "distributed.scheduler.locks.lease-timeout": "200ms", + "distributed.scheduler.locks.lease-validation-interval": "100ms", + "distributed.scheduler.locks.lease-timeout": "100ms", }, ) async def test_release_semaphore_after_timeout(c, s, a, b): @@ -199,14 +199,16 @@ async def test_close_async(c, s, a): assert await sem.acquire() with pytest.warns( RuntimeWarning, - match="Closing semaphore .* but there remain unreleased leases .*", + match="Closing semaphore test but there remain unreleased leases .*", ): await sem.close() - - with pytest.raises( - RuntimeError, match="Semaphore `test` not known or already closed." + # After close, the semaphore is reset + await sem.acquire() + with pytest.warns( + RuntimeWarning, + match="Closing semaphore test but there remain unreleased leases .*", ): - await sem.acquire() + await sem.close() sem2 = await Semaphore(name="t2", max_leases=1) assert await sem2.acquire() @@ -231,14 +233,6 @@ def f(sem_): assert not metric_dict -def test_close_sync(client): - sem = Semaphore() - sem.close() - - with pytest.raises(RuntimeError, match="Semaphore .* not known or already closed."): - sem.acquire() - - @gen_cluster(client=True) async def test_release_once_too_many(c, s, a, b): sem = await Semaphore(name="x") From 651ba5a97937da7f37c88ee3c9f58420db304b4e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 18 Jul 2024 16:02:39 +0200 Subject: [PATCH 070/138] Implement HLG layer for P2P rechunking (#8751) --- distributed/shuffle/_rechunk.py | 320 ++++++++++++++++++---- distributed/shuffle/tests/test_rechunk.py | 30 ++ 2 files changed, 302 insertions(+), 48 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index b33e90730b..1ff0daa97c 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -99,19 +99,30 @@ import mmap import os from collections import defaultdict -from collections.abc import Callable, Generator, Hashable, Sequence +from collections.abc import ( + Callable, + Collection, + Generator, + Hashable, + Iterable, + Iterator, + Mapping, + Sequence, +) from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from itertools import product from pathlib import Path -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple, cast import toolz from tornado.ioloop import IOLoop import dask +import dask.config from dask.base import tokenize -from dask.highlevelgraph import HighLevelGraph, MaterializedLayer +from dask.highlevelgraph import HighLevelGraph +from dask.layers import Layer from dask.typing import Key from distributed.core import PooledRPCCall @@ -138,6 +149,7 @@ import dask.array as da +_T_LowLevelGraph: TypeAlias = dict[Key, tuple] ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] NDSlice: TypeAlias = tuple[slice, ...] @@ -216,27 +228,214 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: if x.size == 0: # Special case for empty array, as the algorithm below does not behave correctly return da.empty(x.shape, chunks=chunks, dtype=x.dtype) + from dask.array.core import new_da_object - dsk = {} token = tokenize(x, chunks) - for ndpartial in _split_partials(x, chunks): - if all(slc.stop == slc.start + 1 for slc in ndpartial.new): - # Single output chunk - dsk.update(partial_concatenate(x, chunks, ndpartial, token)) - else: - dsk.update(partial_rechunk(x, chunks, ndpartial, token)) - layer = MaterializedLayer(dsk) name = rechunk_name(token) - graph = HighLevelGraph.from_collections(name, layer, dependencies=[x]) - arr = da.Array(graph, name, chunks, meta=x) - return arr + disk: bool = dask.config.get("distributed.p2p.disk") + + layer = P2PRechunkLayer( + name=name, + token=token, + chunks=chunks, + chunks_input=x.chunks, + name_input=x.name, + disk=disk, + ) + return new_da_object( + HighLevelGraph.from_collections(name, layer, [x]), + name, + chunks, + meta=x._meta, + dtype=x.dtype, + ) + + +class P2PRechunkLayer(Layer): + name: str + token: str + chunks: ChunkedAxes + chunks_input: ChunkedAxes + name_input: str + disk: bool + keepmap: np.ndarray + + _cached_dict: _T_LowLevelGraph | None + + def __init__( + self, + name: str, + token: str, + chunks: ChunkedAxes, + chunks_input: ChunkedAxes, + name_input: str, + disk: bool, + keepmap: np.ndarray | None = None, + annotations: Mapping[str, Any] | None = None, + ): + import numpy as np + + self.name = name + self.token = token + self.chunks = chunks + self.chunks_input = chunks_input + self.name_input = name_input + self.disk = disk + if keepmap is not None: + self.keepmap = keepmap + else: + shape = tuple(len(axis) for axis in chunks) + self.keepmap = np.ones(shape, dtype=bool) + self._cached_dict = None + super().__init__(annotations=annotations) + + def __repr__(self) -> str: + return f"{type(self).__name__}" + + def get_output_keys(self) -> set[Key]: + import numpy as np + + return { + (self.name,) + nindex + for nindex in np.ndindex(tuple(len(axis) for axis in self.chunks)) + if self.keepmap[nindex] + } + + def is_materialized(self) -> bool: + return self._cached_dict is not None + + @property + def _dict(self) -> _T_LowLevelGraph: + """Materialize full dict representation""" + dsk: _T_LowLevelGraph + if self._cached_dict is not None: + return self._cached_dict + else: + dsk = self._construct_graph() + self._cached_dict = dsk + return self._cached_dict + + def __getitem__(self, key: Key) -> tuple: + return self._dict[key] + + def __iter__(self) -> Iterator[Key]: + return iter(self._dict) + + def __len__(self) -> int: + return len(self._dict) + + def _cull(self, keepmap: np.ndarray) -> P2PRechunkLayer: + return P2PRechunkLayer( + name=self.name, + token=self.token, + chunks=self.chunks, + chunks_input=self.chunks_input, + name_input=self.name_input, + disk=self.disk, + keepmap=keepmap, + annotations=self.annotations, + ) + + def _keys_to_indices(self, keys: Iterable[Key]) -> set[tuple[int, ...]]: + """Simple utility to convert keys to chunk indices.""" + chunks = set() + for key in keys: + if not isinstance(key, tuple) or len(key) < 2 or key[0] != self.name: + continue + chunk = cast(tuple[int, ...], key[1:]) + assert all(isinstance(index, int) for index in chunk) + chunks.add(chunk) + return chunks + + def cull( + self, keys: set[Key], all_keys: Collection[Key] + ) -> tuple[P2PRechunkLayer, dict]: + """Cull a P2PRechunkLayer HighLevelGraph layer. + + The underlying graph will only include the necessary + tasks to produce the keys (indices) included in `keepmap`. + Therefore, "culling" the layer only requires us to reset this + parameter. + """ + import numpy as np + + from dask.array.rechunk import old_to_new + + keepmap = np.zeros_like(self.keepmap, dtype=bool) + indices_to_keep = self._keys_to_indices(keys) + _old_to_new = old_to_new(self.chunks_input, self.chunks) + + culled_deps: defaultdict[Key, set[Key]] = defaultdict(set) + for nindex in indices_to_keep: + old_indices_per_axis = [] + keepmap[nindex] = True + for index, new_axis in zip(nindex, _old_to_new): + old_indices_per_axis.append( + [old_chunk_index for old_chunk_index, _ in new_axis[index]] + ) + for old_nindex in product(*old_indices_per_axis): + culled_deps[(self.name,) + nindex].add((self.name_input,) + old_nindex) + + # Protect against mutations later on with frozenset + frozen_deps = { + output_task: frozenset(input_tasks) + for output_task, input_tasks in culled_deps.items() + } + + if np.array_equal(keepmap, self.keepmap): + return self, frozen_deps + else: + culled_layer = self._cull(keepmap) + return culled_layer, frozen_deps + + def _construct_graph(self) -> _T_LowLevelGraph: + import numpy as np + + from dask.array.rechunk import old_to_new + + dsk: _T_LowLevelGraph = {} + + _old_to_new = old_to_new(self.chunks_input, self.chunks) + chunked_shape = tuple(len(axis) for axis in self.chunks) + + for ndpartial in _split_partials(_old_to_new, chunked_shape): + output_count = np.sum(self.keepmap[ndpartial.new]) + if output_count == 0: + continue + elif output_count == 1: + # Single output chunk + # TODO: Create new partial that contains ONLY the relevant chunk + dsk.update( + partial_concatenate( + input_name=self.name_input, + input_chunks=self.chunks_input, + ndpartial=ndpartial, + token=self.token, + keepmap=self.keepmap, + old_to_new=_old_to_new, + ) + ) + else: + dsk.update( + partial_rechunk( + input_name=self.name_input, + input_chunks=self.chunks_input, + chunks=self.chunks, + ndpartial=ndpartial, + token=self.token, + disk=self.disk, + keepmap=self.keepmap, + ) + ) + return dsk def _split_partials( - x: da.Array, chunks: ChunkedAxes + old_to_new: list[Any], + chunked_shape: tuple[int, ...], ) -> Generator[_NDPartial, None, None]: """Split the rechunking into partials that can be performed separately""" - partials_per_axis = _split_partials_per_axis(x, chunks) + partials_per_axis = _split_partials_per_axis(old_to_new, chunked_shape) indices_per_axis = (range(len(partials)) for partials in partials_per_axis) for nindex, partial_per_axis in zip( product(*indices_per_axis), product(*partials_per_axis) @@ -246,26 +445,19 @@ def _split_partials( def _split_partials_per_axis( - x: da.Array, chunks: ChunkedAxes + old_to_new: list[Any], chunked_shape: tuple[int, ...] ) -> tuple[tuple[_Partial, ...], ...]: """Split the rechunking into partials that can be performed separately on each axis""" - from dask.array.rechunk import old_to_new - - chunked_shape = tuple(len(axis) for axis in chunks) - _old_to_new = old_to_new(x.chunks, chunks) - - sliced_axes = _partial_slices(_old_to_new, chunked_shape) + sliced_axes = _partial_slices(old_to_new, chunked_shape) partial_axes = [] for axis_index, slices in enumerate(sliced_axes): partials = [] for slice_ in slices: last_old_chunk: int - first_old_chunk, first_old_slice = _old_to_new[axis_index][slice_.start][0] - last_old_chunk, last_old_slice = _old_to_new[axis_index][slice_.stop - 1][ - -1 - ] + first_old_chunk, first_old_slice = old_to_new[axis_index][slice_.start][0] + last_old_chunk, last_old_slice = old_to_new[axis_index][slice_.stop - 1][-1] partials.append( _Partial( old=slice(first_old_chunk, last_old_chunk + 1), @@ -320,10 +512,12 @@ def _global_index(partial_index: NDIndex, partial_offset: NDIndex) -> NDIndex: def partial_concatenate( - x: da.Array, - chunks: ChunkedAxes, + input_name: str, + input_chunks: ChunkedAxes, ndpartial: _NDPartial, token: str, + keepmap: np.ndarray, + old_to_new: list[Any], ) -> dict[Key, Any]: import numpy as np @@ -333,12 +527,38 @@ def partial_concatenate( dsk: dict[Key, Any] = {} slice_group = f"rechunk-slice-{token}" + + partial_keepmap = keepmap[ndpartial.new] + assert np.sum(partial_keepmap) == 1 + + ndindex = np.argwhere(partial_keepmap)[0] + + partial_per_axis = [] + for axis_index, index in enumerate(ndindex): + slc = slice( + ndpartial.new[axis_index].start + index, + ndpartial.new[axis_index].start + index + 1, + ) + first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0] + last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1] + partial_per_axis.append( + _Partial( + old=slice(first_old_chunk, last_old_chunk + 1), + new=slc, + left_start=first_old_slice.start, + right_stop=last_old_slice.stop, + ) + ) + + old, new, left_starts, right_stops = zip(*partial_per_axis) + ndpartial = _NDPartial(old, new, left_starts, right_stops, ndpartial.ix) + old_offset = tuple(slice_.start for slice_ in ndpartial.old) shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old) rec_cat_arg = np.empty(shape, dtype="O") - partial_old = _compute_partial_old_chunks(ndpartial, x.chunks) + partial_old = _compute_partial_old_chunks(ndpartial, input_chunks) for old_partial_index in _partial_ndindex(ndpartial.old): old_global_index = _global_index(old_partial_index, old_offset) @@ -348,18 +568,18 @@ def partial_concatenate( ) original_shape = tuple( - axis[index] for index, axis in zip(old_global_index, x.chunks) + axis[index] for index, axis in zip(old_global_index, input_chunks) ) - if _slicing_is_necessary(ndslice, original_shape): + if _slicing_is_necessary(ndslice, original_shape): # type: ignore key = (slice_group,) + ndpartial.ix + old_global_index rec_cat_arg[old_partial_index] = key dsk[key] = ( getitem, - (x.name,) + old_global_index, + (input_name,) + old_global_index, ndslice, ) else: - rec_cat_arg[old_partial_index] = (x.name,) + old_global_index + rec_cat_arg[old_partial_index] = (input_name,) + old_global_index global_index = tuple(int(slice_.start) for slice_ in ndpartial.new) dsk[(rechunk_name(token),) + global_index] = ( concatenate3, @@ -390,10 +610,13 @@ def _slicing_is_necessary(slice: NDSlice, shape: tuple[int | None, ...]) -> bool def partial_rechunk( - x: da.Array, + input_name: str, + input_chunks: ChunkedAxes, chunks: ChunkedAxes, ndpartial: _NDPartial, token: str, + disk: bool, + keepmap: np.ndarray, ) -> dict[Key, Any]: from dask.array.chunk import getitem @@ -410,17 +633,17 @@ def partial_rechunk( # group across all P2P shuffle-like operations # FIXME: Make this group unique per individual P2P shuffle-like operation _barrier_key = barrier_key(ShuffleId(partial_token)) - disk: bool = dask.config.get("distributed.p2p.disk") - ndim = len(x.shape) + ndim = len(input_chunks) - partial_old = _compute_partial_old_chunks(ndpartial, x.chunks) + partial_old = _compute_partial_old_chunks(ndpartial, input_chunks) partial_new: ChunkedAxes = tuple( chunks[axis_index][ndpartial.new[axis_index]] for axis_index in range(ndim) ) transfer_keys = [] for partial_index in _partial_ndindex(ndpartial.old): + # FIXME: Do not shuffle data for output chunks that we culled ndslice = ndslice_for( partial_index, partial_old, ndpartial.left_starts, ndpartial.right_stops ) @@ -428,17 +651,17 @@ def partial_rechunk( global_index = _global_index(partial_index, old_partial_offset) original_shape = tuple( - axis[index] for index, axis in zip(global_index, x.chunks) + axis[index] for index, axis in zip(global_index, input_chunks) ) - if _slicing_is_necessary(ndslice, original_shape): + if _slicing_is_necessary(ndslice, original_shape): # type: ignore input_key = (slice_group,) + ndpartial.ix + global_index dsk[input_key] = ( getitem, - (x.name,) + global_index, + (input_name,) + global_index, ndslice, ) else: - input_key = (x.name,) + global_index + input_key = (input_name,) + global_index key = (transfer_group,) + ndpartial.ix + global_index transfer_keys.append(key) @@ -457,12 +680,13 @@ def partial_rechunk( new_partial_offset = tuple(axis.start for axis in ndpartial.new) for partial_index in _partial_ndindex(ndpartial.new): global_index = _global_index(partial_index, new_partial_offset) - dsk[(unpack_group,) + global_index] = ( - rechunk_unpack, - partial_token, - partial_index, - _barrier_key, - ) + if keepmap[global_index]: + dsk[(unpack_group,) + global_index] = ( + rechunk_unpack, + partial_token, + partial_index, + _barrier_key, + ) return dsk diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 89d9791f43..e6d3b558fe 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -232,6 +232,36 @@ async def test_cull_p2p_rechunk_independent_partitions(c, s, *ws): assert np.all(await c.compute(culled) == a[:5, :2]) +@gen_cluster(client=True) +async def test_cull_p2p_rechunking_single_chunk(c, s, *ws): + a = np.random.default_rng().uniform(0, 1, 1000).reshape((10, 10, 10)) + x = da.from_array(a, chunks=(1, 5, 1)) + new = (5, 1, -1) + rechunked = rechunk(x, chunks=new, method="p2p") + (dsk,) = dask.optimize(rechunked) + culled = rechunked[:5, 1:2] + (dsk_culled,) = dask.optimize(culled) + + # The culled graph requires only 1/2 of the input tasks + n_inputs = len( + [1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")] + ) + assert n_inputs > 0 + + n_culled_inputs = len( + [ + 1 + for key in dsk_culled.dask.get_all_dependencies() + if key[0].startswith("array-") + ] + ) + assert n_culled_inputs == n_inputs / 4 + # The culled graph should also have less than 1/4 the tasks + assert len(dsk_culled.dask) < len(dsk.dask) / 4 + + assert np.all(await c.compute(culled) == a[:5, 1:2]) + + @gen_cluster(client=True) async def test_cull_p2p_rechunk_overlapping_partitions(c, s, *ws): a = np.random.default_rng().uniform(0, 1, 500).reshape((10, 10, 5)) From 5d9d96e217d1379f69e1f9e191231ce9c2784fdd Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Thu, 18 Jul 2024 15:19:11 +0100 Subject: [PATCH 071/138] Creating transitions-failures log event (#8776) Co-authored-by: Hendrik Makait --- distributed/scheduler.py | 15 ++++++++++++++ distributed/tests/test_scheduler.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index edd2f1b9e1..7754581b74 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -17,6 +17,7 @@ import uuid import warnings import weakref +from abc import abstractmethod from collections import defaultdict, deque from collections.abc import ( Callable, @@ -1842,6 +1843,10 @@ def __init__( + repr(self.WORKER_SATURATION) ) + @abstractmethod + def log_event(self, topic: str | Collection[str], msg: Any) -> None: + ... + @property def memory(self) -> MemoryState: return MemoryState.sum(*(w.memory for w in self.workers.values())) @@ -2086,6 +2091,16 @@ def _transition( return recommendations, client_msgs, worker_msgs except Exception: logger.exception("Error transitioning %r from %r to %r", key, start, finish) + self.log_event( + "transitions", + { + "action": "scheduler-transition-failed", + "key": key, + "start": start, + "finish": finish, + "transistion_log": list(self.transition_log), + }, + ) if LOG_PDB: import pdb diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 069ca0f8f2..2ba6a35580 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -4640,6 +4640,37 @@ def assert_rootish(): await c.gather(fut3) +@gen_test() +async def test_transition_failure_triggers_log_event(): + def block_on_event(input, block, executing): + executing.set() + block.wait() + return input + + # Manually spin up cluster to avoid state validation on cluster shutdown in gen_cluster + async with Scheduler(dashboard_address=":0") as s, Worker(s.address) as w, Client( + s.address, asynchronous=True + ) as c: + block = Event() + executing = Event() + + fut = c.submit(block_on_event, 0, block, executing) + await executing.wait() + + # Manually corrupt the state of the processing task + s.tasks[fut.key].processing_on = None + + await block.set() + await async_poll_for( + lambda: sum( + event["action"] == "scheduler-transition-failed" + for _, event in s.get_events("transitions") + ) + == 1, + timeout=5, + ) + + @pytest.mark.skipif( not QUEUING_ON_BY_DEFAULT, reason="The situation handled in this test requires queueing.", From 30c0d293df69ec263aea6b83d365d845577e929b Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 18 Jul 2024 21:45:47 +0200 Subject: [PATCH 072/138] Expose paused and retired workers separately in prometheus (#8613) --- distributed/http/scheduler/prometheus/core.py | 9 +++++++- .../scheduler/tests/test_scheduler_http.py | 23 +++++++++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index 24034bee9d..628f4c2e54 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -7,6 +7,7 @@ import toolz from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily +from distributed.core import Status from distributed.http.prometheus import PrometheusCollector from distributed.http.scheduler.prometheus.semaphore import SemaphoreMetricCollector from distributed.http.scheduler.prometheus.stealing import WorkStealingMetricCollector @@ -49,9 +50,15 @@ def collect(self) -> Iterator[GaugeMetricFamily | CounterMetricFamily]: - len(self.server.saturated), ) worker_states.add_metric(["saturated"], len(self.server.saturated)) + paused_workers = len( + [w for w in self.server.workers.values() if w.status == Status.paused] + ) + worker_states.add_metric(["paused"], paused_workers) worker_states.add_metric( - ["paused_or_retiring"], len(self.server.workers) - len(self.server.running) + ["retiring"], + len(self.server.workers) - paused_workers - len(self.server.running), ) + yield worker_states if self.server.monitor.monitor_gil_contention: diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index 4b647dc6f2..7be4365ffc 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -343,7 +343,8 @@ async def fetch_metrics(): "idle": 2, "partially_saturated": 0, "saturated": 0, - "paused_or_retiring": 0, + "paused": 0, + "retiring": 0, } ev = Event() @@ -353,7 +354,8 @@ async def fetch_metrics(): "idle": 1, "partially_saturated": 1, "saturated": 0, - "paused_or_retiring": 0, + "paused": 0, + "retiring": 0, } y = c.submit(lambda ev: ev.wait(), ev, key="y", workers=[a.address]) @@ -365,7 +367,8 @@ async def fetch_metrics(): "idle": 1, "partially_saturated": 0, "saturated": 1, - "paused_or_retiring": 0, + "paused": 0, + "retiring": 0, } a.monitor.get_process_memory = lambda: 2**40 @@ -375,9 +378,21 @@ async def fetch_metrics(): "idle": 1, "partially_saturated": 0, "saturated": 0, - "paused_or_retiring": 1, + "paused": 1, + "retiring": 0, } + sa.status = Status.stopping + while sa.status != Status.stopping: + await asyncio.sleep(0.01) + + assert await fetch_metrics() == { + "idle": 1, + "partially_saturated": 0, + "saturated": 0, + "paused": 0, + "retiring": 1, + } await ev.set() From f211adabe599d43c2613855a4243f0ebbce2456b Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 19 Jul 2024 15:19:24 +0200 Subject: [PATCH 073/138] Avoid false positives for `p2p-failed` log event (#8777) Co-authored-by: Florian Jetter --- distributed/shuffle/_scheduler_plugin.py | 4 +++ distributed/shuffle/tests/test_rechunk.py | 3 +- distributed/shuffle/tests/test_shuffle.py | 36 +++++++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index c6fbbe210a..5f474c0cfd 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -416,8 +416,12 @@ def transition( if finish == "erred": ts = self.scheduler.tasks[key] for active_shuffle in self.active_shuffles.values(): + # Log once per active shuffle if active_shuffle._failed: continue + # Log IFF a P2P task is the root cause + if ts.exception_blame != ts: + continue barrier = self.scheduler.tasks[barrier_key(active_shuffle.id)] if ( ts == barrier diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index e6d3b558fe..69438b4473 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -246,8 +246,9 @@ async def test_cull_p2p_rechunking_single_chunk(c, s, *ws): n_inputs = len( [1 for key in dsk.dask.get_all_dependencies() if key[0].startswith("array-")] ) + assert n_inputs > 0 - + n_culled_inputs = len( [ 1 diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 311b8f7939..39f53528fd 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -37,6 +37,7 @@ LocalCluster, Nanny, Scheduler, + Semaphore, Worker, ) from distributed.core import ConnectionPool, ErrorMessage, OKMessage @@ -439,6 +440,41 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): await check_scheduler_cleanup(s) +@gen_cluster( + client=True, + nthreads=[("", 1), ("", 1)], + config={"distributed.scheduler.allowed-failures": 0}, +) +async def test_erred_task_before_p2p_does_not_log_event(c, s, a, b): + def block_and_fail_eventually(df, semaphore, event): + acquired = semaphore.acquire(timeout=0) + if acquired: + return df + event.wait() + raise RuntimeError("test error") + + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-02-01", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + semaphore = await Semaphore(max_leases=s.total_nthreads * 2 + 1) + event = Event() + + df = df.map_partitions(block_and_fail_eventually, semaphore, event, meta=df._meta) + with dask.config.set({"dataframe.shuffle.method": "p2p"}): + out = df.shuffle("x") + shuffle_ext = s.plugins["shuffle"] + out = c.compute(out) + await async_poll_for(lambda: shuffle_ext.active_shuffles, timeout=5) + await event.set() + with pytest.raises(RuntimeError, match="test error"): + await out + + assert all(event["action"] != "p2p-failed" for _, event in s.get_events("p2p")) + + @gen_cluster( client=True, nthreads=[("", 1)] * 2, From 9c3da6ab1481d2a37bffc41c8c80aa65a72504b5 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 19 Jul 2024 17:26:07 +0200 Subject: [PATCH 074/138] Restore len() on TaskPrefix (#8783) --- distributed/scheduler.py | 6 +++--- distributed/tests/test_scheduler.py | 9 +++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7754581b74..a8f7d54630 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1000,6 +1000,9 @@ def update_nbytes(self, diff: int) -> None: def _calculate_duration_us(start: float, stop: float) -> int: return max(round((stop - start) * 1e6), 0) + def __len__(self) -> int: + return sum(self.states.values()) + class TaskPrefix(TaskCollection): """Collection tracking all tasks within a prefix @@ -1190,9 +1193,6 @@ def __repr__(self) -> str: + ">" ) - def __len__(self) -> int: - return sum(self.states.values()) - def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2ba6a35580..0a63557106 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2822,6 +2822,8 @@ async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): assert tg.states["memory"] == 0 assert tg.states["released"] == 5 assert sum(tg.states.values()) == 5 + assert len(tg) == 5 + assert len(tp) == 5 assert tg.nbytes_total == sum( ts.get_nbytes() for ts in s.tasks.values() if ts.group is tg ) @@ -2844,6 +2846,8 @@ async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): tg = s.task_groups[y.name] assert tg.states["memory"] == 5 assert sum(tg.states.values()) == 5 + assert len(tg) == 5 + assert len(tp) == 5 tp = s.task_prefixes["add"] assert tg.prefix is tp @@ -2881,6 +2885,9 @@ async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): assert tg.states["forgotten"] == 4 assert tg.states["released"] == 1 assert sum(tg.states.values()) == 5 + assert len(tg) == 5 + assert len(tp) == 5 + assert tg.states == tp.states with pytest.warns(FutureWarning, match="active_states"): assert tp.states == tp.active_states @@ -2895,6 +2902,7 @@ async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): assert tg.states["forgotten"] == 5 assert sum(tg.states.values()) == 5 + assert len(tg) == 5 assert tg.states["forgotten"] == 5 assert tg.name not in s.task_groups @@ -2913,6 +2921,7 @@ async def test_task_group_and_prefix_statistics(c, s, a, b, no_time_resync): assert all(count == 0 for count in tp.states.values()) with pytest.warns(FutureWarning, match="active_states"): assert tp.states == tp.active_states + assert len(tp) == 0 assert tp.duration == 0 assert tp.nbytes_total == 0 assert tp.types == set() From 31cb89cf3b4e327e90fa7bde34d41004630ff336 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 19 Jul 2024 15:36:54 -0500 Subject: [PATCH 075/138] Temporarily pin ``setuptools < 71`` (#8785) --- continuous_integration/environment-3.10.yaml | 2 ++ continuous_integration/environment-3.11.yaml | 2 ++ continuous_integration/environment-3.12.yaml | 2 ++ continuous_integration/environment-3.9.yaml | 2 ++ continuous_integration/environment-mindeps.yaml | 2 ++ 5 files changed, 10 insertions(+) diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index 14d0abb5f8..be2da72095 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -44,6 +44,8 @@ dependencies: - tornado - zict # overridden by git tip below - zstandard + # Temporary fix for https://github.com/pypa/setuptools/issues/4496 + - setuptools < 71 - pip: - git+https://github.com/dask/dask - git+https://github.com/dask-contrib/dask-expr diff --git a/continuous_integration/environment-3.11.yaml b/continuous_integration/environment-3.11.yaml index 005132a024..dfc2a56be7 100644 --- a/continuous_integration/environment-3.11.yaml +++ b/continuous_integration/environment-3.11.yaml @@ -44,6 +44,8 @@ dependencies: - tornado - zict # overridden by git tip below - zstandard + # Temporary fix for https://github.com/pypa/setuptools/issues/4496 + - setuptools < 71 - pip: - git+https://github.com/dask/dask - git+https://github.com/dask-contrib/dask-expr diff --git a/continuous_integration/environment-3.12.yaml b/continuous_integration/environment-3.12.yaml index a0b23b6e92..39a313d5ef 100644 --- a/continuous_integration/environment-3.12.yaml +++ b/continuous_integration/environment-3.12.yaml @@ -44,6 +44,8 @@ dependencies: - tornado - zict # overridden by git tip below - zstandard + # Temporary fix for https://github.com/pypa/setuptools/issues/4496 + - setuptools < 71 - pip: - git+https://github.com/dask/dask - git+https://github.com/dask-contrib/dask-expr diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index 97a2cdd375..47502ee3c2 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -51,6 +51,8 @@ dependencies: - tornado - zict - zstandard + # Temporary fix for https://github.com/pypa/setuptools/issues/4496 + - setuptools < 71 - pip: - git+https://github.com/dask/dask - git+https://github.com/dask-contrib/dask-expr diff --git a/continuous_integration/environment-mindeps.yaml b/continuous_integration/environment-mindeps.yaml index d77cdfb0fe..ce2c2c980c 100644 --- a/continuous_integration/environment-mindeps.yaml +++ b/continuous_integration/environment-mindeps.yaml @@ -19,6 +19,8 @@ dependencies: - tornado=6.0.4 - urllib3=1.24.3 - zict=3.0.0 + # Temporary fix for https://github.com/pypa/setuptools/issues/4496 + - setuptools < 71 # Distributed depends on the latest version of Dask - pip - pip: From 70ae414b60ec5f711a2dce4d3b7c08bba5cf8c49 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 19 Jul 2024 22:37:13 +0200 Subject: [PATCH 076/138] Ensure Locks always register with scheduler (#8781) --- distributed/lock.py | 6 +++--- distributed/semaphore.py | 12 +++++------- distributed/tests/test_locks.py | 18 ++++++++++++++++++ distributed/tests/test_semaphore.py | 9 ++------- 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/distributed/lock.py b/distributed/lock.py index 4c79303bd2..5e10a228a2 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import uuid from distributed.semaphore import Semaphore @@ -51,7 +52,6 @@ def __init__( self, name=None, client=_no_value, - register=True, scheduler_rpc=None, loop=None, ): @@ -64,10 +64,10 @@ def __init__( stacklevel=2, ) + self.name = name or "lock-" + uuid.uuid4().hex super().__init__( max_leases=1, name=name, - register=register, scheduler_rpc=scheduler_rpc, loop=loop, ) @@ -112,4 +112,4 @@ def __getstate__(self): return self.name def __setstate__(self, state): - self.__init__(name=state, register=False) + self.__init__(name=state) diff --git a/distributed/semaphore.py b/distributed/semaphore.py index fd971d46e8..b4d8d900c0 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -330,7 +330,6 @@ def __init__( self, max_leases=1, name=None, - register=True, scheduler_rpc=None, loop=None, ): @@ -344,9 +343,7 @@ def __init__( self.refresh_leases = True - self._do_register = None - if register: - self._do_register = register + self._registered = False # this should give ample time to refresh without introducing another # config parameter since this *must* be smaller than the timeout anyhow @@ -403,6 +400,8 @@ def _verify_running(self): ) async def _register(self): + if self._registered: + return lease_timeout = dask.config.get("distributed.scheduler.locks.lease-timeout") if lease_timeout == "inf": @@ -416,14 +415,14 @@ async def _register(self): lease_timeout=lease_timeout, operation=f"semaphore register id={self.id} name={self.name}", ) + self._registered = True def register(self, **kwargs): return self.sync(self._register) def __await__(self): async def create_semaphore(): - if self._do_register: - await self._register() + await self._register() return self return create_semaphore().__await__() @@ -558,7 +557,6 @@ def __setstate__(self, state): self.__init__( name=name, max_leases=max_leases, - register=False, ) def close(self): diff --git a/distributed/tests/test_locks.py b/distributed/tests/test_locks.py index d13d14577c..44b218b4d6 100644 --- a/distributed/tests/test_locks.py +++ b/distributed/tests/test_locks.py @@ -124,6 +124,24 @@ def f(x, lock=None): assert lock2.name == lock.name +@gen_cluster(client=True) +async def test_serializable_no_ctx(c, s, a, b): + def f(x, lock=None): + lock.acquire() + try: + assert lock.name == "x" + return x + 1 + finally: + lock.release() + + lock = Lock("x") + futures = c.map(f, range(10), lock=lock) + await c.gather(futures) + + lock2 = pickle.loads(pickle.dumps(lock)) + assert lock2.name == lock.name + + @gen_cluster(client=True, nthreads=[]) async def test_locks(c, s): async with Lock("x") as l1: diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 16a07362b5..05cbb46e02 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -202,13 +202,8 @@ async def test_close_async(c, s, a): match="Closing semaphore test but there remain unreleased leases .*", ): await sem.close() - # After close, the semaphore is reset - await sem.acquire() - with pytest.warns( - RuntimeWarning, - match="Closing semaphore test but there remain unreleased leases .*", - ): - await sem.close() + with pytest.raises(RuntimeError, match="not known"): + await sem.acquire() sem2 = await Semaphore(name="t2", max_leases=1) assert await sem2.acquire() From b42145c6617f39ace32464969cd1b5da61f81d85 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 19 Jul 2024 19:40:11 -0500 Subject: [PATCH 077/138] bump version to 2024.7.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 323475ed6f..44a79310f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.7.0", + "dask == 2024.7.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From 4adf564c293bf12abedceaf99ea0b26306c1bd95 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 22 Jul 2024 14:17:42 +0200 Subject: [PATCH 078/138] leave a warning about future instantiation (#8782) --- distributed/client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index b5c6eb3789..ad283a352a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -253,6 +253,11 @@ class Future(WrappedKey): manages future objects in the local Python process to determine what happens in the larger cluster. + .. note:: + + Users should not instantiate futures manually. This can lead to state + corruption and deadlocking clusters. + Parameters ---------- key: str, or tuple From 222e04769886740fc7d0ce4102cc20049265f221 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 23 Jul 2024 12:05:39 +0200 Subject: [PATCH 079/138] Make stealing more robust (#8788) --- distributed/scheduler.py | 4 +++- distributed/stealing.py | 24 +++++++++++++----------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a8f7d54630..09866e117f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -569,7 +569,9 @@ def __hash__(self) -> int: return self._hash def __eq__(self, other: object) -> bool: - return isinstance(other, WorkerState) and other.server_id == self.server_id + return self is other or ( + isinstance(other, WorkerState) and other.server_id == self.server_id + ) @property def has_what(self) -> Set[TaskState]: diff --git a/distributed/stealing.py b/distributed/stealing.py index 1d72e58a22..7e3711b8e2 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -178,13 +178,19 @@ def transition( *args: Any, **kwargs: Any, ) -> None: - if finish == "processing": - ts = self.scheduler.tasks[key] - self.put_key_in_stealable(ts) - elif start == "processing": + # By first checking whether we've started in processing + # and then checking whether we've finished in processing, + # this logic also handles transitions that end up in the same state. + # Since finish is the actual end state of the task, not the desired one, + # this could occur if a transaction decides against moving the task to the + # desired state. + if start == "processing": ts = self.scheduler.tasks[key] self.remove_key_from_stealable(ts) self._remove_from_in_flight(ts) + if finish == "processing": + ts = self.scheduler.tasks[key] + self.put_key_in_stealable(ts) def _add_to_in_flight(self, ts: TaskState, info: InFlightInfo) -> None: self.in_flight[ts] = info @@ -231,10 +237,7 @@ def remove_key_from_stealable(self, ts: TaskState) -> None: return worker, level = result - try: - self.stealable[worker][level].remove(ts) - except KeyError: - pass + self.stealable[worker][level].discard(ts) def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, None]: """The compute to communication time ratio of a key @@ -328,7 +331,7 @@ def move_task_request( pdb.set_trace() raise - async def move_task_confirm( + def move_task_confirm( self, *, key: str, state: str, stimulus_id: str, worker: str | None = None ) -> None: try: @@ -350,8 +353,7 @@ async def move_task_confirm( victim = info["victim"] logger.debug("Confirm move %s, %s -> %s. State: %s", key, victim, thief, state) - if self.scheduler.validate: - assert ts.processing_on == victim + assert ts.processing_on == victim try: _log_msg = [key, state, victim.address, thief.address, stimulus_id] From 7013e2e71bcf0a02905488447b98400ddad72d32 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 23 Jul 2024 17:09:31 +0200 Subject: [PATCH 080/138] Fix PackageInstall plugin (#8794) --- distributed/diagnostics/plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 7b497d511c..7e2c8eea75 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -518,7 +518,6 @@ async def setup(self, nanny): await Semaphore( max_leases=1, name=socket.gethostname(), - register=True, scheduler_rpc=nanny.scheduler, loop=nanny.loop, ) From 7cf5a3600b715908d563766fe2e2f5affc63fa31 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 25 Jul 2024 10:31:48 +0200 Subject: [PATCH 081/138] Add Prometheus metrics for dask_client_connections_{added|removed}_total (#8799) --- distributed/http/scheduler/prometheus/core.py | 12 ++++ .../scheduler/tests/test_scheduler_http.py | 69 ++++++++++++++++++- distributed/scheduler.py | 9 ++- docs/source/prometheus.rst | 12 ++++ 4 files changed, 100 insertions(+), 2 deletions(-) diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index 628f4c2e54..0657471968 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -31,6 +31,18 @@ def collect(self) -> Iterator[GaugeMetricFamily | CounterMetricFamily]: value=len([k for k in self.server.clients if k != "fire-and-forget"]), ) + yield CounterMetricFamily( + self.build_name("client_connections_added"), + "Total number of client connections added", + value=self.server._client_connections_added_total, + ) + + yield CounterMetricFamily( + self.build_name("client_connections_removed"), + "Total number of client connections removed", + value=self.server._client_connections_removed_total, + ) + yield GaugeMetricFamily( self.build_name("desired_workers"), "Number of workers scheduler needs for task graph", diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index 7be4365ffc..c35e9d5111 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -13,7 +13,7 @@ from dask.sizeof import sizeof from distributed import Event, Lock, Scheduler -from distributed.client import wait +from distributed.client import Client, wait from distributed.core import Status from distributed.utils import is_valid_xml, url_escape from distributed.utils_test import ( @@ -112,6 +112,8 @@ async def test_prometheus(c, s, a, b): expected_metrics = { "dask_scheduler_clients", + "dask_scheduler_client_connections_added", + "dask_scheduler_client_connections_removed", "dask_scheduler_desired_workers", "dask_scheduler_workers", "dask_scheduler_last_time", @@ -159,6 +161,71 @@ async def test_metrics_when_prometheus_client_not_installed(prometheus_not_avail assert "Prometheus metrics are not available" in body +@gen_cluster( + nthreads=[], +) +async def test_prometheus_collect_client_connections_totals(s): + pytest.importorskip("prometheus_client") + from prometheus_client.parser import text_string_to_metric_families + + http_client = AsyncHTTPClient() + + async def fetch_metrics(): + port = s.http_server.port + response = await http_client.fetch(f"http://localhost:{port}/metrics") + txt = response.body.decode("utf8") + families = { + family.name: family + for family in text_string_to_metric_families(txt) + if family.name + in ( + "dask_scheduler_client_connections_added", + "dask_scheduler_client_connections_removed", + ) + } + return { + name: [sample.value for sample in family.samples] + for name, family in families.items() + } + + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [0], + "dask_scheduler_client_connections_removed": [0], + } + async with Client(s.address, asynchronous=True): + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [1], + "dask_scheduler_client_connections_removed": [0], + } + + async with Client(s.address, asynchronous=True): + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [2], + "dask_scheduler_client_connections_removed": [0], + } + + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [2], + "dask_scheduler_client_connections_removed": [1], + } + + async with Client(s.address, asynchronous=True): + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [3], + "dask_scheduler_client_connections_removed": [1], + } + + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [3], + "dask_scheduler_client_connections_removed": [2], + } + + assert await fetch_metrics() == { + "dask_scheduler_client_connections_added": [3], + "dask_scheduler_client_connections_removed": [3], + } + + @gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_prometheus_collect_task_states(c, s, a, b): pytest.importorskip("prometheus_client") diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 09866e117f..14115cc749 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3598,6 +3598,9 @@ class Scheduler(SchedulerState, ServerNode): _no_workers_since: float | None # Note: not None iff there are pending tasks no_workers_timeout: float | None + _client_connections_added_total: int + _client_connections_removed_total: int + def __init__( self, loop=None, @@ -3962,6 +3965,9 @@ async def post(self): Scheduler._instances.add(self) self.rpc.allow_offload = False + self._client_connections_added_total = 0 + self._client_connections_removed_total = 0 + ################## # Administration # ################## @@ -5770,6 +5776,7 @@ async def add_client( logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) self.clients[client] = ClientState(client, versions=versions) + self._client_connections_added_total += 1 for plugin in list(self.plugins.values()): try: @@ -5825,7 +5832,7 @@ def remove_client(self, client: str, stimulus_id: str | None = None) -> None: stimulus_id=stimulus_id, ) del self.clients[client] - + self._client_connections_removed_total += 1 for plugin in list(self.plugins.values()): try: plugin.remove_client(scheduler=self, client=client) diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst index 6317e57195..330cf9242e 100644 --- a/docs/source/prometheus.rst +++ b/docs/source/prometheus.rst @@ -24,6 +24,18 @@ The scheduler exposes the following metrics about itself: dask_scheduler_clients Number of clients connected +dask_scheduler_client_connections_added_total + Total number of client connections added to the scheduler + + .. note:: + This metric does *not* count distinct clients. If a client disconnects + and reconnects later on, it will be counted twice. +dask_scheduler_client_connections_removed_total + Total number of client connections removed from the scheduler + + .. note:: + This metric does *not* count distinct clients. If a client disconnects, + then reconnects and disconnects again, it will be counted twice. dask_scheduler_desired_workers Number of workers scheduler needs for task graph dask_scheduler_gil_contention_seconds_total From cd82d0430a460523f08bffbc21944d2e9b571192 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 25 Jul 2024 12:32:13 +0200 Subject: [PATCH 082/138] Add log event for `worker-ttl-timed-out` (#8800) --- distributed/scheduler.py | 8 ++++++++ distributed/tests/test_failed_workers.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 14115cc749..6ea8644543 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8498,6 +8498,14 @@ async def check_worker_ttl(self) -> None: ) if to_restart: + self.log_event( + "scheduler", + { + "action": "worker-ttl-timed-out", + "workers": to_restart.copy(), + "ttl": ttl, + }, + ) await self.restart_workers( to_restart, wait_for_workers=False, diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index b62e40b8e4..35401f7f5f 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -456,6 +456,10 @@ async def test_worker_time_to_live(c, s, a, b): # Note that this value is ignored because is less than 10x heartbeat_interval assert s.worker_ttl == 0.5 assert set(s.workers) == {a.address, b.address} + assert all( + event["action"] != "worker-ttl-timed-out" + for _, event in s.get_events("scheduler") + ) a.periodic_callbacks["heartbeat"].stop() @@ -463,6 +467,18 @@ async def test_worker_time_to_live(c, s, a, b): while set(s.workers) == {a.address, b.address}: await asyncio.sleep(0.01) assert set(s.workers) == {b.address} + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "worker-ttl-timed-out" + ] + assert len(events) == 1 + # This event includes the actual TTL that we applied, i.e, 10 * heartbeat. + assert events[0] == { + "action": "worker-ttl-timed-out", + "workers": [a.address], + "ttl": 5.0, + } # Worker removal is triggered after 10 * heartbeat # This is 10 * 0.5s at the moment of writing. From 30d8139aefa873e4c7e853aefb094a4f2e71f12d Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 25 Jul 2024 12:32:47 +0200 Subject: [PATCH 083/138] Add Prometheus metrics for `dask_worker_{added|removed}_total` (#8798) --- distributed/http/scheduler/prometheus/core.py | 12 ++++ .../scheduler/tests/test_scheduler_http.py | 66 +++++++++++++++++++ distributed/scheduler.py | 6 ++ docs/source/prometheus.rst | 4 ++ 4 files changed, 88 insertions(+) diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index 0657471968..1770a0aafe 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -73,6 +73,18 @@ def collect(self) -> Iterator[GaugeMetricFamily | CounterMetricFamily]: yield worker_states + yield CounterMetricFamily( + self.build_name("workers_added"), + "Total number of workers added", + value=self.server._workers_added_total, + ) + + yield CounterMetricFamily( + self.build_name("workers_removed"), + "Total number of workers removed", + value=self.server._workers_removed_total, + ) + if self.server.monitor.monitor_gil_contention: yield CounterMetricFamily( self.build_name("gil_contention"), diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index c35e9d5111..630c63d292 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -28,6 +28,7 @@ slowinc, wait_for_state, ) +from distributed.worker import Worker DEFAULT_ROUTES = dask.config.get("distributed.scheduler.http.routes") @@ -116,6 +117,8 @@ async def test_prometheus(c, s, a, b): "dask_scheduler_client_connections_removed", "dask_scheduler_desired_workers", "dask_scheduler_workers", + "dask_scheduler_workers_added", + "dask_scheduler_workers_removed", "dask_scheduler_last_time", "dask_scheduler_tasks", "dask_scheduler_tasks_suspicious", @@ -383,6 +386,69 @@ async def fetch_metrics(): assert prefix_state_counts.get(("div", "erred")) == 1 +@gen_cluster( + client=True, + nthreads=[], +) +async def test_prometheus_collect_worker_totals(c, s): + pytest.importorskip("prometheus_client") + from prometheus_client.parser import text_string_to_metric_families + + http_client = AsyncHTTPClient() + + async def fetch_metrics(): + port = s.http_server.port + response = await http_client.fetch(f"http://localhost:{port}/metrics") + txt = response.body.decode("utf8") + families = { + family.name: family + for family in text_string_to_metric_families(txt) + if family.name + in ("dask_scheduler_workers_added", "dask_scheduler_workers_removed") + } + return { + name: [sample.value for sample in family.samples] + for name, family in families.items() + } + + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [0], + "dask_scheduler_workers_removed": [0], + } + async with Worker(s.address): + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [1], + "dask_scheduler_workers_removed": [0], + } + + async with Worker(s.address): + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [2], + "dask_scheduler_workers_removed": [0], + } + + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [2], + "dask_scheduler_workers_removed": [1], + } + + async with Worker(s.address): + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [3], + "dask_scheduler_workers_removed": [1], + } + + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [3], + "dask_scheduler_workers_removed": [2], + } + + assert await fetch_metrics() == { + "dask_scheduler_workers_added": [3], + "dask_scheduler_workers_removed": [3], + } + + @gen_cluster( client=True, config={"distributed.worker.memory.monitor-interval": "10ms"}, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6ea8644543..f88d0d0a61 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3600,6 +3600,8 @@ class Scheduler(SchedulerState, ServerNode): _client_connections_added_total: int _client_connections_removed_total: int + _workers_added_total: int + _workers_removed_total: int def __init__( self, @@ -3967,6 +3969,8 @@ async def post(self): self._client_connections_added_total = 0 self._client_connections_removed_total = 0 + self._workers_added_total = 0 + self._workers_removed_total = 0 ################## # Administration # @@ -4433,6 +4437,7 @@ async def add_worker( server_id=server_id, scheduler=self, ) + self._workers_added_total += 1 if ws.status == Status.running: self.running.add(ws) @@ -5315,6 +5320,7 @@ async def remove_worker( self.idle_task_count.discard(ws) self.saturated.discard(ws) del self.workers[address] + self._workers_removed_total += 1 ws.status = Status.closed self.running.discard(ws) diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst index 330cf9242e..d1332456ce 100644 --- a/docs/source/prometheus.rst +++ b/docs/source/prometheus.rst @@ -51,6 +51,10 @@ dask_scheduler_gil_contention_seconds_total dask_scheduler_workers Number of workers known by scheduler +dask_scheduler_workers_added_total + Total numbers of workers added to the scheduler +dask_scheduler_workers_removed_total + Total number of workers removed from the scheduler dask_scheduler_last_time_total Cumulative SystemMonitor time dask_scheduler_tasks From 386e5fea1cde4aefaf821e319405188266b41832 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 29 Jul 2024 09:23:15 +0200 Subject: [PATCH 084/138] Add Prometheus metric for time spent on GC (#8803) --- distributed/http/scheduler/prometheus/core.py | 8 ++++ .../scheduler/tests/test_scheduler_http.py | 1 + distributed/http/worker/prometheus/core.py | 8 ++++ .../http/worker/tests/test_worker_http.py | 1 + distributed/tests/test_utils_perf.py | 6 +++ distributed/utils_perf.py | 38 ++++++++++++++++--- docs/source/prometheus.rst | 14 +++++++ 7 files changed, 70 insertions(+), 6 deletions(-) diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index 1770a0aafe..9240c98d7c 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -13,6 +13,7 @@ from distributed.http.scheduler.prometheus.stealing import WorkStealingMetricCollector from distributed.http.utils import RequestHandler from distributed.scheduler import ALL_TASK_STATES, Scheduler +from distributed.utils_perf import gc_collect_duration class SchedulerMetricCollector(PrometheusCollector): @@ -93,6 +94,13 @@ def collect(self) -> Iterator[GaugeMetricFamily | CounterMetricFamily]: unit="seconds", ) + yield CounterMetricFamily( + self.build_name("gc_collection"), + "Total time spent on garbage collection", + value=gc_collect_duration(), + unit="seconds", + ) + yield CounterMetricFamily( self.build_name("last_time"), "SystemMonitor last time", diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index 630c63d292..14793cae46 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -130,6 +130,7 @@ async def test_prometheus(c, s, a, b): "dask_scheduler_prefix_state_totals", "dask_scheduler_tick_count", "dask_scheduler_tick_duration_maximum_seconds", + "dask_scheduler_gc_collection_seconds", } try: diff --git a/distributed/http/worker/prometheus/core.py b/distributed/http/worker/prometheus/core.py index dcb3d77d3d..4f0c2cb100 100644 --- a/distributed/http/worker/prometheus/core.py +++ b/distributed/http/worker/prometheus/core.py @@ -10,6 +10,7 @@ from distributed.http.prometheus import PrometheusCollector from distributed.http.utils import RequestHandler +from distributed.utils_perf import gc_collect_duration from distributed.worker import Worker logger = logging.getLogger("distributed.prometheus.worker") @@ -68,6 +69,13 @@ def collect(self) -> Iterator[Metric]: unit="seconds", ) + yield CounterMetricFamily( + self.build_name("gc_collection"), + "Total time spent on garbage collection", + value=gc_collect_duration(), + unit="seconds", + ) + yield GaugeMetricFamily( self.build_name("threads"), "Number of worker threads", diff --git a/distributed/http/worker/tests/test_worker_http.py b/distributed/http/worker/tests/test_worker_http.py index 8fc56662bf..2d9903af98 100644 --- a/distributed/http/worker/tests/test_worker_http.py +++ b/distributed/http/worker/tests/test_worker_http.py @@ -52,6 +52,7 @@ async def test_prometheus(c, s, a): "dask_worker_transfer_outgoing_count", "dask_worker_transfer_outgoing_count_total", "dask_worker_transfer_outgoing_bytes_total", + "dask_worker_gc_collection_seconds_total", } try: diff --git a/distributed/tests/test_utils_perf.py b/distributed/tests/test_utils_perf.py index daf83714ed..2e66141b2a 100644 --- a/distributed/tests/test_utils_perf.py +++ b/distributed/tests/test_utils_perf.py @@ -48,14 +48,20 @@ def check_fraction(timer, ft): timer = RandomTimer() ft = FractionalTimer(n_samples=N, timer=timer) + assert ft.duration_total == 0 for _ in range(N): ft.start_timing() ft.stop_timing() + expected_total = sum(ft._durations) + assert ft.duration_total == pytest.approx(expected_total / ft.MULT) assert len(timer.timings) == N * 2 assert ft.running_fraction is None + assert ft.duration_total > 0 ft.start_timing() ft.stop_timing() + expected_total += ft._durations[-1] + assert ft.duration_total == pytest.approx(expected_total / ft.MULT) assert len(timer.timings) == (N + 1) * 2 assert ft.running_fraction is not None check_fraction(timer, ft) diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index d1fc6cc46f..ad406d6135 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -4,6 +4,7 @@ import logging import threading from collections import deque +from typing import Callable, Final import psutil @@ -76,9 +77,18 @@ class FractionalTimer: elapsed time. """ - MULT = 1e9 # convert to nanoseconds + MULT: Final[float] = 1e9 # convert to nanoseconds - def __init__(self, n_samples, timer=thread_time): + _timer: Callable[[], float] + _n_samples: int + _start_stops: deque[tuple[float, float]] + _durations: deque[int] + _cur_start: float | None + _running_sum: int | None + _running_fraction: float | None + _duration_total: int + + def __init__(self, n_samples: int, timer: Callable[[], float] = thread_time): self._timer = timer self._n_samples = n_samples self._start_stops = deque() @@ -86,8 +96,9 @@ def __init__(self, n_samples, timer=thread_time): self._cur_start = None self._running_sum = None self._running_fraction = None + self._duration_total = 0 - def _add_measurement(self, start, stop): + def _add_measurement(self, start: float, stop: float) -> None: start_stops = self._start_stops durations = self._durations if stop < start or (start_stops and start < start_stops[-1][1]): @@ -98,6 +109,7 @@ def _add_measurement(self, start, stop): duration = int((stop - start) * self.MULT) start_stops.append((start, stop)) durations.append(duration) + self._duration_total += duration n = len(durations) assert n == len(start_stops) @@ -114,11 +126,11 @@ def _add_measurement(self, start, stop): self._running_sum / (stop - old_stop) / self.MULT ) - def start_timing(self): + def start_timing(self) -> None: assert self._cur_start is None self._cur_start = self._timer() - def stop_timing(self): + def stop_timing(self) -> None: stop = self._timer() start = self._cur_start self._cur_start = None @@ -126,7 +138,14 @@ def stop_timing(self): self._add_measurement(start, stop) @property - def running_fraction(self): + def duration_total(self) -> float: + current_duration = 0.0 + if self._cur_start is not None: + current_duration = self._timer() - self._cur_start + return self._duration_total / self.MULT + current_duration + + @property + def running_fraction(self) -> float | None: return self._running_fraction @@ -145,6 +164,7 @@ def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6): self._warn_over_frac = warn_over_frac self._info_over_rss_win = info_over_rss_win self._enabled = False + self._fractional_timer = None def enable(self): assert not self._enabled @@ -244,3 +264,9 @@ def disable_gc_diagnosis(force=False): _gc_diagnosis_users = 0 else: assert _gc_diagnosis.enabled + + +def gc_collect_duration() -> float: + if _gc_diagnosis._fractional_timer is None: + return 0 + return _gc_diagnosis._fractional_timer.duration_total diff --git a/docs/source/prometheus.rst b/docs/source/prometheus.rst index d1332456ce..a38c85f022 100644 --- a/docs/source/prometheus.rst +++ b/docs/source/prometheus.rst @@ -49,6 +49,13 @@ dask_scheduler_gil_contention_seconds_total ``distributed.admin.system-monitor.gil.enabled`` configuration to be set. +dask_scheduler_gc_collection_seconds_total + Total time spent on garbage dask_scheduler_gc_collection_seconds_total + + .. note:: + Due to measurement overhead, this metric only measures + time spent on garbage collection for generation=2 + dask_scheduler_workers Number of workers known by scheduler dask_scheduler_workers_added_total @@ -159,6 +166,13 @@ dask_worker_gil_contention_seconds_total ``distributed.admin.system-monitor.gil.enabled`` configuration to be set. +dask_worker_gc_collection_seconds_total + Total time spent on garbage dask_scheduler_gc_collection_seconds_total + + .. note:: + Due to measurement overhead, this metric only measures + time spent on garbage collection for generation=2 + dask_worker_latency_seconds Latency of worker connection dask_worker_memory_bytes From 5e1d3657e683adbe01a003b1c656e318ebd0ee1f Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:44:15 +0200 Subject: [PATCH 085/138] Change log level for Compute Failed log message (#8802) --- distributed/tests/test_worker.py | 6 +++--- distributed/worker.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index fa78641828..2c02b79bfb 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -171,9 +171,9 @@ def reset(self): tb = await y._traceback() assert any("1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line) - assert "Compute Failed" in hdlr.messages["warning"][0] - assert y.key in hdlr.messages["warning"][0] - assert "executing" in hdlr.messages["warning"][0] + assert "Compute Failed" in hdlr.messages["error"][0] + assert y.key in hdlr.messages["error"][0] + assert "executing" in hdlr.messages["error"][0] logger.setLevel(old_level) # Now we check that both workers are still alive. diff --git a/distributed/worker.py b/distributed/worker.py index 17d715507e..ac4231a2a6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2341,7 +2341,7 @@ async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: ) if ts.state in ("executing", "long-running", "resumed"): - logger.warning( + logger.error( "Compute Failed\n" "Key: %s\n" "State: %s\n" From 40fcd65e991382a956c3b879e438be1b100dff97 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 30 Jul 2024 13:17:09 +0200 Subject: [PATCH 086/138] Remove unused `delete_interval` and `synchronize_worker_interval` from `Scheduler` (#8801) --- distributed/scheduler.py | 6 ------ distributed/tests/test_client.py | 19 ------------------- 2 files changed, 25 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f88d0d0a61..4f31ed44aa 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3606,8 +3606,6 @@ class Scheduler(SchedulerState, ServerNode): def __init__( self, loop=None, - delete_interval="500ms", - synchronize_worker_interval="60s", services=None, service_kwargs=None, allowed_failures=None, @@ -3656,10 +3654,6 @@ def __init__( if validate is None: validate = dask.config.get("distributed.scheduler.validate") self.proc = psutil.Process() - self.delete_interval = parse_timedelta(delete_interval, default="ms") - self.synchronize_worker_interval = parse_timedelta( - synchronize_worker_interval, default="ms" - ) self.service_specs = services or {} self.service_kwargs = service_kwargs or {} self.services = {} diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index c3c4051929..02b065ceda 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2660,25 +2660,6 @@ async def test_futures_of_cancelled_raises(c, s, a, b): await c.gather(c.map(add, [1], y=x)) -@pytest.mark.skip -@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True) -async def test_dont_delete_recomputed_results(c, s, w): - x = c.submit(inc, 1) # compute first time - await wait([x]) - x.__del__() # trigger garbage collection - await asyncio.sleep(0) - xx = c.submit(inc, 1) # compute second time - - start = time() - while xx.key not in w.data: # data shows up - await asyncio.sleep(0.01) - assert time() < start + 1 - - while time() < start + (s.delete_interval + 100) / 1000: # and stays - assert xx.key in w.data - await asyncio.sleep(0.01) - - @pytest.mark.skip(reason="Use fast random selection now") @gen_cluster(client=True) async def test_balance_tasks_by_stacks(c, s, a, b): From c44ad2293b1f2b857418c3dcaf6e7c8c249e23de Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 30 Jul 2024 18:39:47 +0200 Subject: [PATCH 087/138] Reduce noise from GC-related logging (#8804) --- distributed/config.py | 1 + distributed/distributed.yaml | 1 + distributed/{utils_perf.py => gc.py} | 33 +++++++++++-------- distributed/http/scheduler/prometheus/core.py | 2 +- distributed/http/worker/prometheus/core.py | 2 +- distributed/scheduler.py | 2 +- .../tests/{test_utils_perf.py => test_gc.py} | 13 ++++---- distributed/worker.py | 2 +- distributed/worker_memory.py | 2 +- 9 files changed, 33 insertions(+), 25 deletions(-) rename distributed/{utils_perf.py => gc.py} (90%) rename distributed/tests/{test_utils_perf.py => test_gc.py} (91%) diff --git a/distributed/config.py b/distributed/config.py index b6e5f3d0d3..659f147545 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -97,6 +97,7 @@ def _initialize_logging_old_style(config: dict[Any, Any]) -> None: loggers: dict[str, str | int] = { # default values "distributed": "info", "distributed.client": "warning", + "distributed.gc": "warning", "bokeh": "error", "tornado": "critical", "tornado.application": "error", diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index a7ec037e27..26b88322d0 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -3,6 +3,7 @@ distributed: # logging: # distributed: info # distributed.client: warning + # distributed.gc: warning # bokeh: error # # http://stackoverflow.com/questions/21234772/python-tornado-disable-logging-to-stderr # tornado: critical diff --git a/distributed/utils_perf.py b/distributed/gc.py similarity index 90% rename from distributed/utils_perf.py rename to distributed/gc.py index ad406d6135..5a02411ae3 100644 --- a/distributed/utils_perf.py +++ b/distributed/gc.py @@ -11,8 +11,10 @@ from dask.utils import format_bytes from distributed.metrics import thread_time +from distributed.utils import RateLimiterFilter logger = _logger = logging.getLogger(__name__) +logger.addFilter(RateLimiterFilter("full garbage collections took", rate="60s")) class ThrottledGC: @@ -24,7 +26,7 @@ class ThrottledGC: collect() does nothing when repeated calls are so costly and so frequent that the thread would spend more than max_in_gc_frac doing GC. - warn_if_longer is a duration in seconds (10s by default) that can be used + warn_if_longer is a duration in seconds (1s by default) that can be used to log a warning level message whenever an actual call to gc.collect() lasts too long. """ @@ -160,8 +162,8 @@ class GCDiagnosis: N_SAMPLES = 30 - def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6): - self._warn_over_frac = warn_over_frac + def __init__(self, info_over_frac=0.1, info_over_rss_win=10 * 1e6): + self._info_over_frac = info_over_frac self._info_over_rss_win = info_over_rss_win self._enabled = False self._fractional_timer = None @@ -206,22 +208,25 @@ def _gc_callback(self, phase, info): assert phase == "stop" self._fractional_timer.stop_timing() frac = self._fractional_timer.running_fraction - if frac is not None and frac >= self._warn_over_frac: - logger.warning( + if frac is not None: + level = logging.INFO if frac >= self._info_over_frac else logging.DEBUG + logger.log( + level, "full garbage collections took %d%% CPU time " "recently (threshold: %d%%)", 100 * frac, - 100 * self._warn_over_frac, + 100 * self._info_over_frac, ) rss_saved = self._gc_rss_before - rss - if rss_saved >= self._info_over_rss_win: - logger.info( - "full garbage collection released %s " - "from %d reference cycles (threshold: %s)", - format_bytes(rss_saved), - info["collected"], - format_bytes(self._info_over_rss_win), - ) + level = logging.INFO if rss_saved >= self._info_over_rss_win else logging.DEBUG + logger.log( + level, + "full garbage collection released %s " + "from %d reference cycles (threshold: %s)", + format_bytes(rss_saved), + info["collected"], + format_bytes(self._info_over_rss_win), + ) if info["uncollectable"] > 0: # This should ideally never happen on Python 3, but who knows? logger.warning( diff --git a/distributed/http/scheduler/prometheus/core.py b/distributed/http/scheduler/prometheus/core.py index 9240c98d7c..173c0c6200 100644 --- a/distributed/http/scheduler/prometheus/core.py +++ b/distributed/http/scheduler/prometheus/core.py @@ -8,12 +8,12 @@ from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily from distributed.core import Status +from distributed.gc import gc_collect_duration from distributed.http.prometheus import PrometheusCollector from distributed.http.scheduler.prometheus.semaphore import SemaphoreMetricCollector from distributed.http.scheduler.prometheus.stealing import WorkStealingMetricCollector from distributed.http.utils import RequestHandler from distributed.scheduler import ALL_TASK_STATES, Scheduler -from distributed.utils_perf import gc_collect_duration class SchedulerMetricCollector(PrometheusCollector): diff --git a/distributed/http/worker/prometheus/core.py b/distributed/http/worker/prometheus/core.py index 4f0c2cb100..b6d7172c72 100644 --- a/distributed/http/worker/prometheus/core.py +++ b/distributed/http/worker/prometheus/core.py @@ -8,9 +8,9 @@ import prometheus_client from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily, Metric +from distributed.gc import gc_collect_duration from distributed.http.prometheus import PrometheusCollector from distributed.http.utils import RequestHandler -from distributed.utils_perf import gc_collect_duration from distributed.worker import Worker logger = logging.getLogger("distributed.prometheus.worker") diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4f31ed44aa..7dc30b8893 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -101,6 +101,7 @@ from distributed.diagnostics.memory_sampler import MemorySamplerExtension from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name from distributed.event import EventExtension +from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis from distributed.http import get_handlers from distributed.metrics import time from distributed.multi_lock import MultiLockExtension @@ -136,7 +137,6 @@ scatter_to_workers, unpack_remotedata, ) -from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.variable import VariableExtension from distributed.worker import _normalize_task diff --git a/distributed/tests/test_utils_perf.py b/distributed/tests/test_gc.py similarity index 91% rename from distributed/tests/test_utils_perf.py rename to distributed/tests/test_gc.py index 2e66141b2a..3d8b4b7264 100644 --- a/distributed/tests/test_utils_perf.py +++ b/distributed/tests/test_gc.py @@ -8,8 +8,8 @@ import pytest +from distributed.gc import FractionalTimer, GCDiagnosis, disable_gc_diagnosis from distributed.metrics import thread_time -from distributed.utils_perf import FractionalTimer, GCDiagnosis, disable_gc_diagnosis from distributed.utils_test import captured_logger, run_for @@ -78,7 +78,7 @@ def enable_gc_diagnosis_and_log(diag, level="INFO"): if gc.callbacks: print("Unexpected gc.callbacks", gc.callbacks) - with captured_logger("distributed.utils_perf", level=level, propagate=False) as sio: + with captured_logger("distributed.gc", level=level, propagate=False) as sio: gc.disable() gc.collect() # drain any leftover from previous tests diag.enable() @@ -89,17 +89,18 @@ def enable_gc_diagnosis_and_log(diag, level="INFO"): gc.enable() -@pytest.mark.slow +# @pytest.mark.slow def test_gc_diagnosis_cpu_time(): - diag = GCDiagnosis(warn_over_frac=0.75) + diag = GCDiagnosis(info_over_frac=0.75) diag.N_SAMPLES = 3 # shorten tests - with enable_gc_diagnosis_and_log(diag, level="WARN") as sio: + with enable_gc_diagnosis_and_log(diag, level="INFO") as sio: # Spend some CPU time doing only full GCs for _ in range(diag.N_SAMPLES): gc.collect() assert not sio.getvalue() gc.collect() + gc.collect() lines = sio.getvalue().splitlines() assert len(lines) == 1 # Between 80% and 100% @@ -108,7 +109,7 @@ def test_gc_diagnosis_cpu_time(): lines[0], ) - with enable_gc_diagnosis_and_log(diag, level="WARN") as sio: + with enable_gc_diagnosis_and_log(diag, level="INFO") as sio: # Spend half the CPU time doing full GCs for _ in range(diag.N_SAMPLES + 1): t1 = thread_time() diff --git a/distributed/worker.py b/distributed/worker.py index ac4231a2a6..5778d60dac 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -83,6 +83,7 @@ from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name from distributed.diskutils import WorkSpace from distributed.exceptions import Reschedule +from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis from distributed.http import get_handlers from distributed.metrics import context_meter, thread_time, time from distributed.node import ServerNode @@ -114,7 +115,6 @@ wait_for, ) from distributed.utils_comm import gather_from_workers, pack_data, retry_operation -from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.versions import get_versions from distributed.worker_memory import ( DeprecatedMemoryManagerAttribute, diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index 7484263a3d..ac2f3c8188 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -40,10 +40,10 @@ from distributed import system from distributed.compatibility import WINDOWS, PeriodicCallback from distributed.core import Status +from distributed.gc import ThrottledGC from distributed.metrics import context_meter, monotonic from distributed.spill import ManualEvictProto, SpillBuffer from distributed.utils import RateLimiterFilter, has_arg, log_errors -from distributed.utils_perf import ThrottledGC if TYPE_CHECKING: # TODO import from typing (requires Python >=3.10) From 564f28b5637b56fd2e6870c4489d299f8984576a Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 31 Jul 2024 20:14:10 +0200 Subject: [PATCH 088/138] Ensure that adaptive only stops once (#8807) --- distributed/deploy/adaptive.py | 2 - distributed/deploy/adaptive_core.py | 55 ++++++++++++++----- .../deploy/tests/test_adaptive_core.py | 19 ++++++- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 3f61964100..1638659db4 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -109,8 +109,6 @@ def __init__( self.target_duration = parse_timedelta(target_duration) - logger.info("Adaptive scaling started: minimum=%s maximum=%s", minimum, maximum) - super().__init__( minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval ) diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index 0543e117b6..ccb81008cf 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -5,7 +5,7 @@ from collections import defaultdict, deque from collections.abc import Iterable from datetime import timedelta -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Literal, cast import tlz as toolz from tornado.ioloop import IOLoop @@ -17,12 +17,21 @@ from distributed.metrics import time if TYPE_CHECKING: - from distributed.scheduler import WorkerState + from typing_extensions import TypeAlias + from distributed.scheduler import WorkerState logger = logging.getLogger(__name__) +AdaptiveStateState: TypeAlias = Literal[ + "starting", + "running", + "stopped", + "inactive", +] + + class AdaptiveCore: """ The core logic for adaptive deployments, with none of the cluster details @@ -89,6 +98,8 @@ class AdaptiveCore: observed: set[WorkerState] close_counts: defaultdict[WorkerState, int] _adapting: bool + #: Whether this adaptive strategy is periodically adapting + _state: AdaptiveStateState log: deque[tuple[float, dict]] def __init__( @@ -107,12 +118,6 @@ def __init__( self.interval = parse_timedelta(interval, "seconds") self.periodic_callback = None - def f(): - try: - self.periodic_callback.start() - except AttributeError: - pass - if self.interval: import weakref @@ -124,8 +129,10 @@ async def _adapt(): await core.adapt() self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000) - self.loop.add_callback(f) - + self._state = "starting" + self.loop.add_callback(self._start) + else: + self._state = "inactive" try: self.plan = set() self.requested = set() @@ -140,12 +147,34 @@ async def _adapt(): maxlen=dask.config.get("distributed.admin.low-level-log-length") ) + def _start(self) -> None: + if self._state != "starting": + return + + assert self.periodic_callback is not None + self.periodic_callback.start() + self._state = "running" + logger.info( + "Adaptive scaling started: minimum=%s maximum=%s", + self.minimum, + self.maximum, + ) + def stop(self) -> None: - logger.info("Adaptive stop") + if self._state in ("inactive", "stopped"): + return - if self.periodic_callback: + if self._state == "running": + assert self.periodic_callback is not None self.periodic_callback.stop() - self.periodic_callback = None + logger.info( + "Adaptive scaling stopped: minimum=%s maximum=%s", + self.minimum, + self.maximum, + ) + + self.periodic_callback = None + self._state = "stopped" async def target(self) -> int: """The target number of workers that should exist""" diff --git a/distributed/deploy/tests/test_adaptive_core.py b/distributed/deploy/tests/test_adaptive_core.py index 69a74f1c5d..b5cfd734ab 100644 --- a/distributed/deploy/tests/test_adaptive_core.py +++ b/distributed/deploy/tests/test_adaptive_core.py @@ -104,11 +104,12 @@ def safe_target(self): raise OSError() with captured_logger("distributed.deploy.adaptive_core") as log: - adapt = BadAdaptive(minimum=1, maximum=4) - await adapt.adapt() + adapt = BadAdaptive(minimum=1, maximum=4, interval="10ms") + while adapt._state != "stopped": + await asyncio.sleep(0.01) text = log.getvalue() assert "Adaptive stopping due to error" in text - assert "Adaptive stop" in text + assert "Adaptive scaling stopped" in text assert not adapt._adapting assert not adapt.periodic_callback @@ -147,6 +148,18 @@ async def scale_down(self, workers=None): adapt.stop() +@gen_test() +async def test_adaptive_logs_stopping_once(): + with captured_logger("distributed.deploy.adaptive_core") as log: + adapt = MyAdaptive(interval="100ms") + while not adapt.periodic_callback.is_running(): + await asyncio.sleep(0.01) + adapt.stop() + adapt.stop() + lines = log.getvalue().splitlines() + assert sum("Adaptive scaling stopped" in line for line in lines) == 1 + + @gen_test() async def test_adapt_stop_del(): adapt = MyAdaptive(interval="100ms") From 3d47d3a349bcdcc51dde8d323866cf711a7fc897 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:34:29 +0200 Subject: [PATCH 089/138] Fix if-else for send_recv_from_rpc (#8809) --- distributed/core.py | 4 ++-- distributed/tests/test_core.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 28467594b5..d5eacb7d3c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1174,11 +1174,11 @@ async def send_recv_from_rpc(**kwargs): except (RPCClosed, CommClosedError) as e: if comm: raise type(e)( - f"Exception while trying to call remote method {key!r} before comm was established." + f"Exception while trying to call remote method {key!r} using comm {comm!r}." ) from e else: raise type(e)( - f"Exception while trying to call remote method {key!r} using comm {comm!r}." + f"Exception while trying to call remote method {key!r} before comm was established." ) from e self.comms[comm] = True # mark as open diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index af25353bf2..0e94dc96a5 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -22,6 +22,7 @@ from distributed.comm.tcp import TCPBackend, TCPListener from distributed.core import ( ConnectionPool, + RPCClosed, Server, Status, _expects_comm, @@ -923,6 +924,20 @@ async def test_rpc_serialization(): assert result == {"result": inc} +@gen_test() +async def test_rpc_closed_exception(): + async with Server({"echo": echo_serialize}) as server: + await server.listen("tcp://") + + async with rpc(server.address, serializers=["msgpack"]) as r: + r.status = Status.closed + with pytest.raises( + RPCClosed, + match="Exception while trying to call remote method .* before comm was established.", + ): + await r.__getattr__("foo")() + + @gen_cluster() async def test_thread_id(s, a, b): assert s.thread_id == a.thread_id == b.thread_id == threading.get_ident() From 21739997c48965cae504d372dcbc71ff68004bfa Mon Sep 17 00:00:00 2001 From: alex-rakowski Date: Fri, 2 Aug 2024 10:18:26 +0100 Subject: [PATCH 090/138] typo fix (#8812) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 7dc30b8893..4b5791db80 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2100,7 +2100,7 @@ def _transition( "key": key, "start": start, "finish": finish, - "transistion_log": list(self.transition_log), + "transition_log": list(self.transition_log), }, ) if LOG_PDB: From 798183deaa7de3ae663314af584905a634c59e55 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 2 Aug 2024 14:59:38 +0200 Subject: [PATCH 091/138] Fix exception handling for ``WorkerPlugin.setup`` and ``WorkerPlugin.teardown`` (#8810) --- distributed/client.py | 3 +- .../diagnostics/tests/test_worker_plugin.py | 58 ++- distributed/shuffle/tests/test_shuffle.py | 343 +++++++++--------- distributed/worker.py | 13 +- 4 files changed, 228 insertions(+), 189 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index ad283a352a..0601b0db5f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5409,8 +5409,7 @@ async def _unregister_worker_plugin(self, name, nanny=None): for response in responses.values(): if response["status"] == "error": - exc = response["exception"] - tb = response["traceback"] + _, exc, tb = clean_exception(**response) raise exc.with_traceback(tb) return responses diff --git a/distributed/diagnostics/tests/test_worker_plugin.py b/distributed/diagnostics/tests/test_worker_plugin.py index 0f206512b8..001576afe3 100644 --- a/distributed/diagnostics/tests/test_worker_plugin.py +++ b/distributed/diagnostics/tests/test_worker_plugin.py @@ -1,13 +1,14 @@ from __future__ import annotations import asyncio +import logging import warnings import pytest from distributed import Worker, WorkerPlugin from distributed.protocol.pickle import dumps -from distributed.utils_test import async_poll_for, gen_cluster, inc +from distributed.utils_test import async_poll_for, captured_logger, gen_cluster, inc class MyPlugin(WorkerPlugin): @@ -423,3 +424,58 @@ def setup(self, worker): await c.register_plugin(second, idempotent=True) assert "idempotentplugin" in a.plugins assert a.plugins["idempotentplugin"].instance == "first" + + +class BrokenSetupPlugin(WorkerPlugin): + def setup(self, worker): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_register_plugin_with_broken_setup_to_existing_workers_raises(c, s, a): + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_plugin_with_broken_setup_on_new_worker_logs(c, s): + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + async with Worker(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +class BrokenTeardownPlugin(WorkerPlugin): + def teardown(self, worker): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_unregister_worker_plugin_with_broken_teardown_raises(c, s, a): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + await c.unregister_worker_plugin("TestPlugin1") + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_plugin_with_broken_teardown_logs_on_close(c, s): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + + with captured_logger("distributed.worker", level=logging.ERROR) as caplog: + async with Worker(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 39f53528fd..443595b0bb 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -87,29 +87,32 @@ def lose_annotations(request): return request.param -async def check_worker_cleanup( +async def assert_worker_cleanup( worker: Worker, - closed: bool = False, + close: bool = False, interval: float = 0.01, timeout: int | None = 5, ) -> None: """Assert that the worker has no shuffle state""" - deadline = Deadline.after(timeout) plugin = worker.plugins["shuffle"] assert isinstance(plugin, ShuffleWorkerPlugin) - while plugin.shuffle_runs._runs and not deadline.expired: - await asyncio.sleep(interval) - assert not plugin.shuffle_runs._runs - if closed: + deadline = Deadline.after(timeout) + if close: + await worker.close() + assert "shuffle" not in worker.plugins assert plugin.closed + else: + while plugin.shuffle_runs._runs and not deadline.expired: + await asyncio.sleep(interval) + assert not plugin.shuffle_runs._runs for dirpath, dirnames, filenames in os.walk(worker.local_directory): assert "shuffle" not in dirpath for fn in dirnames + filenames: assert "shuffle" not in fn -async def check_scheduler_cleanup( +async def assert_scheduler_cleanup( scheduler: Scheduler, interval: float = 0.01, timeout: int | None = 5 ) -> None: """Assert that the scheduler has no shuffle state""" @@ -175,9 +178,9 @@ async def test_basic_cudf_support(c, s, a, b): result, expected = await c.compute([shuffled, df], sync=True) dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) def get_active_shuffle_run(shuffle_id: ShuffleId, worker: Worker) -> ShuffleRun: @@ -213,9 +216,9 @@ async def test_basic_integration(c, s, a, b, npartitions, disk): result, expected = await c.compute([shuffled, df], sync=True) dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("processes", [True, False]) @@ -260,9 +263,9 @@ async def test_shuffle_with_array_conversion(c, s, a, b, npartitions): else: await c.compute(out) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) def test_shuffle_before_categorize(loop_in_thread): @@ -295,9 +298,9 @@ async def test_concurrent(c, s, a, b): dd.assert_eq(x, df, check_index=False) dd.assert_eq(y, df, check_index=False) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -323,9 +326,9 @@ async def test_bad_disk(c, s, a, b): out = await c.compute(out) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) async def wait_until_worker_has_tasks( @@ -401,15 +404,14 @@ async def test_closed_worker_during_transfer(c, s, a, b): shuffled = df.shuffle("x") fut = c.compute([shuffled, df], sync=True) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) result, expected = await fut dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -428,16 +430,15 @@ async def test_restarting_during_transfer_raises_killed_worker(c, s, a, b): out = df.shuffle("x") out = c.compute(out.x.size) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) with pytest.raises(KilledWorker): await out assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -491,14 +492,13 @@ async def test_restarting_does_not_log_p2p_failed(c, s, a, b): out = df.shuffle("x") out = c.compute(out.x.size) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) await out assert not s.get_events("p2p") await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) class BlockedGetOrCreateShuffleRunManager(_ShuffleRunManager): @@ -538,7 +538,7 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b): shuffle_extB.shuffle_runs.block_get_or_create.set() await shuffle_extA.shuffle_runs.in_get_or_create.wait() - await b.close() + await assert_worker_cleanup(b, close=True) await async_poll_for( lambda: not any(ws.processing for ws in s.workers.values()), timeout=5 ) @@ -552,10 +552,9 @@ async def test_get_or_create_from_dangling_transfer(c, s, a, b): await async_poll_for(lambda: not a.state.tasks, timeout=10) assert not s.plugins["shuffle"].active_shuffles - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) + await assert_worker_cleanup(a) await c.close() - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -581,8 +580,8 @@ async def test_crashed_worker_during_transfer(c, s, a): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -648,15 +647,14 @@ def mock_get_worker_for_range_sharding( shuffled = df.shuffle("x") fut = c.compute([shuffled, df], sync=True) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b, 0.001) - await b.close() + await assert_worker_cleanup(b, close=True) result, expected = await fut dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -691,8 +689,8 @@ def mock_mock_get_worker_for_range_sharding( dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) # @pytest.mark.slow @@ -714,15 +712,14 @@ async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3): ) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w1) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, w2) - await w3.close() + await assert_worker_cleanup(w3, close=True) result, expected = await fut dd.assert_eq(result, expected) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2) - await check_worker_cleanup(w3, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_worker_cleanup(w2) + await assert_scheduler_cleanup(s) class RaiseOnCloseShuffleRun(DataFrameShuffleRun): @@ -749,9 +746,8 @@ async def test_exception_on_close_cleans_up(c, s, caplog): with dask.config.set({"dataframe.shuffle.method": "p2p"}): shuffled = df.shuffle("x") await c.compute([shuffled, df], sync=True) - + await assert_worker_cleanup(w, close=True) assert any("test-exception-on-close" in record.message for record in caplog.records) - await check_worker_cleanup(w, closed=True) class BlockedInputsDoneShuffle(DataFrameShuffleRun): @@ -798,7 +794,7 @@ async def test_closed_worker_during_barrier(c, s, a, b): else: close_worker, alive_worker = b, a alive_shuffle = shuffleA - await close_worker.close() + await assert_worker_cleanup(close_worker, close=True) alive_shuffle.block_inputs_done.set() alive_shuffles = get_active_shuffle_runs(alive_worker) @@ -820,9 +816,8 @@ def shuffle_restarted(): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(close_worker, closed=True) - await check_worker_cleanup(alive_worker) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(alive_worker) + await assert_scheduler_cleanup(s) @mock.patch( @@ -861,7 +856,7 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): else: close_worker, alive_worker = b, a alive_shuffle = shuffleA - await close_worker.close() + await assert_worker_cleanup(close_worker, close=True) with pytest.raises(KilledWorker): await out @@ -870,9 +865,8 @@ async def test_restarting_during_barrier_raises_killed_worker(c, s, a, b): alive_shuffle.block_inputs_done.set() await c.close() - await check_worker_cleanup(close_worker, closed=True) - await check_worker_cleanup(alive_worker) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(alive_worker) + await assert_scheduler_cleanup(s) @mock.patch( @@ -909,7 +903,7 @@ async def test_closed_other_worker_during_barrier(c, s, a, b): else: close_worker, alive_worker = a, b alive_shuffle = shuffleB - await close_worker.close() + await assert_worker_cleanup(close_worker, close=True) alive_shuffle.block_inputs_done.set() alive_shuffles = get_active_shuffle_runs(alive_worker) @@ -931,9 +925,8 @@ def shuffle_restarted(): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(close_worker, closed=True) - await check_worker_cleanup(alive_worker) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(alive_worker) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -981,8 +974,8 @@ def shuffle_restarted(): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -997,15 +990,14 @@ async def test_closed_worker_during_unpack(c, s, a, b): shuffled = df.shuffle("x") fut = c.compute([shuffled, df], sync=True) await wait_for_tasks_in_state(UNPACK_PREFIX, "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) result, expected = await fut dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -1024,16 +1016,15 @@ async def test_restarting_during_unpack_raises_killed_worker(c, s, a, b): out = df.shuffle("x") out = c.compute(out.x.size) await wait_for_tasks_in_state(UNPACK_PREFIX, "memory", 1, b) - await b.close() + await assert_worker_cleanup(b, close=True) with pytest.raises(KilledWorker): await out assert sum(event["action"] == "p2p-failed" for _, event in s.get_events("p2p")) == 1 await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -1059,14 +1050,14 @@ async def test_crashed_worker_during_unpack(c, s, a): dd.assert_eq(result, expected) await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) async def test_heartbeat(c, s, a, b): await a.heartbeat() - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) df = dask.datasets.timeseries( start="2000-01-01", end="2000-01-10", @@ -1084,10 +1075,10 @@ async def test_heartbeat(c, s, a, b): assert s.plugins["shuffle"].heartbeats.values() await out - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del out - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @pytest.mark.skipif("not pa", reason="Requires PyArrow") @@ -1292,10 +1283,10 @@ async def test_head(c, s, a, b): assert list(os.walk(a.local_directory)) == a_files # cleaned up files? assert list(os.walk(b.local_directory)) == b_files - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del out - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) def test_split_by_worker(): @@ -1399,9 +1390,9 @@ async def test_clean_after_forgotten_early(c, s, a, b): await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, a) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) del out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1424,9 +1415,9 @@ async def test_tail(c, s, a, b): assert len(s.tasks) < ntasks_full del partial - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @@ -1454,9 +1445,9 @@ async def test_repeat_shuffle_instance(c, s, a, b, wait_until_forgotten): await c.compute(out) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("wait_until_forgotten", [True, False]) @@ -1485,9 +1476,9 @@ async def test_repeat_shuffle_operation(c, s, a, b, wait_until_forgotten): with dask.config.set({"dataframe.shuffle.method": "p2p"}): await c.compute(df.shuffle("x")) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1532,8 +1523,8 @@ def block(df, in_event, block_event): assert result == expected await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -1561,8 +1552,8 @@ async def test_crashed_worker_after_shuffle_persisted(c, s, a): assert result == expected await c.close() - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)] * 3) @@ -1578,25 +1569,22 @@ async def test_closed_worker_between_repeats(c, s, w1, w2, w3): out = df.shuffle("x") await c.compute(out.head(compute=False)) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2) - await check_worker_cleanup(w3) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_worker_cleanup(w2) + await assert_worker_cleanup(w3) + await assert_scheduler_cleanup(s) - await w3.close() + await assert_worker_cleanup(w3, close=True) await c.compute(out.tail(compute=False)) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2) - await check_worker_cleanup(w3, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_worker_cleanup(w2) + await assert_scheduler_cleanup(s) - await w2.close() + await assert_worker_cleanup(w2, close=True) await c.compute(out.head(compute=False)) - await check_worker_cleanup(w1) - await check_worker_cleanup(w2, closed=True) - await check_worker_cleanup(w3, closed=True) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(w1) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1616,11 +1604,11 @@ async def test_new_worker(c, s, a, b): async with Worker(s.address) as w: await c.compute(persisted) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_worker_cleanup(w) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_worker_cleanup(w) del persisted - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1644,9 +1632,9 @@ async def test_multi(c, s, a, b): out = await c.compute(out.size) assert out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.skipif( @@ -1694,10 +1682,10 @@ async def test_delete_some_results(c, s, a, b): x = x.partitions[: x.npartitions // 2] x = await c.compute(x.size) - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del x - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -1719,11 +1707,11 @@ async def test_add_some_results(c, s, a, b): await c.compute(x.size) - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) del x del y - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @pytest.mark.slow @@ -1743,12 +1731,11 @@ async def test_clean_after_close(c, s, a, b): await wait_for_tasks_in_state("shuffle-transfer", "executing", 1, a) await wait_for_tasks_in_state("shuffle-transfer", "memory", 1, b) - await a.close() - await check_worker_cleanup(a, closed=True) + await assert_worker_cleanup(a, close=True) del out - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class DataFrameShuffleTestPool(AbstractShuffleTestPool): @@ -2115,9 +2102,9 @@ async def test_deduplicate_stale_transfer(c, s, a, b, wait_until_forgotten): expected = await c.compute(df) dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class BlockedBarrierShuffleWorkerPlugin(ShuffleWorkerPlugin): @@ -2172,9 +2159,9 @@ async def test_handle_stale_barrier(c, s, a, b, wait_until_forgotten): result, expected = await fut dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -2270,8 +2257,8 @@ async def test_shuffle_run_consistency(c, s, a): await out del out - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -2317,8 +2304,8 @@ async def test_fail_fetch_race(c, s, a): worker_plugin.block_barrier.set() del out - await check_worker_cleanup(a) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_scheduler_cleanup(s) class BlockedShuffleAccessAndFailShuffleRunManager(_ShuffleRunManager): @@ -2393,7 +2380,7 @@ async def test_replace_stale_shuffle(c, s, a, b): await asyncio.sleep(0) # A is cleaned - await check_worker_cleanup(a) + await assert_worker_cleanup(a) # B is not cleaned assert shuffle_id in get_active_shuffle_runs(b) @@ -2424,9 +2411,9 @@ async def test_replace_stale_shuffle(c, s, a, b): await out del out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2444,9 +2431,9 @@ async def test_handle_null_partitions(c, s, a, b): result = await c.compute(ddf) dd.assert_eq(result, df) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2467,9 +2454,9 @@ def make_partition(i): expected = await expected dd.assert_eq(result, expected) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2496,9 +2483,9 @@ async def test_handle_object_columns(c, s, a, b): result = await c.compute(shuffled) dd.assert_eq(result, df) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2529,9 +2516,9 @@ def make_partition(i): await c.close() del out - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2555,9 +2542,9 @@ def make_partition(i): await c.compute(out) await c.close() - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2582,9 +2569,9 @@ async def test_handle_categorical_data(c, s, a, b): result, expected = await c.compute([shuffled, df], sync=True) dd.assert_eq(result, expected, check_categorical=False) - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @gen_cluster(client=True) @@ -2620,8 +2607,8 @@ async def test_set_index(c, s, *workers): dd.assert_eq(result, df.set_index("a")) await c.close() - await asyncio.gather(*[check_worker_cleanup(w) for w in workers]) - await check_scheduler_cleanup(s) + await asyncio.gather(*[assert_worker_cleanup(w) for w in workers]) + await assert_scheduler_cleanup(s) def test_shuffle_with_existing_index(client): @@ -2741,9 +2728,9 @@ async def test_unpack_is_non_rootish(c, s, a, b): scheduler_plugin.block_barrier.set() result = await result - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class FlakyConnectionPool(ConnectionPool): @@ -2791,10 +2778,10 @@ async def test_flaky_connect_fails_without_retry(c, s, a, b): ): await c.compute(x) - await check_worker_cleanup(a) - await check_worker_cleanup(b) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) await c.close() - await check_scheduler_cleanup(s) + await assert_scheduler_cleanup(s) @gen_cluster( @@ -2823,9 +2810,9 @@ async def test_flaky_connect_recover_with_retry(c, s, a, b): assert len(line) < 250 assert not line or line.startswith("Retrying") - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) class BlockedAfterGatherDep(Worker): @@ -2900,9 +2887,9 @@ def make_partition(partition_id, size): for _, group in result.groupby("b"): assert group["a"].is_monotonic_increasing - await check_worker_cleanup(a) - await check_worker_cleanup(b) - await check_scheduler_cleanup(s) + await assert_worker_cleanup(a) + await assert_worker_cleanup(b) + await assert_scheduler_cleanup(s) @pytest.mark.parametrize("disk", [True, False]) diff --git a/distributed/worker.py b/distributed/worker.py index 5778d60dac..18ef0aca86 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1229,7 +1229,7 @@ async def _register_with_scheduler(self) -> None: *( self.plugin_add(name=name, plugin=plugin) for name, plugin in response["worker-plugins"].items() - ) + ), ) logger.info(" Registered to: %26s", self.scheduler.address) @@ -1560,12 +1560,7 @@ async def close( # type: ignore # Cancel async instructions await BaseWorker.close(self, timeout=timeout) - teardowns = [ - plugin.teardown(self) - for plugin in self.plugins.values() - if hasattr(plugin, "teardown") - ] - await asyncio.gather(*(td for td in teardowns if isawaitable(td))) + await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins)) for extension in self.extensions.values(): if hasattr(extension, "close"): @@ -1870,13 +1865,14 @@ async def plugin_add( self.plugins[name] = plugin - logger.info("Starting Worker plugin %s" % name) + logger.info("Starting Worker plugin %s", name) if hasattr(plugin, "setup"): try: result = plugin.setup(worker=self) if isawaitable(result): result = await result except Exception as e: + logger.exception("Worker plugin %s failed to setup", name) if not catch_errors: raise return error_message(e) @@ -1893,6 +1889,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: if isawaitable(result): result = await result except Exception as e: + logger.exception("Worker plugin %s failed to teardown", name) return error_message(e) return {"status": "OK"} From fea5515030e3b79475e3555fec84e309177132f8 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 2 Aug 2024 15:02:45 +0200 Subject: [PATCH 092/138] Fix exception handling for ``NannyPlugin.setup`` and ``NannyPlugin.teardown`` (#8811) --- .../diagnostics/tests/test_nanny_plugin.py | 59 ++++++++++++++++++- distributed/nanny.py | 12 ++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/distributed/diagnostics/tests/test_nanny_plugin.py b/distributed/diagnostics/tests/test_nanny_plugin.py index db17fe70b5..3c481dce26 100644 --- a/distributed/diagnostics/tests/test_nanny_plugin.py +++ b/distributed/diagnostics/tests/test_nanny_plugin.py @@ -1,10 +1,12 @@ from __future__ import annotations +import logging + import pytest from distributed import Nanny, NannyPlugin from distributed.protocol.pickle import dumps -from distributed.utils_test import gen_cluster +from distributed.utils_test import captured_logger, gen_cluster @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) @@ -160,3 +162,58 @@ def setup(self, nanny): await c.register_plugin(second, idempotent=True) assert "idempotentplugin" in a.plugins assert a.plugins["idempotentplugin"].instance == "first" + + +class BrokenSetupPlugin(NannyPlugin): + def setup(self, nanny): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_register_plugin_with_broken_setup_to_existing_nannies_raises(c, s, a): + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_plugin_with_broken_setup_on_new_nanny_logs(c, s): + await c.register_plugin(BrokenSetupPlugin(), name="TestPlugin1") + + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + async with Nanny(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to setup" in logs + assert "test error" in logs + + +class BrokenTeardownPlugin(NannyPlugin): + def teardown(self, nanny): + raise RuntimeError("test error") + + +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_unregister_nanny_plugin_with_broken_teardown_raises(c, s, a): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + with pytest.raises(RuntimeError, match="test error"): + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + await c.unregister_worker_plugin("TestPlugin1", nanny=True) + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs + + +@gen_cluster(client=True, nthreads=[]) +async def test_nanny_plugin_with_broken_teardown_logs_on_close(c, s): + await c.register_plugin(BrokenTeardownPlugin(), name="TestPlugin1") + + with captured_logger("distributed.nanny", level=logging.ERROR) as caplog: + async with Nanny(s.address): + pass + logs = caplog.getvalue() + assert "TestPlugin1 failed to teardown" in logs + assert "test error" in logs diff --git a/distributed/nanny.py b/distributed/nanny.py index af0d9a62ad..7a14ee6576 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -477,13 +477,14 @@ async def plugin_add( self.plugins[name] = plugin - logger.info("Starting Nanny plugin %s" % name) + logger.info("Starting Nanny plugin %s", name) if hasattr(plugin, "setup"): try: result = plugin.setup(nanny=self) if isawaitable(result): result = await result except Exception as e: + logger.exception("Nanny plugin %s failed to setup", name) return error_message(e) if getattr(plugin, "restart", False): await self.restart(reason=f"nanny-plugin-{name}-restart") @@ -500,6 +501,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: if isawaitable(result): result = await result except Exception as e: + logger.exception("Nanny plugin %s failed to teardown", name) msg = error_message(e) return msg @@ -610,13 +612,7 @@ async def close( # type:ignore[override] await self.preloads.teardown() - teardowns = [ - plugin.teardown(self) - for plugin in self.plugins.values() - if hasattr(plugin, "teardown") - ] - - await asyncio.gather(*(td for td in teardowns if isawaitable(td))) + await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins)) self.stop() if self.process is not None: From 879e5924ccb4f1bb45e0090f6ca9e6f7eaf343ce Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 2 Aug 2024 15:55:56 +0200 Subject: [PATCH 093/138] Fail tasks exceeding `no-workers-timeout` (#8806) --- distributed/distributed-schema.yaml | 11 +- distributed/distributed.yaml | 2 +- distributed/scheduler.py | 365 +++++++++++++++++++++--- distributed/tests/test_scheduler.py | 71 +++-- distributed/tests/test_worker_memory.py | 4 +- 5 files changed, 383 insertions(+), 70 deletions(-) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 45534e9be8..f7e452383c 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -81,16 +81,15 @@ properties: - string - "null" description: | - Shut down the scheduler after this duration if there are pending tasks, - but no workers that can process them. This can either mean that there are - no workers running at all, or that there are idle workers but they've been - excluded through worker or resource restrictions. + Timeout for tasks in an unrunnable state. + + If task remains unrunnable for longer than this, it fails. A task is considered unrunnable IFF + it has no pending dependencies, and the task has restrictions that are not satisfied by + any available worker or no workers are running at all. In adaptive clusters, this timeout must be set to be safely higher than the time it takes for workers to spin up. - Works in conjunction with idle-timeout. - work-stealing: type: boolean description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 26b88322d0..250af10f7d 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -19,7 +19,7 @@ distributed: # after they have been removed from the scheduler events-cleanup-delay: 1h idle-timeout: null # Shut down after this duration, like "1h" or "30 minutes" - no-workers-timeout: null # Shut down if there are tasks but no workers to process them + no-workers-timeout: null # If a task remains unrunnable for longer than this, it fails. work-stealing: True # workers should steal tasks from each other work-stealing-interval: 100ms # Callback time for work stealing worker-saturation: 1.1 # Send this fraction of nthreads root tasks to workers diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4b5791db80..adf99113b9 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -103,7 +103,7 @@ from distributed.event import EventExtension from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis from distributed.http import get_handlers -from distributed.metrics import time +from distributed.metrics import monotonic, time from distributed.multi_lock import MultiLockExtension from distributed.node import ServerNode from distributed.proctitle import setproctitle @@ -1681,8 +1681,8 @@ class SchedulerState: #: Tasks in the "queued" state, ordered by priority queued: HeapSet[TaskState] - #: Tasks in the "no-worker" state - unrunnable: set[TaskState] + #: Tasks in the "no-worker" state with the (monotonic) time when they became unrunnable + unrunnable: dict[TaskState, float] #: Subset of tasks that exist in memory on more than one worker replicated_tasks: set[TaskState] @@ -1755,7 +1755,7 @@ def __init__( host_info: dict[str, dict[str, Any]], resources: dict[str, dict[str, float]], tasks: dict[Key, TaskState], - unrunnable: set[TaskState], + unrunnable: dict[TaskState, float], queued: HeapSet[TaskState], validate: bool, plugins: Iterable[SchedulerPlugin] = (), @@ -2193,12 +2193,74 @@ def _transition_no_worker_processing(self, key: Key, stimulus_id: str) -> RecsMs assert ts in self.unrunnable if ws := self.decide_worker_non_rootish(ts): - self.unrunnable.discard(ts) + self.unrunnable.pop(ts, None) return self._add_to_processing(ts, ws, stimulus_id=stimulus_id) # If no worker, task just stays in `no-worker` return {}, {}, {} + def _transition_no_worker_erred( + self, + key: Key, + stimulus_id: str, + *, + # TODO: Which ones can actually be None? + cause: Key | None = None, + exception: Serialized | None = None, + traceback: Serialized | None = None, + exception_text: str | None = None, + traceback_text: str | None = None, + ) -> RecsMsgs: + ts = self.tasks[key] + + if self.validate: + assert not ts.actor, f"Actors can't be in `no-worker`: {ts}" + assert cause + assert ts in self.unrunnable + assert not ts.processing_on + + self.unrunnable.pop(ts) + + return self._propagate_erred( + ts, + cause=cause, + exception=exception, + traceback=traceback, + exception_text=exception_text, + traceback_text=traceback_text, + ) + + def _transition_queued_erred( + self, + key: Key, + stimulus_id: str, + *, + # TODO: Which ones can actually be None? + cause: Key | None = None, + exception: Serialized | None = None, + traceback: Serialized | None = None, + exception_text: str | None = None, + traceback_text: str | None = None, + ) -> RecsMsgs: + ts = self.tasks[key] + + if self.validate: + assert not ts.actor, f"Actors can't be in `no-worker`: {ts}" + assert cause + assert ts in self.queued + assert not ts.processing_on + + self.queued.remove(ts) + + return self._propagate_erred( + ts, + cause=cause, + exception=exception, + traceback=traceback, + exception_text=exception_text, + traceback_text=traceback_text, + ) + def decide_worker_rootish_queuing_disabled( self, ts: TaskState ) -> WorkerState | None: @@ -2730,8 +2792,6 @@ def _transition_processing_erred( Recommendations, client messages and worker messages to process """ ts = self.tasks[key] - recommendations: Recs = {} - client_msgs: Msgs = {} if self.validate: assert cause or ts.exception_blame @@ -2746,9 +2806,41 @@ def _transition_processing_erred( self._exit_processing_common(ts) + if self.validate: + assert not ts.processing_on + + return self._propagate_erred( + ts, + worker=worker, + cause=cause, + exception=exception, + traceback=traceback, + exception_text=exception_text, + traceback_text=traceback_text, + ) + + def _propagate_erred( + self, + ts: TaskState, + *, + worker: str | None = None, + cause: Key | None = None, + exception: Serialized | None = None, + traceback: Serialized | None = None, + exception_text: str | None = None, + traceback_text: str | None = None, + ) -> RecsMsgs: + recommendations: Recs = {} + client_msgs: Msgs = {} + + ts.state = "erred" + key = ts.key + if not ts.erred_on: ts.erred_on = set() - ts.erred_on.add(worker) + if worker is not None: + ts.erred_on.add(worker) + if exception is not None: ts.exception = exception ts.exception_text = exception_text @@ -2783,8 +2875,6 @@ def _transition_processing_erred( ts.waiters = None - ts.state = "erred" - report_msg = { "op": "task-erred", "key": key, @@ -2802,9 +2892,6 @@ def _transition_processing_erred( recommendations=recommendations, ) - if self.validate: - assert not ts.processing_on - return recommendations, client_msgs, {} def _transition_no_worker_released(self, key: Key, stimulus_id: str) -> RecsMsgs: @@ -2815,7 +2902,7 @@ def _transition_no_worker_released(self, key: Key, stimulus_id: str) -> RecsMsgs assert not ts.who_has assert not ts.waiting_on - self.unrunnable.remove(ts) + self.unrunnable.pop(ts) recommendations: Recs = {} self._propagate_released(ts, recommendations) @@ -2838,9 +2925,13 @@ def _transition_waiting_no_worker(self, key: Key, stimulus_id: str) -> RecsMsgs: if self.validate: self._validate_ready(ts) + assert ts not in self.unrunnable ts.state = "no-worker" - self.unrunnable.add(ts) + self.unrunnable[ts] = monotonic() + + if self.validate: + validate_unrunnable(self.unrunnable) return {}, {}, {} @@ -2873,7 +2964,7 @@ def _transition_queued_processing(self, key: Key, stimulus_id: str) -> RecsMsgs: def _remove_key(self, key: Key) -> None: ts = self.tasks.pop(key) assert ts.state == "forgotten" - self.unrunnable.discard(ts) + self.unrunnable.pop(ts, None) for cs in ts.who_wants or (): cs.wants_what.remove(ts) ts.who_wants = None @@ -2963,11 +3054,13 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs ("waiting", "memory"): _transition_waiting_memory, ("queued", "released"): _transition_queued_released, ("queued", "processing"): _transition_queued_processing, + ("queued", "erred"): _transition_queued_erred, ("processing", "released"): _transition_processing_released, ("processing", "memory"): _transition_processing_memory, ("processing", "erred"): _transition_processing_erred, ("no-worker", "released"): _transition_no_worker_released, ("no-worker", "processing"): _transition_no_worker_processing, + ("no-worker", "erred"): _transition_no_worker_erred, ("released", "forgotten"): _transition_released_forgotten, ("memory", "forgotten"): _transition_memory_forgotten, ("erred", "released"): _transition_erred_released, @@ -3597,7 +3690,6 @@ class Scheduler(SchedulerState, ServerNode): idle_timeout: float | None _no_workers_since: float | None # Note: not None iff there are pending tasks no_workers_timeout: float | None - _client_connections_added_total: int _client_connections_removed_total: int _workers_added_total: int @@ -3794,7 +3886,7 @@ async def post(self): self.generation = 0 self._last_client = None self._last_time = 0 - unrunnable = set() + unrunnable = {} queued = HeapSet(key=operator.attrgetter("priority")) self.datasets = {} @@ -4434,6 +4526,7 @@ async def add_worker( self._workers_added_total += 1 if ws.status == Status.running: self.running.add(ws) + self._refresh_no_workers_since() dh = self.host_info.get(host) if dh is None: @@ -5320,6 +5413,7 @@ async def remove_worker( recommendations: Recs = {} + timestamp = monotonic() processing_keys = {ts.key for ts in ws.processing} for ts in list(ws.processing): k = ts.key @@ -5390,6 +5484,9 @@ async def remove_worker( self.log_event("all", event_msg) self.transitions(recommendations, stimulus_id=stimulus_id) + # Make sure that the timestamp has been collected before tasks were transitioned to no-worker + # to ensure a meaningful error message. + self._refresh_no_workers_since(timestamp=timestamp) awaitables = [] for plugin in list(self.plugins.values()): @@ -5632,6 +5729,7 @@ def validate_key(self, key: Key, ts: TaskState | None = None) -> None: def validate_state(self, allow_overlap: bool = False) -> None: validate_state(self.tasks, self.workers, self.clients) + validate_unrunnable(self.unrunnable) if not (set(self.workers) == set(self.stream_comms)): raise ValueError("Workers not the same in all collections") @@ -5971,6 +6069,7 @@ def handle_worker_status_change( self.idle.pop(ws.address, None) self.idle_task_count.discard(ws) self.saturated.discard(ws) + self._refresh_no_workers_since() def handle_request_refresh_who_has( self, keys: Iterable[Key], worker: str, stimulus_id: str @@ -8552,39 +8651,134 @@ def check_idle(self) -> float | None: return self.idle_since def _check_no_workers(self) -> None: - """Shut down the scheduler if there have been tasks ready to run which have - nowhere to run for `distributed.scheduler.no-workers-timeout`, and there - aren't other tasks running. - """ - if self.status in (Status.closing, Status.closed): - return # pragma: nocover - if ( - (not self.queued and not self.unrunnable) - or (self.queued and self.workers) - or any(ws.processing for ws in self.workers.values()) + self.status in (Status.closing, Status.closed) + or self.no_workers_timeout is None ): - self._no_workers_since = None return - # 1. There are queued or unrunnable tasks and no workers at all - # 2. There are unrunnable tasks and no workers satisfy their restrictions - # (Only rootish tasks can be queued, and rootish tasks can't have restrictions) + now = monotonic() + stimulus_id = f"check-no-workers-timeout-{time()}" - if not self._no_workers_since: - self._no_workers_since = time() - return + recommendations: Recs = {} - if ( - self.no_workers_timeout - and time() > self._no_workers_since + self.no_workers_timeout - ): - logger.info( - "Tasks have been without any workers to run them for %s; " - "shutting scheduler down", - format_time(self.no_workers_timeout), + self._refresh_no_workers_since(now) + + affected = self._check_unrunnable_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id + ) + + affected.update( + self._check_queued_task_timeouts( + now, recommendations=recommendations, stimulus_id=stimulus_id ) - self._ongoing_background_tasks.call_soon(self.close) + ) + self.transitions(recommendations, stimulus_id=stimulus_id) + if affected: + self.log_event( + "scheduler", + {"action": "no-workers-timeout-exceeded", "keys": affected}, + ) + + def _check_unrunnable_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: + assert self.no_workers_timeout + unsatisfied = [] + no_workers = [] + for ts, unrunnable_since in self.unrunnable.items(): + if timestamp <= unrunnable_since + self.no_workers_timeout: + # unrunnable is insertion-ordered, which means that unrunnable_since will + # be monotonically increasing in this loop. + break + if ( + self._no_workers_since is None + or self._no_workers_since >= unrunnable_since + ): + unsatisfied.append(ts) + else: + no_workers.append(ts) + if not unsatisfied and not no_workers: + return set() + + for ts in unsatisfied: + e = pickle.dumps( + NoValidWorkerError( + task=ts.key, + host_restrictions=(ts.host_restrictions or set()).copy(), + worker_restrictions=(ts.worker_restrictions or set()).copy(), + resource_restrictions=(ts.resource_restrictions or {}).copy(), + timeout=self.no_workers_timeout, + ), + ) + r = self.transition( + ts.key, + "erred", + exception=e, + cause=ts.key, + stimulus_id=stimulus_id, + ) + recommendations.update(r) + logger.error( + "Task %s marked as failed because it timed out waiting " + "for its restrictions to become satisfied.", + ts.key, + ) + self._fail_tasks_after_no_workers_timeout( + no_workers, recommendations, stimulus_id + ) + return {ts.key for ts in concat([unsatisfied, no_workers])} + + def _check_queued_task_timeouts( + self, timestamp: float, recommendations: Recs, stimulus_id: str + ) -> set[Key]: + assert self.no_workers_timeout + + if self._no_workers_since is None: + return set() + + if timestamp <= self._no_workers_since + self.no_workers_timeout: + return set() + affected = list(self.queued) + self._fail_tasks_after_no_workers_timeout( + affected, recommendations, stimulus_id + ) + return {ts.key for ts in affected} + + def _fail_tasks_after_no_workers_timeout( + self, timed_out: Iterable[TaskState], recommendations: Recs, stimulus_id: str + ) -> None: + assert self.no_workers_timeout + + for ts in timed_out: + e = pickle.dumps( + NoWorkerError( + task=ts.key, + timeout=self.no_workers_timeout, + ), + ) + r = self.transition( + ts.key, + "erred", + exception=e, + cause=ts.key, + stimulus_id=stimulus_id, + ) + recommendations.update(r) + logger.error( + "Task %s marked as failed because it timed out waiting " + "without any running workers.", + ts.key, + ) + + def _refresh_no_workers_since(self, timestamp: float | None = None) -> None: + if self.running or not (self.queued or self.unrunnable): + self._no_workers_since = None + return + + if not self._no_workers_since: + self._no_workers_since = timestamp or monotonic() + return def adaptive_target(self, target_duration=None): """Desired number of workers based on the current workload @@ -8607,7 +8801,7 @@ def adaptive_target(self, target_duration=None): target_duration = parse_timedelta(target_duration) # CPU - queued = take(100, concat([self.queued, self.unrunnable])) + queued = take(100, concat([self.queued, self.unrunnable.keys()])) queued_occupancy = 0 for ts in queued: if ts.prefix.duration_average == -1: @@ -8894,6 +9088,25 @@ def validate_task_state(ts: TaskState) -> None: assert ts.state != "queued" +def validate_unrunnable(unrunnable: dict[TaskState, float]) -> None: + prev_unrunnable_since: float | None = None + prev_ts: TaskState | None = None + for ts, unrunnable_since in unrunnable.items(): + assert ts.state == "no-worker" + if prev_ts is not None: + assert prev_unrunnable_since is not None + # Ensure that unrunnable_since is monotonically increasing when iterating over unrunnable. + # _check_no_workers relies on this. + assert prev_unrunnable_since <= unrunnable_since, ( + prev_ts, + ts, + prev_unrunnable_since, + unrunnable_since, + ) + prev_ts = ts + prev_unrunnable_since = unrunnable_since + + def validate_worker_state(ws: WorkerState) -> None: for ts in ws.has_what or (): assert ts.who_has @@ -8988,6 +9201,68 @@ def __str__(self) -> str: ) +class NoValidWorkerError(Exception): + def __init__( + self, + task: Key, + host_restrictions: set[str], + worker_restrictions: set[str], + resource_restrictions: dict[str, float], + timeout: float, + ): + super().__init__( + task, host_restrictions, worker_restrictions, resource_restrictions, timeout + ) + + @property + def task(self) -> Key: + return self.args[0] + + @property + def host_restrictions(self) -> Any: + return self.args[1] + + @property + def worker_restrictions(self) -> Any: + return self.args[2] + + @property + def resource_restrictions(self) -> Any: + return self.args[3] + + @property + def timeout(self) -> float: + return self.args[4] + + def __str__(self) -> str: + return ( + f"Attempted to run task {self.task!r} but timed out after {format_time(self.timeout)} " + "waiting for a valid worker matching all restrictions.\n\nRestrictions:\n" + "host_restrictions={self.host_restrictions!s}\n" + "worker_restrictions={self.worker_restrictions!s}\n" + "resource_restrictions={self.resource_restrictions!s}\n" + ) + + +class NoWorkerError(Exception): + def __init__(self, task: Key, timeout: float): + super().__init__(task, timeout) + + @property + def task(self) -> Key: + return self.args[0] + + @property + def timeout(self) -> float: + return self.args[1] + + def __str__(self) -> str: + return ( + f"Attempted to run task {self.task!r} but timed out after {format_time(self.timeout)} " + "waiting without any running workers." + ) + + class WorkerStatusPlugin(SchedulerPlugin): """A plugin to share worker status with a remote observer diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0a63557106..a3c3316999 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -47,7 +47,14 @@ from distributed.protocol import serialize from distributed.protocol.pickle import dumps, loads from distributed.protocol.serialize import ToPickle -from distributed.scheduler import KilledWorker, MemoryState, Scheduler, WorkerState +from distributed.scheduler import ( + KilledWorker, + MemoryState, + NoValidWorkerError, + NoWorkerError, + Scheduler, + WorkerState, +) from distributed.utils import TimeoutError, wait_for from distributed.utils_test import ( NO_AMM, @@ -2108,7 +2115,7 @@ def g(_, ev1, ev2): await ev2.set() -# @pytest.mark.slow +@pytest.mark.slow @gen_cluster( client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False} ) @@ -2445,7 +2452,7 @@ async def test_idle_timeout_no_workers(c, s): nthreads=[], config={"distributed.scheduler.no-workers-timeout": None}, ) -async def test_no_workers_timeout_disabled(c, s, a, b): +async def test_no_workers_timeout_disabled(c, s): """no-workers-timeout has been disabled""" future = c.submit(inc, 1, key="x") await wait_for_state("x", ("queued", "no-worker"), s) @@ -2455,7 +2462,13 @@ async def test_no_workers_timeout_disabled(c, s, a, b): s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running + async with Worker(s.address): + await future + + assert all( + event["action"] != "no-workers-timeout-exceeded" + for _, event in s.get_events("scheduler") + ) @pytest.mark.slow @@ -2466,17 +2479,23 @@ async def test_no_workers_timeout_disabled(c, s, a, b): ) async def test_no_workers_timeout_without_workers(c, s): """Trip no-workers-timeout when there are no workers available""" - # Don't trip scheduler shutdown when there are no tasks + future = c.submit(inc, 1, key="x") + await wait_for_state("x", ("queued", "no-worker"), s) s._check_no_workers() await asyncio.sleep(0.2) s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running + with pytest.raises(NoWorkerError if QUEUING_ON_BY_DEFAULT else NoValidWorkerError): + await future - future = c.submit(inc, 1) - while s.status != Status.closed: - await asyncio.sleep(0.01) + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "no-workers-timeout-exceeded" + ] + assert len(events) == 1 + assert events[0]["keys"] == {"x"} @pytest.mark.slow @@ -2489,8 +2508,16 @@ async def test_no_workers_timeout_bad_restrictions(c, s, a, b): task restrictions """ future = c.submit(inc, 1, key="x", workers=["127.0.0.2:1234"]) - while s.status != Status.closed: - await asyncio.sleep(0.01) + with pytest.raises(NoValidWorkerError): + await future + + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "no-workers-timeout-exceeded" + ] + assert len(events) == 1 + assert events[0]["keys"] == {"x"} @gen_cluster( @@ -2510,8 +2537,13 @@ async def test_no_workers_timeout_queued(c, s, a): s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running await ev.set() + await c.gather(futures) + + assert all( + event["action"] != "no-workers-timeout-exceeded" + for _, event in s.get_events("scheduler") + ) @pytest.mark.slow @@ -2532,14 +2564,21 @@ async def test_no_workers_timeout_processing(c, s, a, b): await asyncio.sleep(0.2) s._check_no_workers() await asyncio.sleep(0.2) - assert s.status == Status.running + + with pytest.raises(NoValidWorkerError): + await y + + events = [ + event + for _, event in s.get_events("scheduler") + if event["action"] == "no-workers-timeout-exceeded" + ] + assert len(events) == 1 + assert events[0]["keys"] == {"y"} await ev.set() await x - while s.status != Status.closed: - await asyncio.sleep(0.01) - @gen_cluster(client=True, config={"distributed.scheduler.bandwidth": "100 GB"}) async def test_bandwidth(c, s, a, b): diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index fc473ba8d1..0994816d2e 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -538,7 +538,7 @@ def f(ev): z = c.submit(inc, 2, key="z") while "z" not in s.tasks or s.tasks["z"].state != "no-worker": await asyncio.sleep(0.01) - assert s.unrunnable == {s.tasks["z"]} + assert s.unrunnable.keys() == {s.tasks["z"]} # Test that a task that already started when the worker paused can complete # and its output can be retrieved. Also test that the now free slot won't be @@ -605,7 +605,7 @@ def f(ev): z = c.submit(inc, 2, key="z") while "z" not in s.tasks or s.tasks["z"].state != "no-worker": await asyncio.sleep(0.01) - assert s.unrunnable == {s.tasks["z"]} + assert s.unrunnable.keys() == {s.tasks["z"]} # Test that a task that already started when the worker paused can complete # and its output can be retrieved. Also test that the now free slot won't be From e0ce0fe455231bea4adfd881b07bbf28d85ed1a1 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:43:01 +0200 Subject: [PATCH 094/138] Update large graph size warning to remove scatter recommendation (#8815) --- distributed/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 0601b0db5f..2ece27802a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3362,7 +3362,9 @@ def _graph_to_futures( warnings.warn( f"Sending large graph of size {format_bytes(pickled_size)}.\n" "This may cause some slowdown.\n" - "Consider scattering data ahead of time and using futures." + "Consider loading the data with Dask directly\n or using futures or " + "delayed objects to embed the data into the graph without repetition.\n" + "See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information." ) computations = self._get_computation_code( From 833831922d4f070003dff9eac1975e999bc22a6e Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Tue, 6 Aug 2024 12:32:59 +0200 Subject: [PATCH 095/138] Run graph normalisation after dask order (#8818) --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index adf99113b9..a7cd91a765 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4884,6 +4884,7 @@ async def update_graph( internal_priority = await offload( dask.order.order, dsk=dsk, dependencies=stripped_deps ) + dsk = valmap(_normalize_task, dsk) self._create_taskstate_from_graph( dsk=dsk, @@ -9383,5 +9384,4 @@ def _materialize_graph( deps.discard(k) dependencies[k] = deps - dsk = valmap(_normalize_task, dsk) return dsk, dependencies, annotations_by_type From 92fc0e24846d3edcbabc5861a3b9a812f3fc76d8 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 6 Aug 2024 15:18:25 -0500 Subject: [PATCH 096/138] bump version to 2024.8.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 44a79310f5..e7e729642f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ requires-python = ">=3.9" dependencies = [ "click >= 8.0", "cloudpickle >= 1.5.0", - "dask == 2024.7.1", + "dask == 2024.8.0", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.0", From ad5f98c2da7b15214e58fa90b0eb67d061a2d845 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 8 Aug 2024 05:16:03 +0200 Subject: [PATCH 097/138] Drop support for Python 3.9 (#8793) Co-authored-by: James Bourbeau --- .github/workflows/ci-pre-commit.yml | 2 +- .github/workflows/tests.yaml | 12 ++-- continuous_integration/environment-3.10.yaml | 7 +++ continuous_integration/environment-3.9.yaml | 60 ------------------- .../environment-mindeps.yaml | 17 +++--- continuous_integration/gpuci/axis.yaml | 1 - continuous_integration/recipes/dask/meta.yaml | 8 +-- .../recipes/distributed/meta.yaml | 18 +++--- distributed/comm/registry.py | 2 +- distributed/protocol/pickle.py | 6 +- distributed/protocol/tests/test_pickle.py | 10 +--- pyproject.toml | 19 +++--- 12 files changed, 46 insertions(+), 116 deletions(-) delete mode 100644 continuous_integration/environment-3.9.yaml diff --git a/.github/workflows/ci-pre-commit.yml b/.github/workflows/ci-pre-commit.yml index d19a45eba5..ba64f9ebf7 100644 --- a/.github/workflows/ci-pre-commit.yml +++ b/.github/workflows/ci-pre-commit.yml @@ -14,5 +14,5 @@ jobs: - uses: actions/checkout@v4.1.3 - uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.12' - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 49c4531eb2..4ded58473c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -26,7 +26,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - environment: [mindeps, "3.9", "3.10", "3.11", "3.12"] + environment: [mindeps, "3.10", "3.11", "3.12"] label: [default] extra_packages: [null] # Cherry-pick test modules to split the overall runtime roughly in half @@ -36,8 +36,6 @@ jobs: # MacOS CI does not have any hosts available; run it on 3.12 only - os: macos-latest environment: mindeps - - os: macos-latest - environment: "3.9" - os: macos-latest environment: "3.10" - os: macos-latest @@ -49,21 +47,21 @@ jobs: include: # Set distributed.scheduler.worker-saturation: .inf - os: ubuntu-latest - environment: "3.9" + environment: "3.10" label: no_queue partition: "ci1" - os: ubuntu-latest - environment: "3.9" + environment: "3.10" label: no_queue partition: "not ci1" # Set dataframe.query-planning: false - os: ubuntu-latest - environment: "3.9" + environment: "3.10" label: no_expr partition: "ci1" - os: ubuntu-latest - environment: "3.9" + environment: "3.10" label: no_expr partition: "not ci1" diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index be2da72095..a3209963dc 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -11,6 +11,7 @@ dependencies: - click - cloudpickle - coverage + - cython # Only tested here; also a dependency of crick - dask # overridden by git tip below - filesystem-spec # overridden by git tip below - gilknocker @@ -21,6 +22,7 @@ dependencies: - jupyter-server-proxy - jupyterlab - locket + - lz4 # Only tested here - msgpack-python - netcdf4 - paramiko @@ -28,12 +30,15 @@ dependencies: - prometheus_client - psutil - pyarrow + - pynvml # Only tested here - pytest - pytest-cov - pytest-faulthandler - pytest-repeat - pytest-rerunfailures - pytest-timeout + - python-snappy # Only tested here + - pytorch # Only tested here - requests - s3fs # overridden by git tip below - scikit-learn @@ -41,6 +46,7 @@ dependencies: - sortedcollections - tblib - toolz + - torchvision # Only tested here - tornado - zict # overridden by git tip below - zstandard @@ -50,6 +56,7 @@ dependencies: - git+https://github.com/dask/dask - git+https://github.com/dask-contrib/dask-expr - git+https://github.com/dask/zict + - git+https://github.com/dask/crick # Only tested here # Revert after https://github.com/dask/distributed/issues/8614 is fixed # - git+https://github.com/dask/s3fs # - git+https://github.com/fsspec/filesystem_spec diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml deleted file mode 100644 index 47502ee3c2..0000000000 --- a/continuous_integration/environment-3.9.yaml +++ /dev/null @@ -1,60 +0,0 @@ -name: dask-distributed -channels: - - conda-forge - - defaults - - pytorch -dependencies: - - python=3.9 - - packaging - - pip - - asyncssh - - bokeh - - click - - cloudpickle - - coverage - - cython # Only tested here; also a dependency of crick - - dask # overridden by git tip below - - filesystem-spec - - gilknocker - - h5py - - ipykernel - - ipywidgets - - jinja2 - - jupyter-server-proxy - - jupyterlab - - locket - - lz4 # Only tested here - - msgpack-python - - netcdf4 - - paramiko - - pre-commit - - prometheus_client - - psutil - - pyarrow - - pynvml # Only tested here - - pytest - - pytest-cov - - pytest-faulthandler - - pytest-repeat - - pytest-rerunfailures - - pytest-timeout - - python-snappy # Only tested here - - pytorch # Only tested here - - requests - - s3fs - - scikit-learn - - scipy - - sortedcollections - - tblib - - toolz - - torchvision # Only tested here - - tornado - - zict - - zstandard - # Temporary fix for https://github.com/pypa/setuptools/issues/4496 - - setuptools < 71 - - pip: - - git+https://github.com/dask/dask - - git+https://github.com/dask-contrib/dask-expr - - git+https://github.com/dask/crick # Only tested here - - keras diff --git a/continuous_integration/environment-mindeps.yaml b/continuous_integration/environment-mindeps.yaml index ce2c2c980c..9403ca64b8 100644 --- a/continuous_integration/environment-mindeps.yaml +++ b/continuous_integration/environment-mindeps.yaml @@ -1,23 +1,22 @@ name: dask-distributed channels: - conda-forge - - defaults dependencies: - - python=3.9 + - python=3.10 - click=8.0 - - cloudpickle=1.5.0 - - cytoolz=0.10.1 + - cloudpickle=2.0.0 + - cytoolz=0.11.2 - jinja2=2.10.3 - locket=1.0.0 - - msgpack-python=1.0.0 + - msgpack-python=1.0.2 - packaging=20.0 - - psutil=5.7.2 - - pyyaml=5.3.1 + - psutil=5.8.0 + - pyyaml=5.4.1 - sortedcontainers=2.0.5 - tblib=1.6.0 - toolz=0.10.0 - - tornado=6.0.4 - - urllib3=1.24.3 + - tornado=6.2.0 + - urllib3=1.26.5 - zict=3.0.0 # Temporary fix for https://github.com/pypa/setuptools/issues/4496 - setuptools < 71 diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index 77e9b9b45c..25b4c78ef9 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -1,5 +1,4 @@ PYTHON_VER: -- "3.9" - "3.10" - "3.11" diff --git a/continuous_integration/recipes/dask/meta.yaml b/continuous_integration/recipes/dask/meta.yaml index 0567c565b0..721e5e345f 100644 --- a/continuous_integration/recipes/dask/meta.yaml +++ b/continuous_integration/recipes/dask/meta.yaml @@ -19,19 +19,19 @@ build: requirements: host: - - python >=3.9 + - python >=3.10 - dask-core {{ dask_version }} - dask-expr {{ dask_expr_version }} - distributed {{ version }} run: - - python >=3.9 + - python >=3.10 - {{ pin_compatible('dask-core', max_pin='x.x.x.x') }} - {{ pin_compatible('dask-expr', max_pin='x.x.x.x') }} - {{ pin_compatible('distributed', exact=True) }} - - cytoolz >=0.8.2 + - cytoolz >=0.11.2 - lz4 >=4.3.2 - numpy >=1.21 - - pandas >=1.3 + - pandas >=2 - bokeh >=2.4.2,!=3.0.* - jinja2 >=2.10.3 - pyarrow >=7.0 diff --git a/continuous_integration/recipes/distributed/meta.yaml b/continuous_integration/recipes/distributed/meta.yaml index 2fb01a1837..e6279ccda0 100644 --- a/continuous_integration/recipes/distributed/meta.yaml +++ b/continuous_integration/recipes/distributed/meta.yaml @@ -24,28 +24,28 @@ build: requirements: host: - - python >=3.9 + - python >=3.10 - pip - dask-core {{ dask_version }} - versioneer =0.29 - tomli # [py<311] run: - - python >=3.9 + - python >=3.10 - click >=8.0 - - cloudpickle >=1.5.0 - - cytoolz >=0.10.1 + - cloudpickle >=2.0.0 + - cytoolz >=0.11.2 - {{ pin_compatible('dask-core', max_pin='x.x.x.x') }} - jinja2 >=2.10.3 - locket >=1.0.0 - - msgpack-python >=1.0.0 + - msgpack-python >=1.0.2 - packaging >=20.0 - - psutil >=5.7.2 - - pyyaml >=5.3.1 + - psutil >=5.8.0 + - pyyaml >=5.4.1 - sortedcontainers >=2.0.5 - tblib >=1.6.0 - toolz >=0.10.0 - - tornado >=6.0.4 - - urllib3 >=1.24.3 + - tornado >=6.2.0 + - urllib3 >=1.26.5 - zict >=3.0.0 run_constrained: - openssl !=1.1.1e diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index bcdc7aa361..47ba730a7d 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -15,7 +15,7 @@ def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: if sys.version_info >= (3, 10): # py3.10 importlib.metadata type annotations are not in mypy yet # https://github.com/python/typeshed/pull/7331 - _entry_points: _EntryPoints = importlib.metadata.entry_points + _entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] else: def _entry_points( diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 8b4b7328e5..cb724af012 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -6,12 +6,9 @@ import pickle import cloudpickle -from packaging.version import parse as parse_version from distributed.protocol.serialize import dask_deserialize, dask_serialize -CLOUDPICKLE_GE_20 = parse_version(cloudpickle.__version__) >= parse_version("2.0.0") - HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL logger = logging.getLogger(__name__) @@ -68,8 +65,7 @@ def dumps(x, *, buffer_callback=None, protocol=HIGHEST_PROTOCOL): pickler.dump(x) result = f.getvalue() if b"__main__" in result or ( - CLOUDPICKLE_GE_20 - and getattr(inspect.getmodule(x), "__name__", None) + getattr(inspect.getmodule(x), "__name__", None) in cloudpickle.list_registry_pickle_by_value() ): if len(result) < 1000 or not _always_use_pickle_for(x): diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 7f34dd66e3..3d9ef38de7 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -13,12 +13,7 @@ from distributed import profile from distributed.protocol import deserialize, serialize -from distributed.protocol.pickle import ( - CLOUDPICKLE_GE_20, - HIGHEST_PROTOCOL, - dumps, - loads, -) +from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads from distributed.protocol.serialize import dask_deserialize, dask_serialize from distributed.utils_test import popen, save_sys_modules @@ -200,9 +195,6 @@ def funcs(): assert wr3() is None -@pytest.mark.skipif( - not CLOUDPICKLE_GE_20, reason="Pickle by value registration not supported" -) def test_pickle_by_value_when_registered(): with save_sys_modules(): with tmpdir() as d: diff --git a/pyproject.toml b/pyproject.toml index e7e729642f..b12a53e567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -24,22 +23,22 @@ classifiers = [ "Topic :: System :: Distributed Computing", ] readme = "README.rst" -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "click >= 8.0", - "cloudpickle >= 1.5.0", + "cloudpickle >= 2.0.0", "dask == 2024.8.0", "jinja2 >= 2.10.3", "locket >= 1.0.0", - "msgpack >= 1.0.0", + "msgpack >= 1.0.2", "packaging >= 20.0", - "psutil >= 5.7.2", - "pyyaml >= 5.3.1", + "psutil >= 5.8.0", + "pyyaml >= 5.4.1", "sortedcontainers >= 2.0.5", "tblib >= 1.6.0", - "toolz >= 0.10.0", - "tornado >= 6.0.4", - "urllib3 >= 1.24.3", + "toolz >= 0.11.2", + "tornado >= 6.2.0", + "urllib3 >= 1.26.5", "zict >= 3.0.0", ] dynamic = ["version"] @@ -178,7 +177,7 @@ timeout_method = "thread" timeout = 300 [tool.mypy] -python_version = "3.9" +python_version = "3.10" # See https://github.com/python/mypy/issues/12286 for automatic multi-platform support platform = "linux" # platform = win32 From fa00237dd3044ca57ba2ede9d35a523850377c07 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 8 Aug 2024 14:56:22 -0500 Subject: [PATCH 098/138] Log ``worker_client`` event (#8819) --- distributed/tests/test_worker.py | 3 +- distributed/tests/test_worker_client.py | 53 +++++++++++++++++++++++++ distributed/worker.py | 8 ++++ distributed/worker_client.py | 7 ++++ 4 files changed, 70 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2c02b79bfb..da68e61a98 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3003,7 +3003,7 @@ async def test_log_remove_worker(c, s, a, b): events = {topic: [ev for _, ev in evs] for topic, evs in s.get_events().items()} for evs in events.values(): for ev in evs: - if ev["action"] == "retire-workers": + if ev.get("action", None) == "retire-workers": for k in ("retired", "could-not-retire"): ev[k] = {addr: "snip" for addr in ev[k]} if "stimulus_id" in ev: # Strip timestamp @@ -3083,6 +3083,7 @@ async def test_log_remove_worker(c, s, a, b): "worker": b.address, }, ], + "worker-get-client": [{"client": c.id, "timeout": 5, "worker": b.address}], } diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index f5e72225c3..a7f8c236b3 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -363,3 +363,56 @@ def long_running(): assert len(res) == 2 assert res[a.address] > 25 assert res[b.address] > 25 + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_log_event(c, s, a): + # Run a task that spawns a worker client + def f(x): + with worker_client(timeout=10, separate_thread=True) as wc: + x = wc.submit(inc, x) + y = wc.submit(double, x) + result = x.result() + y.result() + return result + + future = c.submit(f, 1) + result = await future + assert result == 6 + + # Ensure a corresponding event is logged + for topic in ["worker-get-client", "worker-client"]: + events = [msg for t, msg in s.get_events().items() if t == topic] + assert len(events) == 1 + assert events[0][0][1] == { + "worker": a.address, + "timeout": 10, + "client": c.id, + } + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_log_event_implicit(c, s, a): + # Run a task that spawns a worker client + def f(x): + x = delayed(inc)(x) + y = delayed(double)(x) + result = x.compute() + y.compute() + return result + + future = c.submit(f, 1) + result = await future + assert result == 6 + + # Ensure a corresponding event is logged + events = [ + msg for topic, msg in s.get_events().items() if topic == "worker-get-client" + ] + assert len(events) == 1 + assert events[0][0][1] == { + "worker": a.address, + "timeout": 5, + "client": c.id, + } + # Do not log a `worker-client` since this client was created implicitly + events = [msg for topic, msg in s.get_events().items() if topic == "worker-client"] + assert len(events) == 0 diff --git a/distributed/worker.py b/distributed/worker.py index 18ef0aca86..bb1c063774 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2609,6 +2609,14 @@ def _get_client(self, timeout: float | None = None) -> Client: if not asynchronous: assert self._client.status == "running" + self.log_event( + "worker-get-client", + { + "client": self._client.id, + "timeout": timeout, + }, + ) + return self._client def get_current_task(self) -> Key: diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 355156206d..86965d4cf7 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -53,6 +53,13 @@ def worker_client(timeout=None, separate_thread=True): worker = get_worker() client = get_client(timeout=timeout) + worker.log_event( + "worker-client", + { + "client": client.id, + "timeout": timeout, + }, + ) with contextlib.ExitStack() as stack: if separate_thread: try: From 845c07a4f55fe1c299736995634adea922e6c844 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 8 Aug 2024 21:58:58 +0200 Subject: [PATCH 099/138] Avoid key validation if validation is disabled (#8822) --- distributed/scheduler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a7cd91a765..95487d32eb 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4869,6 +4869,7 @@ async def update_graph( _materialize_graph, graph=graph, global_annotations=annotations or {}, + validate=self.validate, ) del graph if not internal_priority: @@ -9337,11 +9338,12 @@ def transition( def _materialize_graph( - graph: HighLevelGraph, global_annotations: dict[str, Any] + graph: HighLevelGraph, global_annotations: dict[str, Any], validate: bool ) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]: dsk = ensure_dict(graph) - for k in dsk: - validate_key(k) + if validate: + for k in dsk: + validate_key(k) annotations_by_type: defaultdict[str, dict[Key, Any]] = defaultdict(dict) for annotations_type, value in global_annotations.items(): annotations_by_type[annotations_type].update( From e364f42a50bf7ba18a02d3cdb582b593daf94833 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Thu, 8 Aug 2024 22:00:10 +0200 Subject: [PATCH 100/138] avoid excessive attribute access overhead for remove_from_task_prefix_count (#8821) --- distributed/scheduler.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 95487d32eb..363b33bb41 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -775,17 +775,20 @@ def remove_from_processing(self, ts: TaskState) -> None: self._dec_needs_replica(dts) def _remove_from_task_prefix_count(self, ts: TaskState) -> None: - count = self.task_prefix_count[ts.prefix.name] - 1 + prefix_name = ts.prefix.name + count = self.task_prefix_count[prefix_name] - 1 + tp_count = self.task_prefix_count + tp_count_global = self.scheduler._task_prefix_count_global if count: - self.task_prefix_count[ts.prefix.name] = count + tp_count[prefix_name] = count else: - del self.task_prefix_count[ts.prefix.name] + del tp_count[prefix_name] - count = self.scheduler._task_prefix_count_global[ts.prefix.name] - 1 + count = tp_count_global[prefix_name] - 1 if count: - self.scheduler._task_prefix_count_global[ts.prefix.name] = count + tp_count_global[prefix_name] = count else: - del self.scheduler._task_prefix_count_global[ts.prefix.name] + del tp_count_global[prefix_name] def remove_replica(self, ts: TaskState) -> None: """The worker no longer has a task in memory""" From 86dc83c7421a5d660b2e465ca56d5ec938962239 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 12 Aug 2024 15:43:36 +0200 Subject: [PATCH 101/138] MINOR: Extract truncation logic out of partial concatenation in P2P rechunking (#8826) --- distributed/shuffle/_rechunk.py | 62 ++++++++++++++++----------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 1ff0daa97c..99f9d39061 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -399,20 +399,21 @@ def _construct_graph(self) -> _T_LowLevelGraph: chunked_shape = tuple(len(axis) for axis in self.chunks) for ndpartial in _split_partials(_old_to_new, chunked_shape): - output_count = np.sum(self.keepmap[ndpartial.new]) + partial_keepmap = self.keepmap[ndpartial.new] + output_count = np.sum(partial_keepmap) if output_count == 0: continue elif output_count == 1: # Single output chunk - # TODO: Create new partial that contains ONLY the relevant chunk + ndindex = np.argwhere(partial_keepmap)[0] + ndpartial = _truncate_partial(ndindex, ndpartial, _old_to_new) + dsk.update( partial_concatenate( input_name=self.name_input, input_chunks=self.chunks_input, ndpartial=ndpartial, token=self.token, - keepmap=self.keepmap, - old_to_new=_old_to_new, ) ) else: @@ -516,8 +517,6 @@ def partial_concatenate( input_chunks: ChunkedAxes, ndpartial: _NDPartial, token: str, - keepmap: np.ndarray, - old_to_new: list[Any], ) -> dict[Key, Any]: import numpy as np @@ -528,31 +527,6 @@ def partial_concatenate( slice_group = f"rechunk-slice-{token}" - partial_keepmap = keepmap[ndpartial.new] - assert np.sum(partial_keepmap) == 1 - - ndindex = np.argwhere(partial_keepmap)[0] - - partial_per_axis = [] - for axis_index, index in enumerate(ndindex): - slc = slice( - ndpartial.new[axis_index].start + index, - ndpartial.new[axis_index].start + index + 1, - ) - first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0] - last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1] - partial_per_axis.append( - _Partial( - old=slice(first_old_chunk, last_old_chunk + 1), - new=slc, - left_start=first_old_slice.start, - right_stop=last_old_slice.stop, - ) - ) - - old, new, left_starts, right_stops = zip(*partial_per_axis) - ndpartial = _NDPartial(old, new, left_starts, right_stops, ndpartial.ix) - old_offset = tuple(slice_.start for slice_ in ndpartial.old) shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old) @@ -588,6 +562,32 @@ def partial_concatenate( return dsk +def _truncate_partial( + ndindex: NDIndex, + ndpartial: _NDPartial, + old_to_new: list[Any], +) -> _NDPartial: + partial_per_axis = [] + for axis_index, index in enumerate(ndindex): + slc = slice( + ndpartial.new[axis_index].start + index, + ndpartial.new[axis_index].start + index + 1, + ) + first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0] + last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1] + partial_per_axis.append( + _Partial( + old=slice(first_old_chunk, last_old_chunk + 1), + new=slc, + left_start=first_old_slice.start, + right_stop=last_old_slice.stop, + ) + ) + + old, new, left_starts, right_stops = zip(*partial_per_axis) + return _NDPartial(old, new, left_starts, right_stops, ndpartial.ix) + + def _compute_partial_old_chunks( partial: _NDPartial, chunks: ChunkedAxes ) -> ChunkedAxes: From fd92ab83bf96a6b0090d64942e1d6e16a726d263 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 13 Aug 2024 15:11:33 +0200 Subject: [PATCH 102/138] Improve concurrent close for scheduler (#8829) * Improve concurrent close for scheduler * Fix test --- distributed/scheduler.py | 16 +++++++++------- distributed/tests/test_jupyter.py | 4 ++-- distributed/tests/test_scheduler.py | 27 +++++++++++++++++++++++---- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 363b33bb41..a989aa7fe3 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3849,7 +3849,7 @@ async def post(self): """Shut down the server.""" self.log.info("Shutting down on /api/shutdown request.") - await scheduler.close(reason="shutdown requested via Jupyter") + await scheduler.close(reason="jupyter-requested-shutdown") j = ServerApp.instance( config=Config( @@ -4274,7 +4274,7 @@ def del_scheduler_file() -> None: setproctitle(f"dask scheduler [{self.address}]") return self - async def close(self, fast=None, close_workers=None, reason=""): + async def close(self, fast=None, close_workers=None, reason="unknown"): """Send cleanup signal to all coroutines then wait until finished See Also @@ -4291,6 +4291,10 @@ async def close(self, fast=None, close_workers=None, reason=""): await self.finished() return + self.status = Status.closing + logger.info("Closing scheduler. Reason: %s", reason) + setproctitle("dask scheduler [closing]") + async def log_errors(func): try: await func() @@ -4301,10 +4305,6 @@ async def log_errors(func): *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] ) - self.status = Status.closing - logger.info("Scheduler closing due to %s...", reason or "unknown reason") - setproctitle("dask scheduler [closing]") - await self.preloads.teardown() await asyncio.gather( @@ -8652,7 +8652,9 @@ def check_idle(self) -> float | None: "Scheduler closing after being idle for %s", format_time(self.idle_timeout), ) - self._ongoing_background_tasks.call_soon(self.close) + self._ongoing_background_tasks.call_soon( + self.close, reason="idle-timeout-exceeded" + ) return self.idle_since def _check_no_workers(self) -> None: diff --git a/distributed/tests/test_jupyter.py b/distributed/tests/test_jupyter.py index 5280d11093..3f45678f81 100644 --- a/distributed/tests/test_jupyter.py +++ b/distributed/tests/test_jupyter.py @@ -146,8 +146,8 @@ def test_shutsdown_cleanly(requires_default_ports): stderr = subprocess_fut.result().stderr assert "Traceback" not in stderr assert ( - "distributed.scheduler - INFO - Scheduler closing due to shutdown " - "requested via Jupyter...\n" in stderr + "distributed.scheduler - INFO - Closing scheduler. Reason: jupyter-requested-shutdown" + in stderr ) assert "Shutting down on /api/shutdown request.\n" in stderr assert ( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a3c3316999..d492b769eb 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2394,7 +2394,7 @@ async def test_idle_timeout(c, s, a, b): _idle_since = s.check_idle() assert _idle_since == s.idle_since - with captured_logger("distributed.scheduler") as logs: + with captured_logger("distributed.scheduler") as caplog: start = time() while s.status != Status.closed: await asyncio.sleep(0.01) @@ -2405,9 +2405,11 @@ async def test_idle_timeout(c, s, a, b): await asyncio.sleep(0.01) assert time() < start + 1 - assert "idle" in logs.getvalue() - assert "500" in logs.getvalue() - assert "ms" in logs.getvalue() + logs = caplog.getvalue() + assert "idle" in logs + assert "500" in logs + assert "ms" in logs + assert "idle-timeout-exceeded" in logs assert s.idle_since > beginning pc.stop() @@ -5270,3 +5272,20 @@ async def test_stimulus_from_erred_task(c, s, a): logger.getvalue() == "Task f marked as failed because 1 workers died while trying to run it\n" ) + + +@gen_cluster(client=True) +async def test_concurrent_close_requests(c, s, *workers): + class BeforeCloseCounterPlugin(SchedulerPlugin): + async def start(self, scheduler): + self.call_count = 0 + + async def before_close(self): + self.call_count += 1 + + await c.register_plugin(BeforeCloseCounterPlugin(), name="before_close") + with captured_logger("distributed.scheduler", level=logging.INFO) as caplog: + await asyncio.gather(*[s.close(reason="test-reason") for _ in range(5)]) + assert s.plugins["before_close"].call_count == 1 + lines = caplog.getvalue().split("\n") + assert sum("Closing scheduler" in line for line in lines) == 1 From 3db8c467a13e9835cf0af03f0dbd4c6b4b062b59 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 13 Aug 2024 19:52:11 +0200 Subject: [PATCH 103/138] Avoid `RuntimeError: dictionary changed size during iteration` in `Server._shift_counters()` (#8828) --- distributed/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index d5eacb7d3c..a4bb031c12 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -424,10 +424,12 @@ def func(data): return {"status": "OK", "nbytes": len(data)} def _shift_counters(self): - for counter in self.counters.values(): + # Copy counters before iterating to avoid concurrent modification + for counter in list(self.counters.values()): counter.shift() if self.digests is not None: - for digest in self.digests.values(): + # Copy digests before iterating to avoid concurrent modification + for digest in list(self.digests.values()): digest.shift() @property From f12cc4f6805bc53f4e46a0ff6de5da3907ec6b62 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 14 Aug 2024 14:56:40 -0500 Subject: [PATCH 104/138] Update gpuCI `RAPIDS_VER` to `24.10` (#8786) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- continuous_integration/gpuci/axis.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index 25b4c78ef9..7b30d50651 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -9,6 +9,6 @@ LINUX_VER: - ubuntu20.04 RAPIDS_VER: -- "24.08" +- "24.10" excludes: From 0c54e9d950288199e52fd0ed349a1982fe8038ed Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 16 Aug 2024 12:38:28 +0200 Subject: [PATCH 105/138] Reduce frequency of unmanaged memory use warning (#8834) --- distributed/worker_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index ac2f3c8188..5465e94a8e 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -67,7 +67,7 @@ ] worker_logger = logging.getLogger("distributed.worker.memory") -worker_logger.addFilter(RateLimiterFilter(r"Unmanaged memory use is high")) +worker_logger.addFilter(RateLimiterFilter(r"Unmanaged memory use is high", rate="300s")) nanny_logger = logging.getLogger("distributed.nanny.memory") From cdad3cbf5c3a61d16770a1b54e6b7ccb4e153ecc Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 16 Aug 2024 16:53:45 -0500 Subject: [PATCH 106/138] bump version to 2024.8.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b12a53e567..75dd7c320c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">=3.10" dependencies = [ "click >= 8.0", "cloudpickle >= 2.0.0", - "dask == 2024.8.0", + "dask == 2024.8.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.2", From 3075b0886d32cf85c7eadd2863be32706b20e312 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 19 Aug 2024 03:05:56 -0500 Subject: [PATCH 107/138] Bump minimum cloudpickle to 3 (#8836) --- continuous_integration/environment-mindeps.yaml | 2 +- continuous_integration/recipes/distributed/meta.yaml | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/continuous_integration/environment-mindeps.yaml b/continuous_integration/environment-mindeps.yaml index 9403ca64b8..9bc4996515 100644 --- a/continuous_integration/environment-mindeps.yaml +++ b/continuous_integration/environment-mindeps.yaml @@ -4,7 +4,7 @@ channels: dependencies: - python=3.10 - click=8.0 - - cloudpickle=2.0.0 + - cloudpickle=3.0.0 - cytoolz=0.11.2 - jinja2=2.10.3 - locket=1.0.0 diff --git a/continuous_integration/recipes/distributed/meta.yaml b/continuous_integration/recipes/distributed/meta.yaml index e6279ccda0..9b2fd9d5c1 100644 --- a/continuous_integration/recipes/distributed/meta.yaml +++ b/continuous_integration/recipes/distributed/meta.yaml @@ -32,7 +32,7 @@ requirements: run: - python >=3.10 - click >=8.0 - - cloudpickle >=2.0.0 + - cloudpickle >=3.0.0 - cytoolz >=0.11.2 - {{ pin_compatible('dask-core', max_pin='x.x.x.x') }} - jinja2 >=2.10.3 diff --git a/pyproject.toml b/pyproject.toml index 75dd7c320c..8ac18b1c5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ readme = "README.rst" requires-python = ">=3.10" dependencies = [ "click >= 8.0", - "cloudpickle >= 2.0.0", + "cloudpickle >= 3.0.0", "dask == 2024.8.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", From fe79a362a45b2d2b2960f44feec471c6c5920ff3 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 20 Aug 2024 13:29:22 +0200 Subject: [PATCH 108/138] Ensure client_desires_keys does not corrupt Scheduler state (#8827) --- distributed/actor.py | 2 +- distributed/client.py | 54 ++++++----- distributed/queues.py | 13 ++- distributed/scheduler.py | 12 ++- distributed/tests/test_client.py | 135 +++++++++++----------------- distributed/tests/test_queues.py | 11 +++ distributed/tests/test_scheduler.py | 20 ----- distributed/tests/test_spans.py | 46 +--------- distributed/tests/test_variable.py | 9 ++ distributed/variable.py | 36 +++++--- 10 files changed, 140 insertions(+), 198 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index 1fdbf5dae4..d2dea1848e 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -77,7 +77,7 @@ def _try_bind_worker_client(self): if not self._client: try: self._client = get_client() - self._future = Future(self._key, inform=False) + self._future = Future(self._key, self._client) # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable. except ValueError: self._client = None diff --git a/distributed/client.py b/distributed/client.py index 2ece27802a..ebe4299d1a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -163,6 +163,9 @@ def __str__(self) -> str: result = "\n".join([result, self.msg]) return result + def __reduce__(self): + return self.__class__, (self.key, self.reason, self.msg) + class FuturesCancelledError(CancelledError): error_groups: list[CancelledFuturesGroup] @@ -297,13 +300,12 @@ class Future(WrappedKey): # Make sure this stays unique even across multiple processes or hosts _uid = uuid.uuid4().hex - def __init__(self, key, client=None, inform=True, state=None, _id=None): + def __init__(self, key, client=None, state=None, _id=None): self.key = key self._cleared = False self._client = client self._id = _id or (Future._uid, next(Future._counter)) self._input_state = state - self._inform = inform self._state = None self._bind_late() @@ -312,13 +314,11 @@ def client(self): self._bind_late() return self._client + def bind_client(self, client): + self._client = client + self._bind_late() + def _bind_late(self): - if not self._client: - try: - client = get_client() - except ValueError: - client = None - self._client = client if self._client and not self._state: self._client._inc_ref(self.key) self._generation = self._client.generation @@ -328,15 +328,6 @@ def _bind_late(self): else: self._state = self._client.futures[self.key] = FutureState(self.key) - if self._inform: - self._client._send_to_scheduler( - { - "op": "client-desires-keys", - "keys": [self.key], - "client": self._client.id, - } - ) - if self._input_state is not None: try: handler = self._client._state_handlers[self._input_state] @@ -588,13 +579,8 @@ def release(self): except TypeError: # pragma: no cover pass # Shutting down, add_callback may be None - @staticmethod - def make_future(key, id): - # Can't use kwargs in pickle __reduce__ methods - return Future(key=key, _id=id) - def __reduce__(self) -> str | tuple[Any, ...]: - return Future.make_future, (self.key, self._id) + return Future, (self.key,) def __dask_tokenize__(self): return (type(self).__name__, self.key, self._id) @@ -2161,7 +2147,7 @@ def submit( with self._refcount_lock: if key in self.futures: - return Future(key, self, inform=False) + return Future(key, self) if allow_other_workers and workers is None: raise ValueError("Only use allow_other_workers= if using workers=") @@ -2661,7 +2647,7 @@ async def _scatter( timeout=timeout, ) - out = {k: Future(k, self, inform=False) for k in data} + out = {k: Future(k, self) for k in data} for key, typ in types.items(): self.futures[key].finish(type=typ) @@ -2969,12 +2955,14 @@ def list_datasets(self, **kwargs): async def _get_dataset(self, name, default=no_default): with self.as_current(): out = await self.scheduler.publish_get(name=name, client=self.id) - if out is None: if default is no_default: raise KeyError(f"Dataset '{name}' not found") else: return default + for fut in futures_of(out["data"]): + fut.bind_client(self) + self._inform_scheduler_of_futures() return out["data"] def get_dataset(self, name, default=no_default, **kwargs): @@ -3300,6 +3288,14 @@ def _get_computation_code( return tuple(reversed(code)) + def _inform_scheduler_of_futures(self): + self._send_to_scheduler( + { + "op": "client-desires-keys", + "keys": list(self.refcount), + } + ) + def _graph_to_futures( self, dsk, @@ -3348,7 +3344,7 @@ def _graph_to_futures( validate_key(key) # Create futures before sending graph (helps avoid contention) - futures = {key: Future(key, self, inform=False) for key in keyset} + futures = {key: Future(key, self) for key in keyset} # Circular import from distributed.protocol import serialize from distributed.protocol.serialize import ToPickle @@ -3507,7 +3503,7 @@ def _optimize_insert_futures(self, dsk, keys): if not changed: changed = True dsk = ensure_dict(dsk) - dsk[key] = Future(key, self, inform=False) + dsk[key] = Future(key, self) if changed: dsk, _ = dask.optimization.cull(dsk, keys) @@ -6092,7 +6088,7 @@ def futures_of(o, client=None): stack.extend(x.values()) elif type(x) is SubgraphCallable: stack.extend(x.dsk.values()) - elif isinstance(x, Future): + elif isinstance(x, WrappedKey): if x not in seen: seen.add(x) futures.append(x) diff --git a/distributed/queues.py b/distributed/queues.py index c48616459d..e2b0ee6e4d 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -8,7 +8,7 @@ from dask.utils import parse_timedelta from distributed.client import Future -from distributed.utils import wait_for +from distributed.utils import Deadline, wait_for from distributed.worker import get_client logger = logging.getLogger(__name__) @@ -67,15 +67,22 @@ def release(self, name=None, client=None): self.scheduler.client_releases_keys(keys=keys, client="queue-%s" % name) async def put(self, name=None, key=None, data=None, client=None, timeout=None): + deadline = Deadline.after(timeout) if key is not None: + while key not in self.scheduler.tasks: + await asyncio.sleep(0.01) + if deadline.expired: + raise TimeoutError(f"Task {key} unknown to scheduler.") + record = {"type": "Future", "value": key} self.future_refcount[name, key] += 1 self.scheduler.client_desires_keys(keys=[key], client="queue-%s" % name) else: record = {"type": "msgpack", "value": data} - await wait_for(self.queues[name].put(record), timeout=timeout) + await wait_for(self.queues[name].put(record), timeout=deadline.remaining) def future_release(self, name=None, key=None, client=None): + self.scheduler.client_desires_keys(keys=[key], client=client) self.future_refcount[name, key] -= 1 if self.future_refcount[name, key] == 0: self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name) @@ -265,7 +272,7 @@ async def _get(self, timeout=None, batch=False): def process(d): if d["type"] == "Future": - value = Future(d["value"], self.client, inform=True, state=d["state"]) + value = Future(d["value"], self.client, state=d["state"]) if d["state"] == "erred": value._state.set_error(d["exception"], d["traceback"]) self.client._send_to_scheduler( diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a989aa7fe3..57d0eca46b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -670,9 +670,7 @@ def clean(self) -> WorkerState: ) ws._occupancy_cache = self.occupancy - ws.executing = { - ts.key: duration for ts, duration in self.executing.items() # type: ignore - } + ws.executing = {ts.key: duration for ts, duration in self.executing.items()} # type: ignore return ws def __repr__(self) -> str: @@ -4634,7 +4632,7 @@ def _match_graph_with_tasks( ): # bad key lost_keys.add(k) logger.info("User asked for computation on lost data, %s", k) - del dsk[k] + dsk.pop(k, None) del dependencies[k] if k in keys: keys.remove(k) @@ -5595,8 +5593,8 @@ def client_desires_keys(self, keys: Collection[Key], client: str) -> None: for k in keys: ts = self.tasks.get(k) if ts is None: - # For publish, queues etc. - ts = self.new_task(k, None, "released") + warnings.warn(f"Client desires key {k!r} but key is unknown.") + continue if ts.who_wants is None: ts.who_wants = set() ts.who_wants.add(cs) @@ -9345,7 +9343,7 @@ def transition( def _materialize_graph( graph: HighLevelGraph, global_annotations: dict[str, Any], validate: bool ) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]: - dsk = ensure_dict(graph) + dsk: dict = ensure_dict(graph) if validate: for k in dsk: validate_key(k) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 02b065ceda..5c2f023e79 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -102,7 +102,6 @@ dec, div, double, - ensure_no_new_clients, gen_cluster, gen_test, get_cert, @@ -2639,24 +2638,32 @@ def test_futures_of_class(): @gen_cluster(client=True) async def test_futures_of_cancelled_raises(c, s, a, b): x = c.submit(inc, 1) - await c.cancel([x]) - - with pytest.raises(CancelledError): + while x.key not in s.tasks: + await asyncio.sleep(0.01) + await c.cancel([x], reason="testreason") + + # Note: The scheduler currently doesn't remember the reason but rather + # forgets the task immediately. The reason is currently. only raised if the + # client checks on it. Therefore, we expect an unknown reason and definitely + # not a scheduler disconnected which would otherwise indicate a bug, e.g. an + # AssertionError during transitioning. + with pytest.raises(CancelledError, match="(reason: unknown|testreason)"): await x while x.key in s.tasks: await asyncio.sleep(0.01) - with pytest.raises(CancelledError): + + with pytest.raises(CancelledError, match="(reason: unknown|testreason)"): get_obj = c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False) gather_obj = c.gather(get_obj) await gather_obj - with pytest.raises(CancelledError): + with pytest.raises(CancelledError, match="(reason: unknown|testreason)"): await c.submit(inc, x) - with pytest.raises(CancelledError): + with pytest.raises(CancelledError, match="(reason: unknown|testreason)"): await c.submit(add, 1, y=x) - with pytest.raises(CancelledError): + with pytest.raises(CancelledError, match="(reason: unknown|testreason)"): await c.gather(c.map(add, [1], y=x)) @@ -3027,14 +3034,6 @@ async def test_rebalance_unprepared(c, s, a, b): s.validate_state() -@gen_cluster(client=True, config=NO_AMM) -async def test_rebalance_raises_on_explicit_missing_data(c, s, a, b): - """rebalance() raises KeyError if explicitly listed futures disappear""" - f = Future("x", client=c, state="memory") - with pytest.raises(KeyError, match="Could not rebalance keys:"): - await c.rebalance(futures=[f]) - - @gen_cluster(client=True) async def test_receive_lost_key(c, s, a, b): x = c.submit(inc, 1, workers=[a.address]) @@ -4141,51 +4140,6 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b): assert z.status == "cancelled" -@gen_cluster() -async def test_serialize_future(s, a, b): - async with ( - Client(s.address, asynchronous=True) as c1, - Client(s.address, asynchronous=True) as c2, - ): - future = c1.submit(lambda: 1) - result = await future - - for ci in (c1, c2): - with ensure_no_new_clients(): - with ci.as_current(): - future2 = pickle.loads(pickle.dumps(future)) - assert future2.client is ci - assert future2.key in ci.futures - result2 = await future2 - assert result == result2 - with temp_default_client(ci): - future2 = pickle.loads(pickle.dumps(future)) - - -@gen_cluster() -async def test_serialize_future_without_client(s, a, b): - # Do not use a ctx manager to avoid having this being set as a current and/or default client - c1 = await Client(s.address, asynchronous=True, set_as_default=False) - try: - with ensure_no_new_clients(): - - def do_stuff(): - return 1 - - future = c1.submit(do_stuff) - pickled = pickle.dumps(future) - unpickled_fut = pickle.loads(pickled) - - with pytest.raises(RuntimeError): - await unpickled_fut - - with c1.as_current(): - unpickled_fut_ctx = pickle.loads(pickled) - assert await unpickled_fut_ctx == 1 - finally: - await c1.close() - - @gen_cluster() async def test_temp_default_client(s, a, b): async with ( @@ -5827,27 +5781,6 @@ async def test_client_with_name(s, a, b): assert "foo" in text -@gen_cluster(client=True) -async def test_future_defaults_to_default_client(c, s, a, b): - x = c.submit(inc, 1) - await wait(x) - - future = Future(x.key) - assert future.client is c - - -@gen_cluster(client=True) -async def test_future_auto_inform(c, s, a, b): - x = c.submit(inc, 1) - await wait(x) - - async with Client(s.address, asynchronous=True) as client: - future = Future(x.key, client) - - while future.status != "finished": - await asyncio.sleep(0.01) - - def test_client_async_before_loop_starts(cleanup): with pytest.raises( RuntimeError, @@ -8504,3 +8437,41 @@ def test_release_persisted_collection_sync(c): # submitting to the scheduler is different to what we are in # client.compute arr.compute() + + +@pytest.mark.slow() +@pytest.mark.parametrize("do_wait", [True, False]) +def test_worker_clients_do_not_claim_ownership_of_serialize_futures(c, do_wait): + # Note: sending collections like this should be considered an anti-pattern + # but it is possible. As long as the user ensures the futures stay alive + # this is fine but the cluster will not take over this responsibility. The + # client will not unpack the collection when using submit and will therefore + # not handle the dependencies in any way. + # See also https://github.com/dask/distributed/issues/7498 + da = pytest.importorskip("dask.array", exc_type=ImportError) + x = da.arange(10, chunks=(5,)).persist() + if do_wait: + wait(x) + + def f(x): + assert isinstance(x, da.Array) + return x.sum().compute() + + future = c.submit(f, x) + result = future.result() + assert result == sum(range(10)) + del x, future, result + + # Now we delete the persisted collection before computing the result + y = da.arange(10, chunks=(4,)).persist() + if do_wait: + wait(y) + future = c.submit(f, y) + del y + with pytest.raises(FutureCancelledError): + future.result() + del future + + future = c.submit(f, da.arange(10, chunks=(4,)).persist()) + with pytest.raises(FutureCancelledError): + future.result() diff --git a/distributed/tests/test_queues.py b/distributed/tests/test_queues.py index 8a0dd96ca9..11e0f379a8 100644 --- a/distributed/tests/test_queues.py +++ b/distributed/tests/test_queues.py @@ -332,3 +332,14 @@ async def test_unpickle_without_client(s): q3 = pickle.loads(pickled) await q3.put(1) assert await q3.get() == 1 + + +@gen_cluster(client=True, nthreads=[]) +async def test_set_cancelled_future(c, s): + x = c.submit(inc, 1) + await x.cancel() + q = Queue("x") + # FIXME: This is a TimeoutError but pytest doesn't appear to recognize it as + # such + with pytest.raises(Exception, match="unknown to scheduler"): + await q.put(x, timeout="100ms") diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index d492b769eb..a62fcd977e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -32,7 +32,6 @@ CancelledError, Client, Event, - Future, Lock, Nanny, SchedulerPlugin, @@ -4851,25 +4850,6 @@ async def gather_dep(self, worker, *args, **kwargs): assert time() < start + connect_timeout -@gen_cluster(client=True) -async def test_client_desires_keys_creates_ts(c, s, a, b): - """A TaskState object is created by client_desires_keys, and - is only later submitted with submit/compute by a different client - - See also - -------- - test_scheduler.py::test_scatter_creates_ts - test_spans.py::test_client_desires_keys_creates_ts - """ - x = Future(key="x") - await wait_for_state("x", "released", s) - assert s.tasks["x"].run_spec is None - async with Client(s.address, asynchronous=True) as c2: - c2.submit(inc, 1, key="x") - assert await x == 2 - assert s.tasks["x"].run_spec is not None - - @gen_cluster(client=True) async def test_scatter_creates_ts(c, s, a, b): """A TaskState object is created by scatter, and only later becomes runnable diff --git a/distributed/tests/test_spans.py b/distributed/tests/test_spans.py index 2f54b4d3bf..63db4ccf4c 100644 --- a/distributed/tests/test_spans.py +++ b/distributed/tests/test_spans.py @@ -8,7 +8,7 @@ from dask import delayed import distributed -from distributed import Client, Event, Future, Worker, span, wait +from distributed import Client, Event, Worker, span, wait from distributed.diagnostics.plugin import SchedulerPlugin from distributed.utils_test import ( NoSchedulerDelayWorker, @@ -386,46 +386,6 @@ def test_no_tags(): pass -@gen_cluster(client=True) -async def test_client_desires_keys_creates_ts(c, s, a, b): - """A TaskState object is created by client_desires_keys, and - is only later submitted with submit/compute by a different client - - See also - -------- - test_scheduler.py::test_client_desires_keys_creates_ts - test_spans.py::test_client_desires_keys_creates_tg - test_spans.py::test_scatter_creates_ts - test_spans.py::test_scatter_creates_tg - """ - x = Future(key="x") - await wait_for_state("x", "released", s) - assert s.tasks["x"].group.span_id is None - async with Client(s.address, asynchronous=True) as c2: - c2.submit(inc, 1, key="x") - assert await x == 2 - assert s.tasks["x"].group.span_id is not None - - -@gen_cluster(client=True) -async def test_client_desires_keys_creates_tg(c, s, a, b): - """A TaskGroup object is created by client_desires_keys, and - only later gains runnable tasks - - See also - -------- - test_spans.py::test_client_desires_keys_creates_ts - test_spans.py::test_scatter_creates_ts - test_spans.py::test_scatter_creates_tg - """ - x0 = Future(key="x-0") - await wait_for_state("x-0", "released", s) - assert s.tasks["x-0"].group.span_id is None - x1 = c.submit(inc, 1, key="x-1") - assert await x1 == 2 - assert s.tasks["x-0"].group.span_id is not None - - @gen_cluster(client=True) async def test_scatter_creates_ts(c, s, a, b): """A TaskState object is created by scatter, and only later becomes runnable @@ -433,8 +393,6 @@ async def test_scatter_creates_ts(c, s, a, b): See also -------- test_scheduler.py::test_scatter_creates_ts - test_spans.py::test_client_desires_keys_creates_ts - test_spans.py::test_client_desires_keys_creates_tg test_spans.py::test_scatter_creates_tg """ x1 = (await c.scatter({"x": 1}, workers=[a.address]))["x"] @@ -454,8 +412,6 @@ async def test_scatter_creates_tg(c, s, a, b): See also -------- - test_spans.py::test_client_desires_keys_creates_ts - test_spans.py::test_client_desires_keys_creates_tg test_spans.py::test_scatter_creates_ts """ x0 = (await c.scatter({"x-0": 1}))["x-0"] diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 00618da70d..c0ee24416a 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -325,3 +325,12 @@ async def test_unpickle_without_client(s): obj3 = pickle.loads(pickled) await obj3.set(42) assert await obj3.get() == 42 + + +@gen_cluster(client=True, nthreads=[]) +async def test_set_cancelled_future(c, s): + x = c.submit(inc, 1) + await x.cancel() + v = Variable("x") + with pytest.raises(TimeoutError): + await v.set(x, timeout="5ms") diff --git a/distributed/variable.py b/distributed/variable.py index 3df28ff359..befec99484 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -12,7 +12,7 @@ from distributed.client import Future from distributed.metrics import time -from distributed.utils import TimeoutError, log_errors, wait_for +from distributed.utils import Deadline, TimeoutError, log_errors, wait_for from distributed.worker import get_client logger = logging.getLogger(__name__) @@ -39,12 +39,19 @@ def __init__(self, scheduler): {"variable_set": self.set, "variable_get": self.get} ) - self.scheduler.stream_handlers["variable-future-release"] = self.future_release + self.scheduler.stream_handlers[ + "variable-future-received-confirm" + ] = self.future_received_confirm self.scheduler.stream_handlers["variable_delete"] = self.delete - async def set(self, name=None, key=None, data=None, client=None): + async def set(self, name=None, key=None, data=None, client=None, timeout=None): + deadline = Deadline.after(parse_timedelta(timeout)) if key is not None: record = {"type": "Future", "value": key} + while key not in self.scheduler.tasks: + await asyncio.sleep(0.01) + if deadline.expired: + raise TimeoutError(f"Task {key} unknown to scheduler.") self.scheduler.client_desires_keys(keys=[key], client="variable-%s" % name) else: record = {"type": "msgpack", "value": data} @@ -68,7 +75,10 @@ async def release(self, key, name): self.scheduler.client_releases_keys(keys=[key], client="variable-%s" % name) del self.waiting[key, name] - async def future_release(self, name=None, key=None, token=None, client=None): + async def future_received_confirm( + self, name=None, key=None, token=None, client=None + ): + self.scheduler.client_desires_keys([key], client) self.waiting[key, name].remove(token) if not self.waiting[key, name]: async with self.waiting_conditions[name]: @@ -182,13 +192,17 @@ def _verify_running(self): " a Client or Worker." ) - async def _set(self, value): + async def _set(self, value, timeout): if isinstance(value, Future): - await self.client.scheduler.variable_set(key=value.key, name=self.name) + await self.client.scheduler.variable_set( + key=value.key, name=self.name, timeout=timeout + ) else: - await self.client.scheduler.variable_set(data=value, name=self.name) + await self.client.scheduler.variable_set( + data=value, name=self.name, timeout=timeout + ) - def set(self, value, **kwargs): + def set(self, value, timeout="30 s", **kwargs): """Set the value of this variable Parameters @@ -197,19 +211,19 @@ def set(self, value, **kwargs): Must be either a Future or a msgpack-encodable value """ self._verify_running() - return self.client.sync(self._set, value, **kwargs) + return self.client.sync(self._set, value, timeout=timeout, **kwargs) async def _get(self, timeout=None): d = await self.client.scheduler.variable_get( timeout=timeout, name=self.name, client=self.client.id ) if d["type"] == "Future": - value = Future(d["value"], self.client, inform=True, state=d["state"]) + value = Future(d["value"], self.client, state=d["state"]) if d["state"] == "erred": value._state.set_error(d["exception"], d["traceback"]) self.client._send_to_scheduler( { - "op": "variable-future-release", + "op": "variable-future-received-confirm", "name": self.name, "key": d["value"], "token": d["token"], From e9d8233804897201169777f06feef0a1bb556b36 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 20 Aug 2024 15:35:27 +0200 Subject: [PATCH 109/138] Use task-based rechunking to prechunk along partial boundaries (#8831) --- distributed/shuffle/_rechunk.py | 271 ++++++++++------------ distributed/shuffle/tests/test_rechunk.py | 77 +++--- 2 files changed, 177 insertions(+), 171 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 99f9d39061..962b57fb32 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -153,6 +153,8 @@ ChunkedAxis: TypeAlias = tuple[float, ...] # chunks must either be an int or NaN ChunkedAxes: TypeAlias = tuple[ChunkedAxis, ...] NDSlice: TypeAlias = tuple[slice, ...] +SlicedAxis: TypeAlias = tuple[slice, ...] +SlicedAxes: TypeAlias = tuple[SlicedAxis, ...] def rechunk_transfer( @@ -187,14 +189,6 @@ class _Partial(NamedTuple): old: slice #: Slice of the new chunks along this axis that belong to the partial new: slice - #: Index of the first value of the left-most old chunk along this axis - #: to include in this partial. Everything left to this index belongs to - #: the previous partial. - left_start: int - #: Index of the first value of the right-most old chunk along this axis - #: to exclude from this partial. - #: This corresponds to `left_start` of the subsequent partial. - right_stop: int class _NDPartial(NamedTuple): @@ -204,17 +198,6 @@ class _NDPartial(NamedTuple): old: NDSlice #: n-dimensional slice of the new chunks along each axis that belong to the partial new: NDSlice - #: Indices of the first value of the left-most old chunk along each axis - #: to include in this partial. Everything left to this index belongs to - #: the previous partial. - left_starts: NDIndex - #: Indices of the first value of the right-most old chunk along each axis - #: to exclude from this partial. - #: This corresponds to `left_start` of the subsequent partial. - right_stops: NDIndex - #: Index of the partial among all partials. - #: This corresponds to the position of the partial in the n-dimensional grid of - #: partials representing the full rechunk. ix: NDIndex @@ -222,7 +205,14 @@ def rechunk_name(token: str) -> str: return f"rechunk-p2p-{token}" -def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: +def rechunk_p2p( + x: da.Array, + chunks: ChunkedAxes, + *, + threshold: int | None = None, + block_size_limit: int | None = None, + balance: bool = False, +) -> da.Array: import dask.array as da if x.size == 0: @@ -230,6 +220,19 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: return da.empty(x.shape, chunks=chunks, dtype=x.dtype) from dask.array.core import new_da_object + prechunked = _calculate_prechunking(x.chunks, chunks) + if prechunked != x.chunks: + x = cast( + "da.Array", + x.rechunk( + chunks=prechunked, + threshold=threshold, + block_size_limit=block_size_limit, + balance=balance, + method="tasks", + ), + ) + token = tokenize(x, chunks) name = rechunk_name(token) disk: bool = dask.config.get("distributed.p2p.disk") @@ -396,24 +399,22 @@ def _construct_graph(self) -> _T_LowLevelGraph: dsk: _T_LowLevelGraph = {} _old_to_new = old_to_new(self.chunks_input, self.chunks) - chunked_shape = tuple(len(axis) for axis in self.chunks) - for ndpartial in _split_partials(_old_to_new, chunked_shape): + for ndpartial in _split_partials(_old_to_new): partial_keepmap = self.keepmap[ndpartial.new] output_count = np.sum(partial_keepmap) if output_count == 0: continue elif output_count == 1: # Single output chunk - ndindex = np.argwhere(partial_keepmap)[0] - ndpartial = _truncate_partial(ndindex, ndpartial, _old_to_new) - dsk.update( partial_concatenate( input_name=self.name_input, input_chunks=self.chunks_input, ndpartial=ndpartial, token=self.token, + keepmap=self.keepmap, + old_to_new=_old_to_new, ) ) else: @@ -431,71 +432,120 @@ def _construct_graph(self) -> _T_LowLevelGraph: return dsk +def _calculate_prechunking( + old_chunks: ChunkedAxes, new_chunks: ChunkedAxes +) -> ChunkedAxes: + from dask.array.rechunk import old_to_new + + _old_to_new = old_to_new(old_chunks, new_chunks) + + partials = _slice_new_chunks_into_partials(_old_to_new) + + split_axes = [] + + for axis_index, slices in enumerate(partials): + old_to_new_axis = _old_to_new[axis_index] + old_axis = old_chunks[axis_index] + split_axis = [] + for slice_ in slices: + first_new_chunk = slice_.start + first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0] + last_new_chunk = slice_.stop - 1 + last_old_chunk, last_old_slice = old_to_new_axis[last_new_chunk][-1] + + first_chunk_size = old_axis[first_old_chunk] + last_chunk_size = old_axis[last_old_chunk] + + if first_old_chunk == last_old_chunk: + chunk_size = first_chunk_size + if ( + last_old_slice.stop is not None + and last_old_slice.stop != last_chunk_size + ): + chunk_size = last_old_slice.stop + if first_old_slice.start != 0: + chunk_size -= first_old_slice.start + split_axis.append(chunk_size) + continue + + split_axis.append(first_chunk_size - first_old_slice.start) + + split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk]) + + if last_old_slice.stop is not None: + chunk_size = last_old_slice.stop + else: + chunk_size = last_chunk_size + + split_axis.append(chunk_size) + + split_axes.append(split_axis) + return tuple(tuple(axis) for axis in split_axes) + + def _split_partials( old_to_new: list[Any], - chunked_shape: tuple[int, ...], ) -> Generator[_NDPartial, None, None]: """Split the rechunking into partials that can be performed separately""" - partials_per_axis = _split_partials_per_axis(old_to_new, chunked_shape) + partials_per_axis = _split_partials_per_axis(old_to_new) indices_per_axis = (range(len(partials)) for partials in partials_per_axis) for nindex, partial_per_axis in zip( product(*indices_per_axis), product(*partials_per_axis) ): - old, new, left_starts, right_stops = zip(*partial_per_axis) - yield _NDPartial(old, new, left_starts, right_stops, nindex) + old, new = zip(*partial_per_axis) + yield _NDPartial(old, new, nindex) -def _split_partials_per_axis( - old_to_new: list[Any], chunked_shape: tuple[int, ...] -) -> tuple[tuple[_Partial, ...], ...]: +def _split_partials_per_axis(old_to_new: list[Any]) -> tuple[tuple[_Partial, ...], ...]: """Split the rechunking into partials that can be performed separately on each axis""" - sliced_axes = _partial_slices(old_to_new, chunked_shape) + sliced_axes = _slice_new_chunks_into_partials(old_to_new) partial_axes = [] for axis_index, slices in enumerate(sliced_axes): partials = [] for slice_ in slices: last_old_chunk: int - first_old_chunk, first_old_slice = old_to_new[axis_index][slice_.start][0] - last_old_chunk, last_old_slice = old_to_new[axis_index][slice_.stop - 1][-1] + first_old_chunk, _ = old_to_new[axis_index][slice_.start][0] + last_old_chunk, _ = old_to_new[axis_index][slice_.stop - 1][-1] partials.append( _Partial( old=slice(first_old_chunk, last_old_chunk + 1), new=slice_, - left_start=first_old_slice.start, - right_stop=last_old_slice.stop, ) ) partial_axes.append(tuple(partials)) return tuple(partial_axes) -def _partial_slices( - old_to_new: list[list[list[tuple[int, slice]]]], chunked_shape: NDIndex -) -> tuple[tuple[slice, ...], ...]: - """Compute the slices of the new chunks that can be computed separately""" +def _slice_new_chunks_into_partials( + old_to_new: list[list[list[tuple[int, slice]]]] +) -> SlicedAxes: + """Slice the new chunks into partials that can be computed separately""" sliced_axes = [] + chunk_shape = tuple(len(axis) for axis in old_to_new) + for axis_index, old_to_new_axis in enumerate(old_to_new): # Two consecutive output chunks A and B belong to the same partial rechunk - # if B is fully included in the right-most input chunk of A, i.e., - # separating A and B would not allow us to cull more input tasks. + # if A and B share the same input chunks, i.e., separating A and B would not + # allow us to cull more input tasks. # Index of the last input chunk of this partial rechunk - last_old_chunk: int | None = None + first_old_chunk: int | None = None partial_splits = [0] recipe: list[tuple[int, slice]] for new_chunk_index, recipe in enumerate(old_to_new_axis): if len(recipe) == 0: continue - current_last_old_chunk, old_slice = recipe[-1] - if last_old_chunk is None: - last_old_chunk = current_last_old_chunk - elif last_old_chunk != current_last_old_chunk: + current_first_old_chunk, _ = recipe[0] + current_last_old_chunk, _ = recipe[-1] + if first_old_chunk is None: + first_old_chunk = current_first_old_chunk + elif first_old_chunk != current_last_old_chunk: partial_splits.append(new_chunk_index) - last_old_chunk = current_last_old_chunk - partial_splits.append(chunked_shape[axis_index]) + first_old_chunk = current_first_old_chunk + partial_splits.append(chunk_shape[axis_index]) sliced_axes.append( tuple(slice(a, b) for a, b in toolz.sliding_window(2, partial_splits)) ) @@ -517,6 +567,8 @@ def partial_concatenate( input_chunks: ChunkedAxes, ndpartial: _NDPartial, token: str, + keepmap: np.ndarray, + old_to_new: list[Any], ) -> dict[Key, Any]: import numpy as np @@ -527,80 +579,45 @@ def partial_concatenate( slice_group = f"rechunk-slice-{token}" - old_offset = tuple(slice_.start for slice_ in ndpartial.old) + partial_keepmap = keepmap[ndpartial.new] + assert np.sum(partial_keepmap) == 1 + partial_new_index = np.argwhere(partial_keepmap)[0] - shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old) - rec_cat_arg = np.empty(shape, dtype="O") + global_new_index = tuple( + int(ix) + slc.start for ix, slc in zip(partial_new_index, ndpartial.new) + ) - partial_old = _compute_partial_old_chunks(ndpartial, input_chunks) + inputs = tuple( + old_to_new_axis[ix] for ix, old_to_new_axis in zip(global_new_index, old_to_new) + ) + shape = tuple(len(axis) for axis in inputs) + rec_cat_arg = np.empty(shape, dtype="O") - for old_partial_index in _partial_ndindex(ndpartial.old): - old_global_index = _global_index(old_partial_index, old_offset) - # TODO: Precompute slicing to avoid duplicate work - ndslice = ndslice_for( - old_partial_index, partial_old, ndpartial.left_starts, ndpartial.right_stops + for old_partial_index in np.ndindex(shape): + old_global_index, old_slice = zip( + *(input_axis[index] for index, input_axis in zip(old_partial_index, inputs)) ) - original_shape = tuple( - axis[index] for index, axis in zip(old_global_index, input_chunks) + old_axis[index] for index, old_axis in zip(old_global_index, input_chunks) ) - if _slicing_is_necessary(ndslice, original_shape): # type: ignore + if _slicing_is_necessary(old_slice, original_shape): key = (slice_group,) + ndpartial.ix + old_global_index rec_cat_arg[old_partial_index] = key dsk[key] = ( getitem, (input_name,) + old_global_index, - ndslice, + old_slice, ) else: rec_cat_arg[old_partial_index] = (input_name,) + old_global_index - global_index = tuple(int(slice_.start) for slice_ in ndpartial.new) - dsk[(rechunk_name(token),) + global_index] = ( + + dsk[(rechunk_name(token),) + global_new_index] = ( concatenate3, rec_cat_arg.tolist(), ) return dsk -def _truncate_partial( - ndindex: NDIndex, - ndpartial: _NDPartial, - old_to_new: list[Any], -) -> _NDPartial: - partial_per_axis = [] - for axis_index, index in enumerate(ndindex): - slc = slice( - ndpartial.new[axis_index].start + index, - ndpartial.new[axis_index].start + index + 1, - ) - first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0] - last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1] - partial_per_axis.append( - _Partial( - old=slice(first_old_chunk, last_old_chunk + 1), - new=slc, - left_start=first_old_slice.start, - right_stop=last_old_slice.stop, - ) - ) - - old, new, left_starts, right_stops = zip(*partial_per_axis) - return _NDPartial(old, new, left_starts, right_stops, ndpartial.ix) - - -def _compute_partial_old_chunks( - partial: _NDPartial, chunks: ChunkedAxes -) -> ChunkedAxes: - _partial_old = [] - for axis_index in range(len(partial.old)): - c = list(chunks[axis_index][partial.old[axis_index]]) - c[0] = c[0] - partial.left_starts[axis_index] - if (stop := partial.right_stops[axis_index]) is not None: - c[-1] = stop - _partial_old.append(tuple(c)) - return tuple(_partial_old) - - def _slicing_is_necessary(slice: NDSlice, shape: tuple[int | None, ...]) -> bool: """Return True if applying the slice alters the shape, False otherwise.""" return not all( @@ -618,15 +635,12 @@ def partial_rechunk( disk: bool, keepmap: np.ndarray, ) -> dict[Key, Any]: - from dask.array.chunk import getitem - dsk: dict[Key, Any] = {} old_partial_offset = tuple(slice_.start for slice_ in ndpartial.old) partial_token = tokenize(token, ndpartial.ix) # Use `token` to generate a canonical group for the entire rechunk - slice_group = f"rechunk-slice-{token}" transfer_group = f"rechunk-transfer-{token}" unpack_group = rechunk_name(token) # We can use `partial_token` here because the barrier task share their @@ -636,32 +650,19 @@ def partial_rechunk( ndim = len(input_chunks) - partial_old = _compute_partial_old_chunks(ndpartial, input_chunks) + partial_old = tuple( + chunk_axis[partial_axis] + for partial_axis, chunk_axis in zip(ndpartial.old, input_chunks) + ) partial_new: ChunkedAxes = tuple( chunks[axis_index][ndpartial.new[axis_index]] for axis_index in range(ndim) ) transfer_keys = [] for partial_index in _partial_ndindex(ndpartial.old): - # FIXME: Do not shuffle data for output chunks that we culled - ndslice = ndslice_for( - partial_index, partial_old, ndpartial.left_starts, ndpartial.right_stops - ) - global_index = _global_index(partial_index, old_partial_offset) - original_shape = tuple( - axis[index] for index, axis in zip(global_index, input_chunks) - ) - if _slicing_is_necessary(ndslice, original_shape): # type: ignore - input_key = (slice_group,) + ndpartial.ix + global_index - dsk[input_key] = ( - getitem, - (input_name,) + global_index, - ndslice, - ) - else: - input_key = (input_name,) + global_index + input_key = (input_name,) + global_index key = (transfer_group,) + ndpartial.ix + global_index transfer_keys.append(key) @@ -690,26 +691,6 @@ def partial_rechunk( return dsk -def ndslice_for( - partial_index: NDIndex, - chunks: ChunkedAxes, - left_starts: NDIndex, - right_stops: NDIndex, -) -> NDSlice: - slices = [] - shape = tuple(len(axis) for axis in chunks) - for axis_index, chunked_axis in enumerate(chunks): - chunk_index = partial_index[axis_index] - start = left_starts[axis_index] if chunk_index == 0 else 0 - stop = ( - right_stops[axis_index] - if chunk_index == shape[axis_index] - 1 - else chunked_axis[chunk_index] + start - ) - slices.append(slice(start, stop)) - return tuple(slices) - - class Split(NamedTuple): """Slice of a chunk that is concatenated with other splits to create a new chunk diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 69438b4473..7fa608bef3 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -31,6 +31,7 @@ ArrayRechunkRun, ArrayRechunkSpec, Split, + _calculate_prechunking, split_axes, ) from distributed.shuffle.tests.utils import AbstractShuffleTestPool @@ -188,17 +189,19 @@ async def test_rechunk_configuration(c, s, *ws, config_value, keyword): -------- dask.array.tests.test_rechunk.test_rechunk_1d """ - a = np.random.default_rng().uniform(0, 1, 30) - x = da.from_array(a, chunks=((10,) * 3,)) - new = ((6,) * 5,) + a = np.random.default_rng().uniform(0, 1, 100).reshape((10, 10)) + x = da.from_array(a, chunks=(10, 1)) + new = ((1,) * 10, (10,)) config = {"array.rechunk.method": config_value} if config_value is not None else {} with dask.config.set(config): x2 = rechunk(x, chunks=new, method=keyword) expected_algorithm = keyword if keyword is not None else config_value if expected_algorithm == "p2p": - assert all(key[0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) + assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) else: - assert not any(key[0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) + assert not any( + key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__() + ) assert x2.chunks == new assert np.all(await c.compute(x2) == a) @@ -1315,30 +1318,17 @@ async def test_partial_rechunk_homogeneous_distribution(c, s, *workers): async def test_partial_rechunk_taskgroups(c, s): """Regression test for https://github.com/dask/distributed/issues/8656""" arr = da.random.random( - (10, 10, 10), + (10, 10), chunks=( - ( - 2, - 2, - 2, - 2, - 2, - ), - ) - * 3, + (1,) * 10, + (2,) * 5, + ), ) arr = arr.rechunk( ( - ( - 1, - 2, - 2, - 2, - 2, - 1, - ), - ) - * 3, + (2,) * 5, + (1,) * 10, + ), method="p2p", ) @@ -1350,4 +1340,39 @@ async def test_partial_rechunk_taskgroups(c, s): ), timeout=5, ) - assert len(s.task_groups) < 7 + assert len(s.task_groups) < 6 + + +@pytest.mark.parametrize( + ["old", "new", "expected"], + [ + [((2, 2),), ((2, 2),), ((2, 2),)], + [((2, 2),), ((4,),), ((2, 2),)], + [((2, 2),), ((1, 1, 1, 1),), ((2, 2),)], + [((2, 2, 2),), ((1, 2, 2, 1),), ((1, 1, 1, 1, 1, 1),)], + [((1, np.nan),), ((1, np.nan),), ((1, np.nan),)], + ], +) +def test_calculate_prechunking_1d(old, new, expected): + actual = _calculate_prechunking(old, new) + assert actual == expected + + +@pytest.mark.parametrize( + ["old", "new", "expected"], + [ + [((2, 2), (3, 3)), ((2, 2), (3, 3)), ((2, 2), (3, 3))], + [((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))], + [((2, 2), (3, 3)), ((1, 1, 1, 1), (3, 3)), ((2, 2), (3, 3))], + [ + ((2, 2, 2), (3, 3, 3)), + ((1, 2, 2, 1), (2, 3, 4)), + ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)), + ], + [((1, np.nan), (3, 3)), ((1, np.nan), (2, 2, 2)), ((1, np.nan), (2, 1, 1, 2))], + [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))], + ], +) +def test_calculate_prechunking_2d(old, new, expected): + actual = _calculate_prechunking(old, new) + assert actual == expected From f5c30e852b32acd2834f2cf5319bc8cab8556e6a Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 20 Aug 2024 10:42:49 -0500 Subject: [PATCH 110/138] Remove more Python 3.10 compatibility code (#8824) --- distributed/comm/registry.py | 14 +-------- distributed/security.py | 3 +- distributed/tests/test_client.py | 19 ------------ .../tests/test_worker_state_machine.py | 4 --- distributed/utils.py | 31 +------------------ distributed/worker_state_machine.py | 8 +---- 6 files changed, 4 insertions(+), 75 deletions(-) diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index 47ba730a7d..db9c0baa29 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -1,7 +1,6 @@ from __future__ import annotations import importlib.metadata -import sys from abc import ABC, abstractmethod from collections.abc import Iterable from typing import Protocol @@ -12,18 +11,7 @@ def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: ... -if sys.version_info >= (3, 10): - # py3.10 importlib.metadata type annotations are not in mypy yet - # https://github.com/python/typeshed/pull/7331 - _entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] -else: - - def _entry_points( - *, group: str, name: str - ) -> Iterable[importlib.metadata.EntryPoint]: - for ep in importlib.metadata.entry_points().get(group, []): - if ep.name == name: - yield ep +_entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] class Backend(ABC): diff --git a/distributed/security.py b/distributed/security.py index ae8288bec4..e3a3df906b 100644 --- a/distributed/security.py +++ b/distributed/security.py @@ -3,7 +3,6 @@ import datetime import os import ssl -import sys import tempfile import warnings @@ -13,7 +12,7 @@ __all__ = ("Security",) -if sys.version_info >= (3, 10) or ssl.OPENSSL_VERSION_INFO >= (1, 1, 0, 7): +if ssl.OPENSSL_VERSION_INFO >= (1, 1, 0, 7): # The OP_NO_SSL* and OP_NO_TLS* become deprecated in favor of # 'SSLContext.minimum_version' from Python 3.7 onwards, however # this attribute is not available unless the ssl module is compiled diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 5c2f023e79..e3a596855c 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -6823,16 +6823,6 @@ async def f(stacklevel, mode=None): assert "cdn.bokeh.org" in data -@pytest.mark.skipif( - sys.version_info >= (3, 10), - reason="On Py3.10+ semaphore._loop is not bound until .acquire() blocks", -) -@gen_cluster(nthreads=[]) -async def test_client_gather_semaphore_loop(s): - async with Client(s.address, asynchronous=True) as c: - assert c._gather_semaphore._loop is c.loop.asyncio_loop - - @gen_cluster(client=True) async def test_as_completed_condition_loop(c, s, a, b): seq = c.map(inc, range(5)) @@ -6843,15 +6833,6 @@ async def test_as_completed_condition_loop(c, s, a, b): assert ac.condition._loop == c.loop.asyncio_loop -@pytest.mark.skipif( - sys.version_info >= (3, 10), - reason="On Py3.10+ semaphore._loop is not bound until .acquire() blocks", -) -def test_client_connectionpool_semaphore_loop(s, a, b, loop): - with Client(s["address"], loop=loop) as c: - assert c.rpc.semaphore._loop is loop.asyncio_loop - - @pytest.mark.slow @gen_cluster(client=True, nthreads=[], config={"distributed.comm.compression": None}) @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 5508355929..1772ca3a94 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -3,7 +3,6 @@ import asyncio import gc import pickle -import sys from collections import defaultdict from collections.abc import Iterator from time import sleep @@ -287,9 +286,6 @@ def traverse_subclasses(cls: type) -> Iterator[type]: [ pytest.param( TaskState, - marks=pytest.mark.skipif( - sys.version_info < (3, 10), reason="Requires @dataclass(slots=True)" - ), ), *traverse_subclasses(Instruction), *traverse_subclasses(StateMachineEvent), diff --git a/distributed/utils.py b/distributed/utils.py index 6f23799ebd..0c84dd1737 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -20,6 +20,7 @@ import warnings import weakref import xml.etree.ElementTree +from asyncio import Event as LateLoopEvent from asyncio import TimeoutError from collections import deque from collections.abc import ( @@ -441,36 +442,6 @@ def wait(timeout: float | None) -> bool: return result -if sys.version_info >= (3, 10): - from asyncio import Event as LateLoopEvent -else: - # In python 3.10 asyncio.Lock and other primitives no longer support - # passing a loop kwarg to bind to a loop running in another thread - # e.g. calling from Client(asynchronous=False). Instead the loop is bound - # as late as possible: when calling any methods that wait on or wake - # Future instances. See: https://bugs.python.org/issue42392 - class LateLoopEvent: - _event: asyncio.Event | None - - def __init__(self) -> None: - self._event = None - - def set(self) -> None: - if self._event is None: - self._event = asyncio.Event() - - self._event.set() - - def is_set(self) -> bool: - return self._event is not None and self._event.is_set() - - async def wait(self) -> bool: - if self._event is None: - self._event = asyncio.Event() - - return await self._event.wait() - - class _CollectErrorThread: def __init__(self, target: Callable[[], None], daemon: bool, name: str): self._exception: BaseException | None = None diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index e88a026b48..caf6ed2ac7 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -7,7 +7,6 @@ import math import operator import random -import sys import warnings import weakref from collections import Counter, defaultdict, deque @@ -199,12 +198,7 @@ def _default_data_size() -> int: return parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) -# Note: can't specify __slots__ manually to enable slots in Python <3.10 in a @dataclass -# that defines any default values -DC_SLOTS = {"slots": True} if sys.version_info >= (3, 10) else {} - - -@dataclass(repr=False, eq=False, **DC_SLOTS) +@dataclass(repr=False, eq=False, slots=True) class TaskState: """Holds volatile state relating to an individual Dask task. From 30e01fb6683c225223121e0f39190fab01c7f164 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 20 Aug 2024 18:21:13 +0200 Subject: [PATCH 111/138] Fix PipInstall plugin on Worker (#8839) --- distributed/diagnostics/plugin.py | 17 ++++++++++++++--- .../diagnostics/tests/test_install_plugin.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 7e2c8eea75..669f915292 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -442,6 +442,9 @@ def __init__( self.name = f"{self.__class__.__name__}-{uuid.uuid4()}" async def start(self, scheduler: Scheduler) -> None: + from distributed.core import clean_exception + from distributed.protocol.serialize import Serialized, deserialize + self._scheduler = scheduler if InstallPlugin._lock is None: @@ -452,7 +455,7 @@ async def start(self, scheduler: Scheduler) -> None: if self.restart_workers: nanny_plugin = _InstallNannyPlugin(self._install_fn, self.name) - await scheduler.register_nanny_plugin( + responses = await scheduler.register_nanny_plugin( comm=None, plugin=dumps(nanny_plugin), name=self.name, @@ -460,12 +463,21 @@ async def start(self, scheduler: Scheduler) -> None: ) else: worker_plugin = _InstallWorkerPlugin(self._install_fn, self.name) - await scheduler.register_worker_plugin( + responses = await scheduler.register_worker_plugin( comm=None, plugin=dumps(worker_plugin), name=self.name, idempotent=True, ) + for response in responses.values(): + if response["status"] == "error": + response = { # type: ignore[unreachable] + k: deserialize(v.header, v.frames) + for k, v in response.items() + if isinstance(v, Serialized) + } + _, exc, tb = clean_exception(**response) + raise exc.with_traceback(tb) async def close(self) -> None: assert InstallPlugin._lock is not None @@ -563,7 +575,6 @@ async def setup(self, worker): await Semaphore( max_leases=1, name=socket.gethostname(), - register=True, scheduler_rpc=worker.scheduler, loop=worker.loop, ) diff --git a/distributed/diagnostics/tests/test_install_plugin.py b/distributed/diagnostics/tests/test_install_plugin.py index 26e3a6f937..3ebdfefdef 100644 --- a/distributed/diagnostics/tests/test_install_plugin.py +++ b/distributed/diagnostics/tests/test_install_plugin.py @@ -140,6 +140,21 @@ async def test_conda_install_fails_on_returncode(c, s, a, b): assert "install failed" in logs +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_package_install_on_worker(c, s, a): + (addr,) = s.workers + + await c.register_plugin(InstallPlugin(lambda: None, restart_workers=False)) + + +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_package_install_on_nanny(c, s, a): + (addr,) = s.workers + await c.register_plugin(InstallPlugin(lambda: None, restart_workers=False)) + + @pytest.mark.slow @gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) async def test_package_install_restarts_on_nanny(c, s, a): From 5bbceb7283357bd0c7772a4bdd79cc0b1d41ff30 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 21 Aug 2024 05:46:22 -0500 Subject: [PATCH 112/138] Bump `numpy>=1.24` and `pyarrow>=14.0.1` minimum versions (#8837) --- .github/workflows/tests.yaml | 8 ++--- distributed/shuffle/_arrow.py | 28 ++-------------- distributed/shuffle/tests/test_rechunk.py | 5 +-- distributed/shuffle/tests/test_shuffle.py | 39 ++++++----------------- 4 files changed, 17 insertions(+), 63 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 4ded58473c..0e5f82c641 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -69,24 +69,24 @@ jobs: - os: ubuntu-latest environment: mindeps label: numpy - extra_packages: [numpy=1.21] + extra_packages: [numpy=1.24] partition: "ci1" - os: ubuntu-latest environment: mindeps label: numpy - extra_packages: [numpy=1.21] + extra_packages: [numpy=1.24] partition: "not ci1" # dask.dataframe P2P shuffle - os: ubuntu-latest environment: mindeps label: pandas - extra_packages: [numpy=1.21, pandas=2.0, pyarrow=7, pyarrow-hotfix] + extra_packages: [numpy=1.24, pandas=2.0, pyarrow=14.0.1] partition: "ci1" - os: ubuntu-latest environment: mindeps label: pandas - extra_packages: [numpy=1.21, pandas=2.0, pyarrow=7, pyarrow-hotfix] + extra_packages: [numpy=1.24, pandas=2.0, pyarrow=14.0.1] partition: "not ci1" - os: ubuntu-latest diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 71317021ac..973e3414be 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -38,7 +38,7 @@ def check_minimal_arrow_version() -> None: Raises a ModuleNotFoundError if pyarrow is not installed or an ImportError if the installed version is not recent enough. """ - minversion = "7.0.0" + minversion = "14.0.1" try: import pyarrow as pa except ModuleNotFoundError: @@ -52,14 +52,7 @@ def check_minimal_arrow_version() -> None: def concat_tables(tables: Iterable[pa.Table]) -> pa.Table: import pyarrow as pa - if parse(pa.__version__) >= parse("14.0.0"): - return pa.concat_tables(tables, promote_options="permissive") - try: - return pa.concat_tables(tables, promote=True) - except pa.ArrowNotImplementedError as e: - if parse(pa.__version__) >= parse("12.0.0"): - raise e - raise + return pa.concat_tables(tables, promote_options="permissive") def convert_shards( @@ -179,23 +172,8 @@ def read_from_disk(path: Path) -> tuple[list[pa.Table], int]: return shards, size -def concat_arrays(arrays: Iterable[pa.Array]) -> pa.Array: - import pyarrow as pa - - try: - return pa.concat_arrays(arrays) - except pa.ArrowNotImplementedError as e: - if parse(pa.__version__) >= parse("12.0.0"): - raise - if e.args[0].startswith("concatenation of extension"): - raise RuntimeError( - "P2P shuffling requires pyarrow>=12.0.0 to support extension types." - ) from e - raise - - def _copy_table(table: pa.Table) -> pa.Table: import pyarrow as pa - arrs = [concat_arrays(column.chunks) for column in table.columns] + arrs = [pa.concat_arrays(column.chunks) for column in table.columns] return pa.table(data=arrs, schema=table.schema) diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 7fa608bef3..33948c6dce 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -6,7 +6,6 @@ import warnings import pytest -from packaging.version import parse as parse_version np = pytest.importorskip("numpy") da = pytest.importorskip("dask.array") @@ -37,8 +36,6 @@ from distributed.shuffle.tests.utils import AbstractShuffleTestPool from distributed.utils_test import async_poll_for, gen_cluster, gen_test -NUMPY_GE_124 = parse_version(np.__version__) >= parse_version("1.24") - class ArrayRechunkTestPool(AbstractShuffleTestPool): def __init__(self, *args, **kwargs): @@ -175,7 +172,7 @@ async def test_lowlevel_rechunk(tmp_path, n_workers, barrier_first_worker, disk) np.testing.assert_array_equal( concatenate3(old_cs.tolist()), concatenate3(all_chunks.tolist()), - **({"strict": True} if NUMPY_GE_124 else {}), + strict=True, ) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 443595b0bb..0efff85f68 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -15,7 +15,6 @@ from unittest import mock import pytest -from packaging.version import parse as parse_version from tornado.ioloop import IOLoop import dask @@ -73,13 +72,8 @@ try: import pyarrow as pa - - PYARROW_GE_12 = parse_version(pa.__version__).release >= (12,) - PYARROW_GE_14 = parse_version(pa.__version__).release >= (14,) except ImportError: pa = None - PYARROW_GE_12 = False - PYARROW_GE_14 = False @pytest.fixture(params=[0, 0.3, 1], ids=["none", "some", "all"]) @@ -1145,6 +1139,9 @@ def __init__(self, value: int) -> None: ), f"col{next(counter)}": pd.array(["x", "y"] * 50, dtype="category"), f"col{next(counter)}": pd.array(["lorem ipsum"] * 100, dtype="string"), + # Extension types + f"col{next(counter)}": pd.period_range("2022-01-01", periods=100, freq="D"), + f"col{next(counter)}": pd.interval_range(start=0, end=100, freq=1), # FIXME: PyArrow does not support sparse data: # https://issues.apache.org/jira/browse/ARROW-8679 # f"col{next(counter)}": pd.array( @@ -1158,17 +1155,6 @@ def __init__(self, value: int) -> None: # ), } - if PYARROW_GE_12: - columns.update( - { - # Extension types - f"col{next(counter)}": pd.period_range( - "2022-01-01", periods=100, freq="D" - ), - f"col{next(counter)}": pd.interval_range(start=0, end=100, freq=1), - } - ) - columns.update( { # PyArrow dtypes @@ -2502,18 +2488,11 @@ def make_partition(i): with dask.config.set({"dataframe.shuffle.method": "p2p"}): out = ddf.shuffle(on="a", ignore_index=True) - if PYARROW_GE_14: - result, expected = c.compute([ddf, out]) - result = await result - expected = await expected - dd.assert_eq(result, expected) - del result - else: - with raises_with_cause( - RuntimeError, r"shuffling \w+ failed", pa.ArrowInvalid, "incompatible types" - ): - await c.compute(out) - await c.close() + result, expected = c.compute([ddf, out]) + result = await result + expected = await expected + dd.assert_eq(result, expected) + del result del out await assert_worker_cleanup(a) @@ -2536,7 +2515,7 @@ def make_partition(i): with raises_with_cause( RuntimeError, r"(shuffling \w*|shuffle_barrier) failed", - pa.ArrowTypeError if PYARROW_GE_14 else pa.ArrowInvalid, + pa.ArrowTypeError, "incompatible types", ): await c.compute(out) From c073797f8278744891983ac910ca0f5ab40f7df6 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 23 Aug 2024 11:06:42 +0200 Subject: [PATCH 113/138] Remove dump cluster from gen_cluster (#8823) --- distributed/tests/test_utils_test.py | 95 ---------------------------- distributed/utils_test.py | 65 ++----------------- 2 files changed, 6 insertions(+), 154 deletions(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index c750e3faa5..6718060f7c 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -2,7 +2,6 @@ import asyncio import logging -import os import pathlib import signal import socket @@ -16,7 +15,6 @@ from unittest import mock import pytest -import yaml from tornado import gen import dask.config @@ -41,7 +39,6 @@ check_process_leak, check_thread_leak, cluster, - dump_cluster_state, ensure_no_new_clients, freeze_batched_send, gen_cluster, @@ -439,40 +436,6 @@ async def ping_pong(): assert await fut == "pong" -@pytest.mark.slow() -def test_dump_cluster_state_timeout(tmp_path): - sleep_time = 30 - - async def inner_test(c, s, a, b): - await asyncio.sleep(sleep_time) - - # This timeout includes cluster startup and teardown which sometimes can - # take a significant amount of time. For this particular test we would like - # to keep the _test timeout_ small because we intend to trigger it but the - # overall timeout large. - test = gen_cluster(client=True, timeout=5, cluster_dump_directory=tmp_path)( - inner_test - ) - try: - with pytest.raises(asyncio.TimeoutError) as exc: - test() - assert "inner_test" in str(exc) - assert "await asyncio.sleep(sleep_time)" in str(exc) - except gen.TimeoutError: - pytest.xfail("Cluster startup or teardown took too long") - - _, dirs, files = next(os.walk(tmp_path)) - assert not dirs - assert files == [inner_test.__name__ + ".yaml"] - import yaml - - with open(tmp_path / files[0], "rb") as fd: - state = yaml.load(fd, Loader=yaml.Loader) - - assert "scheduler" in state - assert "workers" in state - - def test_assert_story(): now = time() story = [ @@ -558,64 +521,6 @@ async def test_assert_story_identity(c, s, a, strict): assert_story(worker_story, scheduler_story, strict=strict) -@gen_cluster() -async def test_dump_cluster_state(s, a, b, tmp_path): - await dump_cluster_state(s, [a, b], str(tmp_path), "dump") - with open(f"{tmp_path}/dump.yaml") as fh: - out = yaml.safe_load(fh) - - assert out.keys() == {"scheduler", "workers", "versions"} - assert out["workers"].keys() == {a.address, b.address} - - -@gen_cluster(nthreads=[]) -async def test_dump_cluster_state_no_workers(s, tmp_path): - await dump_cluster_state(s, [], str(tmp_path), "dump") - with open(f"{tmp_path}/dump.yaml") as fh: - out = yaml.safe_load(fh) - - assert out.keys() == {"scheduler", "workers", "versions"} - assert out["workers"] == {} - - -@gen_cluster(Worker=Nanny) -async def test_dump_cluster_state_nannies(s, a, b, tmp_path): - await dump_cluster_state(s, [a, b], str(tmp_path), "dump") - with open(f"{tmp_path}/dump.yaml") as fh: - out = yaml.safe_load(fh) - - assert out.keys() == {"scheduler", "workers", "versions"} - assert out["workers"].keys() == s.workers.keys() - - -@gen_cluster() -async def test_dump_cluster_state_unresponsive_local_worker(s, a, b, tmp_path): - a.stop() - await dump_cluster_state(s, [a, b], str(tmp_path), "dump") - with open(f"{tmp_path}/dump.yaml") as fh: - out = yaml.safe_load(fh) - - assert out.keys() == {"scheduler", "workers", "versions"} - assert isinstance(out["workers"][a.address], dict) - assert isinstance(out["workers"][b.address], dict) - - -@pytest.mark.slow -@gen_cluster(client=True, Worker=Nanny) -async def test_dump_cluster_unresponsive_remote_worker(c, s, a, b, tmp_path): - await c.run(lambda dask_worker: dask_worker.stop(), workers=[a.worker_address]) - - await dump_cluster_state(s, [a, b], str(tmp_path), "dump") - with open(f"{tmp_path}/dump.yaml") as fh: - out = yaml.safe_load(fh) - - assert out.keys() == {"scheduler", "workers", "versions"} - assert isinstance(out["workers"][b.worker_address], dict) - assert out["workers"][a.worker_address].startswith( - "OSError('Timed out trying to connect to" - ) - - # Note: WINDOWS constant doesn't work with `mypy --platform win32` if sys.platform == "win32": TERM_SIGNALS = (signal.SIGTERM, signal.SIGINT) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 1fd59b5525..05cd9382b0 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -38,7 +38,6 @@ from dask.typing import Key from distributed import Event, Scheduler, system -from distributed import versions as version_module from distributed.batched import BatchedSend from distributed.client import Client, _global_clients, default_client from distributed.comm import Comm @@ -878,7 +877,7 @@ def gen_cluster( clean_kwargs: dict[str, Any] | None = None, # FIXME: distributed#8054 allow_unclosed: bool = True, - cluster_dump_directory: str | Literal[False] = "test_cluster_dump", + cluster_dump_directory: str | Literal[False] = False, ) -> Callable[[Callable], Callable]: from distributed import Client @@ -901,6 +900,11 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture start end """ + if cluster_dump_directory: + warnings.warn( + "The `cluster_dump_directory` argument is being ignored and will be removed in a future version.", + DeprecationWarning, + ) if nthreads is None: nthreads = [ ("127.0.0.1", 1), @@ -1019,14 +1023,6 @@ async def async_fn(): # This stack indicates where the coro/test is suspended task.print_stack(file=buffer) - if cluster_dump_directory: - await dump_cluster_state( - s=s, - ws=workers, - output_dir=cluster_dump_directory, - func_name=func.__name__, - ) - task.cancel() while not task.cancelled(): await asyncio.sleep(0.01) @@ -1048,18 +1044,6 @@ async def async_fn(): except pytest.xfail.Exception: raise - except Exception: - if cluster_dump_directory and not has_pytestmark( - test_func, "xfail" - ): - await dump_cluster_state( - s=s, - ws=workers, - output_dir=cluster_dump_directory, - func_name=func.__name__, - ) - raise - try: c = default_client() except ValueError: @@ -1122,41 +1106,6 @@ async def async_fn_outer(): return _ -async def dump_cluster_state( - s: Scheduler, ws: list[ServerNode], output_dir: str, func_name: str -) -> None: - """A variant of Client.dump_cluster_state, which does not rely on any of the below - to work: - - - Having a client at all - - Client->Scheduler comms - - Scheduler->Worker comms (unless using Nannies) - """ - scheduler_info = s._to_dict() - workers_info: dict[str, Any] - versions_info = version_module.get_versions() - - if not ws or isinstance(ws[0], Worker): - workers_info = {w.address: w._to_dict() for w in ws} - else: - workers_info = await s.broadcast(msg={"op": "dump_state"}, on_error="return") - workers_info = { - k: repr(v) if isinstance(v, Exception) else v - for k, v in workers_info.items() - } - - state = { - "scheduler": scheduler_info, - "workers": workers_info, - "versions": versions_info, - } - os.makedirs(output_dir, exist_ok=True) - fname = os.path.join(output_dir, func_name) + ".yaml" - with open(fname, "w") as fh: - yaml.safe_dump(state, fh) # Automatically convert tuples to lists - print(f"Dumped cluster state to {fname}") - - def validate_state(*servers: Scheduler | Worker | Nanny) -> None: """Run validate_state() on the Scheduler and all the Workers of the cluster. Excludes workers wrapped by Nannies and workers manually started by the test. @@ -1505,8 +1454,6 @@ def new_config_file(c: dict[str, Any]) -> Iterator[None]: """ Temporarily change configuration file to match dictionary *c*. """ - import yaml - old_file = os.environ.get("DASK_CONFIG") fd, path = tempfile.mkstemp(prefix="dask-config") with os.fdopen(fd, "w") as f: From ea7d35c25ee05b03962fb1719a1fa881dbdc09f2 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 23 Aug 2024 14:49:37 +0200 Subject: [PATCH 114/138] Concatenate small input chunks before P2P rechunking (#8832) Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com> --- distributed/shuffle/_rechunk.py | 173 ++++++++++++++++++++-- distributed/shuffle/tests/test_rechunk.py | 97 +++++++++++- 2 files changed, 248 insertions(+), 22 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 962b57fb32..0f0dfbf21f 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,6 +96,7 @@ from __future__ import annotations +import math import mmap import os from collections import defaultdict @@ -111,7 +112,7 @@ ) from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from itertools import product +from itertools import chain, product from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple, cast @@ -124,6 +125,7 @@ from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer from dask.typing import Key +from dask.utils import parse_bytes from distributed.core import PooledRPCCall from distributed.metrics import context_meter @@ -220,7 +222,7 @@ def rechunk_p2p( return da.empty(x.shape, chunks=chunks, dtype=x.dtype) from dask.array.core import new_da_object - prechunked = _calculate_prechunking(x.chunks, chunks) + prechunked = _calculate_prechunking(x.chunks, chunks, x.dtype, block_size_limit) if prechunked != x.chunks: x = cast( "da.Array", @@ -433,8 +435,140 @@ def _construct_graph(self) -> _T_LowLevelGraph: def _calculate_prechunking( - old_chunks: ChunkedAxes, new_chunks: ChunkedAxes + old_chunks: ChunkedAxes, + new_chunks: ChunkedAxes, + dtype: np.dtype, + block_size_limit: int | None, +) -> ChunkedAxes: + """Calculate how to perform the pre-rechunking step + + During the pre-rechunking step, we + 1. Split input chunks along partial boundaries to make partials completely independent of one another + 2. Merge small chunks within partials to reduce the number of transfer tasks and corresponding overhead + """ + split_axes = _split_chunks_along_partial_boundaries(old_chunks, new_chunks) + + # We can only determine how to concatenate chunks if we can calculate block sizes. + has_nans = (any(math.isnan(y) for y in x) for x in old_chunks) + + if len(new_chunks) <= 1 or not all(new_chunks) or any(has_nans): + return tuple(tuple(chain(*axis)) for axis in split_axes) + + if dtype is None or dtype.hasobject or dtype.itemsize == 0: + return tuple(tuple(chain(*axis)) for axis in split_axes) + + # We made sure that there are no NaNs in split_axes above + return _concatenate_small_chunks( + split_axes, old_chunks, new_chunks, dtype, block_size_limit # type: ignore[arg-type] + ) + + +def _concatenate_small_chunks( + split_axes: list[list[list[int]]], + old_chunks: ChunkedAxes, + new_chunks: ChunkedAxes, + dtype: np.dtype, + block_size_limit: int | None, ) -> ChunkedAxes: + """Concatenate small chunks within partials. + + By concatenating chunks within partials, we reduce the number of P2P transfer tasks and their + corresponding overhead. + + The algorithm used in this function is very similar to :func:`dask.array.rechunk.find_merge_rechunk`, + the main difference is that we have to make sure only to merge chunks within partials. + """ + import numpy as np + + block_size_limit = block_size_limit or dask.config.get("array.chunk-size") + + if isinstance(block_size_limit, str): + block_size_limit = parse_bytes(block_size_limit) + + # Make it a number of elements + block_size_limit //= dtype.itemsize + + # We verified earlier that we do not have any NaNs + largest_old_block = _largest_block_size(old_chunks) # type: ignore[arg-type] + largest_new_block = _largest_block_size(new_chunks) # type: ignore[arg-type] + block_size_limit = max([block_size_limit, largest_old_block, largest_new_block]) + + old_largest_width = [max(chain(*axis)) for axis in split_axes] + new_largest_width = [max(c) for c in new_chunks] + + # This represents how much each dimension increases (>1) or reduces (<1) + # the graph size during rechunking + graph_size_effect = { + dim: len(new_axis) / sum(map(len, split_axis)) + for dim, (split_axis, new_axis) in enumerate(zip(split_axes, new_chunks)) + } + + ndim = len(old_chunks) + + # This represents how much each dimension increases (>1) or reduces (<1) the + # largest block size during rechunking + block_size_effect = { + dim: new_largest_width[dim] / (old_largest_width[dim] or 1) + for dim in range(ndim) + } + + # Our goal is to reduce the number of nodes in the rechunk graph + # by concatenating some adjacent chunks, so consider dimensions where we can + # reduce the # of chunks + candidates = [dim for dim in range(ndim) if graph_size_effect[dim] <= 1.0] + + # Concatenating along each dimension reduces the graph size by a certain factor + # and increases memory largest block size by a certain factor. + # We want to optimize the graph size while staying below the given + # block_size_limit. This is in effect a knapsack problem, except with + # multiplicative values and weights. Just use a greedy algorithm + # by trying dimensions in decreasing value / weight order. + def key(k: int) -> float: + gse = graph_size_effect[k] + bse = block_size_effect[k] + if bse == 1: + bse = 1 + 1e-9 + return (np.log(gse) / np.log(bse)) if bse > 0 else 0 + + sorted_candidates = sorted(candidates, key=key) + + concatenated_axes: list[list[int]] = [[] for i in range(ndim)] + + # Sim all the axes that are no candidates + for i in range(ndim): + if i in candidates: + continue + concatenated_axes[i] = list(chain(*split_axes[i])) + + # We want to concatenate chunks + for axis_index in sorted_candidates: + concatenated_axis = concatenated_axes[axis_index] + multiplier = math.prod( + old_largest_width[:axis_index] + old_largest_width[axis_index + 1 :] + ) + axis_limit = block_size_limit // multiplier + + for partial in split_axes[axis_index]: + current = partial[0] + for chunk in partial[1:]: + if (current + chunk) > axis_limit: + concatenated_axis.append(current) + current = chunk + else: + current += chunk + concatenated_axis.append(current) + old_largest_width[axis_index] = max(concatenated_axis) + return tuple(tuple(axis) for axis in concatenated_axes) + + +def _split_chunks_along_partial_boundaries( + old_chunks: ChunkedAxes, new_chunks: ChunkedAxes +) -> list[list[list[float]]]: + """Split the old chunks along the boundaries of partials, i.e., groups of new chunks that share the same inputs. + + By splitting along the boundaries before rechunkin their input tasks become disjunct and each partial conceptually + operates on an independent sub-array. + """ from dask.array.rechunk import old_to_new _old_to_new = old_to_new(old_chunks, new_chunks) @@ -443,10 +577,13 @@ def _calculate_prechunking( split_axes = [] + # Along each axis, we want to figure out how we have to split input chunks in order to make + # partials disjunct. We then group the resulting input chunks per partial before returning. for axis_index, slices in enumerate(partials): old_to_new_axis = _old_to_new[axis_index] old_axis = old_chunks[axis_index] split_axis = [] + partial_chunks = [] for slice_ in slices: first_new_chunk = slice_.start first_old_chunk, first_old_slice = old_to_new_axis[first_new_chunk][0] @@ -465,22 +602,28 @@ def _calculate_prechunking( chunk_size = last_old_slice.stop if first_old_slice.start != 0: chunk_size -= first_old_slice.start - split_axis.append(chunk_size) - continue - - split_axis.append(first_chunk_size - first_old_slice.start) - - split_axis.extend(old_axis[first_old_chunk + 1 : last_old_chunk]) - - if last_old_slice.stop is not None: - chunk_size = last_old_slice.stop + partial_chunks.append(chunk_size) else: - chunk_size = last_chunk_size + partial_chunks.append(first_chunk_size - first_old_slice.start) - split_axis.append(chunk_size) + partial_chunks.extend(old_axis[first_old_chunk + 1 : last_old_chunk]) + if last_old_slice.stop is not None: + chunk_size = last_old_slice.stop + else: + chunk_size = last_chunk_size + + partial_chunks.append(chunk_size) + split_axis.append(partial_chunks) + partial_chunks = [] + if partial_chunks: + split_axis.append(partial_chunks) split_axes.append(split_axis) - return tuple(tuple(axis) for axis in split_axes) + return split_axes + + +def _largest_block_size(chunks: tuple[tuple[int, ...], ...]) -> int: + return math.prod(map(max, chunks)) def _split_partials( diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 33948c6dce..6d46ff7228 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -847,7 +847,8 @@ async def test_rechunk_avoid_needless_chunking(c, s, *ws): x = da.ones(16, chunks=2) y = x.rechunk(8, method="p2p") dsk = y.__dask_graph__() - assert len(dsk) <= 8 + 2 + # 8 inputs, 2 concatenations of small inputs, 2 outputs + assert len(dsk) <= 8 + 2 + 2 @pytest.mark.parametrize( @@ -1337,7 +1338,7 @@ async def test_partial_rechunk_taskgroups(c, s): ), timeout=5, ) - assert len(s.task_groups) < 6 + assert len(s.task_groups) < 7 @pytest.mark.parametrize( @@ -1351,7 +1352,7 @@ async def test_partial_rechunk_taskgroups(c, s): ], ) def test_calculate_prechunking_1d(old, new, expected): - actual = _calculate_prechunking(old, new) + actual = _calculate_prechunking(old, new, np.dtype, None) assert actual == expected @@ -1359,17 +1360,99 @@ def test_calculate_prechunking_1d(old, new, expected): ["old", "new", "expected"], [ [((2, 2), (3, 3)), ((2, 2), (3, 3)), ((2, 2), (3, 3))], - [((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))], + [((2, 2), (3, 3)), ((4,), (3, 3)), ((4,), (3, 3))], [((2, 2), (3, 3)), ((1, 1, 1, 1), (3, 3)), ((2, 2), (3, 3))], [ ((2, 2, 2), (3, 3, 3)), ((1, 2, 2, 1), (2, 3, 4)), - ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)), + ((1, 2, 2, 1), (2, 3, 4)), ], [((1, np.nan), (3, 3)), ((1, np.nan), (2, 2, 2)), ((1, np.nan), (2, 1, 1, 2))], - [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))], + [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (3,))], ], ) def test_calculate_prechunking_2d(old, new, expected): - actual = _calculate_prechunking(old, new) + actual = _calculate_prechunking(old, new, np.dtype(np.int16), None) + assert actual == expected + + +@pytest.mark.parametrize( + ["old", "new", "expected"], + [ + ( + ((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)), + ((1, 1, 1, 1), (4,), (2, 2)), + ((2, 2), (4,), (1, 1, 1, 1)), + ), + ( + ((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)), + ((1, 1, 1, 1), (2, 2), (2, 2)), + ((2, 2), (2, 2), (2, 2)), + ), + ( + ((2, 2), (1, 1, 1, 1), (1, 1, 1, 1)), + ((1, 1, 1, 1), (2, 2), (4,)), + ((2, 2), (2, 2), (2, 2)), + ), + ( + ((1, 1, 1, 1), (1, 1, 1, 1), (2, 2)), + ((2, 2), (4,), (1, 1, 1, 1)), + ((2, 2), (2, 2), (2, 2)), + ), + ], +) +def test_calculate_prechunking_3d(old, new, expected): + with dask.config.set({"array.chunk-size": "16 B"}): + actual = _calculate_prechunking(old, new, np.dtype(np.int16), None) + assert actual == expected + + +@pytest.mark.parametrize( + ["chunk_size", "expected"], + [ + ("1 B", ((10,), (1,) * 10)), + ("20 B", ((10,), (1,) * 10)), + ("40 B", ((10,), (2, 2, 1, 2, 2, 1))), + ("100 B", ((10,), (5, 5))), + ], +) +def test_calculate_prechunking_concatenation(chunk_size, expected): + old = ((10,), (1,) * 10) + new = ((2,) * 5, (5, 5)) + with dask.config.set({"array.chunk-size": chunk_size}): + actual = _calculate_prechunking(old, new, np.dtype(np.int16), None) + assert actual == expected + + +def test_calculate_prechunking_does_not_concatenate_object_type(): + old = ((10,), (1,) * 10) + new = ((2,) * 5, (5, 5)) + + # Ensure that int dtypes get concatenated + new = ((2,) * 5, (5, 5)) + with dask.config.set({"array.chunk-size": "100 B"}): + actual = _calculate_prechunking(old, new, np.dtype(np.int16), None) + assert actual == ((10,), (5, 5)) + + # Ensure object dtype chunks do not get concatenated + with dask.config.set({"array.chunk-size": "100 B"}): + actual = _calculate_prechunking(old, new, np.dtype(object), None) + assert actual == old + + +@pytest.mark.parametrize( + ["old", "new", "expected"], + [ + [((2, 2), (3, 3)), ((4,), (3, 3)), ((2, 2), (3, 3))], + [ + ((2, 2, 2), (3, 3, 3)), + ((1, 2, 2, 1), (2, 3, 4)), + ((1, 1, 1, 1, 1, 1), (2, 1, 2, 1, 3)), + ], + [((4,), (1, 1, 1)), ((1, 1, 1, 1), (3,)), ((4,), (1, 1, 1))], + ], +) +def test_calculate_prechunking_splitting(old, new, expected): + # _calculate_prechunking does not concatenate on object + actual = _calculate_prechunking(old, new, np.dtype(object), None) assert actual == expected From 277b1f9c866eefe9e16215cf49a72bded584d85c Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 26 Aug 2024 17:29:20 +0200 Subject: [PATCH 115/138] Bump test_pause_while_idle timeout (#8844) --- distributed/tests/test_worker_memory.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 0994816d2e..4a222861d8 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -1097,12 +1097,12 @@ async def test_pause_while_idle(s, a, b): assert sa in s.running a.monitor.get_process_memory = lambda: 2**40 - await async_poll_for(lambda: sa.status == Status.paused, timeout=2) + await async_poll_for(lambda: sa.status == Status.paused, timeout=5) assert a.address not in s.idle assert sa not in s.running a.monitor.get_process_memory = lambda: 0 - await async_poll_for(lambda: sa.status == Status.running, timeout=2) + await async_poll_for(lambda: sa.status == Status.running, timeout=5) assert a.address in s.idle assert sa in s.running @@ -1112,17 +1112,17 @@ async def test_pause_while_saturated(c, s, a, b): sa = s.workers[a.address] ev = Event() futs = c.map(lambda i, ev: ev.wait(), range(3), ev=ev, workers=[a.address]) - await async_poll_for(lambda: len(a.state.tasks) == 3, timeout=2) + await async_poll_for(lambda: len(a.state.tasks) == 3, timeout=5) assert sa in s.saturated assert sa in s.running a.monitor.get_process_memory = lambda: 2**40 - await async_poll_for(lambda: sa.status == Status.paused, timeout=2) + await async_poll_for(lambda: sa.status == Status.paused, timeout=5) assert sa not in s.saturated assert sa not in s.running a.monitor.get_process_memory = lambda: 0 - await async_poll_for(lambda: sa.status == Status.running, timeout=2) + await async_poll_for(lambda: sa.status == Status.running, timeout=5) assert sa in s.saturated assert sa in s.running From ccdf9ea93fdfe8df489674133ba59b34682f3d16 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Tue, 27 Aug 2024 16:32:31 -0400 Subject: [PATCH 116/138] Increase visibility of GPU CI updates (#8841) --- .github/workflows/update-gpuci.yaml | 2 +- CODEOWNERS | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/update-gpuci.yaml b/.github/workflows/update-gpuci.yaml index 36912c1d8c..90c25c16a5 100644 --- a/.github/workflows/update-gpuci.yaml +++ b/.github/workflows/update-gpuci.yaml @@ -58,7 +58,7 @@ jobs: if: ${{ env.UCX_PY_VER != env.NEW_UCX_PY_VER }} # make sure new ucx-py nightlies are available with: token: ${{ secrets.GITHUB_TOKEN }} - draft: true + draft: false commit-message: "Update gpuCI `RAPIDS_VER` to `${{ env.NEW_RAPIDS_VER }}`" title: "Update gpuCI `RAPIDS_VER` to `${{ env.NEW_RAPIDS_VER }}`" author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> diff --git a/CODEOWNERS b/CODEOWNERS index 38c8a480a8..08ebf86238 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -12,3 +12,4 @@ distributed/widgets/* @jacobtomlinson # GPU Support distributed/diagnostics/nvml.py @jacobtomlinson @quasiben +continuous_integration/gpuci/* @jacobtomlinson @quasiben From 4aeed40a783813a342a236fdd2a19bff2369c19c Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 28 Aug 2024 12:21:11 +0200 Subject: [PATCH 117/138] Add tests for choosing default rechunking method (#8843) * Add tests for choosing default method * Improve test case * Trigger CI --- distributed/shuffle/tests/test_rechunk.py | 32 +++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 6d46ff7228..803a33fe6c 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -193,6 +193,38 @@ async def test_rechunk_configuration(c, s, *ws, config_value, keyword): with dask.config.set(config): x2 = rechunk(x, chunks=new, method=keyword) expected_algorithm = keyword if keyword is not None else config_value + if expected_algorithm == "p2p": + assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) + elif expected_algorithm == "tasks": + assert not any( + key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__() + ) + # Neither is specified, so we choose the best one (see test_rechunk_heuristic for a full test of the heuristic) + else: + assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) + + assert x2.chunks == new + assert np.all(await c.compute(x2) == a) + + +@pytest.mark.parametrize( + ["new", "expected_algorithm"], + [ + # All-to-all rechunking defaults to P2P + (((1,) * 100, (100,)), "p2p"), + # Localized rechunking defaults to tasks + (((50, 50), (2,) * 50), "tasks"), + # Less local rechunking first defaults to tasks, + (((25, 25, 25, 25), (4,) * 25), "tasks"), + # then switches to p2p + (((10,) * 10, (10,) * 10), "p2p"), + ], +) +@gen_cluster(client=True) +async def test_rechunk_heuristic(c, s, a, b, new, expected_algorithm): + a = np.random.default_rng().uniform(0, 1, 10000).reshape((100, 100)) + x = da.from_array(a, chunks=(100, 1)) + x2 = rechunk(x, chunks=new) if expected_algorithm == "p2p": assert all(key[0][0].startswith("rechunk-p2p") for key in x2.__dask_keys__()) else: From e19c6306af41ccca8c4a934d6240943a52a1940b Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 30 Aug 2024 10:56:16 +0200 Subject: [PATCH 118/138] Reduce memory footprint of culling P2P rechunking (#8845) --- distributed/shuffle/_rechunk.py | 55 +++++++++++++++++---------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 0f0dfbf21f..efae7d80b7 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -370,28 +370,31 @@ def cull( indices_to_keep = self._keys_to_indices(keys) _old_to_new = old_to_new(self.chunks_input, self.chunks) - culled_deps: defaultdict[Key, set[Key]] = defaultdict(set) - for nindex in indices_to_keep: - old_indices_per_axis = [] - keepmap[nindex] = True - for index, new_axis in zip(nindex, _old_to_new): - old_indices_per_axis.append( - [old_chunk_index for old_chunk_index, _ in new_axis[index]] - ) - for old_nindex in product(*old_indices_per_axis): - culled_deps[(self.name,) + nindex].add((self.name_input,) + old_nindex) + for ndindex in indices_to_keep: + keepmap[ndindex] = True - # Protect against mutations later on with frozenset - frozen_deps = { - output_task: frozenset(input_tasks) - for output_task, input_tasks in culled_deps.items() - } + culled_deps = {} + # Identify the individual partial rechunks + for ndpartial in _split_partials(_old_to_new): + # Cull partials for which we do not keep any output tasks + if not np.any(keepmap[ndpartial.new]): + continue + + # Within partials, we have all-to-all communication. + # Thus, all output tasks share the same input tasks. + deps = frozenset( + (self.name_input,) + ndindex + for ndindex in _ndindices_of_slice(ndpartial.old) + ) + + for ndindex in _ndindices_of_slice(ndpartial.new): + culled_deps[(self.name,) + ndindex] = deps if np.array_equal(keepmap, self.keepmap): - return self, frozen_deps + return self, culled_deps else: culled_layer = self._cull(keepmap) - return culled_layer, frozen_deps + return culled_layer, culled_deps def _construct_graph(self) -> _T_LowLevelGraph: import numpy as np @@ -695,14 +698,12 @@ def _slice_new_chunks_into_partials( return tuple(sliced_axes) -def _partial_ndindex(ndslice: NDSlice) -> np.ndindex: - import numpy as np - - return np.ndindex(tuple(slice.stop - slice.start for slice in ndslice)) +def _ndindices_of_slice(ndslice: NDSlice) -> Iterator[NDIndex]: + return product(*(range(slc.start, slc.stop) for slc in ndslice)) -def _global_index(partial_index: NDIndex, partial_offset: NDIndex) -> NDIndex: - return tuple(index + offset for index, offset in zip(partial_index, partial_offset)) +def _partial_index(global_index: NDIndex, partial_offset: NDIndex) -> NDIndex: + return tuple(index - offset for index, offset in zip(global_index, partial_offset)) def partial_concatenate( @@ -802,8 +803,8 @@ def partial_rechunk( ) transfer_keys = [] - for partial_index in _partial_ndindex(ndpartial.old): - global_index = _global_index(partial_index, old_partial_offset) + for global_index in _ndindices_of_slice(ndpartial.old): + partial_index = _partial_index(global_index, old_partial_offset) input_key = (input_name,) + global_index @@ -822,8 +823,8 @@ def partial_rechunk( dsk[_barrier_key] = (shuffle_barrier, partial_token, transfer_keys) new_partial_offset = tuple(axis.start for axis in ndpartial.new) - for partial_index in _partial_ndindex(ndpartial.new): - global_index = _global_index(partial_index, new_partial_offset) + for global_index in _ndindices_of_slice(ndpartial.new): + partial_index = _partial_index(global_index, new_partial_offset) if keepmap[global_index]: dsk[(unpack_group,) + global_index] = ( rechunk_unpack, From ac7a24e8e821cee3cb77a726b8f39e6bf1f926c1 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 30 Aug 2024 17:21:10 +0200 Subject: [PATCH 119/138] Avoid capturing code of xdist (#8846) --- distributed/distributed.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 250af10f7d..1e7505a116 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -290,6 +290,8 @@ distributed: - xarray - xgboost - xdist + - __channelexec__ # more xdist + - execnet # more xdist ignore-files: - runpy\.py # `python -m pytest` (or other module) shell command - pytest # `pytest` shell command From 4b65be04518a4d166c386ffd1cc127a08275c5d4 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 30 Aug 2024 15:39:05 -0500 Subject: [PATCH 120/138] bump version to 2024.8.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8ac18b1c5b..15b878554d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">=3.10" dependencies = [ "click >= 8.0", "cloudpickle >= 3.0.0", - "dask == 2024.8.1", + "dask == 2024.8.2", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.2", From 2e61816b68296e823b69a83b6845e512b80bc524 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 2 Sep 2024 16:00:29 +0200 Subject: [PATCH 121/138] Update precommit (#8852) --- .flake8 | 2 + .pre-commit-config.yaml | 12 ++-- continuous_integration/scripts/host_info.py | 3 +- .../scripts/parse_stdout.py | 1 + distributed/_concurrent_futures_thread.py | 6 +- distributed/active_memory_manager.py | 13 ++-- distributed/actor.py | 6 +- distributed/broker.py | 6 +- distributed/cfexecutor.py | 2 +- distributed/cli/dask_worker.py | 8 ++- distributed/cli/tests/test_dask_worker.py | 5 +- distributed/client.py | 21 +++--- distributed/comm/registry.py | 3 +- distributed/comm/ucx.py | 5 +- distributed/dashboard/components/scheduler.py | 8 ++- distributed/deploy/ssh.py | 30 +++++---- distributed/deploy/tests/test_adaptive.py | 36 +++++----- distributed/deploy/tests/test_cluster.py | 19 ++++-- distributed/deploy/tests/test_local.py | 66 +++++++++++-------- distributed/deploy/tests/test_spec_cluster.py | 2 +- distributed/deploy/tests/test_subprocess.py | 7 +- distributed/diagnostics/memray.py | 14 ++-- distributed/diagnostics/plugin.py | 24 +++---- distributed/diagnostics/tests/test_nvml.py | 7 +- distributed/metrics.py | 4 +- distributed/objects.py | 1 + distributed/profile.py | 1 + distributed/protocol/compression.py | 1 + distributed/protocol/cupy.py | 1 + distributed/protocol/scipy.py | 1 + distributed/protocol/utils.py | 6 +- distributed/pytest_resourceleaks.py | 1 + distributed/scheduler.py | 42 ++++++------ distributed/semaphore.py | 6 +- distributed/shuffle/_buffer.py | 3 +- distributed/shuffle/_core.py | 2 +- distributed/shuffle/_disk.py | 4 +- distributed/shuffle/_rechunk.py | 4 +- distributed/shuffle/_shuffle.py | 8 +-- distributed/shuffle/_worker_plugin.py | 6 +- distributed/shuffle/tests/test_graph.py | 30 +++++---- distributed/shuffle/tests/test_merge.py | 7 +- distributed/tests/test_client.py | 5 +- distributed/tests/test_computations.py | 1 + distributed/tests/test_core.py | 33 ++++------ distributed/tests/test_event_logging.py | 14 ++-- distributed/tests/test_failed_workers.py | 7 +- distributed/tests/test_preload.py | 37 ++++++----- distributed/tests/test_publish.py | 21 +++--- distributed/tests/test_reschedule.py | 1 + distributed/tests/test_scheduler.py | 20 +++--- distributed/tests/test_tls_functional.py | 1 + distributed/tests/test_worker.py | 19 +++--- distributed/threadpoolexecutor.py | 2 + distributed/utils.py | 10 ++- distributed/utils_comm.py | 6 +- distributed/utils_test.py | 13 ++-- distributed/variable.py | 6 +- distributed/versions.py | 13 ++-- distributed/worker.py | 8 +-- distributed/worker_memory.py | 27 +++++--- docs/source/scheduling-policies.rst | 2 +- 62 files changed, 374 insertions(+), 306 deletions(-) diff --git a/.flake8 b/.flake8 index ec3c89d42d..c4750707d5 100644 --- a/.flake8 +++ b/.flake8 @@ -22,6 +22,8 @@ ignore = B028 # do not compare types, for exact checks use `is` / `is not`, for instance checks use `isinstance()` E721 + # multiple statements on one line; required for black compat + E701, E704 per-file-ignores = **/tests/*: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ef75712a2..dcc5c01a3d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,29 +12,29 @@ repos: - id: isort language_version: python3 - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + rev: v3.17.0 hooks: - id: pyupgrade args: - --py39-plus - repo: https://github.com/psf/black - rev: 23.12.1 + rev: 24.8.0 hooks: - id: black language_version: python3 args: - --target-version=py39 - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 7.1.1 hooks: - id: flake8 language_version: python3 additional_dependencies: # NOTE: autoupdate does not pick up flake8-bugbear since it is a transitive # dependency. Make sure to update flake8-bugbear manually on a regular basis. - - flake8-bugbear==23.12.2 + - flake8-bugbear==24.8.19 - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell additional_dependencies: @@ -42,7 +42,7 @@ repos: types_or: [rst, markdown] files: docs - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.11.2 hooks: - id: mypy # Override default --ignore-missing-imports diff --git a/continuous_integration/scripts/host_info.py b/continuous_integration/scripts/host_info.py index e49c3f864d..f876d28ef7 100644 --- a/continuous_integration/scripts/host_info.py +++ b/continuous_integration/scripts/host_info.py @@ -38,8 +38,7 @@ def main() -> None: else: print("CPU frequency:") for freq in freqs: - # FIXME types-psutil - print(f" - current={freq.current}, min={freq.min}, max={freq.max}") # type: ignore + print(f" - current={freq.current}, min={freq.min}, max={freq.max}") mem = psutil.virtual_memory() print("Memory:") diff --git a/continuous_integration/scripts/parse_stdout.py b/continuous_integration/scripts/parse_stdout.py index e9b53b193c..59becdf5df 100644 --- a/continuous_integration/scripts/parse_stdout.py +++ b/continuous_integration/scripts/parse_stdout.py @@ -1,6 +1,7 @@ """On Windows, pytest-timeout kills off the whole test suite, leaving no junit report behind. Parse the stdout of pytest to generate one. """ + from __future__ import annotations import html diff --git a/distributed/_concurrent_futures_thread.py b/distributed/_concurrent_futures_thread.py index 86de8a2a42..d3c05251b9 100644 --- a/distributed/_concurrent_futures_thread.py +++ b/distributed/_concurrent_futures_thread.py @@ -31,9 +31,9 @@ # workers to exit when their work queues are empty and then waits until the # threads finish. -_threads_queues: weakref.WeakKeyDictionary[ - threading.Thread, queue.Queue -] = weakref.WeakKeyDictionary() +_threads_queues: weakref.WeakKeyDictionary[threading.Thread, queue.Queue] = ( + weakref.WeakKeyDictionary() +) _shutdown = False diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 724bfc1892..4f7c472127 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -4,6 +4,7 @@ See also :mod:`distributed.worker_memory` and :mod:`distributed.spill`, which implement spill/pause/terminate mechanics on the Worker side. """ + from __future__ import annotations import abc @@ -392,12 +393,12 @@ def _enact_suggestions(self) -> None: logger.debug("Enacting suggestions for %d tasks:", len(self.pending)) validate = self.scheduler.validate - drop_by_worker: ( - defaultdict[scheduler_module.WorkerState, list[Key]] - ) = defaultdict(list) - repl_by_worker: ( - defaultdict[scheduler_module.WorkerState, list[Key]] - ) = defaultdict(list) + drop_by_worker: defaultdict[scheduler_module.WorkerState, list[Key]] = ( + defaultdict(list) + ) + repl_by_worker: defaultdict[scheduler_module.WorkerState, list[Key]] = ( + defaultdict(list) + ) for ts, (pending_repl, pending_drop) in self.pending.items(): if not ts.who_has: diff --git a/distributed/actor.py b/distributed/actor.py index d2dea1848e..0af83daf63 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -245,12 +245,10 @@ class BaseActorFuture(abc.ABC, Awaitable[_T]): """ @abc.abstractmethod - def result(self, timeout: str | timedelta | float | None = None) -> _T: - ... + def result(self, timeout: str | timedelta | float | None = None) -> _T: ... @abc.abstractmethod - def done(self) -> bool: - ... + def done(self) -> bool: ... def __repr__(self) -> Literal[""]: return "" diff --git a/distributed/broker.py b/distributed/broker.py index 298225eb85..b8df96620a 100644 --- a/distributed/broker.py +++ b/distributed/broker.py @@ -84,14 +84,12 @@ def _send_to_subscribers(self, topic: str, event: Any) -> None: self._scheduler.send_all(client_msgs, worker_msgs={}) @overload - def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]: - ... + def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]: ... @overload def get_events( self, topic: None = None - ) -> dict[str, tuple[tuple[float, Any], ...]]: - ... + ) -> dict[str, tuple[tuple[float, Any], ...]]: ... def get_events( self, topic: str | None = None diff --git a/distributed/cfexecutor.py b/distributed/cfexecutor.py index 1eadfcac19..13998708cc 100644 --- a/distributed/cfexecutor.py +++ b/distributed/cfexecutor.py @@ -30,7 +30,7 @@ def _cascade_future(future, cf_future): try: typ, exc, tb = result raise exc.with_traceback(tb) - except BaseException as exc: + except BaseException as exc: # noqa: B036 cf_future.set_exception(exc) diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py index 8ee87d0fee..bf0f96420d 100755 --- a/distributed/cli/dask_worker.py +++ b/distributed/cli/dask_worker.py @@ -406,9 +406,11 @@ async def run(): host=host, dashboard=dashboard, dashboard_address=dashboard_address, - name=name - if n_workers == 1 or name is None or name == "" - else str(name) + "-" + str(i), + name=( + name + if n_workers == 1 or name is None or name == "" + else str(name) + "-" + str(i) + ), **kwargs, **port_kwargs_i, ) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index c0e86dd67c..847c28ea4b 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -610,8 +610,9 @@ async def test_set_lifetime_stagger_via_env_var(c, s): env = os.environ.copy() env["DASK_DISTRIBUTED__WORKER__LIFETIME__DURATION"] = "10 seconds" env["DASK_DISTRIBUTED__WORKER__LIFETIME__STAGGER"] = "2 seconds" - with popen(["dask", "worker", s.address], env=env), popen( - ["dask", "worker", s.address], env=env + with ( + popen(["dask", "worker", s.address], env=env), + popen(["dask", "worker", s.address], env=env), ): await c.wait_for_workers(2) [lifetime1, lifetime2] = ( diff --git a/distributed/client.py b/distributed/client.py index ebe4299d1a..152540e9bf 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -133,9 +133,9 @@ logger = logging.getLogger(__name__) -_global_clients: weakref.WeakValueDictionary[ - int, Client -] = weakref.WeakValueDictionary() +_global_clients: weakref.WeakValueDictionary[int, Client] = ( + weakref.WeakValueDictionary() +) _global_client_index = [0] _current_client: ContextVar[Client | None] = ContextVar("_current_client", default=None) @@ -483,6 +483,7 @@ def execute_callback(fut): fn(fut) except BaseException: logger.exception("Error in callback %s of %s:", fn, fut) + raise self.client.loop.add_callback( done_callback, self, partial(cls._cb_executor.submit, execute_callback) @@ -3873,13 +3874,13 @@ async def _restart_workers( name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()} worker_addrs = [name_to_addr.get(w, w) for w in workers] - out: dict[ - str, Literal["OK", "removed", "timed out"] - ] = await self.scheduler.restart_workers( - workers=worker_addrs, - timeout=timeout, - on_error="raise" if raise_for_error else "return", - stimulus_id=f"client-restart-workers-{time()}", + out: dict[str, Literal["OK", "removed", "timed out"]] = ( + await self.scheduler.restart_workers( + workers=worker_addrs, + timeout=timeout, + on_error="raise" if raise_for_error else "return", + stimulus_id=f"client-restart-workers-{time()}", + ) ) # Map keys back to original `workers` input names/addresses out = {w: out[w_addr] for w, w_addr in zip(workers, worker_addrs)} diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index db9c0baa29..9cdc29a87f 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -7,8 +7,7 @@ class _EntryPoints(Protocol): - def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: - ... + def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: ... _entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index ecb6eb081f..54b14fec44 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -5,6 +5,7 @@ .. _UCX: https://github.com/openucx/ucx """ + from __future__ import annotations import functools @@ -360,7 +361,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): await self.ep.recv(header) header = struct.unpack(header_fmt, header) cuda_frames, sizes = header[:nframes], header[nframes:] - except BaseException as e: + except BaseException as e: # noqa: B036 # In addition to UCX exceptions, may be CancelledError or another # "low-level" exception. The only safe thing to do is to abort. # (See also https://github.com/dask/distributed/pull/6574). @@ -390,7 +391,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): try: for each_frame in recv_frames: await self.ep.recv(each_frame) - except BaseException as e: + except BaseException as e: # noqa: B036 # In addition to UCX exceptions, may be CancelledError or another # "low-level" exception. The only safe thing to do is to abort. # (See also https://github.com/dask/distributed/pull/6574). diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 982d234c9a..d994e17acd 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -3923,9 +3923,11 @@ def update(self): # Format event loop as time and GIL (if configured) as % self.data["text"] = [ - f"{x * 100:.1f}%" - if i % 2 and s.monitor.monitor_gil_contention - else format_time(x) + ( + f"{x * 100:.1f}%" + if i % 2 and s.monitor.monitor_gil_contention + else format_time(x) + ) for i, x in enumerate(self.data["values"]) ] update(self.source, self.data) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 2df143412e..481745d139 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -434,13 +434,15 @@ def SSHCluster( "cls": Scheduler, "options": { "address": hosts[0], - "connect_options": connect_options - if isinstance(connect_options, dict) - else connect_options[0], + "connect_options": ( + connect_options + if isinstance(connect_options, dict) + else connect_options[0] + ), "kwargs": scheduler_options, - "remote_python": remote_python[0] - if isinstance(remote_python, list) - else remote_python, + "remote_python": ( + remote_python[0] if isinstance(remote_python, list) else remote_python + ), }, } workers = { @@ -448,14 +450,18 @@ def SSHCluster( "cls": Worker, "options": { "address": host, - "connect_options": connect_options - if isinstance(connect_options, dict) - else connect_options[i + 1], + "connect_options": ( + connect_options + if isinstance(connect_options, dict) + else connect_options[i + 1] + ), "kwargs": worker_options, "worker_class": worker_class, - "remote_python": remote_python[i + 1] - if isinstance(remote_python, list) - else remote_python, + "remote_python": ( + remote_python[i + 1] + if isinstance(remote_python, list) + else remote_python + ), }, } for i, host in enumerate(hosts[1:]) diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index 16576a857a..a71fdfb298 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -189,14 +189,17 @@ async def test_adapt_quickly(): Instead we want to wait a few beats before removing a worker in case the user is taking a brief pause between work """ - async with LocalCluster( - n_workers=0, - asynchronous=True, - processes=False, - silence_logs=False, - dashboard_address=":0", - threads_per_worker=1, - ) as cluster, Client(cluster, asynchronous=True) as client: + async with ( + LocalCluster( + n_workers=0, + asynchronous=True, + processes=False, + silence_logs=False, + dashboard_address=":0", + threads_per_worker=1, + ) as cluster, + Client(cluster, asynchronous=True) as client, + ): adapt = cluster.adapt(interval="20 ms", wait_count=5, maximum=10) future = client.submit(slowinc, 1, delay=0.100) await wait(future) @@ -240,13 +243,16 @@ async def test_adapt_quickly(): @gen_test() async def test_adapt_down(): """Ensure that redefining adapt with a lower maximum removes workers""" - async with LocalCluster( - n_workers=0, - asynchronous=True, - processes=False, - silence_logs=False, - dashboard_address=":0", - ) as cluster, Client(cluster, asynchronous=True) as client: + async with ( + LocalCluster( + n_workers=0, + asynchronous=True, + processes=False, + silence_logs=False, + dashboard_address=":0", + ) as cluster, + Client(cluster, asynchronous=True) as client, + ): cluster.adapt(interval="20ms", maximum=5) futures = client.map(slowinc, range(1000), delay=0.1) diff --git a/distributed/deploy/tests/test_cluster.py b/distributed/deploy/tests/test_cluster.py index 5c2394a80e..b40a15f2c4 100644 --- a/distributed/deploy/tests/test_cluster.py +++ b/distributed/deploy/tests/test_cluster.py @@ -10,9 +10,11 @@ @gen_test() async def test_eq(): - async with Cluster(asynchronous=True, name="A") as clusterA, Cluster( - asynchronous=True, name="A2" - ) as clusterA2, Cluster(asynchronous=True, name="B") as clusterB: + async with ( + Cluster(asynchronous=True, name="A") as clusterA, + Cluster(asynchronous=True, name="A2") as clusterA2, + Cluster(asynchronous=True, name="B") as clusterB, + ): assert clusterA != "A" assert not (clusterA == "A") assert clusterA == clusterA @@ -75,8 +77,11 @@ def test_exponential_backoff(): @gen_test() async def test_sync_context_manager_used_with_async_cluster(): async with Cluster(asynchronous=True, name="A") as cluster: - with pytest.raises( - TypeError, - match=r"Used 'with' with asynchronous class; please use 'async with'", - ), cluster: + with ( + pytest.raises( + TypeError, + match=r"Used 'with' with asynchronous class; please use 'async with'", + ), + cluster, + ): pass diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 1594bc4db5..59abc15573 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -184,35 +184,44 @@ def test_transports_tcp_port(loop): def test_cores(loop): - with LocalCluster( - n_workers=2, - scheduler_port=0, - silence_logs=False, - dashboard_address=":0", - processes=False, - loop=loop, - ) as cluster, Client(cluster.scheduler_address, loop=loop) as client: + with ( + LocalCluster( + n_workers=2, + scheduler_port=0, + silence_logs=False, + dashboard_address=":0", + processes=False, + loop=loop, + ) as cluster, + Client(cluster.scheduler_address, loop=loop) as client, + ): client.scheduler_info() assert len(client.nthreads()) == 2 def test_submit(loop): - with LocalCluster( - n_workers=2, - scheduler_port=0, - silence_logs=False, - dashboard_address=":0", - processes=False, - loop=loop, - ) as cluster, Client(cluster.scheduler_address, loop=loop) as client: + with ( + LocalCluster( + n_workers=2, + scheduler_port=0, + silence_logs=False, + dashboard_address=":0", + processes=False, + loop=loop, + ) as cluster, + Client(cluster.scheduler_address, loop=loop) as client, + ): future = client.submit(lambda x: x + 1, 1) assert future.result() == 2 def test_context_manager(loop): - with LocalCluster( - silence_logs=False, dashboard_address=":0", processes=False, loop=loop - ) as c, Client(c) as e: + with ( + LocalCluster( + silence_logs=False, dashboard_address=":0", processes=False, loop=loop + ) as c, + Client(c) as e, + ): assert e.nthreads() @@ -829,14 +838,17 @@ class MyCluster(LocalCluster): def scale_down(self, *args, **kwargs): pass - async with MyCluster( - n_workers=0, - processes=False, - silence_logs=False, - dashboard_address=":0", - loop=None, - asynchronous=True, - ) as cluster, Client(cluster, asynchronous=True) as c: + async with ( + MyCluster( + n_workers=0, + processes=False, + silence_logs=False, + dashboard_address=":0", + loop=None, + asynchronous=True, + ) as cluster, + Client(cluster, asynchronous=True) as c, + ): assert not cluster.workers await cluster.scale(2) diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index 0a30f01508..878e778cc4 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -415,7 +415,7 @@ def __str__(self): __repr__ = __str__ - async def start(self): + async def start_unsafe(self): await asyncio.gather(*self.workers) async def close(self): diff --git a/distributed/deploy/tests/test_subprocess.py b/distributed/deploy/tests/test_subprocess.py index 9033e5e22e..fa878f08c1 100644 --- a/distributed/deploy/tests/test_subprocess.py +++ b/distributed/deploy/tests/test_subprocess.py @@ -81,9 +81,10 @@ async def test_subprocess_cluster_does_not_depend_on_logging(): with new_config_file( {"distributed": {"logging": {"distributed": logging.CRITICAL + 1}}} ): - async with SubprocessCluster( - asynchronous=True, dashboard_address=":0" - ) as cluster, Client(cluster, asynchronous=True) as client: + async with ( + SubprocessCluster(asynchronous=True, dashboard_address=":0") as cluster, + Client(cluster, asynchronous=True) as client, + ): result = await client.submit(lambda x: x + 1, 10) assert result == 11 diff --git a/distributed/diagnostics/memray.py b/distributed/diagnostics/memray.py index 82b6f389a2..7dcee2f231 100644 --- a/distributed/diagnostics/memray.py +++ b/distributed/diagnostics/memray.py @@ -71,8 +71,11 @@ def _fetch_memray_profile( def memray_workers( directory: str | pathlib.Path = "memray-profiles", workers: int | None | list[str] = None, - report_args: Sequence[str] - | Literal[False] = ("flamegraph", "--temporal", "--leaks"), + report_args: Sequence[str] | Literal[False] = ( + "flamegraph", + "--temporal", + "--leaks", + ), fetch_reports_parallel: bool | int = True, **memray_kwargs: Any, ) -> Iterator[None]: @@ -183,8 +186,11 @@ def memray_workers( @contextlib.contextmanager def memray_scheduler( directory: str | pathlib.Path = "memray-profiles", - report_args: Sequence[str] - | Literal[False] = ("flamegraph", "--temporal", "--leaks"), + report_args: Sequence[str] | Literal[False] = ( + "flamegraph", + "--temporal", + "--leaks", + ), **memray_kwargs: Any, ) -> Iterator[None]: """Generate a Memray profile on the Scheduler and download the generated report. diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 669f915292..71c6cc1ed9 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -526,13 +526,11 @@ def __init__(self, _install_fn: Callable[[], None], name: str): async def setup(self, nanny): from distributed.semaphore import Semaphore - async with ( - await Semaphore( - max_leases=1, - name=socket.gethostname(), - scheduler_rpc=nanny.scheduler, - loop=nanny.loop, - ) + async with await Semaphore( + max_leases=1, + name=socket.gethostname(), + scheduler_rpc=nanny.scheduler, + loop=nanny.loop, ): self._install_fn() @@ -571,13 +569,11 @@ def __init__(self, install_fn: Callable[[], None], name: str): async def setup(self, worker): from distributed.semaphore import Semaphore - async with ( - await Semaphore( - max_leases=1, - name=socket.gethostname(), - scheduler_rpc=worker.scheduler, - loop=worker.loop, - ) + async with await Semaphore( + max_leases=1, + name=socket.gethostname(), + scheduler_rpc=worker.scheduler, + loop=worker.loop, ): self._install_fn() diff --git a/distributed/diagnostics/tests/test_nvml.py b/distributed/diagnostics/tests/test_nvml.py index d9a95486d4..6ddfbc67ed 100644 --- a/distributed/diagnostics/tests/test_nvml.py +++ b/distributed/diagnostics/tests/test_nvml.py @@ -162,9 +162,10 @@ def test_visible_devices_bad_uuid(): if nvml.device_get_count() < 1: pytest.skip("No GPUs available") - with mock.patch.dict( - os.environ, {"CUDA_VISIBLE_DEVICES": "NOT-A-GPU-UUID"} - ), pytest.raises(ValueError, match="Devices in CUDA_VISIBLE_DEVICES"): + with ( + mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "NOT-A-GPU-UUID"}), + pytest.raises(ValueError, match="Devices in CUDA_VISIBLE_DEVICES"), + ): nvml._pynvml_handles() diff --git a/distributed/metrics.py b/distributed/metrics.py index 3013a447c6..529b47fd3f 100755 --- a/distributed/metrics.py +++ b/distributed/metrics.py @@ -195,7 +195,9 @@ class ContextMeter: _callbacks: ContextVar[dict[Hashable, Callable[[Hashable, float, str], None]]] def __init__(self): - self._callbacks = ContextVar(f"MetricHook<{id(self)}>._callbacks", default={}) + self._callbacks = ContextVar( + f"MetricHook<{id(self)}>._callbacks", default={} # noqa: B039 + ) def __reduce__(self): assert self is context_meter, "Found copy of singleton" diff --git a/distributed/objects.py b/distributed/objects.py index 76bad0be8d..53b7b91565 100644 --- a/distributed/objects.py +++ b/distributed/objects.py @@ -1,5 +1,6 @@ """This file contains custom objects. These are mostly regular objects with more useful _repr_ and _repr_html_ methods.""" + from __future__ import annotations from urllib.parse import urlparse diff --git a/distributed/profile.py b/distributed/profile.py index 1ba4e2994d..a194a53180 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -24,6 +24,7 @@ 'children': {...}}} } """ + from __future__ import annotations import bisect diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 6b92048455..5b8eac78fb 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -3,6 +3,7 @@ Includes utilities for determining whether or not to compress """ + from __future__ import annotations import zlib diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py index 2cc2169b07..fb98490ce6 100644 --- a/distributed/protocol/cupy.py +++ b/distributed/protocol/cupy.py @@ -1,6 +1,7 @@ """ Efficient serialization GPU arrays. """ + from __future__ import annotations import copyreg diff --git a/distributed/protocol/scipy.py b/distributed/protocol/scipy.py index 89e42bfe05..90a7910693 100644 --- a/distributed/protocol/scipy.py +++ b/distributed/protocol/scipy.py @@ -1,6 +1,7 @@ """ Efficient serialization of SciPy sparse matrices. """ + from __future__ import annotations import scipy diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 0777b8f9e0..09226e00be 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -95,8 +95,7 @@ def unpack_frames( *, remainder: bool = False, partial: Literal[False] = False, -) -> list[memoryview]: - ... +) -> list[memoryview]: ... @overload @@ -105,8 +104,7 @@ def unpack_frames( *, remainder: bool = False, partial: Literal[True], -) -> tuple[list[memoryview], list[int]]: - ... +) -> tuple[list[memoryview], list[int]]: ... def unpack_frames(b, *, remainder=False, partial=False): diff --git a/distributed/pytest_resourceleaks.py b/distributed/pytest_resourceleaks.py index 0521648e0a..e647f68048 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -41,6 +41,7 @@ def test1(): unreliable. On Linux, this can be improved by reducing the MALLOC_TRIM glibc setting (see distributed.yaml). """ + from __future__ import annotations import gc diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 57d0eca46b..fa965c6910 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1847,8 +1847,7 @@ def __init__( ) @abstractmethod - def log_event(self, topic: str | Collection[str], msg: Any) -> None: - ... + def log_event(self, topic: str | Collection[str], msg: Any) -> None: ... @property def memory(self) -> MemoryState: @@ -1918,7 +1917,7 @@ def _clear_task_state(self) -> None: self.unknown_durations, self.replicated_tasks, ): - collection.clear() # type: ignore + collection.clear() @property def is_idle(self) -> bool: @@ -7399,8 +7398,7 @@ async def retire_workers( close_workers: bool = False, remove: bool = True, stimulus_id: str | None = None, - ) -> dict[str, Any]: - ... + ) -> dict[str, Any]: ... @overload async def retire_workers( @@ -7410,8 +7408,7 @@ async def retire_workers( close_workers: bool = False, remove: bool = True, stimulus_id: str | None = None, - ) -> dict[str, Any]: - ... + ) -> dict[str, Any]: ... @overload async def retire_workers( @@ -7427,8 +7424,7 @@ async def retire_workers( minimum: int | None = None, target: int | None = None, attribute: str = "address", - ) -> dict[str, Any]: - ... + ) -> dict[str, Any]: ... @log_errors async def retire_workers( @@ -7710,12 +7706,10 @@ def update_data( self.client_desires_keys(keys=list(who_has), client=client) @overload - def report_on_key(self, key: Key, *, client: str | None = None) -> None: - ... + def report_on_key(self, key: Key, *, client: str | None = None) -> None: ... @overload - def report_on_key(self, *, ts: TaskState, client: str | None = None) -> None: - ... + def report_on_key(self, *, ts: TaskState, client: str | None = None) -> None: ... def report_on_key(self, key=None, *, ts=None, client=None): if (ts is None) == (key is None): @@ -7806,9 +7800,11 @@ def get_processing( def get_who_has(self, keys: Iterable[Key] | None = None) -> dict[Key, list[str]]: if keys is not None: return { - key: [ws.address for ws in self.tasks[key].who_has or ()] - if key in self.tasks - else [] + key: ( + [ws.address for ws in self.tasks[key].who_has or ()] + if key in self.tasks + else [] + ) for key in keys } else: @@ -7823,9 +7819,11 @@ def get_has_what( if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts.key for ts in self.workers[w].has_what] - if w in self.workers - else [] + w: ( + [ts.key for ts in self.workers[w].has_what] + if w in self.workers + else [] + ) for w in workers } else: @@ -8555,12 +8553,10 @@ def unsubscribe_topic(self, topic: str, client: str) -> None: self._broker.unsubscribe(topic, client) @overload - def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]: - ... + def get_events(self, topic: str) -> tuple[tuple[float, Any], ...]: ... @overload - def get_events(self) -> dict[str, tuple[tuple[float, Any], ...]]: - ... + def get_events(self) -> dict[str, tuple[tuple[float, Any], ...]]: ... def get_events( self, topic: str | None = None diff --git a/distributed/semaphore.py b/distributed/semaphore.py index b4d8d900c0..f650f4f815 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -64,10 +64,8 @@ def __init__(self, scheduler): dask.config.get("distributed.scheduler.locks.lease-validation-interval"), default="s", ) - self.scheduler.periodic_callbacks[ - "semaphore-lease-timeout" - ] = pc = PeriodicCallback( - self._check_lease_timeout, validation_callback_time * 1000 + self.scheduler.periodic_callbacks["semaphore-lease-timeout"] = pc = ( + PeriodicCallback(self._check_lease_timeout, validation_callback_time * 1000) ) pc.start() diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index 60cc0a86c4..f402c51b7c 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -126,8 +126,7 @@ async def process(self, id: str, shards: list[ShardType], size: int) -> None: self.bytes_memory -= size @abc.abstractmethod - async def _process(self, id: str, shards: list[ShardType]) -> None: - ... + async def _process(self, id: str, shards: list[ShardType]) -> None: ... def read(self, id: str) -> ShardType: raise NotImplementedError() # pragma: nocover diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index b0c4fc17e1..c8ae80a75e 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -447,7 +447,7 @@ class ShuffleSpec(abc.ABC, Generic[_T_partition_id]): @property @abc.abstractmethod - def output_partitions(self) -> Generator[_T_partition_id, None, None]: + def output_partitions(self) -> Generator[_T_partition_id]: """Output partitions""" @abc.abstractmethod diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 327c06f3df..bed442705c 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -78,7 +78,7 @@ def release_read(self) -> None: self._condition.notify_all() @contextmanager - def write(self) -> Generator[None, None, None]: + def write(self) -> Generator[None]: self.acquire_write() try: yield @@ -86,7 +86,7 @@ def write(self) -> Generator[None, None, None]: self.release_write() @contextmanager - def read(self) -> Generator[None, None, None]: + def read(self) -> Generator[None]: self.acquire_read() try: yield diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index efae7d80b7..20b3715085 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -631,7 +631,7 @@ def _largest_block_size(chunks: tuple[tuple[int, ...], ...]) -> int: def _split_partials( old_to_new: list[Any], -) -> Generator[_NDPartial, None, None]: +) -> Generator[_NDPartial]: """Split the rechunking into partials that can be performed separately""" partials_per_axis = _split_partials_per_axis(old_to_new) indices_per_axis = (range(len(partials)) for partials in partials_per_axis) @@ -1102,7 +1102,7 @@ class ArrayRechunkSpec(ShuffleSpec[NDIndex]): old: ChunkedAxes @property - def output_partitions(self) -> Generator[NDIndex, None, None]: + def output_partitions(self) -> Generator[NDIndex]: yield from product(*(range(len(c)) for c in self.new)) def pick_worker(self, partition: NDIndex, workers: Sequence[str]) -> str: diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 508e2f4823..4a21909fa0 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -567,7 +567,7 @@ class DataFrameShuffleSpec(ShuffleSpec[int]): drop_column: bool @property - def output_partitions(self) -> Generator[int, None, None]: + def output_partitions(self) -> Generator[int]: yield from self.parts_out def pick_worker(self, partition: int, workers: Sequence[str]) -> str: @@ -600,9 +600,9 @@ def create_run_on_worker( rpc=plugin.worker.rpc, digest_metric=plugin.worker.digest_metric, scheduler=plugin.worker.scheduler, - memory_limiter_disk=plugin.memory_limiter_disk - if self.disk - else ResourceLimiter(None), + memory_limiter_disk=( + plugin.memory_limiter_disk if self.disk else ResourceLimiter(None) + ), memory_limiter_comms=plugin.memory_limiter_comms, disk=self.disk, drop_column=self.drop_column, diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 57d2cfe369..e4033d4310 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -210,8 +210,7 @@ async def _fetch( async def _refresh( self, shuffle_id: ShuffleId, - ) -> ShuffleRun: - ... + ) -> ShuffleRun: ... @overload async def _refresh( @@ -219,8 +218,7 @@ async def _refresh( shuffle_id: ShuffleId, spec: ShuffleSpec, key: Key, - ) -> ShuffleRun: - ... + ) -> ShuffleRun: ... async def _refresh( self, diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py index 7fa4d5b2ce..33c20b01f8 100644 --- a/distributed/shuffle/tests/test_graph.py +++ b/distributed/shuffle/tests/test_graph.py @@ -32,9 +32,12 @@ def test_raise_on_complex_numbers(dtype): df = dd.from_pandas( pd.DataFrame({"x": pd.array(range(10), dtype=dtype)}), npartitions=5 ) - with pytest.raises( - TypeError, match=f"p2p does not support data of type '{df.x.dtype}'" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): + with ( + pytest.raises( + TypeError, match=f"p2p does not support data of type '{df.x.dtype}'" + ), + dask.config.set({"dataframe.shuffle.method": "p2p"}), + ): df.shuffle("x") @@ -50,9 +53,10 @@ def __init__(self, value: int) -> None: pd.DataFrame({"x": pd.array([Stub(i) for i in range(10)], dtype="object")}), npartitions=5, ) - with pytest.raises( - TypeError, match="p2p does not support custom objects" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): + with ( + pytest.raises(TypeError, match="p2p does not support custom objects"), + dask.config.set({"dataframe.shuffle.method": "p2p"}), + ): df.shuffle("x") @@ -60,17 +64,19 @@ def test_raise_on_sparse_data(): df = dd.from_pandas( pd.DataFrame({"x": pd.array(range(10), dtype="Sparse[float64]")}), npartitions=5 ) - with pytest.raises( - TypeError, match="p2p does not support sparse data" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): + with ( + pytest.raises(TypeError, match="p2p does not support sparse data"), + dask.config.set({"dataframe.shuffle.method": "p2p"}), + ): df.shuffle("x") def test_raise_on_non_string_column_name(): df = dd.from_pandas(pd.DataFrame({"a": range(10), 1: range(10)}), npartitions=5) - with pytest.raises( - TypeError, match="p2p requires all column names to be str" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): + with ( + pytest.raises(TypeError, match="p2p requires all column names to be str"), + dask.config.set({"dataframe.shuffle.method": "p2p"}), + ): df.shuffle("a") diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 8c196af3c0..112fcd038d 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -61,9 +61,10 @@ async def test_minimal_version(c, s, a, b): B = pd.DataFrame({"y": [1, 3, 4, 4, 5, 6], "z": [6, 5, 4, 3, 2, 1]}) b = dd.repartition(B, [0, 2, 5]) - with pytest.raises( - ModuleNotFoundError, match="requires pyarrow" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): + with ( + pytest.raises(ModuleNotFoundError, match="requires pyarrow"), + dask.config.set({"dataframe.shuffle.method": "p2p"}), + ): await c.compute(dd.merge(a, b, left_on="x", right_on="z")) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index e3a596855c..ab676b51a7 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3811,7 +3811,7 @@ class UnhandledExceptions(Exception): @contextmanager -def catch_unhandled_exceptions() -> Generator[None, None, None]: +def catch_unhandled_exceptions() -> Generator[None]: loop = asyncio.get_running_loop() ctxs: list[dict[str, Any]] = [] @@ -6939,8 +6939,7 @@ async def test_get_task_metadata_multiple(c, s, a, b): @gen_cluster(client=True) async def test_register_worker_plugin_instance_required(c, s, a, b): - class MyPlugin(WorkerPlugin): - ... + class MyPlugin(WorkerPlugin): ... with pytest.raises(TypeError, match="instance"): await c.register_plugin(MyPlugin) diff --git a/distributed/tests/test_computations.py b/distributed/tests/test_computations.py index 70fd7633e8..e820a6fc4b 100644 --- a/distributed/tests/test_computations.py +++ b/distributed/tests/test_computations.py @@ -1,4 +1,5 @@ """Tests for distributed.scheduler.Computation objects""" + from __future__ import annotations import pytest diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 0e94dc96a5..6f4836ae08 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1199,38 +1199,27 @@ async def long_handler(comm, delay=10): def test_expects_comm(): class A: - def empty(self): - ... + def empty(self): ... - def one_arg(self, arg): - ... + def one_arg(self, arg): ... - def comm_arg(self, comm): - ... + def comm_arg(self, comm): ... - def stream_arg(self, stream): - ... + def stream_arg(self, stream): ... - def two_arg(self, arg, other): - ... + def two_arg(self, arg, other): ... - def comm_arg_other(self, comm, other): - ... + def comm_arg_other(self, comm, other): ... - def stream_arg_other(self, stream, other): - ... + def stream_arg_other(self, stream, other): ... - def arg_kwarg(self, arg, other=None): - ... + def arg_kwarg(self, arg, other=None): ... - def comm_posarg_only(self, comm, /, other): - ... + def comm_posarg_only(self, comm, /, other): ... - def comm_not_leading_position(self, other, comm): - ... + def comm_not_leading_position(self, other, comm): ... - def stream_not_leading_position(self, other, stream): - ... + def stream_not_leading_position(self, other, stream): ... expected_warning = "first argument of a RPC handler `stream` is deprecated" diff --git a/distributed/tests/test_event_logging.py b/distributed/tests/test_event_logging.py index 43effcb12f..e97cc68dea 100644 --- a/distributed/tests/test_event_logging.py +++ b/distributed/tests/test_event_logging.py @@ -75,9 +75,10 @@ def log_scheduler(dask_scheduler): @gen_cluster(client=True, nthreads=[]) async def test_log_event_multiple_clients(c, s): - async with Client(s.address, asynchronous=True) as c2, Client( - s.address, asynchronous=True - ) as c3: + async with ( + Client(s.address, asynchronous=True) as c2, + Client(s.address, asynchronous=True) as c3, + ): received_events = [] def get_event_handler(handler_id): @@ -185,9 +186,10 @@ async def user_event_handler(event): @gen_cluster(nthreads=[]) async def test_topic_subscribe_unsubscribe(s): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): def event_handler(recorded_events, event): _, msg = event diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 35401f7f5f..85cbd3a5e1 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -216,9 +216,10 @@ def test_worker_doesnt_await_task_completion(loop): @gen_cluster(Worker=Nanny, timeout=60) async def test_multiple_clients_restart(s, a, b): - async with Client(s.address, asynchronous=True) as c1, Client( - s.address, asynchronous=True - ) as c2: + async with ( + Client(s.address, asynchronous=True) as c1, + Client(s.address, asynchronous=True) as c2, + ): x = c1.submit(inc, 1) y = c2.submit(inc, 2) xx = await x diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index a8fb86bf8e..5a003f9fa8 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -40,9 +40,10 @@ def check_worker(): with open(path, "w") as f: f.write(PRELOAD_TEXT) - with cluster(worker_kwargs={"preload": [path]}) as (s, workers), Client( - s["address"], loop=loop - ) as c: + with ( + cluster(worker_kwargs={"preload": [path]}) as (s, workers), + Client(s["address"], loop=loop) as c, + ): assert c.run(check_worker) == { worker["address"]: worker["address"] for worker in workers } @@ -108,10 +109,13 @@ def check_worker(): with open(path, "w") as f: f.write(PRELOAD_TEXT) - with cluster(worker_kwargs={"preload": ["worker_info"]}) as ( - s, - workers, - ), Client(s["address"], loop=loop) as c: + with ( + cluster(worker_kwargs={"preload": ["worker_info"]}) as ( + s, + workers, + ), + Client(s["address"], loop=loop) as c, + ): assert c.run(check_worker) == { worker["address"]: worker["address"] for worker in workers } @@ -170,14 +174,17 @@ async def test_preload_import_time(): @gen_test() async def test_web_preload(): - with mock.patch( - "urllib3.PoolManager.request", - **{ - "return_value.data": b"def dask_setup(dask_server):" - b"\n dask_server.foo = 1" - b"\n" - }, - ) as request, captured_logger("distributed.preloading") as log: + with ( + mock.patch( + "urllib3.PoolManager.request", + **{ + "return_value.data": b"def dask_setup(dask_server):" + b"\n dask_server.foo = 1" + b"\n" + }, + ) as request, + captured_logger("distributed.preloading") as log, + ): async with Scheduler( host="localhost", preload=["http://example.com/preload"] ) as s: diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index 3669030cca..4c59da9358 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -16,9 +16,10 @@ @gen_cluster() async def test_publish_simple(s, a, b): - async with Client(s.address, asynchronous=True) as c, Client( - s.address, asynchronous=True - ) as f: + async with ( + Client(s.address, asynchronous=True) as c, + Client(s.address, asynchronous=True) as f, + ): data = await c.scatter(range(3)) await c.publish_dataset(data=data) assert "data" in s.extensions["publish"].datasets @@ -54,9 +55,10 @@ async def test_publish_non_string_key(s, a, b): @gen_cluster() async def test_publish_roundtrip(s, a, b): - async with Client(s.address, asynchronous=True) as c, Client( - s.address, asynchronous=True - ) as f: + async with ( + Client(s.address, asynchronous=True) as c, + Client(s.address, asynchronous=True) as f, + ): data = await c.scatter([0, 1, 2]) await c.publish_dataset(data=data) @@ -149,9 +151,10 @@ def test_unpublish_multiple_datasets_sync(client): @gen_cluster() async def test_publish_bag(s, a, b): db = pytest.importorskip("dask.bag") - async with Client(s.address, asynchronous=True) as c, Client( - s.address, asynchronous=True - ) as f: + async with ( + Client(s.address, asynchronous=True) as c, + Client(s.address, asynchronous=True) as f, + ): bag = db.from_sequence([0, 1, 2]) bagp = c.persist(bag) diff --git a/distributed/tests/test_reschedule.py b/distributed/tests/test_reschedule.py index fe6ef45644..feb3d64e18 100644 --- a/distributed/tests/test_reschedule.py +++ b/distributed/tests/test_reschedule.py @@ -3,6 +3,7 @@ Note that this functionality is also used by work stealing; see test_steal.py for additional tests. """ + from __future__ import annotations import asyncio diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a62fcd977e..cfb6fdefdb 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -3177,11 +3177,13 @@ async def connect(self, *args, **kwargs): async def test_gather_failing_can_recover(c, s, a, b): x = await c.scatter({"x": 1}, workers=a.address) rpc = await FlakyConnectionPool(failing_connections=1) - with mock.patch.object(s, "rpc", rpc), dask.config.set( - {"distributed.comm.retry.count": 1} - ), captured_handler( - logging.getLogger("distributed").handlers[0] - ) as distributed_log: + with ( + mock.patch.object(s, "rpc", rpc), + dask.config.set({"distributed.comm.retry.count": 1}), + captured_handler( + logging.getLogger("distributed").handlers[0] + ) as distributed_log, + ): res = await s.gather(keys=["x"]) assert re.match( r"\A\d+-\d+-\d+ \d+:\d+:\d+,\d+ - distributed.utils_comm - INFO - " @@ -4697,9 +4699,11 @@ def block_on_event(input, block, executing): return input # Manually spin up cluster to avoid state validation on cluster shutdown in gen_cluster - async with Scheduler(dashboard_address=":0") as s, Worker(s.address) as w, Client( - s.address, asynchronous=True - ) as c: + async with ( + Scheduler(dashboard_address=":0") as s, + Worker(s.address) as w, + Client(s.address, asynchronous=True) as c, + ): block = Event() executing = Event() diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index ac7a0a17ab..bbddc550e0 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -2,6 +2,7 @@ Various functional tests for TLS networking. Most are taken from other test files and adapted. """ + from __future__ import annotations import asyncio diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index da68e61a98..8bfcd8347b 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -450,6 +450,7 @@ async def raiser(): # Prevent test failure from killing the whole pytest process traceback.print_exc() pytest.fail(f"BaseException propagated back to test: {e!r}. See stdout.") + raise # Nanny restarts it await c.wait_for_workers(1) @@ -663,9 +664,12 @@ async def test_close_on_disconnect(s, w): @gen_cluster(nthreads=[]) async def test_memory_limit_auto(s): - async with Worker(s.address, nthreads=1) as a, Worker( - s.address, nthreads=2 - ) as b, Worker(s.address, nthreads=100) as c, Worker(s.address, nthreads=200) as d: + async with ( + Worker(s.address, nthreads=1) as a, + Worker(s.address, nthreads=2) as b, + Worker(s.address, nthreads=100) as c, + Worker(s.address, nthreads=200) as d, + ): assert isinstance(a.memory_manager.memory_limit, Number) assert isinstance(b.memory_manager.memory_limit, Number) @@ -1768,11 +1772,10 @@ async def test_heartbeat_missing_real_cluster(s, a): assumption_msg = "Test assumptions have changed. Race condition may have been fixed; this test may be removable." - with captured_logger( - "distributed.worker", level=logging.WARNING - ) as wlogger, captured_logger( - "distributed.scheduler", level=logging.WARNING - ) as slogger: + with ( + captured_logger("distributed.worker", level=logging.WARNING) as wlogger, + captured_logger("distributed.scheduler", level=logging.WARNING) as slogger, + ): with freeze_batched_send(s.stream_comms[a.address]): await s.remove_worker(a.address, stimulus_id="foo") assert not s.workers diff --git a/distributed/threadpoolexecutor.py b/distributed/threadpoolexecutor.py index 60209144d1..0d7bf609a3 100644 --- a/distributed/threadpoolexecutor.py +++ b/distributed/threadpoolexecutor.py @@ -20,6 +20,7 @@ Copyright 2001-2016 Python Software Foundation; All Rights Reserved """ + from __future__ import annotations import itertools @@ -62,6 +63,7 @@ def _worker(executor, work_queue): del executor except BaseException: logger.critical("Exception in worker", exc_info=True) + raise finally: del thread_state.proceed del thread_state.executor diff --git a/distributed/utils.py b/distributed/utils.py index 0c84dd1737..e360667bc9 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -449,7 +449,7 @@ def __init__(self, target: Callable[[], None], daemon: bool, name: str): def wrapper() -> None: try: target() - except BaseException as e: + except BaseException as e: # noqa: B036 self._exception = e self._thread = thread = threading.Thread( @@ -723,13 +723,11 @@ def key_split_group(x: object) -> str: @overload -def log_errors(func: Callable[P, T], /) -> Callable[P, T]: - ... +def log_errors(func: Callable[P, T], /) -> Callable[P, T]: ... @overload -def log_errors(*, pdb: bool = False, unroll_stack: int = 1) -> _LogErrors: - ... +def log_errors(*, pdb: bool = False, unroll_stack: int = 1) -> _LogErrors: ... def log_errors(func=None, /, *, pdb=False, unroll_stack=0): @@ -865,7 +863,7 @@ def silence_logging(level, root="distributed"): @contextlib.contextmanager def silence_logging_cmgr( level: str | int, root: str = "distributed" -) -> Generator[None, None, None]: +) -> Generator[None]: """ Temporarily change all StreamHandlers for the given logger to the given level """ diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py index 7c10c25635..d37a808363 100644 --- a/distributed/utils_comm.py +++ b/distributed/utils_comm.py @@ -385,8 +385,10 @@ async def retry( delay_min: float, delay_max: float, jitter_fraction: float = 0.1, - retry_on_exceptions: type[BaseException] - | tuple[type[BaseException], ...] = (EnvironmentError, IOError), + retry_on_exceptions: type[BaseException] | tuple[type[BaseException], ...] = ( + EnvironmentError, + IOError, + ), operation: str | None = None, ) -> T: """ diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 05cd9382b0..334f1eced4 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -131,9 +131,12 @@ def loop(loop_in_thread): @pytest.fixture def loop_in_thread(cleanup): loop_started = concurrent.futures.Future() - with concurrent.futures.ThreadPoolExecutor( - 1, thread_name_prefix="test IOLoop" - ) as tpe, config_for_cluster_tests(): + with ( + concurrent.futures.ThreadPoolExecutor( + 1, thread_name_prefix="test IOLoop" + ) as tpe, + config_for_cluster_tests(), + ): async def run(): io_loop = IOLoop.current() @@ -1415,7 +1418,7 @@ def captured_handler(handler): @contextmanager -def captured_context_meter() -> Generator[defaultdict[tuple, float], None, None]: +def captured_context_meter() -> Generator[defaultdict[tuple, float]]: """Capture distributed.metrics.context_meter metrics into a local defaultdict""" # Don't cast int metrics to float metrics: defaultdict[tuple, float] = defaultdict(int) @@ -2060,7 +2063,7 @@ def raises_with_cause( expected_cause: type[BaseException] | tuple[type[BaseException], ...], match_cause: str | None, *more_causes: type[BaseException] | tuple[type[BaseException], ...] | str | None, -) -> Generator[None, None, None]: +) -> Generator[None]: """Contextmanager to assert that a certain exception with cause was raised. It can travel the causes recursively by adding more expected, match pairs at the end. diff --git a/distributed/variable.py b/distributed/variable.py index befec99484..29ab379c7f 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -39,9 +39,9 @@ def __init__(self, scheduler): {"variable_set": self.set, "variable_get": self.get} ) - self.scheduler.stream_handlers[ - "variable-future-received-confirm" - ] = self.future_received_confirm + self.scheduler.stream_handlers["variable-future-received-confirm"] = ( + self.future_received_confirm + ) self.scheduler.stream_handlers["variable_delete"] = self.delete async def set(self, name=None, key=None, data=None, client=None, timeout=None): diff --git a/distributed/versions.py b/distributed/versions.py index 05325dae21..a85ee517b1 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -44,8 +44,9 @@ def get_versions( - packages: Iterable[str | tuple[str, Callable[[ModuleType], str | None]]] - | None = None + packages: ( + Iterable[str | tuple[str, Callable[[ModuleType], str | None]]] | None + ) = None ) -> dict[str, dict[str, Any]]: """Return basic information on our software installation, and our installed versions of packages @@ -139,9 +140,11 @@ def error_message(scheduler, workers, source, source_name="Client"): versions.add(source_version) worker_versions = { - workers[w].get(pkg, "MISSING") - if isinstance(workers[w], dict) - else workers[w] + ( + workers[w].get(pkg, "MISSING") + if isinstance(workers[w], dict) + else workers[w] + ) for w in workers } versions |= worker_versions diff --git a/distributed/worker.py b/distributed/worker.py index bb1c063774..7e3fecb9b2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1870,7 +1870,7 @@ async def plugin_add( try: result = plugin.setup(worker=self) if isawaitable(result): - result = await result + await result except Exception as e: logger.exception("Worker plugin %s failed to setup", name) if not catch_errors: @@ -1887,7 +1887,7 @@ async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: if hasattr(plugin, "teardown"): result = plugin.teardown(worker=self) if isawaitable(result): - result = await result + await result except Exception as e: logger.exception("Worker plugin %s failed to teardown", name) return error_message(e) @@ -3011,7 +3011,7 @@ def apply_function_simple( # Any other `BaseException` types would ultimately be ignored by asyncio if # raised here, after messing up the worker state machine along their way. raise - except BaseException as e: + except BaseException as e: # noqa: B036 # Users _shouldn't_ use `BaseException`s, but if they do, we can assume they # aren't a reason to shut down the whole system (since we allow the # system-shutting-down `SystemExit` and `KeyboardInterrupt` to pass through) @@ -3056,7 +3056,7 @@ async def apply_function_async( # Any other `BaseException` types would ultimately be ignored by asyncio if # raised here, after messing up the worker state machine along their way. raise - except BaseException as e: + except BaseException as e: # noqa: B036 # NOTE: this includes `CancelledError`! Since it's a user task, that's _not_ # a reason to shut down the worker. # Users _shouldn't_ use `BaseException`s, but if they do, we can assume they diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index 5465e94a8e..2ac4bcf4e7 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -18,6 +18,7 @@ Worker. - :mod:`distributed.active_memory_manager`, which runs on the scheduler side """ + from __future__ import annotations import asyncio @@ -229,9 +230,11 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None: "Process memory: %s -- Worker memory limit: %s", int(frac * 100), format_bytes(memory), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", + ( + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None" + ), ) worker.status = Status.paused elif worker.status == Status.paused: @@ -240,9 +243,11 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None: "Process memory: %s -- Worker memory limit: %s", int(frac * 100), format_bytes(memory), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", + ( + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None" + ), ) worker.status = Status.running @@ -449,9 +454,11 @@ def memory_monitor(self, nanny: Nanny) -> None: nanny_logger.warning( f"Worker {nanny.worker_address} (pid={process.pid}) is slow to %s", # On Windows, kill() is an alias to terminate() - "terminate; trying again" - if WINDOWS - else "accept SIGTERM; sending SIGKILL", + ( + "terminate; trying again" + if WINDOWS + else "accept SIGTERM; sending SIGKILL" + ), ) process.kill() @@ -540,4 +547,4 @@ def __get__(self, instance: Nanny | Worker | None, owner: type) -> Any: # This is triggered by Sphinx return None # pragma: nocover _warn_deprecated(instance, "memory_monitor") - return partial(instance.memory_manager.memory_monitor, instance) + return partial(instance.memory_manager.memory_monitor, instance) # type: ignore diff --git a/docs/source/scheduling-policies.rst b/docs/source/scheduling-policies.rst index 278e5064f9..fd60e06a71 100644 --- a/docs/source/scheduling-policies.rst +++ b/docs/source/scheduling-policies.rst @@ -242,7 +242,7 @@ There are two downsides to this queueing: 2. For embarrassingly-parallel workloads like a ``client.map``, there can be a minor increase in overhead per task, because each time a task finishes, a scheduler<->worker roundtrip message is required before the next task starts. In most - cases, this overhead is not even measureable and not something to worry about. + cases, this overhead is not even measurable and not something to worry about. This will only matter if you have very fast tasks, or a very slow network—that is, if your task runtime is the same order of magnitude as your network latency. For From c473680a88db4f46f8e2f7807fc9343f730e3053 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 3 Sep 2024 17:13:07 +0200 Subject: [PATCH 122/138] Speed up ``Client.map`` by computing ``token`` only once for ``func`` and ``kwargs`` (#8855) --- distributed/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index 152540e9bf..b13fac566a 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -877,8 +877,9 @@ def _keys(self) -> Iterable[Key]: else: if self.pure: + tok = tokenize(self.func, self.kwargs) keys = [ - self.key + "-" + tokenize(self.func, self.kwargs, args) # type: ignore + self.key + "-" + tokenize(tok, args) # type: ignore for args in zip(*self.iterables) ] else: From 56392168a67abe5951eae6ad48a8fa1a0f556388 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Sep 2024 10:22:14 -0500 Subject: [PATCH 123/138] Bump JamesIves/github-pages-deploy-action from 4.5.0 to 4.6.4 (#8853) --- .github/workflows/test-report.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-report.yaml b/.github/workflows/test-report.yaml index c0b65f1d31..490f091a4f 100644 --- a/.github/workflows/test-report.yaml +++ b/.github/workflows/test-report.yaml @@ -54,7 +54,7 @@ jobs: mv test_report.html test_short_report.html deploy/ - name: Deploy 🚀 - uses: JamesIves/github-pages-deploy-action@v4.5.0 + uses: JamesIves/github-pages-deploy-action@v4.6.4 with: branch: gh-pages folder: deploy From d728052ea1be136df86cbfb53568a37cd90db875 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Tue, 3 Sep 2024 17:36:06 +0200 Subject: [PATCH 124/138] Fix test nanny timeout (#8847) --- distributed/nanny.py | 47 +++++++++++++---------- distributed/tests/test_nanny.py | 68 ++++++++++++++++++++++++++++----- 2 files changed, 86 insertions(+), 29 deletions(-) diff --git a/distributed/nanny.py b/distributed/nanny.py index 7a14ee6576..859b9f22dc 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -516,7 +516,7 @@ async def _(): await self.instantiate() try: - await wait_for(_(), timeout) + await wait_for(asyncio.shield(_()), timeout) except asyncio.TimeoutError: logger.error( f"Restart timed out after {timeout}s; returning before finished" @@ -745,26 +745,30 @@ async def start(self) -> Status: os.environ.update(self.pre_spawn_env) try: - await self.process.start() - except OSError: - logger.exception("Nanny failed to start process", exc_info=True) - # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent - await self.process.terminate() - self.status = Status.failed - try: - msg = await self._wait_until_connected(uid) - except Exception: - # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent - await self.process.terminate() - self.status = Status.failed - raise + try: + await self.process.start() + except OSError: + # This can only happen if the actual process creation failed, e.g. + # multiprocessing.Process.start failed. This is not tested! + logger.exception("Nanny failed to start process", exc_info=True) + # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent + await self.process.terminate() + self.status = Status.failed + try: + msg = await self._wait_until_connected(uid) + except Exception: + # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent + await self.process.terminate() + self.status = Status.failed + raise + finally: + self.running.set() if not msg: return self.status self.worker_address = msg["address"] self.worker_dir = msg["dir"] assert self.worker_address self.status = Status.running - self.running.set() return self.status @@ -799,6 +803,7 @@ def mark_stopped(self): msg = self._death_message(self.process.pid, r) logger.info(msg) self.status = Status.stopped + self.running.clear() self.stopped.set() # Release resources self.process.close() @@ -830,11 +835,6 @@ async def kill( """ deadline = time() + timeout - if self.status == Status.stopped: - return - if self.status == Status.stopping: - await self.stopped.wait() - return # If the process is not properly up it will not watch the closing queue # and we may end up leaking this process # Therefore wait for it to be properly started before killing it @@ -842,10 +842,17 @@ async def kill( await self.running.wait() assert self.status in ( + Status.stopping, + Status.stopped, Status.running, Status.failed, # process failed to start, but hasn't been joined yet Status.closing_gracefully, ), self.status + if self.status == Status.stopped: + return + if self.status == Status.stopping: + await self.stopped.wait() + return self.status = Status.stopping logger.info("Nanny asking worker to close. Reason: %s", reason) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 52073c13d2..b05b7dc90c 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -208,25 +208,47 @@ async def test_scheduler_file(): s.stop() -@pytest.mark.xfail( - os.environ.get("MINDEPS") == "true", - reason="Timeout errors with mindeps environment", -) -@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)]) -async def test_nanny_timeout(c, s, a): +@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)]) +async def test_nanny_restart(c, s, a): + x = await c.scatter(123) + assert await c.submit(lambda: 1) == 1 + + await a.restart() + + while x.status != "cancelled": + await asyncio.sleep(0.1) + + assert await c.submit(lambda: 1) == 1 + + +@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)]) +async def test_nanny_restart_timeout(c, s, a): x = await c.scatter(123) with captured_logger( logging.getLogger("distributed.nanny"), level=logging.ERROR ) as logger: - await a.restart(timeout=0.1) + await a.restart(timeout=0) out = logger.getvalue() assert "timed out" in out.lower() - start = time() while x.status != "cancelled": await asyncio.sleep(0.1) - assert time() < start + 7 + + assert await c.submit(lambda: 1) == 1 + + +@gen_cluster(client=True, Worker=Nanny, nthreads=[("", 1)]) +async def test_nanny_restart_timeout_stress(c, s, a): + x = await c.scatter(123) + restarts = [a.restart(timeout=random.random()) for _ in range(100)] + await asyncio.gather(*restarts) + + while x.status != "cancelled": + await asyncio.sleep(0.1) + + assert await c.submit(lambda: 1) == 1 + assert len(s.workers) == 1 @gen_cluster( @@ -582,6 +604,34 @@ async def test_worker_start_exception(s): assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue() +@gen_cluster(nthreads=[]) +async def test_worker_start_exception_while_killing(s): + nanny = Nanny(s.address, worker_class=BrokenWorker) + + async def try_to_kill_nanny(): + while not nanny.process or nanny.process.status != Status.starting: + await asyncio.sleep(0) + await nanny.kill() + + kill_task = asyncio.create_task(try_to_kill_nanny()) + with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs: + with raises_with_cause( + RuntimeError, + "Nanny failed to start", + RuntimeError, + "BrokenWorker failed to start", + ): + async with nanny: + pass + await kill_task + assert nanny.status == Status.failed + # ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed` + assert nanny.process is None + assert "Restarting worker" not in logs.getvalue() + # Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.) + assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue() + + @gen_cluster(nthreads=[]) async def test_failure_during_worker_initialization(s): with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs: From 7c2134e2de9e4e0beae89f289440c2ebf0da20b3 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 4 Sep 2024 02:55:30 -0500 Subject: [PATCH 125/138] Point to user code with idempotent plugin warning (#8856) --- distributed/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/client.py b/distributed/client.py index b13fac566a..5c3268f3cf 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5084,6 +5084,7 @@ def register_plugin( "future version. Please mark your plugin as idempotent by setting its " "`.idempotent` attribute to `True`.", FutureWarning, + stacklevel=2, ) else: idempotent = getattr(plugin, "idempotent", False) From 50169e991aaeb2f3b5af3d91e4f8b4324aef4982 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 4 Sep 2024 11:38:14 -0500 Subject: [PATCH 126/138] Use new ``tokenize`` module (#8858) --- distributed/client.py | 3 ++- distributed/diagnostics/progress.py | 2 +- distributed/protocol/serialize.py | 4 ++-- distributed/scheduler.py | 2 +- distributed/shuffle/_merge.py | 3 ++- distributed/shuffle/_rechunk.py | 2 +- distributed/shuffle/_shuffle.py | 2 +- distributed/tests/test_client.py | 4 ++-- 8 files changed, 12 insertions(+), 10 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 5c3268f3cf..b038c0e46f 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -44,11 +44,12 @@ from tlz import first, groupby, merge, partition_all, valmap import dask -from dask.base import collections_to_dsk, tokenize +from dask.base import collections_to_dsk from dask.core import flatten, validate_key from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer from dask.optimization import SubgraphCallable +from dask.tokenize import tokenize from dask.typing import Key, NoDefault, no_default from dask.utils import ( apply, diff --git a/distributed/diagnostics/progress.py b/distributed/diagnostics/progress.py index 3fd7bbbd07..1712cc5df3 100644 --- a/distributed/diagnostics/progress.py +++ b/distributed/diagnostics/progress.py @@ -9,7 +9,7 @@ from tlz import groupby, valmap -from dask.base import tokenize +from dask.tokenize import tokenize from dask.utils import key_split from distributed.diagnostics.plugin import SchedulerPlugin diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index a2e593e06a..4ad5c3f98a 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -14,8 +14,8 @@ import msgpack import dask -from dask.base import normalize_token from dask.sizeof import sizeof +from dask.tokenize import normalize_token from dask.utils import typename from distributed.metrics import context_meter @@ -776,7 +776,7 @@ def register_serialization_lazy(toplevel, func): @partial(normalize_token.register, Serialized) def normalize_Serialized(o): - return [o.header] + o.frames # for dask.base.tokenize + return [o.header] + o.frames # for dask.tokenize.tokenize # Teach serialize how to handle bytes diff --git a/distributed/scheduler.py b/distributed/scheduler.py index fa965c6910..647f808ed8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -54,8 +54,8 @@ import dask import dask.utils -from dask.base import TokenizationError, normalize_token, tokenize from dask.core import get_deps, iskey, validate_key +from dask.tokenize import TokenizationError, normalize_token, tokenize from dask.typing import Key, no_default from dask.utils import ( _deprecated, diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index eb248b4e22..0fb403a3f8 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -5,9 +5,10 @@ from typing import TYPE_CHECKING, Any import dask -from dask.base import is_dask_collection, tokenize +from dask.base import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer +from dask.tokenize import tokenize from distributed.shuffle._arrow import check_minimal_arrow_version from distributed.shuffle._core import ShuffleId, barrier_key, get_worker_plugin diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 20b3715085..15e3ea1d78 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -121,9 +121,9 @@ import dask import dask.config -from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer +from dask.tokenize import tokenize from dask.typing import Key from dask.utils import parse_bytes diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 4a21909fa0..154f25c4ba 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -21,9 +21,9 @@ from tornado.ioloop import IOLoop import dask -from dask.base import tokenize from dask.highlevelgraph import HighLevelGraph from dask.layers import Layer +from dask.tokenize import tokenize from dask.typing import Key from distributed.core import PooledRPCCall diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index ab676b51a7..0016294e70 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -43,6 +43,7 @@ import dask.bag as db from dask import delayed from dask.optimization import SubgraphCallable +from dask.tokenize import tokenize from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile from distributed import ( @@ -73,7 +74,6 @@ futures_of, get_task_metadata, temp_default_client, - tokenize, wait, ) from distributed.cluster_dump import load_cluster_dump @@ -1127,7 +1127,7 @@ async def test_scatter_non_list(c, s, a, b): @gen_cluster(client=True) async def test_scatter_tokenize_local(c, s, a, b): - from dask.base import normalize_token + from dask.tokenize import normalize_token class MyObj: pass From b28822bb8449bfe6a2799098a7c8cc88fb9ff2c0 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Thu, 5 Sep 2024 15:22:22 -0500 Subject: [PATCH 127/138] Bump ``bokeh`` minimum version to 3.1.0 (#8861) --- continuous_integration/recipes/dask/meta.yaml | 2 +- distributed/dashboard/components/scheduler.py | 23 ++++--------------- distributed/dashboard/core.py | 6 ----- distributed/dashboard/utils.py | 21 ++++++++--------- distributed/scheduler.py | 10 +++----- distributed/versions.py | 2 +- 6 files changed, 18 insertions(+), 46 deletions(-) diff --git a/continuous_integration/recipes/dask/meta.yaml b/continuous_integration/recipes/dask/meta.yaml index 721e5e345f..bd8f8be7da 100644 --- a/continuous_integration/recipes/dask/meta.yaml +++ b/continuous_integration/recipes/dask/meta.yaml @@ -32,7 +32,7 @@ requirements: - lz4 >=4.3.2 - numpy >=1.21 - pandas >=2 - - bokeh >=2.4.2,!=3.0.* + - bokeh >=3.1.0 - jinja2 >=2.10.3 - pyarrow >=7.0 diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index d994e17acd..3605945abd 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -38,6 +38,7 @@ Range1d, ResetTool, Select, + TabPanel, Tabs, TapTool, Title, @@ -73,10 +74,8 @@ ProfileTimePlot, SystemMonitor, ) -from distributed.dashboard.core import TabPanel from distributed.dashboard.utils import ( _DATATABLE_STYLESHEETS_KWARGS, - BOKEH_VERSION, PROFILING, transpose, update, @@ -2271,18 +2270,8 @@ def __init__(self, scheduler, **kwargs): self.edge_source = ColumnDataSource({"x": [], "y": [], "visible": []}) filter = GroupFilter(column_name="visible", group="True") - if BOKEH_VERSION.major < 3: - filter_kwargs = {"filters": [filter]} - else: - filter_kwargs = {"filter": filter} - node_view = CDSView(**filter_kwargs) - edge_view = CDSView(**filter_kwargs) - - # Bokeh >= 3.0 automatically infers the source to use - if BOKEH_VERSION.major < 3: - node_view.source = self.node_source - edge_view.source = self.edge_source - + node_view = CDSView(filter=filter) + edge_view = CDSView(filter=filter) node_colors = factor_cmap( "state", factors=["waiting", "queued", "processing", "memory", "released", "erred"], @@ -4515,10 +4504,6 @@ def update(self): "box-shadow": "inset 1px 0 8px 0 lightgray", "overflow": "auto", } -if BOKEH_VERSION.major < 3: - _BOKEH_STYLES_KWARGS = {"style": _STYLES} -else: - _BOKEH_STYLES_KWARGS = {"styles": _STYLES} class SchedulerLogs: @@ -4538,7 +4523,7 @@ def __init__(self, scheduler, start=None): ) )._repr_html_() - self.root = Div(text=logs_html, **_BOKEH_STYLES_KWARGS) + self.root = Div(text=logs_html, styles=_STYLES) @log_errors diff --git a/distributed/dashboard/core.py b/distributed/dashboard/core.py index 96211bb2ea..77d0ff1ba6 100644 --- a/distributed/dashboard/core.py +++ b/distributed/dashboard/core.py @@ -26,12 +26,6 @@ ) -if BOKEH_VERSION.major < 3: - from bokeh.models import Panel as TabPanel # noqa: F401 -else: - from bokeh.models import TabPanel # noqa: F401 - - if BOKEH_VERSION < parse_version("3.3.0"): from bokeh.server.server import BokehTornado as DaskBokehTornado else: diff --git a/distributed/dashboard/utils.py b/distributed/dashboard/utils.py index ce36a17ff6..03d5271e2a 100644 --- a/distributed/dashboard/utils.py +++ b/distributed/dashboard/utils.py @@ -18,18 +18,15 @@ PROFILING = False -if BOKEH_VERSION.major < 3: - _DATATABLE_STYLESHEETS_KWARGS = {} -else: - _DATATABLE_STYLESHEETS_KWARGS = { - "stylesheets": [ - """ - .bk-data-table { - z-index: 0; - } - """ - ] - } +_DATATABLE_STYLESHEETS_KWARGS = { + "stylesheets": [ + """ +.bk-data-table { +z-index: 0; +} +""" + ] +} def transpose(lod): diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 647f808ed8..64ad6d964f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -8421,17 +8421,13 @@ def profile_to_figure(state): sysmon.update() # Scheduler logs - from distributed.dashboard.components.scheduler import ( - _BOKEH_STYLES_KWARGS, - SchedulerLogs, - ) + from distributed.dashboard.components.scheduler import _STYLES, SchedulerLogs logs = SchedulerLogs(self, start=start) - from bokeh.models import Div, Tabs + from bokeh.models import Div, TabPanel, Tabs import distributed - from distributed.dashboard.core import TabPanel # HTML html = """ @@ -8472,7 +8468,7 @@ def profile_to_figure(state): dask_version=dask.__version__, distributed_version=distributed.__version__, ) - html = Div(text=html, **_BOKEH_STYLES_KWARGS) + html = Div(text=html, styles=_STYLES) html = TabPanel(child=html, title="Summary") compute = TabPanel(child=compute, title="Worker Profile (compute)") diff --git a/distributed/versions.py b/distributed/versions.py index a85ee517b1..7a3a7a5b05 100644 --- a/distributed/versions.py +++ b/distributed/versions.py @@ -14,7 +14,7 @@ from packaging.requirements import Requirement -BOKEH_REQUIREMENT = Requirement("bokeh>=2.4.2,!=3.0.*") +BOKEH_REQUIREMENT = Requirement("bokeh>=3.1.0") required_packages = [ ("dask", lambda p: p.__version__), From ec3f4eccd1b8f8541dfa3b72e363464192823a4a Mon Sep 17 00:00:00 2001 From: Mario Linker <15095261+maldag@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:48:13 +0200 Subject: [PATCH 128/138] Work/fix firewall for localhost (#8868) Co-authored-by: Hendrik Makait --- distributed/deploy/local.py | 3 +++ distributed/node.py | 9 +++++++-- distributed/tests/test_scheduler.py | 2 +- distributed/tests/test_worker.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 69f5d8af35..b7f736e1ed 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -62,6 +62,9 @@ class LocalCluster(SpecCluster): 'localhost:8787' or '0.0.0.0:8787'. Defaults to ':8787'. Set to ``None`` to disable the dashboard. Use ':0' for a random port. + When specifying only a port like ':8787', the dashboard will bind to the given interface from the ``host`` parameter. + If ``host`` is empty, binding will occur on all interfaces '0.0.0.0'. + To avoid firewall issues when deploying locally, set ``host`` to 'localhost'. worker_dashboard_address: str Address on which to listen for the Bokeh worker diagnostics server like 'localhost:8787' or '0.0.0.0:8787'. Defaults to None which disables the dashboard. diff --git a/distributed/node.py b/distributed/node.py index d41a340985..56d0a1882c 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -143,9 +143,14 @@ def start_http_server( self.http_server = HTTPServer(self.http_application, ssl_options=ssl_options) http_addresses = clean_dashboard_address(dashboard_address or default_port) - for http_address in http_addresses: - if http_address["address"] is None: + # Handle default case for dashboard address + # In case dashboard_address is given, e.g. ":8787" + # the address is empty and it is intended to listen to all interfaces + if dashboard_address is not None and http_address["address"] == "": + http_address["address"] = "0.0.0.0" + + if http_address["address"] is None or http_address["address"] == "": address = self._start_address if isinstance(address, (list, tuple)): address = address[0] diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index cfb6fdefdb..dada3e3fc9 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1969,7 +1969,7 @@ async def test_scheduler_file(): @pytest.mark.parametrize( "dashboard_address,expect", [ - (None, ("::", "0.0.0.0")), + (None, ("::", "0.0.0.0", "127.0.0.1")), ("127.0.0.1:0", ("127.0.0.1",)), ], ) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 8bfcd8347b..0a2fed3219 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1115,7 +1115,7 @@ async def test_service_hosts_match_worker(s): async with Worker(s.address, host="tcp://127.0.0.1") as w: sock = first(w.http_server._sockets.values()) - assert sock.getsockname()[0] in ("::", "0.0.0.0") + assert sock.getsockname()[0] in ("::", "127.0.0.1") # See what happens with e.g. `dask worker --listen-address tcp://:8811` async with Worker(s.address, host="") as w: From 4f3ac263919bd442f666eacaa8388b3d9ebe7781 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 13 Sep 2024 22:51:58 +0200 Subject: [PATCH 129/138] Homogeneously schedule P2P's unpack tasks (#8873) --- distributed/scheduler.py | 18 ++++-------- distributed/shuffle/_scheduler_plugin.py | 2 +- distributed/shuffle/tests/test_rechunk.py | 36 +++++++++++++++++++++++ distributed/shuffle/tests/test_shuffle.py | 27 ----------------- distributed/tests/test_scheduler.py | 24 --------------- 5 files changed, 42 insertions(+), 65 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 64ad6d964f..ab5b28fffc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1425,14 +1425,8 @@ class TaskState: #: be rejected. run_id: int | None - #: Whether to consider this task rootish in the context of task queueing - #: True - #: Always consider this task rootish - #: False - #: Never consider this task rootish - #: None - #: Use a heuristic to determine whether this task should be considered rootish - _rootish: bool | None + #: Whether to allow queueing this task if it is rootish + _queueable: bool #: Cached hash of :attr:`~TaskState.client_key` _hash: int @@ -1489,7 +1483,7 @@ def __init__( self.metadata = None self.annotations = None self.erred_on = None - self._rootish = None + self._queueable = True self.run_id = None self.group = group group.add(self) @@ -2286,7 +2280,7 @@ def decide_worker_rootish_queuing_disabled( """ if self.validate: # See root-ish-ness note below in `decide_worker_rootish_queuing_enabled` - assert math.isinf(self.WORKER_SATURATION) + assert math.isinf(self.WORKER_SATURATION) or not ts._queueable pool = self.idle.values() if self.idle else self.running if not pool: @@ -2452,7 +2446,7 @@ def _transition_waiting_processing(self, key: Key, stimulus_id: str) -> RecsMsgs # removed, there should only be one, which combines co-assignment and # queuing. Eventually, special-casing root tasks might be removed entirely, # with better heuristics. - if math.isinf(self.WORKER_SATURATION): + if math.isinf(self.WORKER_SATURATION) or not ts._queueable: if not (ws := self.decide_worker_rootish_queuing_disabled(ts)): return {ts.key: "no-worker"}, {}, {} else: @@ -3090,8 +3084,6 @@ def is_rootish(self, ts: TaskState) -> bool: and have few or no dependencies. Tasks may also be explicitly marked as rootish to override this heuristic. """ - if ts._rootish is not None: - return ts._rootish if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions: return False tg = ts.group diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 5f474c0cfd..0cf1c3338e 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -300,7 +300,7 @@ def _ensure_output_tasks_are_non_rootish(self, spec: ShuffleSpec) -> None: """ barrier = self.scheduler.tasks[barrier_key(spec.id)] for dependent in barrier.dependents: - dependent._rootish = False + dependent._queueable = False @log_errors() def _set_restriction(self, ts: TaskState, worker: str) -> None: diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index 803a33fe6c..bf55b45457 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -4,9 +4,12 @@ import math import random import warnings +from collections import defaultdict import pytest +from distributed.diagnostics.plugin import SchedulerPlugin + np = pytest.importorskip("numpy") da = pytest.importorskip("dask.array") @@ -1488,3 +1491,36 @@ def test_calculate_prechunking_splitting(old, new, expected): # _calculate_prechunking does not concatenate on object actual = _calculate_prechunking(old, new, np.dtype(object), None) assert actual == expected + + +@gen_cluster(client=True, nthreads=[("", 1)] * 4, config={"array.chunk-size": "1 B"}) +async def test_homogeneously_schedule_unpack(c, s, *ws): + class SchedulingTrackerPlugin(SchedulerPlugin): + async def start(self, scheduler): + self.scheduler = scheduler + self.counts = defaultdict(int) + self.seen = set() + + def transition(self, key, start, finish, *args, stimulus_id, **kwargs): + if key in self.seen: + return + + if not isinstance(key, tuple) or not isinstance(key[0], str): + return + + if not key[0].startswith("rechunk-p2p"): + return + + if start != "waiting" or finish != "processing": + return + + self.seen.add(key) + self.counts[self.scheduler.tasks[key].processing_on.address] += 1 + + await c.register_plugin(SchedulingTrackerPlugin(), name="tracker") + res = da.random.random((100, 100), chunks=(1, -1)).rechunk((-1, 1)) + await c.compute(res) + counts = s.plugins["tracker"].counts + min_count = min(counts.values()) + max_count = max(counts.values()) + assert min_count >= max_count, counts diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 0efff85f68..23e61eef2d 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -2685,33 +2685,6 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None: return await super().barrier(id, run_id, consistent) -@gen_cluster(client=True) -async def test_unpack_is_non_rootish(c, s, a, b): - with pytest.warns(UserWarning): - scheduler_plugin = BlockedBarrierShuffleSchedulerPlugin(s) - df = dask.datasets.timeseries( - start="2000-01-01", - end="2000-01-21", - dtypes={"x": float, "y": float}, - freq="10 s", - ) - df = df.shuffle("x") - result = c.compute(df) - - await scheduler_plugin.in_barrier.wait() - - unpack_tss = [ts for key, ts in s.tasks.items() if key_split(key) == UNPACK_PREFIX] - assert len(unpack_tss) == 20 - assert not any(s.is_rootish(ts) for ts in unpack_tss) - del unpack_tss - scheduler_plugin.block_barrier.set() - result = await result - - await assert_worker_cleanup(a) - await assert_worker_cleanup(b) - await assert_scheduler_cleanup(s) - - class FlakyConnectionPool(ConnectionPool): def __init__(self, *args, failing_connects=0, **kwargs): self.attempts = 0 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index dada3e3fc9..a5c3e22b63 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -284,30 +284,6 @@ def random(**kwargs): test_decide_worker_coschedule_order_neighbors_() -@gen_cluster( - client=True, - nthreads=[], -) -async def test_override_is_rootish(c, s): - x = c.submit(lambda x: x + 1, 1, key="x") - await async_poll_for(lambda: "x" in s.tasks, timeout=5) - ts_x = s.tasks["x"] - assert ts_x._rootish is None - assert s.is_rootish(ts_x) - - ts_x._rootish = False - assert not s.is_rootish(ts_x) - - y = c.submit(lambda y: y + 1, 1, key="y", workers=["not-existing"]) - await async_poll_for(lambda: "y" in s.tasks, timeout=5) - ts_y = s.tasks["y"] - assert ts_y._rootish is None - assert not s.is_rootish(ts_y) - - ts_y._rootish = True - assert s.is_rootish(ts_y) - - @pytest.mark.skipif( QUEUING_ON_BY_DEFAULT, reason="Not relevant with queuing on; see https://github.com/dask/distributed/issues/7204", From 8bafad5450592d9a647080589af1ca3072cc5341 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 13 Sep 2024 16:21:33 -0500 Subject: [PATCH 130/138] bump version to 2024.9.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 15b878554d..e5d68b635c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">=3.10" dependencies = [ "click >= 8.0", "cloudpickle >= 3.0.0", - "dask == 2024.8.2", + "dask == 2024.9.0", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.2", From 80b3af58a2e9c279221a570651c252221a8fc78c Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 17 Sep 2024 12:32:55 -0500 Subject: [PATCH 131/138] Support P2P rechunking datetime arrays (#8875) --- distributed/shuffle/_core.py | 7 +++++++ distributed/shuffle/tests/test_rechunk.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index c8ae80a75e..4eaea179df 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -526,6 +526,12 @@ def handle_unpack_errors(id: ShuffleId) -> Iterator[None]: raise RuntimeError(f"P2P shuffling {id} failed during unpack phase") from e +def _handle_datetime(buf: Any) -> Any: + if hasattr(buf, "dtype") and buf.dtype.kind in "Mm": + return buf.view("u8") + return buf + + def _mean_shard_size(shards: Iterable) -> int: """Return estimated mean size in bytes of each shard""" size = 0 @@ -534,6 +540,7 @@ def _mean_shard_size(shards: Iterable) -> int: if not isinstance(shard, int): # This also asserts that shard is a Buffer and that we didn't forget # a container or metadata type above + shard = _handle_datetime(shard) size += memoryview(shard).nbytes count += 1 if count == 10: diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index bf55b45457..f2cd8564cc 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -1524,3 +1524,15 @@ def transition(self, key, start, finish, *args, stimulus_id, **kwargs): min_count = min(counts.values()) max_count = max(counts.values()) assert min_count >= max_count, counts + + +@pytest.mark.parametrize("method", ["tasks", "p2p"]) +@gen_cluster(client=True) +async def test_rechunk_datetime(c, s, *ws, method): + pd = pytest.importorskip("pandas") + + x = pd.date_range("2005-01-01", "2005-01-10").to_numpy(dtype="datetime64[ns]") + dx = da.from_array(x, chunks=10) + result = dx.rechunk(2, method=method) + result = await c.compute(result) + np.testing.assert_array_equal(x, result) From 42168f7b64bba893647d171f633735501d87fa8b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:30:15 -0500 Subject: [PATCH 132/138] Bump jacobtomlinson/gha-anaconda-package-version from 0.1.3 to 0.1.4 (#8878) --- .github/workflows/update-gpuci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/update-gpuci.yaml b/.github/workflows/update-gpuci.yaml index 90c25c16a5..d88ae20be1 100644 --- a/.github/workflows/update-gpuci.yaml +++ b/.github/workflows/update-gpuci.yaml @@ -21,7 +21,7 @@ jobs: - name: Get latest cuDF nightly version id: cudf_latest - uses: jacobtomlinson/gha-anaconda-package-version@0.1.3 + uses: jacobtomlinson/gha-anaconda-package-version@0.1.4 with: org: "rapidsai-nightly" package: "cudf" @@ -29,7 +29,7 @@ jobs: - name: Get latest UCX-Py nightly version id: ucx_py_latest - uses: jacobtomlinson/gha-anaconda-package-version@0.1.3 + uses: jacobtomlinson/gha-anaconda-package-version@0.1.4 with: org: "rapidsai-nightly" package: "ucx-py" From f87032513f27cf61a75baf82f0254ed1e7416164 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 24 Sep 2024 20:45:36 +0200 Subject: [PATCH 133/138] Don't consider scheduler idle while executing ``Scheduler.update_graph`` (#8877) Co-authored-by: James Bourbeau --- distributed/scheduler.py | 14 ++++++++++++-- distributed/tests/test_scheduler.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ab5b28fffc..d1f18d7c2f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3686,6 +3686,7 @@ class Scheduler(SchedulerState, ServerNode): _client_connections_removed_total: int _workers_added_total: int _workers_removed_total: int + _active_graph_updates: int def __init__( self, @@ -4049,6 +4050,7 @@ async def post(self): self._client_connections_removed_total = 0 self._workers_added_total = 0 self._workers_removed_total = 0 + self._active_graph_updates = 0 ################## # Administration # @@ -4841,6 +4843,7 @@ async def update_graph( stimulus_id: str | None = None, ) -> None: start = time() + self._active_graph_updates += 1 try: try: graph = deserialize(graph_header, graph_frames).data @@ -4913,8 +4916,11 @@ async def update_graph( # (which may not have been added to who_wants yet) client=client, ) - end = time() - self.digest_metric("update-graph-duration", end - start) + finally: + self._active_graph_updates -= 1 + assert self._active_graph_updates >= 0 + end = time() + self.digest_metric("update-graph-duration", end - start) def _generate_taskstates( self, @@ -8607,6 +8613,10 @@ def check_idle(self) -> float | None: self.idle_since = None return None + if self._active_graph_updates > 0: + self.idle_since = None + return None + if ( self.queued or self.unrunnable diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a5c3e22b63..cf3b8b8d4f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2389,6 +2389,34 @@ async def test_idle_timeout(c, s, a, b): pc.stop() +@gen_cluster(client=True) +async def test_idle_during_update_graph(c, s, a, b): + class UpdateGraphTrackerPlugin(SchedulerPlugin): + def start(self, scheduler): + self.scheduler = scheduler + self.idle_during_update_graph = None + + def update_graph(self, *args, **kwargs): + self.idle_during_update_graph = self.scheduler.check_idle() is not None + + await c.register_plugin(UpdateGraphTrackerPlugin(), name="tracker") + plugin = s.plugins["tracker"] + # The cluster is idle because no work ever existed + assert s.check_idle() is not None + beginning = time() + assert s.idle_since < beginning + await c.submit(lambda x: x, 1) + # The cluster may be considered not idle because of the unit of work + s.check_idle() + # Now the cluster must be idle + assert s.check_idle() is not None + end = time() + assert beginning <= s.idle_since + assert s.idle_since <= end + # Ensure the cluster isn't idle while `Scheduler.update_graph` was being run + assert plugin.idle_during_update_graph is False + + @gen_cluster(client=True, nthreads=[]) async def test_idle_timeout_no_workers(c, s): """Test that idle-timeout is not triggered if there are no workers available From f7b7f176786857b04df3384fc3b5f25cbbf65fa5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:35:24 -0400 Subject: [PATCH 134/138] Update gpuCI `RAPIDS_VER` to `24.12` (#8879) --- continuous_integration/gpuci/axis.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index 7b30d50651..41fca92d9f 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -9,6 +9,6 @@ LINUX_VER: - ubuntu20.04 RAPIDS_VER: -- "24.10" +- "24.12" excludes: From e52da4609e88e5b558582a6761efc32da205a452 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Fri, 27 Sep 2024 13:00:43 +0200 Subject: [PATCH 135/138] Don't stop Adaptive on error (#8871) --- distributed/deploy/adaptive.py | 115 ++++++- distributed/deploy/adaptive_core.py | 116 ++----- distributed/deploy/tests/test_adaptive.py | 293 +++++++++++++++--- .../deploy/tests/test_adaptive_core.py | 136 ++------ 4 files changed, 387 insertions(+), 273 deletions(-) diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 1638659db4..dd4e411f5f 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -1,20 +1,39 @@ from __future__ import annotations import logging +from collections.abc import Hashable +from datetime import timedelta from inspect import isawaitable +from typing import TYPE_CHECKING, Any, Callable, Literal, cast from tornado.ioloop import IOLoop import dask.config from dask.utils import parse_timedelta +from distributed.compatibility import PeriodicCallback +from distributed.core import Status from distributed.deploy.adaptive_core import AdaptiveCore from distributed.protocol import pickle from distributed.utils import log_errors +if TYPE_CHECKING: + from typing_extensions import TypeAlias + + from distributed.deploy.cluster import Cluster + from distributed.scheduler import WorkerState + logger = logging.getLogger(__name__) +AdaptiveStateState: TypeAlias = Literal[ + "starting", + "running", + "stopped", + "inactive", +] + + class Adaptive(AdaptiveCore): ''' Adaptively allocate workers based on scheduler load. A superclass. @@ -81,16 +100,21 @@ class Adaptive(AdaptiveCore): specified in the dask config under the distributed.adaptive key. ''' + interval: float | None + periodic_callback: PeriodicCallback | None + #: Whether this adaptive strategy is periodically adapting + state: AdaptiveStateState + def __init__( self, - cluster=None, - interval=None, - minimum=None, - maximum=None, - wait_count=None, - target_duration=None, - worker_key=None, - **kwargs, + cluster: Cluster, + interval: str | float | timedelta | None = None, + minimum: int | None = None, + maximum: int | float | None = None, + wait_count: int | None = None, + target_duration: str | float | timedelta | None = None, + worker_key: Callable[[WorkerState], Hashable] | None = None, + **kwargs: Any, ): self.cluster = cluster self.worker_key = worker_key @@ -99,20 +123,78 @@ def __init__( if interval is None: interval = dask.config.get("distributed.adaptive.interval") if minimum is None: - minimum = dask.config.get("distributed.adaptive.minimum") + minimum = cast(int, dask.config.get("distributed.adaptive.minimum")) if maximum is None: - maximum = dask.config.get("distributed.adaptive.maximum") + maximum = cast(float, dask.config.get("distributed.adaptive.maximum")) if wait_count is None: - wait_count = dask.config.get("distributed.adaptive.wait-count") + wait_count = cast(int, dask.config.get("distributed.adaptive.wait-count")) if target_duration is None: - target_duration = dask.config.get("distributed.adaptive.target-duration") + target_duration = cast( + str, dask.config.get("distributed.adaptive.target-duration") + ) + + self.interval = parse_timedelta(interval, "seconds") + self.periodic_callback = None + + if self.interval and self.cluster: + import weakref + + self_ref = weakref.ref(self) + + async def _adapt(): + adaptive = self_ref() + if not adaptive or adaptive.state != "running": + return + if adaptive.cluster.status != Status.running: + adaptive.stop(reason="cluster-not-running") + return + try: + await adaptive.adapt() + except Exception: + logger.warning( + "Adaptive encountered an error while adapting", exc_info=True + ) + + self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000) + self.state = "starting" + self.loop.add_callback(self._start) + else: + self.state = "inactive" self.target_duration = parse_timedelta(target_duration) - super().__init__( - minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval + super().__init__(minimum=minimum, maximum=maximum, wait_count=wait_count) + + def _start(self) -> None: + if self.state != "starting": + return + + assert self.periodic_callback is not None + self.periodic_callback.start() + self.state = "running" + logger.info( + "Adaptive scaling started: minimum=%s maximum=%s", + self.minimum, + self.maximum, ) + def stop(self, reason: str = "unknown") -> None: + if self.state in ("inactive", "stopped"): + return + + if self.state == "running": + assert self.periodic_callback is not None + self.periodic_callback.stop() + logger.info( + "Adaptive scaling stopped: minimum=%s maximum=%s. Reason: %s", + self.minimum, + self.maximum, + reason, + ) + + self.periodic_callback = None + self.state = "stopped" + @property def scheduler(self): return self.cluster.scheduler_comm @@ -210,6 +292,9 @@ async def scale_up(self, n): def loop(self) -> IOLoop: """Override Adaptive.loop""" if self.cluster: - return self.cluster.loop + return self.cluster.loop # type: ignore[return-value] else: return IOLoop.current() + + def __del__(self): + self.stop(reason="adaptive-deleted") diff --git a/distributed/deploy/adaptive_core.py b/distributed/deploy/adaptive_core.py index ccb81008cf..128353d9cf 100644 --- a/distributed/deploy/adaptive_core.py +++ b/distributed/deploy/adaptive_core.py @@ -2,37 +2,24 @@ import logging import math +from abc import ABC, abstractmethod from collections import defaultdict, deque from collections.abc import Iterable -from datetime import timedelta -from typing import TYPE_CHECKING, Literal, cast +from typing import TYPE_CHECKING, cast import tlz as toolz -from tornado.ioloop import IOLoop import dask.config -from dask.utils import parse_timedelta -from distributed.compatibility import PeriodicCallback from distributed.metrics import time if TYPE_CHECKING: - from typing_extensions import TypeAlias - from distributed.scheduler import WorkerState logger = logging.getLogger(__name__) -AdaptiveStateState: TypeAlias = Literal[ - "starting", - "running", - "stopped", - "inactive", -] - - -class AdaptiveCore: +class AdaptiveCore(ABC): """ The core logic for adaptive deployments, with none of the cluster details @@ -91,54 +78,22 @@ class AdaptiveCore: minimum: int maximum: int | float wait_count: int - interval: int | float - periodic_callback: PeriodicCallback | None - plan: set[WorkerState] - requested: set[WorkerState] - observed: set[WorkerState] close_counts: defaultdict[WorkerState, int] - _adapting: bool - #: Whether this adaptive strategy is periodically adapting - _state: AdaptiveStateState log: deque[tuple[float, dict]] + _adapting: bool def __init__( self, minimum: int = 0, maximum: int | float = math.inf, wait_count: int = 3, - interval: str | int | float | timedelta = "1s", ): if not isinstance(maximum, int) and not math.isinf(maximum): - raise TypeError(f"maximum must be int or inf; got {maximum}") + raise ValueError(f"maximum must be int or inf; got {maximum}") self.minimum = minimum self.maximum = maximum self.wait_count = wait_count - self.interval = parse_timedelta(interval, "seconds") - self.periodic_callback = None - - if self.interval: - import weakref - - self_ref = weakref.ref(self) - - async def _adapt(): - core = self_ref() - if core: - await core.adapt() - - self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000) - self._state = "starting" - self.loop.add_callback(self._start) - else: - self._state = "inactive" - try: - self.plan = set() - self.requested = set() - self.observed = set() - except Exception: - pass # internal state self.close_counts = defaultdict(int) @@ -147,38 +102,22 @@ async def _adapt(): maxlen=dask.config.get("distributed.admin.low-level-log-length") ) - def _start(self) -> None: - if self._state != "starting": - return - - assert self.periodic_callback is not None - self.periodic_callback.start() - self._state = "running" - logger.info( - "Adaptive scaling started: minimum=%s maximum=%s", - self.minimum, - self.maximum, - ) - - def stop(self) -> None: - if self._state in ("inactive", "stopped"): - return + @property + @abstractmethod + def plan(self) -> set[WorkerState]: ... - if self._state == "running": - assert self.periodic_callback is not None - self.periodic_callback.stop() - logger.info( - "Adaptive scaling stopped: minimum=%s maximum=%s", - self.minimum, - self.maximum, - ) + @property + @abstractmethod + def requested(self) -> set[WorkerState]: ... - self.periodic_callback = None - self._state = "stopped" + @property + @abstractmethod + def observed(self) -> set[WorkerState]: ... + @abstractmethod async def target(self) -> int: """The target number of workers that should exist""" - raise NotImplementedError() + ... async def workers_to_close(self, target: int) -> list: """ @@ -198,11 +137,11 @@ async def safe_target(self) -> int: return n - async def scale_down(self, n: int) -> None: - raise NotImplementedError() + @abstractmethod + async def scale_down(self, n: int) -> None: ... - async def scale_up(self, workers: Iterable) -> None: - raise NotImplementedError() + @abstractmethod + async def scale_up(self, workers: Iterable) -> None: ... async def recommendations(self, target: int) -> dict: """ @@ -270,20 +209,5 @@ async def adapt(self) -> None: await self.scale_up(**recommendations) if status == "down": await self.scale_down(**recommendations) - except OSError: - if status != "down": - logger.error("Adaptive stopping due to error", exc_info=True) - self.stop() - else: - logger.error( - "Error during adaptive downscaling. Ignoring.", exc_info=True - ) finally: self._adapting = False - - def __del__(self): - self.stop() - - @property - def loop(self) -> IOLoop: - return IOLoop.current() diff --git a/distributed/deploy/tests/test_adaptive.py b/distributed/deploy/tests/test_adaptive.py index a71fdfb298..441c1c609d 100644 --- a/distributed/deploy/tests/test_adaptive.py +++ b/distributed/deploy/tests/test_adaptive.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import math from time import sleep @@ -17,8 +18,16 @@ Worker, wait, ) +from distributed.core import Status +from distributed.deploy.cluster import Cluster from distributed.metrics import time -from distributed.utils_test import async_poll_for, gen_cluster, gen_test, slowinc +from distributed.utils_test import ( + async_poll_for, + captured_logger, + gen_cluster, + gen_test, + slowinc, +) def test_adaptive_local_cluster(loop): @@ -80,39 +89,6 @@ async def test_adaptive_local_cluster_multi_workers(): await c.gather(futures) -@pytest.mark.xfail(reason="changed API") -@gen_test() -async def test_adaptive_scale_down_override(): - class TestAdaptive(Adaptive): - def __init__(self, *args, **kwargs): - self.min_size = kwargs.pop("min_size", 0) - super().__init__(*args, **kwargs) - - async def workers_to_close(self, **kwargs): - num_workers = len(self.cluster.workers) - to_close = await self.scheduler.workers_to_close(**kwargs) - if num_workers - len(to_close) < self.min_size: - to_close = to_close[: num_workers - self.min_size] - - return to_close - - class TestCluster(LocalCluster): - def scale_up(self, n, **kwargs): - assert False - - async with TestCluster( - n_workers=10, processes=False, asynchronous=True, dashboard_address=":0" - ) as cluster: - ta = cluster.adapt( - min_size=2, interval=0.1, scale_factor=2, Adaptive=TestAdaptive - ) - await asyncio.sleep(0.3) - - # Assert that adaptive cycle does not reduce cluster below minimum size - # as determined via override. - assert len(cluster.scheduler.workers) == 2 - - @gen_test() async def test_min_max(): async with LocalCluster( @@ -400,17 +376,23 @@ async def test_adapt_cores_memory(): @gen_test() async def test_adaptive_config(): - with dask.config.set( - {"distributed.adaptive.minimum": 10, "distributed.adaptive.wait-count": 8} - ): - try: - adapt = Adaptive(interval="5s") - assert adapt.minimum == 10 - assert adapt.maximum == math.inf - assert adapt.interval == 5 - assert adapt.wait_count == 8 - finally: - adapt.stop() + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + with dask.config.set( + {"distributed.adaptive.minimum": 10, "distributed.adaptive.wait-count": 8} + ): + try: + adapt = Adaptive(cluster, interval="5s") + assert adapt.minimum == 10 + assert adapt.maximum == math.inf + assert adapt.interval == 5 + assert adapt.wait_count == 8 + finally: + adapt.stop() @gen_test() @@ -427,6 +409,8 @@ async def test_update_adaptive(): first = cluster.adapt(maximum=1) second = cluster.adapt(maximum=2) await asyncio.sleep(0.2) + assert first.state == "stopped" + assert second.state == "running" assert first.periodic_callback is None assert second.periodic_callback.is_running() @@ -454,6 +438,19 @@ async def test_adaptive_no_memory_limit(): ) +@gen_test() +async def test_adapt_gets_stopped_on_cluster_close(): + class MyCluster(Cluster): + pass + + async with MyCluster(asynchronous=True) as cluster: + adapt = cluster.adapt(minimum=1, maximum=10, interval="10ms") + while adapt.state != "running": + await asyncio.sleep(0.01) + await cluster.close() + assert adapt.state == "stopped" + + @gen_test() async def test_scale_needs_to_be_awaited(): """ @@ -495,13 +492,12 @@ async def test_adaptive_stopped(): n_workers=0, asynchronous=True, dashboard_address=":0" ) as cluster: instance = cluster.adapt(interval="10ms") + await async_poll_for(lambda: instance.state == "running", timeout=5) assert instance.periodic_callback is not None - - await async_poll_for(lambda: instance.periodic_callback.is_running(), timeout=5) - + assert instance.periodic_callback.is_running() pc = instance.periodic_callback - - await async_poll_for(lambda: not pc.is_running(), timeout=5) + await async_poll_for(lambda: instance.state == "stopped", timeout=5) + assert not pc.is_running() @pytest.mark.parametrize("saturation", [1, float("inf")]) @@ -544,3 +540,200 @@ async def test_respect_average_nthreads(c, s, w): await asyncio.sleep(0.001) assert s.adaptive_target() == 40 + + +class MyAdaptive(Adaptive): + def __init__(self, *args, interval=None, **kwargs): + super().__init__(*args, interval=interval, **kwargs) + self._target = 0 + self._log = [] + self._observed = set() + self._plan = set() + self._requested = set() + + @property + def observed(self): + return self._observed + + @property + def plan(self): + return self._plan + + @property + def requested(self): + return self._requested + + async def target(self): + return self._target + + async def scale_up(self, n=0): + self._plan = self._requested = set(range(n)) + + async def scale_down(self, workers=()): + for collection in [self.plan, self.requested, self.observed]: + for w in workers: + collection.discard(w) + + +@gen_test() +async def test_adaptive_stops_on_cluster_status_change(): + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + adapt = Adaptive(cluster, interval="100 ms") + assert adapt.state == "starting" + await async_poll_for(lambda: adapt.state == "running", timeout=5) + + assert adapt.periodic_callback + assert adapt.periodic_callback.is_running() + + try: + cluster.status = Status.closing + + await async_poll_for(lambda: adapt.state != "running", timeout=5) + assert adapt.state == "stopped" + assert not adapt.periodic_callback + finally: + # Set back to running to let normal shutdown do its thing + cluster.status = Status.running + + +@gen_test() +async def test_interval(): + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + adapt = MyAdaptive(cluster=cluster, interval="100 ms") + assert not adapt.plan + + for i in [0, 3, 1]: + start = time() + adapt._target = i + while len(adapt.plan) != i: + await asyncio.sleep(0.01) + assert time() < start + 2 + + adapt.stop() + await asyncio.sleep(0.05) + + adapt._target = 10 + await asyncio.sleep(0.02) + assert len(adapt.plan) == 1 # last value from before, unchanged + + +@gen_test() +async def test_adapt_logs_error_in_safe_target(): + class BadAdaptive(MyAdaptive): + """Adaptive subclass which raises an OSError when attempting to adapt + + We use this to check that error handling works properly + """ + + def safe_target(self): + raise OSError() + + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + with captured_logger( + "distributed.deploy.adaptive", level=logging.WARNING + ) as log: + adapt = cluster.adapt( + Adaptive=BadAdaptive, minimum=1, maximum=4, interval="10ms" + ) + while "encountered an error" not in log.getvalue(): + await asyncio.sleep(0.01) + assert "stop" not in log.getvalue() + assert adapt.state == "running" + assert adapt.periodic_callback + assert adapt.periodic_callback.is_running() + + +@gen_test() +async def test_adapt_callback_logs_error_in_scale_down(): + class BadAdaptive(MyAdaptive): + async def scale_down(self, workers=None): + raise OSError() + + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + adapt = cluster.adapt( + Adaptive=BadAdaptive, minimum=1, maximum=4, wait_count=0, interval="10ms" + ) + adapt._target = 2 + await async_poll_for(lambda: adapt.state == "running", timeout=5) + assert adapt.periodic_callback.is_running() + await adapt.adapt() + assert len(adapt.plan) == 2 + assert len(adapt.requested) == 2 + with captured_logger( + "distributed.deploy.adaptive", level=logging.WARNING + ) as log: + adapt._target = 0 + while "encountered an error" not in log.getvalue(): + await asyncio.sleep(0.01) + assert "stop" not in log.getvalue() + assert not adapt._adapting + assert adapt.periodic_callback + assert adapt.periodic_callback.is_running() + + +@pytest.mark.parametrize("wait_until_running", [True, False]) +@gen_test() +async def test_adaptive_logs_stopping_once(wait_until_running): + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + with captured_logger("distributed.deploy.adaptive") as log: + adapt = cluster.adapt(Adaptive=MyAdaptive, interval="100ms") + if wait_until_running: + await async_poll_for(lambda: adapt.state == "running", timeout=5) + assert adapt.periodic_callback + assert adapt.periodic_callback.is_running() + pc = adapt.periodic_callback + else: + assert adapt.periodic_callback + assert not adapt.periodic_callback.is_running() + pc = adapt.periodic_callback + + adapt.stop() + adapt.stop() + assert adapt.state == "stopped" + assert not adapt.periodic_callback + assert not pc.is_running() + lines = log.getvalue().splitlines() + assert sum("Adaptive scaling stopped" in line for line in lines) == 1 + + +@gen_test() +async def test_adapt_stop_del(): + async with LocalCluster( + n_workers=0, + asynchronous=True, + silence_logs=False, + dashboard_address=":0", + ) as cluster: + adapt = cluster.adapt(Adaptive=MyAdaptive, interval="100ms") + pc = adapt.periodic_callback + await async_poll_for(lambda: adapt.state == "running", timeout=5) # noqa: F821 + + # Remove reference of adaptive object from cluster + cluster._adaptive = None + del adapt + await async_poll_for(lambda: not pc.is_running(), timeout=5) diff --git a/distributed/deploy/tests/test_adaptive_core.py b/distributed/deploy/tests/test_adaptive_core.py index b5cfd734ab..cc2336e76a 100644 --- a/distributed/deploy/tests/test_adaptive_core.py +++ b/distributed/deploy/tests/test_adaptive_core.py @@ -1,23 +1,35 @@ from __future__ import annotations -import asyncio - from distributed.deploy.adaptive_core import AdaptiveCore -from distributed.metrics import time -from distributed.utils_test import captured_logger, gen_test +from distributed.utils_test import gen_test -class MyAdaptive(AdaptiveCore): - def __init__(self, *args, interval=None, **kwargs): - super().__init__(*args, interval=interval, **kwargs) +class MyAdaptiveCore(AdaptiveCore): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._observed = set() + self._plan = set() + self._requested = set() self._target = 0 self._log = [] + @property + def observed(self): + return self._observed + + @property + def plan(self): + return self._plan + + @property + def requested(self): + return self._requested + async def target(self): return self._target async def scale_up(self, n=0): - self.plan = self.requested = set(range(n)) + self._plan = self._requested = set(range(n)) async def scale_down(self, workers=()): for collection in [self.plan, self.requested, self.observed]: @@ -27,7 +39,7 @@ async def scale_down(self, workers=()): @gen_test() async def test_safe_target(): - adapt = MyAdaptive(minimum=1, maximum=4) + adapt = MyAdaptiveCore(minimum=1, maximum=4) assert await adapt.safe_target() == 1 adapt._target = 10 assert await adapt.safe_target() == 4 @@ -35,7 +47,7 @@ async def test_safe_target(): @gen_test() async def test_scale_up(): - adapt = MyAdaptive(minimum=1, maximum=4) + adapt = MyAdaptiveCore(minimum=1, maximum=4) await adapt.adapt() assert adapt.log[-1][1] == {"status": "up", "n": 1} assert adapt.plan == {0} @@ -48,12 +60,12 @@ async def test_scale_up(): @gen_test() async def test_scale_down(): - adapt = MyAdaptive(minimum=1, maximum=4, wait_count=2) + adapt = MyAdaptiveCore(minimum=1, maximum=4, wait_count=2) adapt._target = 10 await adapt.adapt() assert len(adapt.log) == 1 - adapt.observed = {0, 1, 3} # all but 2 have arrived + adapt._observed = {0, 1, 3} # all but 2 have arrived adapt._target = 2 await adapt.adapt() @@ -70,103 +82,3 @@ async def test_scale_down(): await adapt.adapt() await adapt.adapt() assert list(adapt.log) == old - - -@gen_test() -async def test_interval(): - adapt = MyAdaptive(interval="5 ms") - assert not adapt.plan - - for i in [0, 3, 1]: - start = time() - adapt._target = i - while len(adapt.plan) != i: - await asyncio.sleep(0.001) - assert time() < start + 2 - - adapt.stop() - await asyncio.sleep(0.05) - - adapt._target = 10 - await asyncio.sleep(0.02) - assert len(adapt.plan) == 1 # last value from before, unchanged - - -@gen_test() -async def test_adapt_oserror_safe_target(): - class BadAdaptive(MyAdaptive): - """AdaptiveCore subclass which raises an OSError when attempting to adapt - - We use this to check that error handling works properly - """ - - def safe_target(self): - raise OSError() - - with captured_logger("distributed.deploy.adaptive_core") as log: - adapt = BadAdaptive(minimum=1, maximum=4, interval="10ms") - while adapt._state != "stopped": - await asyncio.sleep(0.01) - text = log.getvalue() - assert "Adaptive stopping due to error" in text - assert "Adaptive scaling stopped" in text - assert not adapt._adapting - assert not adapt.periodic_callback - - -@gen_test() -async def test_adapt_oserror_scale(): - """ - FIXME: - If we encounter an OSError during scale down, we continue as before. It is - not entirely clear if this is the correct behaviour but defines the current - state. - This was probably introduced to protect against comm failures during - shutdown but the scale down command should be robust call to the scheduler - which is never scaled down. - """ - - class BadAdaptive(MyAdaptive): - async def scale_down(self, workers=None): - raise OSError() - - adapt = BadAdaptive(minimum=1, maximum=4, wait_count=0, interval="10ms") - adapt._target = 2 - while not adapt.periodic_callback.is_running(): - await asyncio.sleep(0.01) - await adapt.adapt() - assert len(adapt.plan) == 2 - assert len(adapt.requested) == 2 - with captured_logger("distributed.deploy.adaptive_core") as log: - adapt._target = 0 - await adapt.adapt() - text = log.getvalue() - assert "Error during adaptive downscaling" in text - assert not adapt._adapting - assert adapt.periodic_callback - assert adapt.periodic_callback.is_running() - adapt.stop() - - -@gen_test() -async def test_adaptive_logs_stopping_once(): - with captured_logger("distributed.deploy.adaptive_core") as log: - adapt = MyAdaptive(interval="100ms") - while not adapt.periodic_callback.is_running(): - await asyncio.sleep(0.01) - adapt.stop() - adapt.stop() - lines = log.getvalue().splitlines() - assert sum("Adaptive scaling stopped" in line for line in lines) == 1 - - -@gen_test() -async def test_adapt_stop_del(): - adapt = MyAdaptive(interval="100ms") - pc = adapt.periodic_callback - while not adapt.periodic_callback.is_running(): - await asyncio.sleep(0.01) - - del adapt - while pc.is_running(): - await asyncio.sleep(0.01) From 7ab72493b6df194bf6cb86ec1c4ac7c27cf1b3a1 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 27 Sep 2024 19:54:24 -0500 Subject: [PATCH 136/138] bump version to 2024.9.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e5d68b635c..4ff037513d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ requires-python = ">=3.10" dependencies = [ "click >= 8.0", "cloudpickle >= 3.0.0", - "dask == 2024.9.0", + "dask == 2024.9.1", "jinja2 >= 2.10.3", "locket >= 1.0.0", "msgpack >= 1.0.2", From b8dd8e782151721e86ecd20b9958d9a6581e53bc Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Tue, 1 Oct 2024 11:06:51 -0500 Subject: [PATCH 137/138] Switch from mambaforge to miniforge in CI (#8881) --- .github/workflows/conda.yml | 2 +- .github/workflows/test-report.yaml | 1 - .github/workflows/tests.yaml | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml index 9dae4fba59..81d9299d9e 100644 --- a/.github/workflows/conda.yml +++ b/.github/workflows/conda.yml @@ -32,7 +32,7 @@ jobs: - name: Set up Python uses: conda-incubator/setup-miniconda@v3.0.3 with: - miniforge-variant: Mambaforge + miniforge-version: latest use-mamba: true python-version: 3.9 channel-priority: strict diff --git a/.github/workflows/test-report.yaml b/.github/workflows/test-report.yaml index 490f091a4f..4853372fa0 100644 --- a/.github/workflows/test-report.yaml +++ b/.github/workflows/test-report.yaml @@ -23,7 +23,6 @@ jobs: - name: Setup Conda Environment uses: conda-incubator/setup-miniconda@v3.0.3 with: - miniforge-variant: Mambaforge miniforge-version: latest condarc-file: continuous_integration/condarc use-mamba: true diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 0e5f82c641..2deffff9d6 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -125,7 +125,6 @@ jobs: - name: Setup Conda Environment uses: conda-incubator/setup-miniconda@v3.0.3 with: - miniforge-variant: Mambaforge miniforge-version: latest condarc-file: continuous_integration/condarc use-mamba: true From 36020d6abe4506af08bae2807d78477b7916c8aa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 09:37:26 -0500 Subject: [PATCH 138/138] Bump JamesIves/github-pages-deploy-action from 4.6.4 to 4.6.8 (#8880) --- .github/workflows/test-report.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-report.yaml b/.github/workflows/test-report.yaml index 4853372fa0..120cf23817 100644 --- a/.github/workflows/test-report.yaml +++ b/.github/workflows/test-report.yaml @@ -53,7 +53,7 @@ jobs: mv test_report.html test_short_report.html deploy/ - name: Deploy 🚀 - uses: JamesIves/github-pages-deploy-action@v4.6.4 + uses: JamesIves/github-pages-deploy-action@v4.6.8 with: branch: gh-pages folder: deploy